@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
201 lines (187 loc) • 6.29 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { getSize } from '../../../util/shape';
import { outputDimsSize } from '../../util/convTranspose';
import { Operation } from '../operation';
/**
 * GPU operation that evaluates a transposed convolution of the texture `X`
 * with the kernel texture `W` by generating a fragment shader.
 *
 * The shader text is assembled from the template strings below; the helpers
 * `initIndex`, `incrementIndex`, `getVarModifier`, `getDefaultMain`,
 * `copyPad` and `compute` are inherited from `Operation`.
 */
export class ConvTransposeOperation extends Operation {
  /**
   * @param tensorConstructor forwarded unchanged to `Operation`
   * @param dtype forwarded unchanged to `Operation`
   * @param allocator forwarded unchanged to `Operation`
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // Fixed upper bound for the shader loops; the real trip counts (CG and
    // kernelSize) are enforced with a `break` inside each loop.
    // Presumably needed because the shader language wants constant loop
    // bounds - TODO confirm.
    this.maxIterations = 1000000;
  }

  /**
   * Builds the shader snippet that derives `inputIx` for every spatial axis
   * from the output index and the (flipped) kernel index.
   *
   * The snippet sets `skip = true` when the computed position is not a
   * multiple of the stride or lies outside `shapeX`. A stride of -1 ends the
   * loop early (presumably the filler value for unused axes - verify against
   * `copyPad`).
   *
   * @returns {string} shader source fragment
   */
  updateInputIx() {
    const spatialAxes = this.maxRank - 2;
    return `
for (int d = 0; d < ${spatialAxes}; d++) {
int stride = strides[d];
int pad = pads[d];
int dilation = dilations[d];
if (stride == -1) {
break;
}
int trans_kernel_ix = shapeW[d + 2] - kernelIx[d + 2] - 1;
inputIx[d+2] = index[d + 2] - pad + trans_kernel_ix * dilation;
int divS = inputIx[d+2] / stride;
int resS = inputIx[d+2] - divS*stride;
if (resS != 0) {
skip = true;
break;
}
inputIx[d+2] = divS;
if (inputIx[d+2] < 0 || inputIx[d+2] >= shapeX[d+2]) {
skip = true;
break;
}
}
`;
  }

  /**
   * Builds the body of the shader's `process` function: for every channel of
   * the group and every kernel position it accumulates
   * `_X(inputIx) * _W(kernelIx)` into `res`, unless the position was skipped.
   *
   * @returns {string} shader source fragment
   */
  getMainBody() {
    const { maxRank, maxIterations } = this;
    // The generated pieces are requested in the order they appear in the
    // shader text.
    const kernelIxInit = this.initIndex('kernelIx');
    const inputIxInit = this.initIndex('inputIx');
    const inputIxUpdate = this.updateInputIx();
    const kernelIxAdvance = this.incrementIndex('kernelIx', 'shapeW');
    return `
int n = index[0];
int m = index[1];
int kernelIx[${maxRank}];
${kernelIxInit}
for (int i = 0; i < ${maxRank}; i++) {
if (i >= dataRank) {
break;
}
kernelIx[i+2] = 0;
}
kernelIx[0] = m;
int inputIx[${maxRank}];
${inputIxInit}
inputIx[0] = n;
for (int cg = 0; cg < ${maxIterations}; cg++) {
if (cg >= CG) {
break;
}
int c = m * CG + cg;
int d = c/C;
c = c - d*C;
inputIx[1] = c;
kernelIx[1] = cg;
for (int kIx = 0; kIx < ${maxIterations}; kIx++) {
if (kIx >= kernelSize) {
break;
}
bool skip = false;
${inputIxUpdate}
if (!skip) {
res += _X(inputIx) * _W(kernelIx);
}
${kernelIxAdvance}
}
}
`;
  }

  /**
   * Builds the shader declarations for the scalar and array parameters.
   *
   * NOTE(review): `pads` is declared with `maxRank` elements here, while
   * `getUniformAttrs` and `calc` use a length of `maxRank * 2` for it -
   * confirm this mismatch is intended.
   *
   * @returns {string} shader source fragment
   */
  getVariables() {
    const scalarNames = ['CG', 'kernelSize', 'dataRank', 'C'];
    const arrayNames = ['dilations', 'pads', 'strides'];
    const declarations = [
      ...scalarNames.map(
        (name) => `${this.getVarModifier(name)} int ${name};`
      ),
      ...arrayNames.map(
        (name) => `${this.getVarModifier(name)} int ${name}[${this.maxRank}];`
      ),
    ];
    return `\n${declarations.join('\n')}\n`;
  }

  /**
   * Assembles the complete fragment shader: the `process` function wrapping
   * the main body, followed by the default `main` supplied by `Operation`.
   *
   * @param info compilation info (not used by this implementation)
   * @returns {string} shader source
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    const processBody = this.getMainBody();
    const defaultMain = this.getDefaultMain();
    return `
float process(int index[${this.maxRank}]) {
float res = 0.0;
${processBody}
return res;
}
${defaultMain}
`;
  }

  /**
   * @returns {string[]} names of the input textures read by the shader
   */
  getTextureNames() {
    return ['X', 'W'];
  }

  /**
   * @returns {Array<{name: string, length?: number}>} descriptors of the
   *     shader parameters; array parameters carry their element count
   */
  getUniformAttrs() {
    const rank = this.maxRank;
    const scalarAttrs = ['CG', 'kernelSize', 'C', 'dataRank'].map((name) => ({
      name,
    }));
    const arrayAttrs = [
      { name: 'pads', length: rank * 2 },
      { name: 'strides', length: rank },
      { name: 'dilations', length: rank },
    ];
    return [...scalarAttrs, ...arrayAttrs];
  }

  /**
   * Runs the operation.
   *
   * When the operation is fully static and already has an output shape, the
   * shader is run with the textures only. Otherwise the output shape and the
   * shader parameters are derived from `input` first.
   *
   * @param input object with tensors `X` and `W` plus `pads`, `strides` and
   *     `dilations` arrays
   * @returns whatever `compute` returns
   */
  calc(input) {
    const textures = { X: input.X, W: input.W };
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, textures);
    }

    const xShape = input.X.shape;
    const wShape = input.W.shape;
    const [N, C] = xShape;
    const [M, CG] = wShape;
    const spatialX = xShape.slice(2);
    const spatialW = wShape.slice(2);
    const kernelSize = getSize(spatialW);

    // `pads` holds two equally long halves that are passed to
    // outputDimsSize as separate arguments.
    const half = input.pads.length / 2;
    const spatialOut = outputDimsSize(
      spatialX,
      spatialW,
      input.pads.slice(0, half),
      input.pads.slice(half),
      input.dilations,
      input.strides
    );
    const outputShape = [N, M].concat(spatialOut);

    return this.compute(outputShape, textures, {
      CG,
      kernelSize,
      C,
      dataRank: spatialX.length,
      pads: this.copyPad(input.pads, this.maxRank * 2),
      strides: this.copyPad(input.strides),
      dilations: this.copyPad(input.dilations),
    });
  }

  /**
   * Computes the output shape `[N, M, ...spatial]`, where `N` is the first
   * dimension of `X`, `M` the first dimension of `W`, and the spatial part
   * comes from `outputDimsSize`.
   *
   * @param input object with tensors `X` and `W` plus `pads`, `strides` and
   *     `dilations` arrays
   * @returns {number[]} output shape
   */
  getOutputShape(input) {
    const xShape = input.X.shape;
    const wShape = input.W.shape;
    const half = input.pads.length / 2;
    const spatialOut = outputDimsSize(
      xShape.slice(2),
      wShape.slice(2),
      input.pads.slice(0, half),
      input.pads.slice(half),
      input.dilations,
      input.strides
    );
    return [xShape[0], wShape[0]].concat(spatialOut);
  }

  /**
   * Fills in the values that follow from the static shapes in `info`
   * (modifying `info` in place), adjusts `maxRank` to the shape rank and
   * delegates to `Operation.compile`.
   *
   * When both shapes are present, the values derived from `shapeX` win for
   * `dataRank` and `maxRank` because they are written last.
   *
   * @param info compilation info, optionally containing `shapeW`/`shapeX`
   */
  compile(info) {
    const { shapeW, shapeX } = info;
    if (shapeW !== undefined) {
      const rankW = shapeW.length;
      info.CG = shapeW[1];
      info.kernelSize = getSize(shapeW.slice(2));
      info.dataRank = rankW - 2;
      this.maxRank = rankW;
    }
    if (shapeX !== undefined) {
      const rankX = shapeX.length;
      info.C = shapeX[1];
      info.dataRank = rankX - 2;
      this.maxRank = rankX;
    }
    super.compile(info);
  }

  /**
   * Collects everything known about `input` (shapes, texture sizes, output
   * size and convolution parameters) into a compilation info object.
   *
   * The output texture size is obtained from `defaultAllocator`.
   *
   * @param input object with tensors `X` and `W` plus `pads`, `strides` and
   *     `dilations` arrays
   * @returns {object} compilation info
   */
  getCompilationInfo(input) {
    const { X, W } = input;
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(
      getSize(outputShape),
      this.dtype
    );
    const kernelSize = getSize(W.shape.slice(2));
    const C = X.shape[1];
    const dataRank = X.shape.slice(2).length;
    return {
      shapeX: X.shape,
      widthX: X.memory.width,
      heightX: X.memory.height,
      shapeW: W.shape,
      widthW: W.memory.width,
      heightW: W.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      pads: input.pads,
      dilations: input.dilations,
      strides: input.strides,
      CG: W.shape[1],
      kernelSize,
      dataRank,
      C,
    };
  }

  /**
   * Builds a string describing the shapes and convolution parameters of
   * `input`, with the parts joined by `-`.
   *
   * NOTE(review): `dilations` appears twice; it is kept so the generated
   * string stays unchanged.
   *
   * @param input object with tensors `X` and `W` plus `pads`, `strides` and
   *     `dilations` arrays
   * @returns {string} description string
   */
  getInputInfoString(input) {
    const parts = [
      input.X.shape,
      input.W.shape,
      input.dilations,
      input.pads,
      input.dilations,
      input.strides,
    ];
    return parts.map((part) => String(part)).join('-');
  }
}
//# sourceMappingURL=convTranspose.js.map