@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
191 lines (187 loc) • 7.82 kB
JavaScript
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { computeStrides, getSize } from '../../../util/shape';
import { Operation } from '../operation';
/**
 * Computes the output shape of a Gather along `axis`, together with the
 * per-output-dimension strides used to map an output index to a flat position
 * in the gathered tensor X and in the indices tensor.
 *
 * The output dimensions are `shapeX[0..axis) ++ indicesShape ++ shapeX(axis..r)`,
 * so the output rank is `r + q - 1`. Output dimensions that come from
 * `indicesShape` have input stride 0, and output dimensions that come from
 * `shapeX` have index stride 0.
 *
 * @param {number[]} shapeX - shape of the gathered tensor X (rank r)
 * @param {number[]} indicesShape - shape of the indices tensor (rank q)
 * @param {number} axis - dimension of X that is gathered along
 * @returns {{resultShape: number[], mappedInputStrides: number[], mappedIndexStrides: number[]}}
 */
function computeGatherMapping(shapeX, indicesShape, axis) {
  const r = shapeX.length;
  const q = indicesShape.length;
  const inputStrides = computeStrides(shapeX);
  const indexStrides = computeStrides(indicesShape);
  const resultRank = r + q - 1;
  const resultShape = new Array(resultRank);
  const mappedInputStrides = new Array(resultRank).fill(0);
  const mappedIndexStrides = new Array(resultRank).fill(0);
  // Leading dimensions of X are copied through unchanged.
  for (let i = 0; i < axis; i++) {
    resultShape[i] = shapeX[i];
    mappedInputStrides[i] = inputStrides[i];
    mappedIndexStrides[i] = 0;
  }
  // The gathered axis is replaced by every dimension of the indices tensor.
  for (let i = 0; i < q; i++) {
    resultShape[i + axis] = indicesShape[i];
    mappedIndexStrides[i + axis] = indexStrides[i];
    mappedInputStrides[i + axis] = 0;
  }
  // Trailing dimensions of X follow, shifted by q - 1.
  for (let i = axis + 1; i < r; i++) {
    resultShape[i + q - 1] = shapeX[i];
    mappedInputStrides[i + q - 1] = inputStrides[i];
    mappedIndexStrides[i + q - 1] = 0;
  }
  return { resultShape, mappedInputStrides, mappedIndexStrides };
}

/**
 * Throws when more index values are supplied than the shader can look up.
 *
 * @param {number} count - number of index values
 * @param {number} max - capacity of the `indexValues` uniform array
 * @throws {Error} when `count > max`
 */
function checkIndexCount(count, max) {
  if (count > max) {
    throw new Error(`Gather on GPU can deal with at most ${max} indices, input had ${count}`);
  }
}

/**
 * GPU (shader-based) Gather: the output element at `[...pre, ...idx, ...post]`
 * is `X[...pre, indexValues[flat(idx)], ...post]`, where `axis` selects the
 * gathered dimension of X.
 *
 * Index values are passed through the uniform array `indexValues`, so at most
 * `gatherMaxIxSize` of them are supported.
 */
export class GatherOperation extends Operation {
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
    // Capacity of the `indexValues` uniform array.
    this.gatherMaxIxSize = 10;
  }

  getVariables() {
    return `
    ${this.getVarModifier('axis')} int axis;
    ${this.getVarModifier('indexValues')} int indexValues[${this.gatherMaxIxSize}];
    ${this.getVarModifier('mappedIndexStrides')} int mappedIndexStrides[${this.maxRank}];
    ${this.getVarModifier('mappedInputStrides')} int mappedInputStrides[${this.maxRank}];
    `;
  }

  getUniformAttrs() {
    return [
      { name: 'axis' },
      { name: 'indexValues', length: this.gatherMaxIxSize },
      { name: 'mappedInputStrides', length: this.maxRank },
      { name: 'mappedIndexStrides', length: this.maxRank },
    ];
  }

  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    // For every output position, the first loop accumulates the flat position in
    // X (from the non-gathered dimensions) and the flat position in the indices
    // tensor (from the gathered ones). Then indexValues[indexPos] times the
    // stride of `axis` in X is added.
    // `strideAxis` is read BEFORE the `-1` (end of output index) check. With
    // scalar indices on the last axis of X, the output has no dimension at
    // position `axis`. Reading it after the break would leave it 0.
    // The second loop finds indexValues[indexPos] by scanning. This is presumably
    // because GLSL ES 1.0 fragment shaders only guarantee constant-index access
    // to uniform arrays (confirm).
    return `
    float process(int index[${this.maxRank}]) {
      int inputPos = 0;
      int indexPos = 0;
      int strideAxis = 0;
      for (int i = 0; i < ${this.maxRank}; i++) {
        if (i == axis) {
          strideAxis = stridesX[i];
        }
        if (index[i] == -1) {
          break;
        }
        inputPos += mappedInputStrides[i]*index[i];
        indexPos += mappedIndexStrides[i]*index[i];
      }
      for (int i = 0; i < ${this.gatherMaxIxSize}; i++) {
        if (i == indexPos) {
          inputPos += indexValues[i]*strideAxis;
          break;
        }
      }
      return getValueAtPos(inputPos, widthX, heightX, X);
    }

    ${this.getDefaultMain()}
    `;
  }

  getTextureNames() {
    return ['X'];
  }

  /**
   * Runs the gather on the GPU.
   *
   * @param {{X: *, indices: *, axis: number}} input
   * @throws {Error} when `input.indices` has more than `gatherMaxIxSize` entries
   *   (only checked when the operation is not fully static)
   */
  calc(input) {
    if (this.fullyStatic && this.outputShape !== undefined) {
      // Everything was baked in at compile time; only the texture is needed.
      return this.compute(this.outputShape, { X: input.X });
    }
    checkIndexCount(input.indices.size, this.gatherMaxIxSize);
    const { resultShape, mappedInputStrides, mappedIndexStrides } = computeGatherMapping(
      input.X.shape,
      input.indices.shape,
      input.axis,
    );
    return this.compute(resultShape, { X: input.X }, {
      axis: input.axis,
      indexValues: this.pad(Array.from(input.indices.values), this.gatherMaxIxSize),
      mappedInputStrides: this.pad(mappedInputStrides),
      mappedIndexStrides: this.pad(mappedIndexStrides),
    });
  }

  /**
   * @param {{X: *, indices: *, axis: number}} input
   * @returns {number[]} the output shape, of rank `rank(X) + rank(indices) - 1`
   */
  getOutputShape(input) {
    return computeGatherMapping(input.X.shape, input.indices.shape, input.axis).resultShape;
  }

  /**
   * Specializes the operation for known shapes.
   *
   * When `info.indices` and `info.axis` are provided, they are turned into the
   * `indexValues`/`mapped*Strides` entries of `info`, and `indices` is removed
   * from `info` (mutates the argument).
   *
   * @throws {Error} when `info.indices` has more than `gatherMaxIxSize` values
   */
  compile(info) {
    if (info.shapeX !== undefined) {
      this.maxRank = info.shapeX.length;
      if (info.indices !== undefined && info.axis !== undefined) {
        const indexValues = Array.from(info.indices.values);
        // Validate before mutating `info`, so a rejected call leaves it untouched.
        checkIndexCount(indexValues.length, this.gatherMaxIxSize);
        const { mappedInputStrides, mappedIndexStrides } = computeGatherMapping(
          info.shapeX,
          info.indices.shape,
          info.axis,
        );
        info.mappedIndexStrides = mappedIndexStrides;
        info.mappedInputStrides = mappedInputStrides;
        info.indexValues = indexValues;
        // NOTE(review): presumably removed so the base compile only sees plain
        // values, not the tensor — confirm against Operation.compile.
        delete info['indices'];
      }
    }
    super.compile(info);
  }

  /**
   * Builds the compilation info for a concrete input.
   *
   * @throws {Error} when `input.indices` has more than `gatherMaxIxSize` values
   */
  getCompilationInfo(input) {
    const indexValues = Array.from(input.indices.values);
    checkIndexCount(indexValues.length, this.gatherMaxIxSize);
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(getSize(outputShape), this.dtype);
    const { mappedInputStrides, mappedIndexStrides } = computeGatherMapping(
      input.X.shape,
      input.indices.shape,
      input.axis,
    );
    return {
      shapeX: input.X.shape,
      widthX: input.X.memory.width,
      heightX: input.X.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      axis: input.axis,
      indexValues,
      mappedIndexStrides,
      mappedInputStrides,
    };
  }

  getInputInfoString(input) {
    return `${input.X.shape}-${input.axis}-${Array.from(input.indices.values)}-${input.indices.shape}`;
  }
}
//# sourceMappingURL=gather.js.map