// onnxruntime-web: a JavaScript library for running ONNX models in browsers.
// (Listing metadata: 305 lines (285 loc), 13.7 kB, text/typescript.)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
/** Attributes of the Einsum operator; extends the cache-key carrying attribute base type. */
export interface EinsumAttributes extends AttributeWithCacheKey {
// The Einsum equation, e.g. "ij,jk->ik". parseEinsumAttributes in this file strips all whitespace before storing it.
readonly equation: string;
}
// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS)
// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication
// "ij->ji" expresses matrix transpose
// "ii->i" diagonal elements of a square matrix
// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
// 'Z' or '...' to represent arbitrary dimensions.
// Regular-expression sources used to validate Einsum equations.
// A symbol is a single ASCII letter or the ellipsis token '...'.
const symbolPattern = '[a-zA-Z]|\\.\\.\\.';
// A term is one or more symbols.
const termPattern = `(${symbolPattern})+`;
// Anchored variant: matches only when the whole string is a single term.
const termPatternOnly = `^${termPattern}$`;
// The LHS is one or more terms separated by commas.
const lhsPattern = `(${termPattern},)*${termPattern}`;
// Anchored variant: matches only when the whole string is a valid LHS.
const lhsPatternOnly = `^${lhsPattern}$`;
/** Bookkeeping kept for every distinct symbol registered while parsing an Einsum equation. */
interface SymbolInfo {
  /** Size of the dimension the symbol is bound to. */
  dimValue: number;
  /** Number of times the symbol has been registered across the processed terms. */
  count: number;
  /** Index of the term behind each registration (0, 1, 2, ... for inputs; -1 for the output term). */
  inputIndices: number[];
}
/**
 * One term of an Einsum equation: records, for every symbol in the term, the axis positions at which the symbol
 * occurs.
 */
class EinsumTerm {
  /** Map from symbol to the axis positions (within this term's tensor) the symbol occupies. */
  symbolToIndices: Map<string, number[]> = new Map<string, number[]>();
  /** -1 for the output term and 0, 1, 2, ... for input terms. */
  inputIndex: number;

  constructor(inputIndex = -1) {
    this.inputIndex = inputIndex;
  }

  /** Records that `symbol` occupies axis `index` of this term. A symbol may be recorded more than once. */
  addSymbol(symbol: string, index: number) {
    const positions = this.symbolToIndices.get(symbol);
    if (positions !== undefined) {
      positions.push(index);
      return;
    }
    this.symbolToIndices.set(symbol, [index]);
  }
}
/**
 * Parsed representation of an Einsum equation bound to a concrete list of input tensors.
 *
 * The constructor validates the equation, binds every symbol to a dimension size taken from the inputs, derives the
 * output term (explicitly from the text after '->', or implicitly when '->' is absent) and computes the output shape.
 *
 * Ellipsis dimensions are tracked under the synthetic one-character symbols '0', '1', '2', ... so that the rest of the
 * code can treat them like ordinary symbols.
 */
class EinsumEquation {
  /**
   * @param inputs Input tensors; exactly one per comma-separated LHS term.
   * @param equation The equation with all whitespace already removed.
   * @throws Error if the equation is malformed or inconsistent with the input shapes.
   */
  constructor(inputs: readonly TensorView[], public readonly equation: string) {
    this.hasEllipsis = false;
    this.ellipsisDims = [];
    this.symbolToInfo = new Map<string, SymbolInfo>();
    this.lhs = new Array<EinsumTerm>();
    this.outputDims = [];

    // Split on the first '->' only. Anything after it (including a second '->') is the RHS and is validated below.
    const arrowAt = equation.indexOf('->');
    const isExplicit = arrowAt !== -1;
    const lhs = isExplicit ? equation.substring(0, arrowAt) : equation;

    if (!new RegExp(lhsPatternOnly).test(lhs)) {
      throw new Error('Invalid LHS term');
    }
    const inputTerms = lhs.split(',');
    if (inputTerms.length !== inputs.length) {
      throw new Error('Number of LHS terms does not match the number of inputs');
    }
    inputTerms.forEach((inputTerm, index) => {
      const dims = inputs[index].dims.slice();
      this.lhs.push(this.processTerm(inputTerm, true, dims, index));
    });

    let rhs: string;
    if (isExplicit) {
      // An explicit empty RHS ("ij->") is valid and means that the output is a scalar.
      rhs = equation.substring(arrowAt + 2);
      if (rhs !== '' && !new RegExp(termPatternOnly).test(rhs)) {
        throw new Error('Invalid RHS');
      }
    } else {
      // Implicit mode: the ellipsis dimensions (if any) come first, followed by the alphabetically sorted letters
      // that appear exactly once in the LHS. The synthetic ellipsis symbols are excluded by the letter test.
      const letters = [...this.symbolToInfo.entries()]
                          .filter(([sym, info]) => info.count === 1 && /^[a-zA-Z]$/.test(sym))
                          .map(([sym]) => sym)
                          .sort();
      rhs = (this.hasEllipsis ? '...' : '') + letters.join('');
    }

    // Compute output dims
    const rhsSymbols = rhs.match(new RegExp(symbolPattern, 'g')) ?? [];
    for (const symbol of rhsSymbols) {
      if (symbol === '...') {
        this.outputDims = this.outputDims.concat(this.ellipsisDims);
      } else {
        const info = this.symbolToInfo.get(symbol);
        if (info === undefined) {
          throw new Error('Invalid RHS symbol');
        }
        this.outputDims.push(info.dimValue);
      }
    }
    this.rhs = this.processTerm(rhs, false, this.outputDims);
  }  // End of EinsumEquation constructor

  /**
   * Registers one occurrence of a symbol.
   *
   * @param symbol A letter, or a synthetic ellipsis symbol.
   * @param dimValue Size of the dimension the occurrence is bound to.
   * @param inputIndex Index of the term the occurrence comes from (-1 for the output term).
   * @throws Error if the symbol is already bound to a different dimension size.
   */
  addSymbol(symbol: string, dimValue: number, inputIndex: number) {
    const info = this.symbolToInfo.get(symbol);
    if (info === undefined) {
      this.symbolToInfo.set(symbol, {count: 1, dimValue, inputIndices: [inputIndex]});
      return;
    }
    // Every occurrence has to agree on the size, including the second one.
    if (info.dimValue !== dimValue) {
      throw new Error('Dimension mismatch');
    }
    info.count++;
    info.inputIndices.push(inputIndex);
  }

  /**
   * Processes one input/output term: maps each symbol to the axis positions it occupies and registers the symbols
   * (with their dimension sizes) in the equation-wide table.
   *
   * @param term Text of the term, e.g. "ij" or "...ij".
   * @param isInput true for an LHS term, false for the RHS term.
   * @param dims Shape of the tensor the term describes.
   * @param index Index of the input the term belongs to; -1 for the output.
   * @returns The per-term symbol-to-axis mapping.
   * @throws Error if the term is malformed or does not fit `dims`.
   */
  processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
    const rank = dims.length;
    let ellipsis = false;  // whether THIS term's ellipsis has been seen already
    let ellipsisDims: number[] = [];
    let nextDim = 0;
    if (!new RegExp(termPatternOnly).test(term)) {
      if (isInput) {
        throw new Error('Invalid LHS term');
      }
      // For output empty string is allowed because the output may be reduced to a scalar value
      if (term !== '') {
        throw new Error('Invalid RHS term');
      }
    }
    const indexSymbols: string[] = term.match(new RegExp(symbolPattern, 'g')) ?? [];
    // Without an ellipsis every axis needs exactly one symbol; otherwise symbols would be bound to missing axes or
    // trailing axes would be silently ignored.
    if (!indexSymbols.includes('...') && indexSymbols.length !== rank) {
      throw new Error('Number of symbols in a term does not match the rank of its tensor');
    }
    const einsumTerm = new EinsumTerm(index);
    // symbol can be either a letter, 'a' to 'z' or 'A' to 'Z', or '...'
    for (const [i, symbol] of indexSymbols.entries()) {
      if (symbol === '...') {
        if (ellipsis) {
          throw new Error('Only one ellipsis is allowed per input term');
        }
        ellipsis = true;
        const ellipsisDimLength = rank - indexSymbols.length + 1;
        if (ellipsisDimLength < 0) {
          throw new Error('Ellipsis out of bounds');
        }
        ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
        if (this.hasEllipsis) {
          if (this.ellipsisDims.length !== ellipsisDims.length ||
              this.ellipsisDims.toString() !== ellipsisDims.toString()) {
            throw new Error('Ellipsis dimensions mismatch');
          }
        } else if (isInput) {
          this.hasEllipsis = true;
          this.ellipsisDims = ellipsisDims;
        } else {
          throw new Error('Ellipsis must be specified in the LHS');
        }
        // Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
        for (let j = 0; j < ellipsisDims.length; j++) {
          const ellipsisSymbol = String.fromCharCode('0'.charCodeAt(0) + j);
          einsumTerm.addSymbol(ellipsisSymbol, i + j);
          this.addSymbol(ellipsisSymbol, dims[nextDim++], index);
        }
      } else {
        // Only an ellipsis that precedes the symbol in this very term shifts the symbol's axis: the ellipsis
        // occupies one slot in the symbol list but covers ellipsisDims.length axes.
        einsumTerm.addSymbol(symbol, i + (ellipsis ? ellipsisDims.length - 1 : 0));
        this.addSymbol(symbol, dims[nextDim++], index);
      }
    }
    return einsumTerm;
  }

  symbolToInfo: Map<string, SymbolInfo>;  // All symbols in the equation
  hasEllipsis: boolean;                   // The equation has ellipsis or not
  ellipsisDims: number[];                 // The dimensions the equation's ellipsis corresponds to
  lhs: EinsumTerm[];                      // Terms on the left-hand side of the equation
  rhs: EinsumTerm;                        // Term on the right-hand side of the equation
  outputDims: number[];                   // Output dimensions of the equation
}  // End of class EinsumEquation
/** Returns the name of the uniform holding the loop bound for `name` (the name with a `_max` suffix). */
const appendMax = (name: string): string => `${name}_max`;
/**
 * Builds the program description (shader source generator, cache hint and run data) for one Einsum computation.
 *
 * Each shader invocation handles one output element: symbols present in the output term are copied from the output
 * indices into the input indices, and every other symbol becomes a nested loop over which the product of all inputs
 * is summed.
 *
 * @param inputShapes Shapes of the inputs, in input order.
 * @param dataType Data type used for all input variables and for the output.
 * @param einsumEquation The parsed equation.
 * @param outputShape Shape of the output.
 */
const createEinsumProgramInfo =
    (inputShapes: Array<readonly number[]>, dataType: number, einsumEquation: EinsumEquation,
     outputShape: readonly number[]): ProgramInfo => {
      const ranks = inputShapes.map((dims) => dims.length);
      const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
      const outputSize = ShapeUtil.size(outputShape);
      const output = outputVariable('output', dataType, outputShape.length);
      // Symbols that do not appear in the output term are reduced; each one needs a loop bound uniform.
      const uniformsSymbols =
          [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol));
      // Ellipsis dimensions are stored under the synthetic symbols '0', '1', ... which cannot be used as WGSL
      // identifiers (an identifier must not start with a digit), so they are mapped to 'ellipsis0', 'ellipsis1', ...
      const toShaderName = (symbol: string): string =>
          /^[a-zA-Z]$/.test(symbol) ? symbol : `ellipsis${symbol.charCodeAt(0) - '0'.charCodeAt(0)}`;

      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const idxCopy: string[] = [];
        const initProd = 'var prod = 1.0;';
        const initSum = 'var sum = 0.0;';
        const updateSum = 'sum += prod;';
        const reduceOpsSetIndices: string[] = [];
        const reduceOpsLoopHeaders: string[] = [];
        const reduceOpsLoopFooters: string[] = [];
        const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
        einsumEquation.symbolToInfo.forEach((info, symbol) => {
          if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
            // Output symbol: copy the output index into every input axis the symbol is bound to.
            const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0];
            if (outputIndex !== undefined) {
              einsumEquation.lhs.forEach((term, i) => {
                if (info.inputIndices.includes(i)) {
                  const indices = term.symbolToIndices.get(symbol);
                  if (indices === undefined) {
                    throw new Error('Invalid symbol error');
                  }
                  indices.forEach((index) => {
                    idxCopy.push(`${
                        inputVars[i].indicesSet(
                            `input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`);
                  });
                }
              });
            }
          } else {
            // Reduced symbol: one loop; inside it the loop variable is written to every input axis bound to it.
            const loopVar = toShaderName(symbol);
            einsumEquation.lhs.forEach((term, i) => {
              if (info.inputIndices.includes(i)) {
                const indices = term.symbolToIndices.get(symbol);
                if (indices === undefined) {
                  throw new Error('Invalid symbol error');
                }
                indices.forEach((index) => {
                  reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, loopVar)}`);
                });
              }
            });
            reduceOpsLoopHeaders.push(
                `for(var ${loopVar}: u32 = 0; ${loopVar} < uniforms.${appendMax(loopVar)}; ${loopVar}++) {`);
            reduceOpsLoopFooters.push('}');
          }
        });
        // Every input contributes exactly one factor to the summed product, no matter how many of its axes are
        // reduced (possibly none).
        const reduceOpCompute =
            inputVars.map((inputVar, i) => `prod *= ${inputVar.getByIndices(`input${i}Indices`)};`);
        const reduceOps = isReduceOpsWithoutLoop ?
            [
              ...idxCopy,
              `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`
            ] :
            [
              ...idxCopy,
              initSum,
              ...reduceOpsLoopHeaders,
              ...reduceOpsSetIndices,
              initProd,
              ...reduceOpCompute,
              updateSum,
              ...reduceOpsLoopFooters,
            ];
        return `
            ${
            shaderHelper
                .registerUniforms(
                    uniformsSymbols.map((symbol) => ({name: `${appendMax(toShaderName(symbol))}`, type: 'u32'})))
                .registerUniform('outputSize', 'u32')
                .declareVariables(...inputVars, output)}

            ${shaderHelper.mainStart()}
            ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
            var outputIndices = ${output.offsetToIndices('global_idx')};
            ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
            ${reduceOps.join('\n')};
            ${output.setByOffset('global_idx', 'sum')};
          }`;
      };
      return {
        name: 'Einsum',
        shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')},
        getRunData: () => {
          // Uniform values are pushed in the order the uniforms are registered in the shader: one loop bound per
          // reduced symbol, then outputSize, then the shape variables of every input and of the output.
          // uniformsSymbols is derived from the keys of symbolToInfo, so the lookup below always succeeds; the
          // fallback only exists to satisfy the type checker.
          const programUniforms: ProgramUniform[] = uniformsSymbols.map(
              (symbol) => ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue ?? 0}));
          programUniforms.push({type: DataType.uint32, data: outputSize});
          inputShapes.forEach((dims) => {
            programUniforms.push(...createTensorShapeVariables(dims));
          });
          programUniforms.push(...createTensorShapeVariables(outputShape));
          return ({
            outputs: [{dims: outputShape, dataType}],
            dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
            programUniforms
          });
        },
        getShaderSource,
      };
    };
/**
 * Entry point of the Einsum operator: parses the equation against the context's inputs and runs the resulting
 * program. The data type of the first input is used for the whole program.
 */
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
  const {inputs} = context;
  const parsedEquation = new EinsumEquation(inputs, attributes.equation);
  const shapes = inputs.map(({dims}) => dims);
  const programInfo = createEinsumProgramInfo(shapes, inputs[0].dataType, parsedEquation, parsedEquation.outputDims);
  context.compute(programInfo);
};
/**
 * Converts the raw operator attributes into EinsumAttributes. All whitespace is removed from the equation before it
 * is wrapped together with its cache key.
 */
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
  const rawEquation = attributes.equation as string;
  const equation = rawEquation.replace(/\s+/g, '');
  return createAttributeWithCacheKey({equation});
};