onnxruntime-web
Version:
A JavaScript library for running ONNX models in browsers
324 lines (301 loc) • 9.84 kB
text/typescript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Tensor } from '../../../tensor';
import { BroadcastUtil, ShapeUtil } from '../../../util';
import { FunctionType, GlslValueFunction } from '../glsl-definitions';
import { getGlsl } from '../glsl-source';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
/**
 * Builds the GLSL value function `add_`, which returns the sum of its two operands.
 *
 * The generated source declares two overloads: one for `float` and one for `vec4`.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslAdd(): GlslValueFunction {
  const fnName = 'add_';
  const source = `
float ${fnName}(float a, float b) {
return a + b;
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
return v1 + v2;
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `div_`, which divides its first operand by its second.
 *
 * The generated source declares two overloads: one for `float` and one for `vec4`.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslDiv(): GlslValueFunction {
  const fnName = 'div_';
  const source = `
float ${fnName}(float a, float b) {
return a / b;
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
return v1 / v2;
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `mul_`, which returns the product of its two operands.
 *
 * The generated source declares two overloads: one for `float` and one for `vec4`.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslMul(): GlslValueFunction {
  const fnName = 'mul_';
  const source = `
float ${fnName}(float a, float b) {
return a * b;
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
return v1 * v2;
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `sub_`, which subtracts its second operand from its first.
 *
 * The generated source declares two overloads: one for `float` and one for `vec4`.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslSub(): GlslValueFunction {
  const fnName = 'sub_';
  const source = `
float ${fnName}(float a, float b) {
return a - b;
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
return v1 - v2;
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `equal_`, an equality comparison whose
 * result is converted back to floating point.
 *
 * The `float` overload returns `float(a == b)`; the `vec4` overload wraps the
 * GLSL built-in `equal(v1, v2)` in a `vec4(...)` constructor.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslEqual(): GlslValueFunction {
  const fnName = 'equal_';
  const source = `
float ${fnName}(float a, float b) {
return float(a == b);
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
return vec4(equal(v1, v2));
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `greater_`, a "strictly greater than"
 * comparison whose result is converted back to floating point.
 *
 * The `float` overload returns `float(a > b)`; the `vec4` overload compares
 * the r, g, b and a components one by one.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslGreater(): GlslValueFunction {
  const fnName = 'greater_';
  const source = `
float ${fnName}(float a, float b) {
return float(a > b);
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
return vec4( v1.r > v2.r ,
v1.g > v2.g,
v1.b > v2.b,
v1.a > v2.a );
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `less_`, a "strictly less than" comparison
 * whose result is converted back to floating point.
 *
 * The `float` overload returns `float(a < b)`; the `vec4` overload compares
 * the r, g, b and a components one by one.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslLess(): GlslValueFunction {
  const fnName = 'less_';
  const source = `
float ${fnName}(float a, float b) {
return float(a < b);
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
return vec4( v1.r < v2.r ,
v1.g < v2.g,
v1.b < v2.b,
v1.a < v2.a );
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `and_`, a logical AND over operands that
 * are first converted to booleans.
 *
 * The `float` overload returns `float(bool(a) && bool(b))`; the `vec4`
 * overload converts both vectors to `bvec4` and combines them component by
 * component.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslAnd(): GlslValueFunction {
  const fnName = 'and_';
  const source = `
float ${fnName}(float a, float b) {
return float( bool(a) && bool(b) );
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
bvec4 b1 = bvec4(v1);
bvec4 b2 = bvec4(v2);
return vec4( b1.r && b2.r ,
b1.g && b2.g,
b1.b && b2.b,
b1.a && b2.a );
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `or_`, a logical OR over operands that are
 * first converted to booleans.
 *
 * The `float` overload returns `float(bool(a) || bool(b))`; the `vec4`
 * overload converts both vectors to `bvec4` and combines them component by
 * component.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslOr(): GlslValueFunction {
  const fnName = 'or_';
  const source = `
float ${fnName}(float a, float b) {
return float( bool(a) || bool(b) );
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
bvec4 b1 = bvec4(v1);
bvec4 b2 = bvec4(v2);
return vec4( b1.r || b2.r ,
b1.g || b2.g,
b1.b || b2.b,
b1.a || b2.a );
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function `xor_`, a logical exclusive OR over
 * operands that are first converted to booleans.
 *
 * The `float` overload returns `float(bool(a) ^^ bool(b))`; the `vec4`
 * overload converts both vectors to `bvec4` and combines them component by
 * component.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslXor(): GlslValueFunction {
  const fnName = 'xor_';
  const source = `
float ${fnName}(float a, float b) {
return float( bool(a) ^^ bool(b) );
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
bvec4 b1 = bvec4(v1);
bvec4 b2 = bvec4(v2);
return vec4( b1.r ^^ b2.r ,
b1.g ^^ b2.g,
b1.b ^^ b2.b,
b1.a ^^ b2.a );
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Builds the GLSL value function for exponentiation by delegating to
 * `glslBuiltinBinary` with the built-in name `pow`.
 *
 * @returns whatever `glslBuiltinBinary('pow')` produces.
 */
export function glslPow(): GlslValueFunction {
  const powFunction = glslBuiltinBinary('pow');
  return powFunction;
}
/**
 * Builds the GLSL value function `prelu_`.
 *
 * For each value `a` with slope `b`, the generated code yields `a * b` when
 * `a` is below zero and `a` unchanged otherwise. The `vec4` overload applies
 * that rule to the r, g, b and a components independently.
 *
 * @returns the GLSL source (`body`), the function name (`name`) and the
 *          value-based function type.
 */
export function glslPRelu(): GlslValueFunction {
  const fnName = 'prelu_';
  const source = `
float ${fnName}(float a, float b) {
return a < 0.0 ? a * b: a;
}
vec4 ${fnName}(vec4 v1, vec4 v2) {
return vec4(
v1.r < 0.0 ? v1.r * v2.r: v1.r,
v1.g < 0.0 ? v1.g * v2.g: v1.g,
v1.b < 0.0 ? v1.b * v2.b: v1.b,
v1.a < 0.0 ? v1.a * v2.a: v1.a
);
}
`;
  return { body: source, name: fnName, type: FunctionType.ValueBased };
}
/**
 * Wraps a two-argument GLSL function in a value function named `<fname>_`.
 *
 * The generated source declares a `float` overload and a `vec4` overload;
 * each simply forwards both arguments to `fname`.
 *
 * @param fname name of the GLSL function to forward to; it is inserted into
 *              the shader source verbatim.
 * @returns the GLSL source (`body`), the wrapper name (`name`) and the
 *          value-based function type.
 */
function glslBuiltinBinary(fname: string): GlslValueFunction {
  const wrapperName = fname + '_';
  const source = `
float ${wrapperName}(float a, float b) {
return ${fname}(a, b);
}
vec4 ${wrapperName}(vec4 v1, vec4 v2) {
return ${fname}(v1, v2);
}
`;
  return { body: source, name: wrapperName, type: FunctionType.ValueBased };
}
/**
 * Creates the loader record for a two-input ("A", "B") shader program.
 *
 * The loader carries the program name, input names, input texture types and
 * cache hint up front; the program info itself is only built when `get()` is
 * called, via `createBinaryProgramInfo`.
 *
 * @param handler inference handler; `handler.session.pack` selects packed or
 *                unpacked textures for both inputs.
 * @param inputs the operand tensors.
 * @param glslFunc GLSL value function implementing the operation; its name
 *                 becomes the program name.
 * @param outputTensorType output data type; defaults to the first input's type.
 * @param cacheKey optional value stored as the loader's `cacheHint`.
 * @returns the program info loader.
 */
const createBinaryProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  glslFunc: GlslValueFunction,
  outputTensorType: Tensor.DataType = inputs[0].type,
  cacheKey?: string,
): ProgramInfoLoader => {
  let operandTextureType: TextureType;
  if (handler.session.pack) {
    operandTextureType = TextureType.packed;
  } else {
    operandTextureType = TextureType.unpacked;
  }

  return {
    name: glslFunc.name,
    inputNames: ['A', 'B'],
    inputTypes: [operandTextureType, operandTextureType],
    cacheHint: cacheKey,
    get: () => createBinaryProgramInfo(handler, inputs, glslFunc, outputTensorType),
  };
};
/**
 * Builds the program info (shader source plus input/output description) for
 * a two-input operation.
 *
 * Three shader variants are produced:
 *  - identical input shapes: a `main()` that samples A and B at `TexCoords`
 *    and applies the `vec4` overload of `glslFunc`;
 *  - differing shapes with packed textures: a `main()` that reads both
 *    operands through `getAAtOutCoords()` / `getBAtOutCoords()`;
 *  - differing shapes with unpacked textures: a `process(indices)` function
 *    that derives per-input indices (or index 0 for a rank-0 input) and
 *    applies the `float` overload. This variant reports `hasMain` as the
 *    (falsy) pack flag because it defines no `main()`.
 *
 * @param handler inference handler supplying the pack flag and the GL context version.
 * @param inputs the two operand tensors.
 * @param glslFunc GLSL value function implementing the operation.
 * @param outputTensorType output data type; defaults to the first input's type.
 * @returns the program info for the operation.
 * @throws Error when `BroadcastUtil.calcShape` yields no shape for the two inputs.
 */
const createBinaryProgramInfo = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  glslFunc: GlslValueFunction,
  outputTensorType: Tensor.DataType = inputs[0].type,
): ProgramInfo => {
  const packed = handler.session.pack;
  const textureType = packed ? TextureType.packed : TextureType.unpacked;
  const aDims = inputs[0].dims;
  const bDims = inputs[1].dims;

  // Same shape: no index translation is needed, sample both textures directly.
  if (ShapeUtil.areEqual(aDims, bDims)) {
    const glsl = getGlsl(handler.session.backend.glContext.version);
    const shaderSource = `
${glslFunc.body}
void main() {
vec4 v1 = ${glsl.texture2D}(A, TexCoords);
vec4 v2 = ${glsl.texture2D}(B, TexCoords);
vec4 result = ${glslFunc.name}(v1, v2);
${glsl.output} = result;
}
`;
    return {
      name: glslFunc.name,
      inputNames: ['A', 'B'],
      inputTypes: [textureType, textureType],
      output: { dims: aDims, type: outputTensorType, textureType },
      shaderSource,
      hasMain: true,
    };
  }

  // Shapes differ: the output takes the broadcast shape of the two inputs.
  const outputShape = BroadcastUtil.calcShape(aDims, bDims, false);
  if (!outputShape) {
    throw new Error("Can't perform binary op on the given tensors");
  }

  const outputRank = outputShape.length;
  const aIsScalar = aDims.length === 0;
  const bIsScalar = bDims.length === 0;
  // A rank-0 input still gets a one-element index array in the shader.
  const aRank = aIsScalar ? 1 : aDims.length;
  const bRank = bIsScalar ? 1 : bDims.length;
  const aBcast = aIsScalar ? 'aindices[0] = 0;' : 'bcastIndices_A(indices, aindices);';
  const bBcast = bIsScalar ? 'bindices[0] = 0;' : 'bcastIndices_B(indices, bindices);';
  const glsl = getGlsl(handler.session.backend.glContext.version);

  let shaderSource: string;
  if (packed) {
    shaderSource = `
${glslFunc.body}
void main() {
vec4 a = getAAtOutCoords();
vec4 b = getBAtOutCoords();
vec4 result = ${glslFunc.name}(a, b);
${glsl.output} = result;
}`;
  } else {
    shaderSource = `
${glslFunc.body}
float process(int indices[${outputRank}]) {
int aindices[${aRank}];
int bindices[${bRank}];
${aBcast}
${bBcast}
return ${glslFunc.name}(_A(aindices), _B(bindices));
}`;
  }

  return {
    name: glslFunc.name,
    inputNames: ['A', 'B'],
    inputTypes: [textureType, textureType],
    output: { dims: outputShape, type: outputTensorType, textureType },
    shaderSource,
    hasMain: packed,
  };
};
/**
 * Runs the binary program built from `glslAdd()` on the given inputs.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const add = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslAdd());
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslAnd()` on the given inputs,
 * requesting a `'bool'` output tensor.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const and = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslAnd(), 'bool');
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslDiv()` on the given inputs.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const div = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslDiv());
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslEqual()` on the given inputs,
 * requesting a `'bool'` output tensor.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const equal = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslEqual(), 'bool');
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslGreater()` on the given inputs,
 * requesting a `'bool'` output tensor.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const greater = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslGreater(), 'bool');
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslLess()` on the given inputs,
 * requesting a `'bool'` output tensor.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const less = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslLess(), 'bool');
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslMul()` on the given inputs.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const mul = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslMul());
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslOr()` on the given inputs,
 * requesting a `'bool'` output tensor.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const or = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslOr(), 'bool');
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslPow()` on the given inputs.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const pow = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslPow());
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslPRelu()` on the given inputs.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const pRelu = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslPRelu());
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslSub()` on the given inputs.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const sub = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslSub());
  const output = handler.run(programLoader, inputs);
  return [output];
};
/**
 * Runs the binary program built from `glslXor()` on the given inputs,
 * requesting a `'bool'` output tensor.
 *
 * @param handler inference handler used to build and run the program.
 * @param inputs the operand tensors.
 * @returns a one-element array holding the result of `handler.run`.
 */
export const xor = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  const programLoader = createBinaryProgramInfoLoader(handler, inputs, glslXor(), 'bool');
  const output = handler.run(programLoader, inputs);
  return [output];
};