@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
135 lines (132 loc) • 4.77 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';
/**
 * Expands per-axis slice parameters into full-rank arrays.
 *
 * For every entry k of `axes`, the sliced dimension gets
 * `ceil((ends[k] - starts[k]) / steps[k])` elements, a read offset of
 * `starts[k]` and a stride of `steps[k]`. Dimensions not listed in `axes`
 * keep their full extent, offset 0 and stride 1.
 *
 * Axes may be given in any order. A negative axis counts from the end
 * (`axis + rank`). Axes still outside `[0, rank)` after normalisation are
 * ignored.
 *
 * NOTE(review): starts/ends are used as given (no clamping or negative-index
 * handling); presumably the caller has already normalised them — TODO confirm.
 *
 * @param {number[]} shape - Shape of the input tensor.
 * @param {number[]} axes - Axes to slice.
 * @param {number[]} starts - Start index per entry of `axes`.
 * @param {number[]} ends - End index per entry of `axes`.
 * @param {number[]} steps - Stride per entry of `axes`.
 * @returns {{resultShape: number[], offsets: number[], steps: number[]}}
 *   Output shape plus full-rank offset and stride arrays.
 */
function computeSliceParams(shape, axes, starts, ends, steps) {
  const rank = shape.length;
  const resultShape = [...shape];
  const offsets = new Array(rank).fill(0);
  const fullSteps = new Array(rank).fill(1);
  for (let k = 0; k < axes.length; k++) {
    const axis = axes[k] < 0 ? axes[k] + rank : axes[k];
    if (axis < 0 || axis >= rank) {
      // Out-of-range axes are ignored, matching the previous behaviour.
      continue;
    }
    resultShape[axis] = Math.ceil((ends[k] - starts[k]) / steps[k]);
    offsets[axis] = starts[k];
    fullSteps[axis] = steps[k];
  }
  return { resultShape, offsets, steps: fullSteps };
}

/**
 * GPU slice operation.
 *
 * Each output element at index `i` reads the input element at
 * `i * steps + offsets`, computed per dimension.
 */
export class SliceOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Builds the GLSL fragment shader.
   *
   * The `process` function maps an output index to an input index per
   * dimension (`index * steps + offsets`) and reads `X` there. The loop stops
   * at the first `-1` entry in `index`; presumably that marks unused trailing
   * dimensions up to `maxRank` — TODO confirm against `Operation`.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
float process(int index[${this.maxRank}]) {
int inIx[${this.maxRank}];
${this.initIndex('inIx')}
for (int i = 0; i < ${this.maxRank}; i++) {
if (index[i] == -1) {
break;
}
inIx[i] = index[i]*steps[i] + offsets[i];
}
return _X(inIx);
}
${this.getDefaultMain()}
`;
  }

  /** The single input texture is named `X`. */
  getTextureNames() {
    return ['X'];
  }

  /** Declares the `offsets` and `steps` int arrays (length `maxRank`) used by the shader. */
  getVariables() {
    return `
${this.getVarModifier('offsets')} int offsets[${this.maxRank}];
${this.getVarModifier('steps')} int steps[${this.maxRank}];
`;
  }

  /** Uniform arrays fed to the shader, each of length `maxRank`. */
  getUniformAttrs() {
    return [
      { name: 'offsets', length: this.maxRank },
      { name: 'steps', length: this.maxRank },
    ];
  }

  /**
   * Runs the slice.
   *
   * When the operation was compiled fully static with a known output shape,
   * the precompiled shape is used and no uniforms are passed. Otherwise the
   * output shape and the offset/step uniforms are derived from `input`.
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { X: input.X });
    }
    const { resultShape, offsets, steps } = computeSliceParams(
      input.X.shape,
      input.axes,
      input.starts,
      input.ends,
      input.steps,
    );
    return this.compute(
      resultShape,
      { X: input.X },
      {
        offsets: this.pad(offsets),
        steps: this.pad(steps),
      },
    );
  }

  /** Returns the output shape for the given slice parameters. */
  getOutputShape(input) {
    return computeSliceParams(
      input.X.shape,
      input.axes,
      input.starts,
      input.ends,
      input.steps,
    ).resultShape;
  }

  /**
   * Compiles the shader for the given (possibly partial) compilation info.
   *
   * When the slice parameters are all present, they are converted in place
   * into full-rank `offsets`/`steps`, and `starts`, `ends` and `axes` are
   * removed from `info` before it is passed to the base `compile`.
   */
  compile(info) {
    if (info.shapeX !== undefined) {
      this.maxRank = info.shapeX.length;
      if (
        info.axes !== undefined &&
        info.starts !== undefined &&
        info.ends !== undefined &&
        info.steps !== undefined
      ) {
        const { offsets, steps } = computeSliceParams(
          info.shapeX,
          info.axes,
          info.starts,
          info.ends,
          info.steps,
        );
        info.offsets = offsets;
        info.steps = steps;
        delete info['starts'];
        delete info['ends'];
        delete info['axes'];
      }
    }
    super.compile(info);
  }

  /**
   * Builds the compilation info for `input`.
   *
   * The result holds the input and output shapes, their texture dimensions,
   * and the full-rank offsets/steps.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
    const { offsets, steps } = computeSliceParams(
      input.X.shape,
      input.axes,
      input.starts,
      input.ends,
      input.steps,
    );
    return {
      shapeX: input.X.shape,
      widthX: input.X.memory.width,
      heightX: input.X.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      offsets,
      steps,
    };
  }

  /** Key describing the input shape and slice parameters. */
  getInputInfoString(input) {
    return `${input.X.shape}-${input.axes}-${input.starts}-${input.ends}-${input.steps}`;
  }
}
//# sourceMappingURL=slice.js.map