@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
89 lines (88 loc) • 3.27 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { computeStrides, getSize } from '../../../util/shape';
import { Operation } from '../operation';
/**
 * Computes, for each output axis of a transpose, the stride of the input axis
 * that output axis reads from.
 *
 * @param {number[]} shape - Shape of the input tensor.
 * @param {number[]} permutation - `permutation[i]` is the input axis that
 *   becomes output axis `i` (same convention as `getOutputShape`).
 * @returns {number[]} Array of length `shape.length` whose entry `i` is
 *   `computeStrides(shape)[permutation[i]]`.
 */
function computeMappedStrides(shape, permutation) {
  const inputStrides = computeStrides(shape);
  // Bounded by the input rank (not permutation.length), matching getOutputShape.
  return Array.from({ length: shape.length }, (_, i) => inputStrides[permutation[i]]);
}

/**
 * GPU operation that permutes the axes of tensor `A`.
 *
 * The fragment shader reads each output element from `A` through
 * `mappedStrides`, which holds the input strides reordered by the permutation.
 * `mappedStrides` is either passed per call from `calc` or placed into the
 * compilation info by `compile` when both the input shape and permutation are
 * known ahead of time.
 */
export class TransposeOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * Declares the shader's `mappedStrides` int array of length `maxRank`.
   * The storage qualifier comes from the base class' `getVarModifier`
   * (NOTE(review): presumably uniform vs. compile-time constant — confirm in Operation).
   *
   * @returns {string} GLSL declaration snippet.
   */
  getVariables() {
    return `
${this.getVarModifier('mappedStrides')} int mappedStrides[${this.maxRank}];
`;
  }

  /**
   * @returns {{name: string, length: number}[]} The single uniform this
   *   operation uses: `mappedStrides`, sized to `maxRank`.
   */
  getUniformAttrs() {
    return [{ name: 'mappedStrides', length: this.maxRank }];
  }

  /**
   * Builds the fragment shader. Each output index is looked up in `A` via
   * `getValueAt` with `mappedStrides` (NOTE(review): `getValueAt` is defined
   * outside this file; presumably it combines index and strides into a flat
   * position in `A`).
   *
   * @returns {string} GLSL source.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
float process(int index[${this.maxRank}]) {
return getValueAt(index, mappedStrides, widthA, heightA, A);
}
${this.getDefaultMain()}
`;
  }

  /** @returns {string[]} Names of the input textures: only `A`. */
  getTextureNames() {
    return ['A'];
  }

  /**
   * Transposes `input.A` according to `input.permutation`.
   *
   * @param {{A: object, permutation: number[]}} input - Tensor and permutation,
   *   where `permutation[i]` is the input axis that becomes output axis `i`.
   * @returns {*} Whatever `this.compute` returns for the permuted output shape.
   */
  calc(input) {
    // NOTE(review): fullyStatic/outputShape are base-class state; presumably
    // compile() already fixed mappedStrides, so no per-call values are passed.
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { A: input.A });
    }
    const outputShape = this.getOutputShape(input);
    const mappedStrides = computeMappedStrides(input.A.shape, input.permutation);
    // pad() is a base-class helper (presumably extends the array to maxRank entries).
    return this.compute(outputShape, { A: input.A }, { mappedStrides: this.pad(mappedStrides) });
  }

  /**
   * @param {{A: {shape: number[]}, permutation: number[]}} input
   * @returns {number[]} Output shape, where entry `i` is `A.shape[permutation[i]]`.
   */
  getOutputShape(input) {
    const { shape } = input.A;
    const { permutation } = input;
    return Array.from({ length: shape.length }, (_, i) => shape[permutation[i]]);
  }

  /**
   * Specializes the shader for the information known at compile time.
   *
   * When `info.shapeA` is given, `maxRank` is set to its length. If
   * `info.permutation` is also given, it is converted into
   * `info.mappedStrides` and removed from `info`. Note that this mutates the
   * caller's `info` object. `info` is then handed to the base class.
   *
   * @param {object} info - Compilation info (e.g. from getCompilationInfo).
   */
  compile(info) {
    if (info.shapeA !== undefined) {
      this.maxRank = info.shapeA.length;
      if (info.permutation !== undefined) {
        info.mappedStrides = computeMappedStrides(info.shapeA, info.permutation);
        // NOTE(review): presumably removed so the base class only sees names it
        // knows (e.g. the uniforms from getUniformAttrs) — confirm in Operation.
        delete info['permutation'];
      }
    }
    super.compile(info);
  }

  /**
   * Collects everything needed to compile a fully specialized shader for `input`.
   *
   * @param {{A: object, permutation: number[]}} input
   * @returns {object} Shapes and texture dimensions of `A` and the output, plus
   *   the precomputed `mappedStrides`.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
    const mappedStrides = computeMappedStrides(input.A.shape, input.permutation);
    return {
      shapeA: input.A.shape,
      widthA: input.A.memory.width,
      heightA: input.A.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      mappedStrides,
    };
  }

  /**
   * @param {{A: {shape: number[]}, permutation: number[]}} input
   * @returns {string} `"<shape>-<permutation>"`, e.g. `"2,3,4-2,0,1"`
   *   (presumably used as a cache key by the base class).
   */
  getInputInfoString(input) {
    return `${input.A.shape}-${input.permutation}`;
  }
}
//# sourceMappingURL=transpose.js.map