onnxruntime-web
Version:
A Javascript library for running ONNX models on browsers
176 lines (175 loc) • 5.9 kB
JavaScript
;
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
Object.defineProperty(exports, '__esModule', { value: true });
exports.ShapeUtilsGlslLib = void 0;
const glsl_definitions_1 = require('./glsl-definitions');
/**
* GLSL Library responsible for data types and routines for manipulating
* coordinates and mapping to/from tensor indices
*/
class ShapeUtilsGlslLib extends glsl_definitions_1.GlslLib {
constructor(context) {
super(context);
}
getFunctions() {
return {
...this.bcastIndex(),
...this.bcastMatmulIndex(),
...this.offsetToIndices(),
...this.indicesToOffset(),
...this.incrementIndices(),
};
}
getCustomTypes() {
return {};
}
bcastIndex() {
const outputRank = this.context.outputTextureLayout.shape.length;
const result = {};
this.context.programInfo.inputNames.forEach((name, i) => {
const shape = this.context.inputTextureLayouts[i].unpackedShape;
if (shape.length <= outputRank) {
const rank = shape.length;
const dimOffset = outputRank - rank;
const funcName = `bcastIndices_${name}`;
let block = '';
for (let i = 0; i < rank; ++i) {
block += `
realIndices[${i}] = int( mod(float(bcastedIndices[${dimOffset + i}]), ${shape[i]}.0) );
`;
}
const body = `
void ${funcName} (int bcastedIndices[${outputRank}], out int realIndices[${rank}]) {
${block}
}
`;
result[funcName] = new glsl_definitions_1.GlslLibRoutine(body);
}
});
return result;
}
bcastMatmulIndex() {
const outputRank = this.context.outputTextureLayout.shape.length;
const result = {};
this.context.programInfo.inputNames.forEach((name, i) => {
const shape = this.context.inputTextureLayouts[i].shape;
if (!(shape.length < 2 || shape.length > outputRank)) {
const rank = shape.length;
const dimOffset = outputRank - rank;
const funcName = `bcastMatmulIndices_${name}`;
let block = '';
for (let i = 0; i < rank - 2; ++i) {
block += `
realIndices[${i}] = int( mod(float(bcastedIndices[${dimOffset + i}]), ${shape[i]}.0) );
`;
}
const body = `
void ${funcName}(int bcastedIndices[${outputRank}], out int realIndices[${rank}]) {
${block}
realIndices[${rank - 1}] = bcastedIndices[${outputRank - 1}];
realIndices[${rank - 2}] = bcastedIndices[${outputRank - 2}];
}
`;
result[funcName] = new glsl_definitions_1.GlslLibRoutine(body);
}
});
return result;
}
indicesToOffset() {
const result = {};
this.context.programInfo.inputNames.forEach((name, i) => {
const shape = this.context.inputTextureLayouts[i].shape;
const strides = this.context.inputTextureLayouts[i].strides;
const rank = shape.length;
let funcName = `indicesToOffset_${name}`;
result[funcName] = new glsl_definitions_1.GlslLibRoutine(
ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides),
);
funcName = `indicesToOffset_${name}_T`;
result[funcName] = new glsl_definitions_1.GlslLibRoutine(
ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides.slice().reverse()),
);
});
return result;
}
static indexToOffsetSingle(name, rank, strides) {
let block = '';
for (let i = rank - 1; i >= 0; --i) {
block += `
offset += indices[${i}] * ${strides[i]};
`;
}
return `
int ${name}(int indices[${rank}]) {
int offset = 0;
${block}
return offset;
}
`;
}
offsetToIndices() {
const result = {};
this.context.programInfo.inputNames.forEach((name, i) => {
const shape = this.context.inputTextureLayouts[i].shape;
const strides = this.context.inputTextureLayouts[i].strides;
const rank = shape.length;
let funcName = `offsetToIndices_${name}`;
result[funcName] = new glsl_definitions_1.GlslLibRoutine(
ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides),
);
funcName = `offsetToIndices_${name}_T`;
result[funcName] = new glsl_definitions_1.GlslLibRoutine(
ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides.slice().reverse()),
);
});
return result;
}
static offsetToIndicesSingle(name, rank, strides) {
const stridesBlock = [];
for (let i = 0; i < rank - 1; ++i) {
stridesBlock.push(`
indices[${i}] = offset / ${strides[i]};`);
stridesBlock.push(`
offset -= indices[${i}] * ${strides[i]};`);
}
stridesBlock.push(`
indices[${rank - 1}] = offset;`);
return `
void ${name}(int offset, out int indices[${rank}]) {
${stridesBlock.join('')}
}
`;
}
incrementIndices() {
const result = {};
this.context.programInfo.inputNames.forEach((name, i) => {
const shape = this.context.inputTextureLayouts[i].shape;
const rank = shape.length;
const funcName = `incrementIndices_${name}`;
let shapeInit = '';
for (let i = 0; i < rank; ++i) {
shapeInit += `
shape[${i}] = ${shape[i]};`;
}
const body = `
void ${funcName}(int axis, out int indices[${rank}]) {
int shape[${rank}];
${shapeInit};
for(int i = ${rank} -1 ; i >= 0; --i) {
if(i > axis) continue;
indices[i] += 1;
if(indices[i] < shape[i]) {
break;
}
indices[i] = 0;
}
}
`;
result[funcName] = new glsl_definitions_1.GlslLibRoutine(body);
});
return result;
}
}
exports.ShapeUtilsGlslLib = ShapeUtilsGlslLib;
//# sourceMappingURL=glsl-shape-utils-lib.js.map