UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

89 lines (88 loc) 3.27 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { computeStrides, getSize } from '../../../util/shape';
import { Operation } from '../operation';

/**
 * Reorders the strides of a tensor by a transpose permutation.
 *
 * Entry `i` of the result is the input stride of axis `permutation[i]`, i.e.
 * the stride output axis `i` uses when indexing into the input. This single
 * helper replaces three previously duplicated copies of the same loop.
 *
 * @param {number[]} shape - Shape of the input tensor.
 * @param {number[]} permutation - Output axis `i` reads input axis `permutation[i]`.
 * @returns {number[]} Array of length `shape.length` with the permuted strides.
 */
function computeMappedStrides(shape, permutation) {
  const inputStrides = computeStrides(shape);
  return Array.from({ length: shape.length }, (_, i) => inputStrides[permutation[i]]);
}

/**
 * GPU operation that transposes tensor `A` by a permutation of its axes.
 *
 * No data is rearranged up front: the fragment shader reads `A` through
 * `mappedStrides` (the input strides reordered by the permutation), so each
 * output index is mapped to the corresponding input element.
 */
export class TransposeOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /** Shader declaration of `mappedStrides`, sized to `maxRank` entries. */
  getVariables() {
    return `
  ${this.getVarModifier('mappedStrides')} int mappedStrides[${this.maxRank}];
  `;
  }

  /** `mappedStrides` is the only uniform; its length is `maxRank`. */
  getUniformAttrs() {
    return [{ name: 'mappedStrides', length: this.maxRank }];
  }

  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
  float process(int index[${this.maxRank}]) {
    return getValueAt(index, mappedStrides, widthA, heightA, A);
  }

  ${this.getDefaultMain()}
  `;
  }

  getTextureNames() {
    return ['A'];
  }

  /**
   * Runs the transpose.
   *
   * @param input - Object with `A` (the tensor to transpose) and
   *   `permutation` (output axis `i` reads input axis `permutation[i]`).
   * @returns The result of `this.compute` for the permuted output shape.
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      // NOTE(review): presumably the shape and mappedStrides were fixed when the
      // operation was compiled, so only the input texture has to be supplied.
      return this.compute(this.outputShape, { A: input.A });
    }
    const outputShape = this.getOutputShape(input);
    const mappedStrides = computeMappedStrides(input.A.shape, input.permutation);
    return this.compute(outputShape, { A: input.A }, { mappedStrides: this.pad(mappedStrides) });
  }

  /**
   * Shape of the transposed tensor: `outputShape[i] = A.shape[permutation[i]]`.
   * The result always has the same length as `A.shape`.
   */
  getOutputShape(input) {
    const { shape } = input.A;
    return Array.from({ length: shape.length }, (_, i) => shape[input.permutation[i]]);
  }

  /**
   * Compiles the shader for the given compilation info.
   *
   * When `shapeA` is known, it fixes `maxRank`. If a `permutation` is also given,
   * it is converted into `mappedStrides` and removed from `info`. This mutates
   * the caller's object.
   * NOTE(review): the deletion presumably keeps unknown keys away from
   * `Operation.compile`; confirm before changing.
   */
  compile(info) {
    if (info.shapeA !== undefined) {
      this.maxRank = info.shapeA.length;
      if (info.permutation !== undefined) {
        info.mappedStrides = computeMappedStrides(info.shapeA, info.permutation);
        delete info['permutation'];
      }
    }
    super.compile(info);
  }

  /**
   * Builds the full compilation info (input/output shapes, texture dimensions
   * and `mappedStrides`) for a concrete input.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
    return {
      shapeA: input.A.shape,
      widthA: input.A.memory.width,
      heightA: input.A.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      mappedStrides: computeMappedStrides(input.A.shape, input.permutation),
    };
  }

  /** Cache key for this input: the input shape and the permutation. */
  getInputInfoString(input) {
    return `${input.A.shape}-${input.permutation}`;
  }
}
//# sourceMappingURL=transpose.js.map