@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
158 lines (150 loc) • 5.45 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { computeStrides, getSize } from '../../../util/shape';
import { poolResultShape } from '../../util/pool';
import { Operation } from '../operation';
/**
 * Base GPU operation that reduces ("pools") tensor X over a set of axes.
 *
 * The generated fragment shader visits every input element that maps onto one
 * output position and folds them together using the GLSL snippets returned by
 * `init`, `update` and `post`. `update(curr, res)` is not defined in this
 * class; it is presumably supplied by subclasses — confirm against them.
 */
export class PoolOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // Constant upper bound for the shader's reduction loop; the loop exits
    // early once `sumSize` elements have been processed. A reduction over
    // more than this many elements would be truncated.
    this.maxIterations = 1000000;
  }

  /**
   * Declares the shader variables used by the reduction.
   * Each variable asks `getVarModifier` with its own name, so its modifier
   * reflects how that specific variable is provided.
   *
   * @returns {string} GLSL declarations.
   */
  getVariables() {
    return `
${this.getVarModifier('mappedInputStrides')} int mappedInputStrides[${this.maxRank}];
${this.getVarModifier('sumDims')} int sumDims[${this.maxRank}];
${this.getVarModifier('sumSize')} int sumSize;
`;
  }

  /**
   * Builds the GLSL reduction shader.
   *
   * For an output index the shader:
   *  1. computes a flat input position as sum(mappedInputStrides[i] * index[i]),
   *     stopping at the first i where either value is -1 (presumably the
   *     padding value written by `this.pad` — confirm);
   *  2. converts that position into the multi-index `inputIx` via stridesX;
   *  3. reads `sumSize` elements of X, starting at `inputIx` and advancing it
   *     with `incrementConditional` over the dimensions flagged in `sumDims`,
   *     combining them with `init` (first element) and `update` (the rest);
   *  4. applies `post` to the accumulated value and returns it.
   *
   * @param {object} info - Compilation info (unused here).
   * @returns {string} GLSL source.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
float process(int index[${this.maxRank}]) {
int inputIx[${this.maxRank}];
${this.initIndex('inputIx')}
int inputPos = 0;
for (int i = 0; i < ${this.maxRank}; i++) {
if (mappedInputStrides[i] == -1 || index[i] == -1) {
break;
}
inputPos += mappedInputStrides[i]*index[i];
}
${this.posToIndex('stridesX', 'inputIx', 'inputPos')}
float res = 0.0;
for (int i = 0; i < ${this.maxIterations}; i++) {
if (i >= sumSize) {
break;
}
float curr = _X(inputIx);
if (i == 0) {
res = ${this.init('curr')};
} else {
res = ${this.update('curr', 'res')};
}
${this.incrementConditional('inputIx', 'shapeX', 'sumDims')}
}
${this.post('res')}
return res;
}
${this.getDefaultMain()}
`;
  }

  /**
   * GLSL statement(s) applied to the accumulator after the loop.
   * The default is a no-op.
   *
   * @param {string} res - Name of the accumulator variable.
   * @returns {string}
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  post(res) {
    return '';
  }

  /**
   * GLSL expression used to initialise the accumulator from the first element.
   * The default is the element itself.
   *
   * @param {string} res - GLSL expression for the first element.
   * @returns {string}
   */
  init(res) {
    return res;
  }

  /** @returns {string[]} Names of the input textures read by the shader. */
  getTextureNames() {
    return ['X'];
  }

  /** @returns {object[]} Uniforms that `calc` supplies to `compute`. */
  getUniformAttrs() {
    return [
      { name: 'mappedInputStrides', length: this.maxRank },
      { name: 'sumDims', length: this.maxRank },
      { name: 'sumSize' },
    ];
  }

  /**
   * Derives the reduction parameters shared by `calc`, `compile` and
   * `getCompilationInfo`.
   *
   * @param {number[]} shape - Shape of input tensor X.
   * @param {number[]} axes - Axes reduced over.
   * @param {boolean} keepDims - Passed through to `poolResultShape`.
   * @returns {{outputShape: number[], mappedInputStrides: number[], sumDims: number[], sumSize: number}}
   *   `mappedInputStrides` are the input strides selected by the index map
   *   from `poolResultShape`; `sumDims` has a 1 at every reduced axis and a 0
   *   elsewhere; `sumSize` is the number of elements combined per output
   *   value (1 when `axes` is empty).
   */
  computePoolParameters(shape, axes, keepDims) {
    const [outputShape, ixMap] = poolResultShape(shape, axes, keepDims);
    const inputStrides = computeStrides(shape);
    const mappedInputStrides = Array.from(ixMap, (i) => inputStrides[i]);

    const sumDims = new Array(shape.length).fill(0);
    let sumSize = 1;
    for (const axis of axes) {
      sumDims[axis] = 1;
      sumSize *= shape[axis];
    }

    return { outputShape, mappedInputStrides, sumDims, sumSize };
  }

  /**
   * Runs the reduction of `input.X` over `input.axes`.
   *
   * @param {{X: object, axes: number[], keepDims: boolean}} input
   * @returns The value returned by `this.compute`.
   */
  calc(input) {
    // When fully static, the reduction parameters were supplied at compile
    // time (see `compile`), so only the output shape and texture are passed.
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { X: input.X });
    }

    const { outputShape, mappedInputStrides, sumDims, sumSize } =
      this.computePoolParameters(input.X.shape, input.axes, input.keepDims);

    return this.compute(
      outputShape,
      { X: input.X },
      {
        mappedInputStrides: this.pad(mappedInputStrides),
        sumDims: this.pad(sumDims),
        sumSize,
      },
    );
  }

  /**
   * @param {{X: object, axes: number[], keepDims: boolean}} input
   * @returns {number[]} Shape of the reduction result.
   */
  getOutputShape(input) {
    const [outputShape] = poolResultShape(input.X.shape, input.axes, input.keepDims);
    return outputShape;
  }

  /**
   * Compiles the shader. When the input shape, axes and keepDims are all known,
   * the reduction parameters are computed up front and written into `info`.
   * `axes`/`keepDims` are removed because they are replaced by those
   * parameters, and `maxRank` is narrowed to the input rank.
   * Note: this mutates the caller's `info` object before forwarding it.
   *
   * @param {object} info - Compilation info (mutated).
   */
  compile(info) {
    if (info.shapeX !== undefined &&
        info.axes !== undefined &&
        info.keepDims !== undefined) {
      const { outputShape, mappedInputStrides, sumDims, sumSize } =
        this.computePoolParameters(info.shapeX, info.axes, info.keepDims);

      info.sumDims = sumDims;
      info.shapeOutput = outputShape;
      info.mappedInputStrides = mappedInputStrides;
      info.sumSize = sumSize;

      delete info['keepDims'];
      delete info['axes'];

      this.maxRank = info.shapeX.length;
    }
    super.compile(info);
  }

  /**
   * Builds a compilation-info object describing `input`: the input and output
   * shapes, texture dimensions and reduction parameters.
   *
   * @param {{X: object, axes: number[], keepDims: boolean}} input
   * @returns {object}
   */
  getCompilationInfo(input) {
    const { outputShape, mappedInputStrides, sumDims, sumSize } =
      this.computePoolParameters(input.X.shape, input.axes, input.keepDims);

    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);

    return {
      shapeX: input.X.shape,
      widthX: input.X.memory.width,
      heightX: input.X.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      mappedInputStrides,
      sumDims,
      sumSize,
    };
  }

  /**
   * @param {{X: object, axes: number[], keepDims: boolean}} input
   * @returns {string} A key combining input shape, axes and keepDims.
   */
  getInputInfoString(input) {
    return `${input.X.shape}-${input.axes}-${input.keepDims}`;
  }
}
//# sourceMappingURL=pool.js.map