UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

135 lines (132 loc) 4.77 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';

/**
 * Pairs every sliced dimension with the position of its entry in `axes`.
 *
 * Dimensions are visited in ascending order while `axes` is consumed front to
 * back. An entry of `axes` is therefore only matched when it is reached in
 * ascending order, and an entry that never matches (out of range, or smaller
 * than an earlier entry) also prevents every later entry from matching.
 *
 * @param {number} rank - Number of dimensions of the input.
 * @param {number[]} axes - Dimensions that are sliced.
 * @returns {Array<[number, number]>} `[dimension, indexIntoAxes]` pairs.
 */
function matchSlicedDimensions(rank, axes) {
  const matches = [];
  let axIx = 0;
  for (let dim = 0; dim < rank && axIx < axes.length; dim++) {
    if (dim === axes[axIx]) {
      matches.push([dim, axIx]);
      axIx++;
    }
  }
  return matches;
}

/**
 * Computes the shape that results from slicing `shape`.
 *
 * @param {number[]} shape - Shape of the input.
 * @param {number[]} axes - Dimensions that are sliced.
 * @param {number[]} starts - Start index for each entry of `axes`.
 * @param {number[]} ends - End index for each entry of `axes`.
 * @param {number[]} steps - Step width for each entry of `axes`.
 * @returns {number[]} A copy of `shape` where every matched dimension is
 *   replaced by `ceil((end - start) / step)`.
 */
function computeSlicedShape(shape, axes, starts, ends, steps) {
  const resultShape = [...shape];
  for (const [dim, axIx] of matchSlicedDimensions(shape.length, axes)) {
    resultShape[dim] = Math.ceil((ends[axIx] - starts[axIx]) / steps[axIx]);
  }
  return resultShape;
}

/**
 * Expands the per-axis `starts`/`steps` lists to one entry per dimension.
 *
 * @param {number} rank - Number of dimensions of the input.
 * @param {number[]} axes - Dimensions that are sliced.
 * @param {number[]} starts - Start index for each entry of `axes`.
 * @param {number[]} steps - Step width for each entry of `axes`.
 * @returns {{offsets: number[], steps: number[]}} Arrays of length `rank`;
 *   dimensions that are not matched keep offset 0 and step 1.
 */
function computeOffsetsAndSteps(rank, axes, starts, steps) {
  const offsets = new Array(rank).fill(0);
  const stepsPerDim = new Array(rank).fill(1);
  for (const [dim, axIx] of matchSlicedDimensions(rank, axes)) {
    offsets[dim] = starts[axIx];
    stepsPerDim[dim] = steps[axIx];
  }
  return { offsets, steps: stepsPerDim };
}

/**
 * GPU operation that reads a strided slice of the texture `X`.
 *
 * For every output index the shader reads the input at
 * `index * steps + offsets` (per dimension).
 *
 * NOTE(review): `maxRank`, `fullyStatic`, `outputShape`, `compute`, `pad`,
 * `initIndex`, `getDefaultMain` and `getVarModifier` are not defined in this
 * file; they are presumably provided by `Operation` - confirm there.
 */
export class SliceOperation extends Operation {
  /**
   * Forwards all arguments unchanged to the `Operation` constructor.
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Builds the fragment shader source.
   *
   * The generated `process` function maps the output index to an input index
   * via `index[i]*steps[i] + offsets[i]` and returns `_X` at that position.
   * The loop stops at the first index entry equal to -1.
   *
   * @param {object} info - Compilation info; not used by this shader.
   * @returns {string} GLSL source.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
    float process(int index[${this.maxRank}]) {
      int inIx[${this.maxRank}];
      ${this.initIndex('inIx')}

      for (int i = 0; i < ${this.maxRank}; i++) {
        if (index[i] == -1) {
          break;
        }

        inIx[i] = index[i]*steps[i] + offsets[i];
      }

      return _X(inIx);
    }

    ${this.getDefaultMain()}
    `;
  }

  /**
   * @returns {string[]} Names of the input textures; only `X`.
   */
  getTextureNames() {
    return ['X'];
  }

  /**
   * @returns {string} GLSL declarations of the `offsets` and `steps` integer
   *   arrays, each of length `maxRank`.
   */
  getVariables() {
    return `
    ${this.getVarModifier('offsets')} int offsets[${this.maxRank}];
    ${this.getVarModifier('steps')} int steps[${this.maxRank}];
    `;
  }

  /**
   * @returns {Array<{name: string, length: number}>} Descriptors for the
   *   `offsets` and `steps` attributes, each of length `maxRank`.
   */
  getUniformAttrs() {
    return [
      { name: 'offsets', length: this.maxRank },
      { name: 'steps', length: this.maxRank },
    ];
  }

  /**
   * Runs the slice.
   *
   * When the operation is fully static and an output shape is already known,
   * only the texture is passed on. Otherwise the output shape, offsets and
   * steps are derived from `input` and the latter two are passed to `compute`
   * after going through `this.pad`.
   *
   * @param {object} input - Provides `X` (with a `shape`), `axes`, `starts`,
   *   `ends` and `steps`.
   * @returns {*} Whatever `this.compute` returns.
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { X: input.X });
    }

    const { shape } = input.X;
    const resultShape = computeSlicedShape(
      shape,
      input.axes,
      input.starts,
      input.ends,
      input.steps,
    );
    const { offsets, steps } = computeOffsetsAndSteps(
      shape.length,
      input.axes,
      input.starts,
      input.steps,
    );

    return this.compute(
      resultShape,
      { X: input.X },
      {
        offsets: this.pad(offsets),
        steps: this.pad(steps),
      },
    );
  }

  /**
   * @param {object} input - Provides `X` (with a `shape`), `axes`, `starts`,
   *   `ends` and `steps`.
   * @returns {number[]} Shape of the sliced result.
   */
  getOutputShape(input) {
    return computeSlicedShape(
      input.X.shape,
      input.axes,
      input.starts,
      input.ends,
      input.steps,
    );
  }

  /**
   * Prepares `info` and hands it to `Operation.compile`.
   *
   * When `shapeX` is known, `maxRank` is set to its length. If `axes`,
   * `starts`, `ends` and `steps` are all present as well, `info` is modified
   * in place: `offsets` is added, `steps` is replaced by its per-dimension
   * form, and `starts`, `ends` and `axes` are removed.
   *
   * @param {object} info - Compilation info; mutated as described above.
   */
  compile(info) {
    if (info.shapeX !== undefined) {
      this.maxRank = info.shapeX.length;

      const hasSliceParams =
        info.axes !== undefined &&
        info.starts !== undefined &&
        info.ends !== undefined &&
        info.steps !== undefined;

      if (hasSliceParams) {
        const { offsets, steps } = computeOffsetsAndSteps(
          info.shapeX.length,
          info.axes,
          info.starts,
          info.steps,
        );
        info.offsets = offsets;
        info.steps = steps;
        delete info['starts'];
        delete info['ends'];
        delete info['axes'];
      }
    }
    super.compile(info);
  }

  /**
   * Collects everything known about a concrete input for compilation.
   *
   * @param {object} input - Provides `X` (with `shape` and `memory.width` /
   *   `memory.height`), `axes`, `starts`, `ends` and `steps`.
   * @returns {object} Input and output shapes, their texture dimensions, and
   *   the per-dimension `offsets` and `steps`.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(
      getSize(outputShape),
      this.dtype,
    );
    const { offsets, steps } = computeOffsetsAndSteps(
      input.X.shape.length,
      input.axes,
      input.starts,
      input.steps,
    );

    return {
      shapeX: input.X.shape,
      widthX: input.X.memory.width,
      heightX: input.X.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      offsets,
      steps,
    };
  }

  /**
   * @param {object} input - Provides `X` (with a `shape`), `axes`, `starts`,
   *   `ends` and `steps`.
   * @returns {string} The shape, axes, starts, ends and steps joined by `-`.
   */
  getInputInfoString(input) {
    return `${input.X.shape}-${input.axes}-${input.starts}-${input.ends}-${input.steps}`;
  }
}
//# sourceMappingURL=slice.js.map