UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

169 lines (166 loc) 6.07 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions';

/**
 * GLSL Library responsible for data types and routines for manipulating
 * coordinates and mapping to/from tensor indices.
 *
 * Every routine is generated per program input and named after that input
 * (e.g. `indicesToOffset_<inputName>`); the input's shape and strides are
 * baked into the generated GLSL source as literal constants.
 */
export class ShapeUtilsGlslLib extends GlslLib {
  constructor(context: GlslContext) {
    super(context);
  }

  /** Returns every routine this library generates, keyed by GLSL function name. */
  getFunctions(): { [name: string]: GlslLibRoutine } {
    return {
      ...this.bcastIndex(),
      ...this.bcastMatmulIndex(),
      ...this.offsetToIndices(),
      ...this.indicesToOffset(),
      ...this.incrementIndices(),
    };
  }

  /** This library defines no custom GLSL types. */
  getCustomTypes() {
    return {};
  }

  /**
   * Generates `bcastIndices_<name>` for every input whose unpacked rank does not
   * exceed the output rank.
   *
   * GLSL signature:
   * `void bcastIndices_<name>(int bcastedIndices[outputRank], out int realIndices[rank])`.
   * The input's dimensions are right-aligned with the output's, and each output
   * index is reduced modulo the matching input dimension size.
   */
  protected bcastIndex(): { [name: string]: GlslLibRoutine } {
    const outputRank = this.context.outputTextureLayout.shape.length;
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, inputIndex) => {
      const shape = this.context.inputTextureLayouts[inputIndex].unpackedShape;
      if (shape.length > outputRank) {
        return;
      }
      const rank = shape.length;
      // Number of leading output dimensions that have no counterpart in this input.
      const dimOffset = outputRank - rank;
      const funcName = `bcastIndices_${name}`;
      let block = '';
      for (let d = 0; d < rank; ++d) {
        block += `
        realIndices[${d}] = int( mod(float(bcastedIndices[${dimOffset + d}]), ${shape[d]}.0) );`;
      }
      const body = `
      void ${funcName} (int bcastedIndices[${outputRank}], out int realIndices[${rank}]) {
        ${block}
      }
      `;
      result[funcName] = new GlslLibRoutine(body);
    });
    return result;
  }

  /**
   * Generates `bcastMatmulIndices_<name>` for every input whose rank `r`
   * satisfies `2 <= r <= outputRank`.
   *
   * The leading `r - 2` dimensions are mapped like in `bcastIndex` (right-aligned,
   * reduced modulo the input dimension size); the last two indices are copied
   * unchanged from the last two output indices.
   *
   * NOTE(review): this reads `shape` whereas `bcastIndex` reads `unpackedShape` —
   * confirm the difference is intended.
   */
  protected bcastMatmulIndex(): { [name: string]: GlslLibRoutine } {
    const outputRank = this.context.outputTextureLayout.shape.length;
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, inputIndex) => {
      const shape = this.context.inputTextureLayouts[inputIndex].shape;
      if (shape.length < 2 || shape.length > outputRank) {
        return;
      }
      const rank = shape.length;
      const dimOffset = outputRank - rank;
      const funcName = `bcastMatmulIndices_${name}`;
      let block = '';
      for (let d = 0; d < rank - 2; ++d) {
        block += `
        realIndices[${d}] = int( mod(float(bcastedIndices[${dimOffset + d}]), ${shape[d]}.0) );`;
      }
      const body = `
      void ${funcName}(int bcastedIndices[${outputRank}], out int realIndices[${rank}]) {
        ${block}
        realIndices[${rank - 1}] = bcastedIndices[${outputRank - 1}];
        realIndices[${rank - 2}] = bcastedIndices[${outputRank - 2}];
      }
      `;
      result[funcName] = new GlslLibRoutine(body);
    });
    return result;
  }

  /**
   * Generates, for every input:
   * - `indicesToOffset_<name>`: dot product of the indices with the input's strides;
   * - `indicesToOffset_<name>_T`: the same with the strides in reverse order.
   */
  protected indicesToOffset(): { [name: string]: GlslLibRoutine } {
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, inputIndex) => {
      const layout = this.context.inputTextureLayouts[inputIndex];
      const rank = layout.shape.length;
      const strides = layout.strides;

      const funcName = `indicesToOffset_${name}`;
      result[funcName] = new GlslLibRoutine(ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides));

      const transposedName = `indicesToOffset_${name}_T`;
      result[transposedName] = new GlslLibRoutine(
        ShapeUtilsGlslLib.indexToOffsetSingle(transposedName, rank, strides.slice().reverse()),
      );
    });
    return result;
  }

  /**
   * Builds the GLSL source of `int <name>(int indices[rank])`, which returns
   * `indices[0] * strides[0] + ... + indices[rank - 1] * strides[rank - 1]`.
   */
  static indexToOffsetSingle(name: string, rank: number, strides: readonly number[]): string {
    let block = '';
    for (let d = rank - 1; d >= 0; --d) {
      block += `
        offset += indices[${d}] * ${strides[d]};`;
    }
    return `
      int ${name}(int indices[${rank}]) {
        int offset = 0;
        ${block}
        return offset;
      }
      `;
  }

  /**
   * Generates, for every input:
   * - `offsetToIndices_<name>`: inverse of `indicesToOffset_<name>`;
   * - `offsetToIndices_<name>_T`: inverse of `indicesToOffset_<name>_T`.
   */
  protected offsetToIndices(): { [name: string]: GlslLibRoutine } {
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, inputIndex) => {
      const layout = this.context.inputTextureLayouts[inputIndex];
      const rank = layout.shape.length;
      const strides = layout.strides;

      const funcName = `offsetToIndices_${name}`;
      result[funcName] = new GlslLibRoutine(ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides));

      // `indicesToOffset_<name>_T` pairs indices[i] with strides[rank - 1 - i]. Greedy
      // division over the *reversed* strides would divide by the smallest stride
      // first and dump the whole offset into indices[0]; instead decompose with the
      // strides in their original order and write each component to the mirrored slot.
      const transposedName = `offsetToIndices_${name}_T`;
      result[transposedName] = new GlslLibRoutine(
        ShapeUtilsGlslLib.offsetToIndicesSingle(transposedName, rank, strides, true),
      );
    });
    return result;
  }

  /**
   * Builds the GLSL source of `void <name>(int offset, out int indices[rank])`,
   * which splits `offset` into indices by successive integer division by `strides`.
   *
   * Component `d` of the decomposition uses `strides[d]` and the final component
   * receives the remainder, so the result is only correct when each stride is a
   * multiple of the next and the last stride is 1 (e.g. row-major strides).
   *
   * @param name GLSL function name to emit.
   * @param rank number of indices.
   * @param strides stride of each component, in decomposition order.
   * @param reverseIndices when true, component `d` is written to
   *   `indices[rank - 1 - d]` instead of `indices[d]`. Defaults to false.
   */
  static offsetToIndicesSingle(
    name: string,
    rank: number,
    strides: readonly number[],
    reverseIndices = false,
  ): string {
    // Destination slot in `indices` for component `d` of the decomposition.
    const slot = (d: number): number => (reverseIndices ? rank - 1 - d : d);
    const statements: string[] = [];
    for (let d = 0; d < rank - 1; ++d) {
      statements.push(`
        indices[${slot(d)}] = offset / ${strides[d]};`);
      statements.push(`
        offset -= indices[${slot(d)}] * ${strides[d]};`);
    }
    statements.push(`
        indices[${slot(rank - 1)}] = offset;`);
    return `
      void ${name}(int offset, out int indices[${rank}]) {
        ${statements.join('')}
      }
      `;
  }

  /**
   * Generates `incrementIndices_<name>(int axis, inout int indices[rank])` for every
   * input.
   *
   * The routine adds 1 to `indices[axis]` in place. When a position reaches its
   * dimension size it wraps to 0 and the carry moves to the preceding position;
   * positions after `axis` are left untouched.
   */
  protected incrementIndices(): { [name: string]: GlslLibRoutine } {
    const result: { [name: string]: GlslLibRoutine } = {};
    this.context.programInfo.inputNames.forEach((name, inputIndex) => {
      const shape = this.context.inputTextureLayouts[inputIndex].shape;
      const rank = shape.length;
      const funcName = `incrementIndices_${name}`;
      let shapeInit = '';
      for (let d = 0; d < rank; ++d) {
        shapeInit += `
        shape[${d}] = ${shape[d]};`;
      }
      // `indices` must be `inout`: the routine reads the caller's values before
      // updating them, and an `out` parameter is not initialized from the caller.
      // The loop runs over constant bounds and skips positions after `axis`,
      // presumably because GLSL ES 1.00 requires constant loop bounds.
      const body = `
      void ${funcName}(int axis, inout int indices[${rank}]) {
        int shape[${rank}];
        ${shapeInit}
        for(int i = ${rank} -1 ; i >= 0; --i) {
          if(i > axis) continue;
          indices[i] += 1;
          if(indices[i] < shape[i]) {
            break;
          }
          indices[i] = 0;
        }
      }
      `;
      result[funcName] = new GlslLibRoutine(body);
    });
    return result;
  }
}