@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
166 lines (154 loc) • 5.38 kB
JavaScript
import { getSize } from '../../../util/shape';
import { outputDimsSize } from '../../util/conv';
import { Operation } from '../operation';
import { defaultAllocator } from '../../../tensor/gpu/gl';
/**
 * WebGL operation computing average pooling over the spatial dimensions of X.
 *
 * The shader treats X as `[N, C, ...spatial]`: index[0] / index[1] are copied
 * through unchanged and only dimensions 2.. are pooled. Pads are laid out as
 * `[begin_0, ..., begin_{D-1}, end_0, ..., end_{D-1}]` (see getOutputShape,
 * which splits them into halves).
 */
export class AveragePoolOperation extends Operation {
  /**
   * @param tensorConstructor - forwarded to Operation
   * @param dtype - forwarded to Operation
   * @param allocator - forwarded to Operation
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // Constant upper bound for the kernel loop in the shader; the loop breaks
    // out as soon as `kernelSize` iterations are done. (Presumably a constant
    // bound is used because GLSL ES loops need compile-time bounds.)
    this.maxIterations = 1000000;
  }

  /**
   * GLSL snippet mapping the output position `index` plus the current kernel
   * offset `kernelIx` to the input position `inputIx`. Sets `skip = true` when
   * the position lies in the padding (outside `shapeX`).
   */
  updateInputIx() {
    // strides[d] == -1 marks the end of the real spatial dims — presumably the
    // fill value copyPad uses for unused slots; confirm against Operation.
    return `
    for (int d = 0; d < ${this.maxRank - 2}; d++) {
      int stride = strides[d];
      int pad = pads[d];
      if (stride == -1) {
        break;
      }
      inputIx[d+2] = index[d+2]*stride - pad + kernelIx[d];
      if (inputIx[d+2] < 0 || inputIx[d+2] >= shapeX[d+2]) {
        skip = true;
        break;
      }
    }
    `;
  }

  /**
   * GLSL declarations of the operation parameters. Each one is either a
   * uniform or a compile-time value, depending on getVarModifier.
   */
  getVariables() {
    // `pads` carries begin AND end pads for every spatial dim, so it is sized
    // 2 * maxRank to match the uniform length declared in getUniformAttrs and
    // the array built in calc. The shader itself only reads the begin half.
    return `
    ${this.getVarModifier('kernelSize')} int kernelSize;
    ${this.getVarModifier('dataRank')} int dataRank;
    ${this.getVarModifier('includePad')} int includePad;
    ${this.getVarModifier('pads')} int pads[${this.maxRank * 2}];
    ${this.getVarModifier('strides')} int strides[${this.maxRank}];
    ${this.getVarModifier('kernelShape')} int kernelShape[${this.maxRank}];
    `;
  }

  /**
   * Fragment shader: for one output position, walk all `kernelSize` kernel
   * offsets, sum the in-bounds input values, and divide by the number of
   * counted positions. Padding positions are counted only when
   * includePad == 1.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
    float process(int index[${this.maxRank}]) {
      float res = 0.0;
      int count = 0;

      int n = index[0];
      int c = index[1];

      int kernelIx[${this.maxRank}];
      ${this.initIndex('kernelIx')}
      for (int i = 0; i < ${this.maxRank}; i++) {
        if (i >= dataRank) {
          break;
        }
        kernelIx[i] = 0;
      }

      int inputIx[${this.maxRank}];
      ${this.initIndex('inputIx')}
      inputIx[0] = n;
      inputIx[1] = c;

      for (int kIx = 0; kIx < ${this.maxIterations}; kIx++) {
        if (kIx >= kernelSize) {
          break;
        }
        bool skip = false;

        ${this.updateInputIx()}

        if (!skip) {
          res += _X(inputIx);
        }
        if (!skip || includePad == 1) {
          count += 1;
        }

        ${this.incrementIndex('kernelIx', 'kernelShape')}
      }

      res = res / float(count);
      return res;
    }

    ${this.getDefaultMain()}
    `;
  }

  /** Names of the input textures sampled by the shader. */
  getTextureNames() {
    return ['X'];
  }

  /** Uniform attributes with their array lengths (scalars have no length). */
  getUniformAttrs() {
    return [
      { name: 'kernelSize' },
      { name: 'dataRank' },
      { name: 'includePad' },
      { name: 'pads', length: this.maxRank * 2 },
      { name: 'strides', length: this.maxRank },
      { name: 'kernelShape', length: this.maxRank },
    ];
  }

  /**
   * Run the pooling on `input.X`.
   *
   * @param input - `{ X, kernelShape, pads, strides, includePad }`
   * @returns the result of `this.compute` for the pooled output shape
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { X: input.X });
    }

    const outputShape = this.getOutputShape(input);

    return this.compute(outputShape, { X: input.X }, {
      kernelSize: getSize(input.kernelShape),
      includePad: input.includePad ? 1 : 0,
      dataRank: input.X.shape.length - 2,
      pads: this.copyPad(input.pads, this.maxRank * 2),
      strides: this.copyPad(input.strides),
      kernelShape: this.copyPad(input.kernelShape),
    });
  }

  /**
   * Output shape `[N, C, ...R]`, where R is computed by outputDimsSize from
   * the spatial dims, the kernel shape, the begin/end pad halves, a dilation
   * of 1 per dim, and the strides.
   */
  getOutputShape(input) {
    const N = input.X.shape[0];
    const C = input.X.shape[1];
    const D = input.X.shape.slice(2);
    const half = input.pads.length / 2;

    const R = outputDimsSize(
      D,
      input.kernelShape,
      input.pads.slice(0, half),
      input.pads.slice(half),
      new Array(D.length).fill(1),
      input.strides,
    );

    return [N, C].concat(R);
  }

  /**
   * Specialize the operation for the (partially) known parameters in `info`.
   * When the input shape is known, maxRank is fixed to its rank. A boolean
   * includePad is converted to the 0/1 int the shader declares. Note that this
   * mutates `info` before handing it to Operation.compile.
   */
  compile(info) {
    if (info.shapeX !== undefined) {
      info.dataRank = info.shapeX.length - 2;
      this.maxRank = info.shapeX.length;
    }

    if (info.includePad === true) {
      info.includePad = 1;
    } else if (info.includePad === false) {
      info.includePad = 0;
    }

    super.compile(info);
  }

  /** Build a fully specified compilation info object for `input`. */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);

    const kernelSize = getSize(input.kernelShape);

    return {
      shapeX: input.X.shape,
      widthX: input.X.memory.width,
      heightX: input.X.memory.height,
      kernelShape: input.kernelShape,

      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,

      pads: input.pads,
      strides: input.strides,
      kernelSize: kernelSize,
      dataRank: input.X.shape.length - 2,
      includePad: input.includePad ? 1 : 0,
    };
  }

  /** String key of every input property that affects the compiled shader. */
  getInputInfoString(input) {
    return `${input.X.shape}-${input.kernelShape}-${input.pads}-${input.strides}-${input.includePad}`;
  }
}
//# sourceMappingURL=averagePool.js.map