@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
136 lines (130 loc) • 4.2 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { Operation } from '../operation';
/**
 * WebGL operation that pads a tensor `X`.
 *
 * Pads use the layout `[begin_0, ..., begin_{r-1}, end_0, ..., end_{r-1}]`
 * for a rank-`r` input (see `getOutputShape`). Supported modes are
 * `'constant'`, `'reflect'` and `'edge'`, encoded as the integer flags
 * 0, 1 and 2 for the shader.
 */
export class PadOperation extends Operation {
  /**
   * @param {*} tensorConstructor - forwarded to `Operation`.
   * @param {*} dtype - forwarded to `Operation`.
   * @param {*} allocator - forwarded to `Operation`.
   */
  constructor(tensorConstructor, dtype, allocator) {
    // Kept explicit so that only these three arguments reach the base class.
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Builds the GLSL fragment shader for padding.
   *
   * For every output position, each dimension is shifted back by its leading
   * pad (`pads[i]`) to get an input position, which is then resolved per
   * `mode`:
   *  - 0 (constant): positions outside the input yield `value`.
   *  - 1 (reflect): mirrors around the border without repeating it
   *    (-1 -> 1, shapeX -> shapeX - 2). Only one reflection is applied, so
   *    this presumably assumes each pad is smaller than the dimension size.
   *  - 2 (edge): clamps to the nearest valid input position.
   *
   * The per-dimension loop stops at the first `index[i] == -1`; presumably
   * -1 marks unused trailing dimensions when rank < maxRank (TODO confirm).
   *
   * @param {*} info - unused.
   * @returns {string} GLSL source.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
float process(int index[${this.maxRank}]) {
int inputIx[${this.maxRank}];
${this.initIndex('inputIx')}
if (mode == 0) {
float res = value;
int outOfBounds = 0;
for (int i = 0; i < ${this.maxRank}; i++) {
if (index[i] == -1) {
break;
}
inputIx[i] = index[i] - pads[i];
if (inputIx[i] < 0 || inputIx[i] >= shapeX[i]) {
outOfBounds = 1;
break;
}
}
if (outOfBounds == 0) {
res = _X(inputIx);
}
return res;
} else if (mode == 1) {
for (int i = 0; i < ${this.maxRank}; i++) {
if (index[i] == -1) {
break;
}
inputIx[i] = index[i] - pads[i];
if (inputIx[i] < 0) {
inputIx[i] = -inputIx[i];
} else if (inputIx[i] >= shapeX[i]) {
inputIx[i] = 2*shapeX[i] - inputIx[i] - 2;
}
}
return _X(inputIx);
} else {
for (int i = 0; i < ${this.maxRank}; i++) {
if (index[i] == -1) {
break;
}
inputIx[i] = index[i] - pads[i];
if (inputIx[i] < 0) {
inputIx[i] = 0;
} else if (inputIx[i] >= shapeX[i]) {
inputIx[i] = shapeX[i] - 1;
}
}
return _X(inputIx);
}
}
${this.getDefaultMain()}
`;
  }

  /** @returns {string[]} the single input texture name. */
  getTextureNames() {
    return ['X'];
  }

  /**
   * Declares the shader variables `pads`, `value` and `mode`. Their storage
   * modifier comes from `getVarModifier`, presumably uniform vs. compile-time
   * constant depending on what was fixed at compile time (TODO confirm).
   *
   * @returns {string} GLSL declarations.
   */
  getVariables() {
    return `
${this.getVarModifier('pads')} int pads[${this.maxRank * 2}];
${this.getVarModifier('value')} float value;
${this.getVarModifier('mode')} int mode;
`;
  }

  /**
   * Describes the uniforms set in `calc`. `mode` has no explicit type here;
   * presumably the base class defaults it to int (TODO confirm).
   *
   * @returns {Array<object>} uniform descriptors.
   */
  getUniformAttrs() {
    return [
      { name: 'value', type: 'float' },
      { name: 'pads', length: this.maxRank * 2 },
      { name: 'mode' },
    ];
  }

  /**
   * Maps a pad mode name to the integer flag the shader understands.
   *
   * @param {string} mode - `'constant'`, `'reflect'` or `'edge'`.
   * @returns {number} 0, 1 or 2 respectively.
   * @throws {Error} for any other value. Previously every unknown value
   *   (including `undefined` and typos) silently became edge padding.
   */
  getModeFlag(mode) {
    switch (mode) {
      case 'constant':
        return 0;
      case 'reflect':
        return 1;
      case 'edge':
        return 2;
      default:
        throw new Error(`Unsupported pad mode: ${mode}`);
    }
  }

  /**
   * Runs the pad operation.
   *
   * When the operation was compiled fully static and an output shape is known,
   * no uniforms are passed. Otherwise the output shape is computed and the
   * pads, value and mode are supplied as uniforms.
   *
   * @param {{input: *, pads: number[], value: number, mode: string}} input
   * @returns {*} result of `this.compute`.
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { X: input.input });
    }
    const resultShape = this.getOutputShape(input);
    return this.compute(resultShape, { X: input.input }, {
      // copyPad presumably sizes pads to maxRank * 2 entries (base class).
      pads: this.copyPad(input.pads, this.maxRank * 2),
      value: input.value,
      mode: this.getModeFlag(input.mode),
    });
  }

  /**
   * Output shape: each dimension grows by its begin pad (`pads[i]`) plus its
   * end pad (`pads[i + rank]`).
   *
   * @param {{input: {shape: number[]}, pads: number[]}} input
   * @returns {number[]} padded shape (a new array).
   */
  getOutputShape(input) {
    const rank = input.input.shape.length;
    const resultShape = [...input.input.shape];
    for (let i = 0; i < rank; i++) {
      resultShape[i] += input.pads[i] + input.pads[i + rank];
    }
    return resultShape;
  }

  /**
   * Compiles the shader for the given compilation info.
   *
   * Sets `maxRank` from `info.shapeX` when present. A string `mode` is
   * converted to its numeric flag on a shallow copy, so the caller's `info`
   * object is left untouched.
   *
   * @param {object} info - compilation info (see `getCompilationInfo`).
   */
  compile(info) {
    if (info.shapeX !== undefined) {
      this.maxRank = info.shapeX.length;
    }
    const normalizedInfo =
      typeof info.mode === 'string'
        ? Object.assign({}, info, { mode: this.getModeFlag(info.mode) })
        : info;
    super.compile(normalizedInfo);
  }

  /**
   * Collects everything needed to compile a fully specialized shader for
   * `input`: input/output shapes, texture dimensions, pads, mode flag and
   * padding value.
   *
   * @param {{input: *, pads: number[], value: number, mode: string}} input
   * @returns {object} compilation info.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
    return {
      shapeX: input.input.shape,
      widthX: input.input.memory.width,
      heightX: input.input.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      pads: input.pads,
      mode: this.getModeFlag(input.mode),
      value: input.value,
    };
  }

  /**
   * Builds a string that identifies this input configuration (presumably
   * used as a cache key by the base class — TODO confirm).
   *
   * Interpolating `value` uses JavaScript's number-to-string conversion,
   * which produces the shortest string that round-trips to the same double.
   * Distinct values therefore always give distinct keys.
   *
   * @param {{input: {shape: number[]}, pads: number[], mode: string, value: number}} input
   * @returns {string}
   */
  getInputInfoString(input) {
    return `${input.input.shape}-${input.pads}-${input.mode}-${input.value}`;
  }
}
//# sourceMappingURL=pad.js.map