UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

282 lines (256 loc) 10.2 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';

type Mode = 'bilinear' | 'nearest' | 'bicubic';
type PaddingMode = 'zeros' | 'border' | 'reflection';
type Format = 'NHWC' | 'NCHW';

/**
 * Attributes of the GridSample operator as consumed by the WebGPU kernel.
 * (The interface name's spelling is kept as-is for compatibility with existing importers.)
 */
export interface GridSampeAttributes extends AttributeWithCacheKey {
  /** 0: normalized -1/1 map to the outer pixel edges; otherwise to the corner pixel centers. */
  alignCorners: number;
  /** Interpolation used to read X at the sampled location. */
  mode: Mode;
  /** How samples that fall outside X are resolved. */
  paddingMode: PaddingMode;
  /** Memory layout of X and of the output tensor. */
  format: Format;
}

/**
 * Positions of the batch (n), channel (c), height (h) and width (w) axes inside the 4-D input and
 * output tensors. It is passed explicitly to every shader-source generator rather than kept in
 * module-level mutable state, so generating an NHWC program can never change the axis order used
 * by a later NCHW program.
 */
interface AxisLayout {
  readonly n: number;
  readonly c: number;
  readonly h: number;
  readonly w: number;
}

const NCHW_LAYOUT: AxisLayout = { n: 0, c: 1, h: 2, w: 3 };
const NHWC_LAYOUT: AxisLayout = { n: 0, c: 3, h: 1, w: 2 };

/** Returns the axis layout for the given tensor format (anything but 'NHWC' is treated as NCHW). */
const getAxisLayout = (format: Format): AxisLayout => (format === 'NHWC' ? NHWC_LAYOUT : NCHW_LAYOUT);

/** Accepted spellings of the mode attribute: opset 16 uses bilinear/bicubic, opset 20 linear/cubic. */
const MODE_ALIASES: ReadonlyMap<string, Mode> = new Map<string, Mode>([
  ['linear', 'bilinear'],
  ['cubic', 'bicubic'],
]);

/**
 * Validates the X and grid inputs.
 *
 * @throws Error if fewer than two inputs are given, X is not 4-D, grid rank differs from X rank,
 *   the last grid dimension is not the number of spatial dimensions of X, or batch sizes differ.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (inputs.length < 2) {
    throw new Error('GridSample requires 2 inputs: X and grid.');
  }
  const [xDims, gridDims] = [inputs[0].dims, inputs[1].dims];
  if (xDims.length !== 4) {
    throw new Error('only 4-D tensor is supported.');
  }
  if (xDims.length !== gridDims.length) {
    throw new Error('input dimensions must be equal to grid dimensions');
  }
  if (xDims.length - 2 !== gridDims[gridDims.length - 1]) {
    throw new Error(`last dimension of grid must be equal to ${xDims.length - 2}`);
  }
  if (xDims[0] !== gridDims[0]) {
    throw new Error('grid batch size must match input batch size');
  }
};

/**
 * WGSL helper returning the four cubic-convolution weights (alpha = -0.75) for a fractional
 * offset x: the taps sit at distances |x|+1, |x|, 1-|x| and 2-|x| from the sample point.
 */
const gsGetCubicCoeffs = `
  fn gs_get_cubic_coeffs(x: f32) -> vec4<f32> {
    let cubic_alpha = -0.75f;
    let x_abs = abs(x);
    var coeffs: vec4<f32>;
    coeffs[0] = (((cubic_alpha * (x_abs + 1) - 5 * cubic_alpha) * (x_abs + 1) + 8 * cubic_alpha) * (x_abs + 1) - 4 * cubic_alpha);
    coeffs[1] = (((cubic_alpha + 2) * x_abs - (cubic_alpha + 3)) * x_abs * x_abs + 1);
    coeffs[2] = (((cubic_alpha + 2) * (1 - x_abs) - (cubic_alpha + 3)) * (1 - x_abs) * (1 - x_abs) + 1);
    coeffs[3] = (((cubic_alpha * (2 - x_abs) - 5 * cubic_alpha) * (2 - x_abs) + 8 * cubic_alpha) * (2 - x_abs) - 4 * cubic_alpha);
    return coeffs;
  }
`;

/**
 * WGSL helper that interpolates a 4x4 neighbourhood: first along x within each row p[i], then
 * along y across the four row results. Matrix entries are converted to f32 explicitly because
 * WGSL does not implicitly mix f32 and f16 operands (this matters when dataType is f16).
 */
const gsBicubicInterpolate = (dataType: string): string => `
  fn gs_bicubic_interpolate(p: mat4x4<${dataType}>, x: f32, y: f32) -> ${dataType} {
    var v: vec4<f32>;
    var coeffs = gs_get_cubic_coeffs(x);
    for (var i = 0; i < 4; i++) {
      v[i] = coeffs[0] * f32(p[i][0]) + coeffs[1] * f32(p[i][1]) + coeffs[2] * f32(p[i][2]) + coeffs[3] * f32(p[i][3]);
    }
    coeffs = gs_get_cubic_coeffs(y);
    let pixel = ${dataType}(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]);
    return pixel;
  }
`;

/** WGSL helper mapping a normalized coordinate in [-1, 1] to pixel space, per align_corners. */
const gsDenormalize = (attributes: GridSampeAttributes): string => `
  fn gs_denormalize(n: f32, length: i32) -> f32 {
    ${
      attributes.alignCorners === 0
        ? `
    // alignCorners: false => [-1, 1] to [-0.5, length - 0.5]
    return ((n + 1.0) * f32(length) - 1.0) / 2.0;
    `
        : `
    // alignCorners: true => [-1, 1] to [0, length - 1]
    return (n + 1.0) / 2.0 * (f32(length - 1));
    `
    }
  }
`;

/**
 * WGSL helper that reflects an integer coordinate back into [x_min, x_max]. It is emitted only
 * for padding mode 'reflection', the only mode whose pixel_at_grid calls it.
 */
const gsReflect = (attributes: GridSampeAttributes): string =>
  attributes.paddingMode !== 'reflection'
    ? ''
    : `
  fn gs_reflect(x: i32, x_min: f32, x_max: f32) -> u32 {
    var dx = 0.0;
    var fx = f32(x);
    let range = x_max - x_min;
    if (range <= 0.0) {
      // Degenerate axis (e.g. size 1 with align_corners): only index 0 exists, and dividing
      // by a zero range below would produce inf/NaN.
      return 0u;
    }
    if (fx < x_min) {
      dx = x_min - fx;
      let n = u32(dx / range);
      let r = dx - f32(n) * range;
      if (n % 2 == 0) {
        fx = x_min + r;
      } else {
        fx = x_max - r;
      }
    } else if (fx > x_max) {
      dx = fx - x_max;
      let n = u32(dx / range);
      let r = dx - f32(n) * range;
      if (n % 2 == 0) {
        fx = x_max - r;
      } else {
        fx = x_min + r;
      }
    }
    return u32(fx);
  }
`;

/**
 * WGSL helper that fetches X at integer (row r, column c) for one batch/channel, resolving
 * out-of-range positions according to the padding mode.
 *
 * @throws Error for an unsupported padding mode.
 */
const pixelAtGrid = (
  input: IndicesHelper,
  dataType: string,
  attributes: GridSampeAttributes,
  layout: AxisLayout,
): string => {
  let addressing: string;
  switch (attributes.paddingMode) {
    case 'zeros':
      // Outside X contributes 0.
      addressing = `
    if (r >= 0 && r < H && c >= 0 && c < W) {
      indices[${layout.h}] = u32(r);
      indices[${layout.w}] = u32(c);
    } else {
      return ${dataType}(0);
    }`;
      break;
    case 'border':
      // Outside X repeats the nearest edge pixel.
      addressing = `
    indices[${layout.h}] = u32(clamp(r, 0, H - 1));
    indices[${layout.w}] = u32(clamp(c, 0, W - 1));`;
      break;
    case 'reflection':
      // border is (x_min, y_min, x_max, y_max): rows use y bounds, columns use x bounds.
      addressing = `
    indices[${layout.h}] = gs_reflect(r, border[1], border[3]);
    indices[${layout.w}] = gs_reflect(c, border[0], border[2]);`;
      break;
    default:
      throw new Error(`padding mode ${attributes.paddingMode} is not supported`);
  }
  return `
  fn pixel_at_grid(r: i32, c: i32, H: i32, W: i32, batch: u32, channel: u32, border: vec4<f32>) -> ${dataType} {
    var indices = vec4<u32>(0);
    indices[${layout.n}] = batch;
    indices[${layout.c}] = channel;${addressing}
    return ${input.getByIndices('indices')};
  }
`;
};

/**
 * WGSL statements (run inside main) that compute `result` from the denormalized sample position
 * (x, y) using the configured interpolation mode, then store it at global_idx.
 *
 * @throws Error for an unsupported mode.
 */
const computePixel = (
  output: IndicesHelper,
  dataType: string,
  attributes: GridSampeAttributes,
  layout: AxisLayout,
): string => {
  // Trailing arguments shared by every pixel_at_grid call below.
  const rest = `H_in, W_in, indices[${layout.n}], indices[${layout.c}], border`;
  let body: string;
  switch (attributes.mode) {
    case 'nearest':
      body = `
    let result = pixel_at_grid(i32(round(y)), i32(round(x)), ${rest});`;
      break;
    case 'bilinear':
      body = `
    let x1 = i32(floor(x));
    let y1 = i32(floor(y));
    let x2 = x1 + 1;
    let y2 = y1 + 1;
    let p11 = pixel_at_grid(y1, x1, ${rest});
    let p12 = pixel_at_grid(y1, x2, ${rest});
    let p21 = pixel_at_grid(y2, x1, ${rest});
    let p22 = pixel_at_grid(y2, x2, ${rest});
    let dx2 = ${dataType}(f32(x2) - x);
    let dx1 = ${dataType}(x - f32(x1));
    let dy2 = ${dataType}(f32(y2) - y);
    let dy1 = ${dataType}(y - f32(y1));
    let result = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);`;
      break;
    case 'bicubic':
      body = `
    let x0 = i32(floor(x)) - 1;
    let y0 = i32(floor(y)) - 1;
    var p: mat4x4<${dataType}>;
    for (var h = 0; h < 4; h++) {
      for (var w = 0; w < 4; w++) {
        p[h][w] = pixel_at_grid(h + y0, w + x0, ${rest});
      }
    }
    let dx = x - f32(x0 + 1);
    let dy = y - f32(y0 + 1);
    let result = gs_bicubic_interpolate(p, dx, dy);`;
      break;
    default:
      throw new Error(`mode ${attributes.mode} is not supported`);
  }
  return `${body}
    ${output.setByOffset('global_idx', 'result')}`;
};

/**
 * Builds the WebGPU program for GridSample. Batch and channels of the output come from X; its
 * spatial size comes from the grid's H_out and W_out.
 */
const createGridSampleProgramInfo = (inputs: readonly TensorView[], attributes: GridSampeAttributes): ProgramInfo => {
  const layout = getAxisLayout(attributes.format);
  const xDims = inputs[0].dims;
  const gridDims = inputs[1].dims;
  const x = inputVariable('x', inputs[0].dataType, xDims.length);
  // Discard the grid's last dimension (size 2) and read each (x, y) pair as a vec2 instead.
  const gridShape = [gridDims[0], gridDims[1], gridDims[2]];
  const grid = inputVariable('grid', inputs[1].dataType, gridShape.length, 2);
  const outputShape =
    attributes.format === 'NHWC'
      ? [xDims[0], gridDims[1], gridDims[2], xDims[3]]
      : [xDims[0], xDims[1], gridDims[1], gridDims[2]];
  const output = outputVariable('output', inputs[0].dataType, outputShape.length);
  const dataType = x.type.value;
  const outputSize = ShapeUtil.size(outputShape);
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    ...createTensorShapeVariables(xDims, gridShape, outputShape),
  ];

  // Valid sampling range per axis; it must agree with gs_denormalize for the same align_corners.
  const rangeSource =
    attributes.alignCorners === 0
      ? `
    let x_min = -0.5;
    let x_max = f32(W_in) - 0.5;
    let y_min = -0.5;
    let y_max = f32(H_in) - 0.5;`
      : `
    let x_min = 0.0;
    let x_max = f32(W_in) - 1.0;
    let y_min = 0.0;
    let y_max = f32(H_in) - 1.0;`;

  const getShaderSource = (shaderHelper: ShaderHelper) => `
  ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(x, grid, output)}
  ${gsGetCubicCoeffs}
  ${gsBicubicInterpolate(dataType)}
  ${gsDenormalize(attributes)}
  ${gsReflect(attributes)}
  ${pixelAtGrid(x, dataType, attributes, layout)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let H_in = i32(uniforms.x_shape[${layout.h}]);
    let W_in = i32(uniforms.x_shape[${layout.w}]);${rangeSource}
    let border = vec4<f32>(x_min, y_min, x_max, y_max);

    let indices = ${output.offsetToIndices('global_idx')};
    var grid_indices = vec3<u32>(indices[${layout.n}], indices[${layout.h}], indices[${layout.w}]);
    let nxy = ${grid.getByIndices('grid_indices')};
    var x = gs_denormalize(f32(nxy[0]), W_in);
    var y = gs_denormalize(f32(nxy[1]), H_in);

    ${computePixel(output, dataType, attributes, layout)}
  }`;

  return {
    name: 'GridSample',
    shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies: ['type', 'type'] },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
      // NOTE(review): assumes mainStart() uses its default workgroup size of 64.
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};

/**
 * Runs GridSample: samples X (input 0) at the normalized locations given by grid (input 1).
 *
 * @throws Error if the inputs fail validation or the mode/padding mode is unsupported.
 */
export const gridSample = (context: ComputeContext, attributes: GridSampeAttributes): void => {
  validateInputs(context.inputs);
  context.compute(createGridSampleProgramInfo(context.inputs, attributes));
};

/**
 * Converts raw kernel attributes into GridSampeAttributes.
 *
 * Missing values fall back to the ONNX defaults:
 * - align_corners = 0
 * - mode = bilinear
 * - padding_mode = zeros
 * - format = NCHW
 *
 * The opset-20 mode names 'linear' and 'cubic' are mapped onto 'bilinear' and 'bicubic'. Any
 * other unknown mode is passed through, and gridSample then rejects it.
 */
export const parseGridSampleAttributes = (attributes: Record<string, unknown>): GridSampeAttributes => {
  const rawMode = (attributes.mode as string | undefined) ?? 'bilinear';
  return createAttributeWithCacheKey({
    alignCorners: (attributes.align_corners as number | undefined) ?? 0,
    mode: MODE_ALIASES.get(rawMode) ?? (rawMode as Mode),
    paddingMode: (attributes.padding_mode as PaddingMode | undefined) ?? 'zeros',
    format: (attributes.format as Format | undefined) ?? 'NCHW',
  });
};