UNPKG

onnxruntime-web

Version:

A Javascript library for running ONNX models on browsers

238 lines (227 loc) 11.2 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import {
  createTensorShapeVariables,
  getMaxComponents,
  inputVariable,
  outputVariable,
  ShaderHelper,
  UniformsArrayType,
} from './common';

/**
 * Attributes of the DequantizeLinear operator.
 *
 * NOTE(review): the interface name is missing an "a" ("Liner"); it is exported, so it is kept
 * unchanged for backward compatibility.
 */
export interface DequantizeLinerAttributes extends AttributeWithCacheKey {
  /** Axis used for per-axis and blocked quantization; normalized against the rank of x before use. */
  axis: number;
  /** Block size for blocked quantization; block validation only runs when this is greater than 0. */
  blockSize: number;
}

/**
 * Validates the inputs of DequantizeLinear: `inputs[0]` is x, `inputs[1]` is the scale and the
 * optional `inputs[2]` is the zero-point.
 *
 * @param inputs the operator inputs (2 or 3 tensors).
 * @param attributes the operator attributes; `axis` and `blockSize` are only read when
 *     `blockSize > 0`.
 * @throws Error when the input count, data types, ranks, shapes or block size are inconsistent.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): void => {
  if (inputs.length < 2 || inputs.length > 3) {
    throw new Error('DequantizeLinear requires 2 or 3 inputs.');
  }
  // The former `inputs[1].dims === inputs[2].dims` test compared array identity and threw when it
  // was *true*; the rank and element-wise comparisons further down are the real shape check.
  if (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType) {
    throw new Error('x and x-zero-point must have the same data type.');
  }
  if (inputs[0].dataType === DataType.int32 && inputs.length > 2) {
    throw new Error('In the case of dequantizing int32 there is no zero point.');
  }

  const xDims = inputs[0].dims;
  const scaleDims = inputs[1].dims;
  if (scaleDims.length !== 0 && scaleDims.length !== 1 && scaleDims.length !== xDims.length) {
    throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.');
  }

  // Scale and zero-point inputs must have the same shape.
  if (inputs.length > 2) {
    const zeroPointDims = inputs[2].dims;
    if (scaleDims.length !== zeroPointDims.length) {
      throw new Error('scale and zero-point inputs must have the same rank.');
    }
    if (!scaleDims.every((d, i) => d === zeroPointDims[i])) {
      throw new Error('scale and zero-point inputs must have the same shape.');
    }
  }

  // Validate blockSize (block quantization).
  if (attributes.blockSize > 0) {
    if (scaleDims.length === 0 || (scaleDims.length === 1 && scaleDims[0] === 1)) {
      throw new Error('blockSize must be set only for block quantization.');
    }
    // Scale input rank must be the same as the input rank. This is checked before the
    // per-dimension comparison, which is only meaningful when both ranks agree.
    if (scaleDims.length !== xDims.length) {
      throw new Error('For block qunatization the scale input rank must be the same as the x rank.');
    }
    // Normalize the axis the same way the program builder does, so that the dims lookups below do
    // not index with the raw attribute value (presumably this maps negative axes — the helper is
    // defined outside this file).
    const axis = ShapeUtil.normalizeAxis(attributes.axis, xDims.length);
    if (!scaleDims.every((d, i) => i === axis || d === xDims[i])) {
      throw new Error('For block qunatization, scale input shape to match the input shape except for the axis');
    }
    const dI = xDims[axis];
    const si = scaleDims[axis];
    // ceil(dI / (si - 1) - 1) equals ceil(dI / (si - 1)) - 1 because the subtrahend is an integer.
    if (attributes.blockSize < Math.ceil(dI / si) || attributes.blockSize > Math.ceil(dI / (si - 1) - 1)) {
      throw new Error('blockSize must be with in the range [ceil(dI / Si), ceil(dI / (Si - 1) - 1)].');
    }
  }
};

/**
 * Builds the program info (shader source generator, uniforms, dispatch size) for DequantizeLinear.
 *
 * The generated shader computes `output = (x - zero_point) * scale`. int8/uint8 inputs are bound as
 * `u32` words holding four values each and are unpacked in the shader.
 *
 * @param inputs x, scale and optionally zero-point.
 * @param attributes operator attributes (`axis`, `blockSize`, `cacheKey`).
 * @returns the program info passed to `context.compute`.
 */
const createDequantizeLinearProgramInfo = (
  inputs: readonly TensorView[],
  attributes: DequantizeLinerAttributes,
): ProgramInfo => {
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
  const inputType = inputs[0].dataType;
  const isSigned = inputType === DataType.int8;
  const outputShape = inputs[0].dims; // output shape is same as the input shape
  const dataType = inputs[1].dataType; // output type is same as the scale input type
  const outputSize = ShapeUtil.size(outputShape);
  // 8-bit inputs are read as u32 words, four values per word.
  const isPacked = inputType === DataType.int8 || inputType === DataType.uint8;
  const inputShape = isPacked ? [Math.ceil(ShapeUtil.size(inputs[0].dims) / 4)] : inputs[0].dims;
  const scaleShape = inputs[1].dims;
  const zeroPointInput = inputs.length > 2 ? inputs[2] : undefined;
  const zeroPointShape = zeroPointInput
    ? isPacked
      ? [Math.ceil(ShapeUtil.size(zeroPointInput.dims) / 4)]
      : zeroPointInput.dims
    : undefined;

  // Scales input is a scalar for per-tensor/per-layer quantization, 1-D tensor for per-axis quantization
  // or tensor with same rank as input for blocked quantization.
  const perLayerQuantization = scaleShape.length === 0 || (scaleShape.length === 1 && scaleShape[0] === 1);
  const perAxisQuantization = perLayerQuantization === false && scaleShape.length === 1;
  // Left unnecessary commented-out assignment for documentation
  // const blockQuantization = perLayerQuantization === false && perAxisQuantization === false;

  const maxComponents = getMaxComponents(outputSize);
  // Vectorized output is only used for per-layer quantization; for packed input it additionally
  // requires 4 components so that one u32 word maps onto exactly one output vector.
  const useComponents = perLayerQuantization && (!isPacked || maxComponents === 4);
  const components = useComponents ? maxComponents : 1;
  const inputComponent = useComponents && !isPacked ? maxComponents : 1;

  const input = inputVariable('input', isPacked ? DataType.uint32 : inputType, inputShape.length, inputComponent);
  const scale = inputVariable('scale', dataType, scaleShape.length);
  const zeroPoint = zeroPointShape
    ? inputVariable('zero_point', isPacked ? DataType.uint32 : inputType, zeroPointShape.length)
    : undefined;
  const output = outputVariable('output', dataType, outputShape.length, components);

  const inputVariables = [input, scale];
  const inputShapes = [inputShape, scaleShape];
  if (zeroPoint && zeroPointShape) {
    inputVariables.push(zeroPoint);
    inputShapes.push(zeroPointShape);
  }

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize / components },
    { type: DataType.uint32, data: axis },
    { type: DataType.uint32, data: attributes.blockSize },
    ...createTensorShapeVariables(...inputShapes, outputShape),
  ];

  // When the output is written as vec4 (components === 4), invocation `global_idx` produces output
  // elements [4 * global_idx, 4 * global_idx + 3], which live in packed word `global_idx`. Only the
  // scalar path (one element per invocation) has to divide the element index by 4.
  const packedInputOffset = components === 1 ? 'global_idx / 4' : 'global_idx';

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'axis', type: 'u32' },
      { name: 'block_size', type: 'u32' },
    ];
    return `
      ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
      ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
          let output_indices = ${output.offsetToIndices('global_idx')};

          // Set input x
          ${(() => {
            if (isPacked) {
              return `
            let input = ${input.getByOffset(packedInputOffset)};
            let x_vec = ${isSigned ? 'unpack4xI8(input)' : 'unpack4xU8(input)'};
            let x_value = ${components === 1 ? 'x_vec[global_idx % 4]' : 'x_vec'};`;
            } else {
              return `let x_value = ${input.getByOffset('global_idx')};`;
            }
          })()};

          // Set scale input
          ${(() => {
            if (perLayerQuantization) {
              // scale input is a scalar ()
              return `let scale_value= ${scale.getByOffset('0')}`;
            } else if (perAxisQuantization) {
              // scale input is a 1D tensor
              return `
            let scale_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
            let scale_value= ${scale.getByOffset('scale_index')};`;
            } else {
              // Block quantization. Scale input rank is same as input/output rank.
              return `
            var scale_indices: ${scale.type.indices} = output_indices;
            let index = ${scale.indicesGet('scale_indices', 'uniforms.axis')} / uniforms.block_size;
            ${scale.indicesSet('scale_indices', 'uniforms.axis', 'index')};
            let scale_value= ${scale.getByIndices('scale_indices')};`;
            }
          })()};

          // Set zero-point input
          ${(() => {
            if (zeroPoint) {
              if (perLayerQuantization) {
                // zero-point input is a scalar
                if (isPacked) {
                  return `
                let zero_point_input = ${zeroPoint.getByOffset('0')};
                let zero_point_vec =  ${isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
                let zero_point_value= zero_point_vec[0]`;
                } else {
                  return `let zero_point_value = ${zeroPoint.getByOffset('0')}`;
                }
              } else if (perAxisQuantization) {
                // zero-point input is a 1D tensor
                if (isPacked) {
                  return `
                let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
                let zero_point_input = ${zeroPoint.getByOffset('zero_point_index / 4')};
                let zero_point_vec =  ${isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
                let zero_point_value = zero_point_vec[zero_point_index % 4]`;
                } else {
                  return `
                let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
                let zero_point_value = ${zeroPoint.getByOffset('zero_point_index')};`;
                }
              } else {
                // BlockedQuantization. The zero-point input shape is same as the input shape except along axis.
                if (isPacked) {
                  return `
                let zero_point_offset = ${scale.indicesToOffset('scale_indices')};
                let zero_point_input = ${zeroPoint.getByOffset('zero_point_offset / 4')};
                let zero_point_vec =  ${isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
                let zero_point_value = zero_point_vec[zero_point_offset % 4];`;
                } else {
                  return `let zero_point_value = ${zeroPoint.getByIndices('scale_indices')};`;
                }
              }
            } else {
              return `let zero_point_value = ${isPacked ? (isSigned ? 'i32' : 'u32') : input.type.value}(0);`;
            }
          })()};
      // Compute and write output
      ${output.setByOffset('global_idx', `${output.type.value}(x_value - zero_point_value) * scale_value`)};
      }`;
  };

  return {
    name: 'DequantizeLinear',
    shaderCache: {
      // The generated shader text varies with `components` and with the quantization mode, both of
      // which depend on element counts rather than on rank, so they are folded into the hint.
      // NOTE(review): assumes the shader cache key is built from this hint plus rank-level input
      // information only — confirm against the program manager. Extending the hint is harmless
      // either way; it can only make the cache more fine-grained.
      hint: `${attributes.cacheKey};${components};${perLayerQuantization};${perAxisQuantization}`,
      inputDependencies: zeroPoint ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
    },
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / components / 64), y: 1, z: 1 },
      programUniforms,
    }),
  };
};

/**
 * Entry point of the DequantizeLinear operator: validates the context inputs and runs the program.
 *
 * @param context compute context providing the input tensors and the `compute` call.
 * @param attributes parsed operator attributes.
 * @throws Error when input validation fails.
 */
export const dequantizeLinear = (context: ComputeContext, attributes: DequantizeLinerAttributes): void => {
  validateInputs(context.inputs, attributes);
  context.compute(createDequantizeLinearProgramInfo(context.inputs, attributes));
};

/**
 * Converts the raw attribute record into `DequantizeLinerAttributes` with a cache key.
 *
 * @param attributes raw attributes; `axis` and `blockSize` are read and asserted to be numbers
 *     without a runtime check.
 * @returns the attributes object produced by `createAttributeWithCacheKey`.
 */
export const parseDequantizeLinearAttributes = (attributes: Record<string, unknown>): DequantizeLinerAttributes =>
  createAttributeWithCacheKey({ axis: attributes.axis as number, blockSize: attributes.blockSize as number });