UNPKG

onnxruntime-web

Version:

A JavaScript library for running ONNX models in browsers

406 lines (378 loc) 16.1 kB
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { TensorView } from '../../tensor-view';
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
import { DataType } from '../../../wasm-common';

import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
import { createSplitProgramInfo, SplitAttributes } from './split';
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
import { RotaryEmbeddingAttributes, createRotaryEmbeddingProgramInfo } from './rotary-embedding';
import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';

/** Attributes of the GroupQueryAttention operator as consumed by this WebGPU kernel. */
export interface GroupQueryAttentionAttributes {
  /** Number of query heads. */
  numHeads: number;
  /** Number of key/value heads. */
  kvNumHeads: number;
  /** Forwarded to the attention parameters (`params.scale`) and to the rotary-embedding programs. */
  scale: number;
  /** Must be 0; softcap is not supported by this kernel. */
  softcap: number;
  /** Non-zero enables rotary embedding on query and key (requires cos/sin cache inputs). */
  doRotary: number;
  /** Must be 0; interleaved rotary embedding is not supported by this kernel. */
  rotaryInterleaved: number;
  /** Must be false; smooth softmax is not supported by this kernel. */
  smoothSoftmax: boolean;
  /** Must be -1; local (sliding-window) attention is not supported by this kernel. */
  localWindowSize: number;
}

/**
 * Validates the GroupQueryAttention inputs against `attributes` and derives the shape parameters consumed by
 * the shared attention implementation.
 *
 * Input slots read here: 0 query, 1 key, 2 value, 3 past_key, 4 past_value, 5 seqlens. A tensor with empty
 * `dims` is treated the same as an absent optional input. When `doRotary` is non-zero, more than 7 inputs
 * must be supplied (the caller reads cos_cache / sin_cache from slots 7 and 8).
 *
 * @param inputs - The operator inputs.
 * @param attributes - The operator attributes.
 * @returns The derived {@link AttentionParameters}.
 * @throws Error when an unsupported attribute value is set or input shapes are inconsistent.
 */
export const validateInputs = (
  inputs: readonly TensorView[],
  attributes: GroupQueryAttentionAttributes,
): AttentionParameters => {
  if (attributes.doRotary && inputs.length <= 7) {
    throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
  }
  const query = inputs[0];
  const key = inputs[1];
  const value = inputs[2];
  const pastKey = inputs[3];
  const pastValue = inputs[4];

  if (attributes.localWindowSize !== -1) {
    throw new Error('Local attention is not supported');
  }
  if (attributes.softcap !== 0) {
    throw new Error('Softcap is not supported');
  }
  if (attributes.rotaryInterleaved !== 0) {
    throw new Error('Rotary interleaved is not supported');
  }
  if (attributes.smoothSoftmax) {
    throw new Error('Smooth softmax is not supported');
  }

  // Abbreviation and Meanings:
  //   B:    batch_size
  //   S:    sequence_length (input sequence length of query)
  //   P:    past_sequence_length (past sequence length of key or value)
  //   L:    kv_sequence_length (input sequence length of key or value)
  //   M:    max_sequence_length
  //   T:    total_sequence_length = past_sequence_length + kv_sequence_length
  //   N:    num_heads
  //   H:    head size for Q and K, aka q_head_size or k_head_size or qk_head_size
  //   H_v:  v_head_size
  //   D_i:  input hidden size
  //   D:    hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
  //   D_v:  v_hidden_size = num_heads * v_head_size
  //
  //   past_key   : (B, N, S*, H)
  //   past_value : (B, N, S*, H)
  // When no packing for q/k/v:
  //   query (Q)  : (B, S, D)
  //   key (K)    : (B, L, D) or (B, N, S*, H)
  //   value (V)  : (B, L, D_v) or (B, N, S*, H)
  // When packed kv is used:
  //   query (Q)  : (B, S, D)
  //   key (K)    : (B, L, N, 2, H)
  //   value (V)  : None
  // When packed qkv is used:
  //   query (Q)  : (B, L, N, 3, H) or (B, S, 3*D)
  //   key (K)    : None
  //   value (V)  : None

  if (query.dims.length !== 3 && query.dims.length !== 5) {
    throw new Error('Input query is expected to have 3 or 5 dimensions');
  }

  const batchSize = query.dims[0];
  const sequenceLength = query.dims[1];
  let hiddenSize = query.dims.length === 3 ? query.dims[2] : attributes.numHeads * query.dims[4];
  let kvSequenceLength = sequenceLength;
  let pastSequenceLength = 0;
  const packedQKV = !key || key.dims.length === 0;
  // For packed QKV the last query dim holds Q, K and V heads side by side, so divide by the total head count.
  const headSize = !packedQKV
    ? Math.floor(hiddenSize / attributes.numHeads)
    : Math.floor(hiddenSize / (attributes.numHeads + 2 * attributes.kvNumHeads));
  if (packedQKV) {
    hiddenSize = headSize * attributes.numHeads;
  }
  const hasPastKey = pastKey && pastKey.dims.length !== 0;
  const hasPastValue = pastValue && pastValue.dims.length !== 0;
  // Currently the onnxruntime GQA specification only supports key/value in BNSH format.
  const isPastkvBSNH =
    hasPastKey &&
    pastKey.dims.length === 4 &&
    pastKey.dims[0] === batchSize &&
    pastKey.dims[1] !== attributes.kvNumHeads &&
    pastKey.dims[2] === attributes.kvNumHeads &&
    pastKey.dims[3] === headSize;

  if (isPastkvBSNH) {
    throw new Error('BSNH pastKey/pastValue is not supported');
  }
  if (hasPastKey && hasPastValue) {
    if (pastKey.dims.length !== 4) {
      throw new Error('Input "past_key" is expected to have 4 dimensions');
    }
    if (pastValue.dims.length !== 4) {
      throw new Error('Input "past_value" is expected to have 4 dimensions');
    }
    pastSequenceLength = pastKey.dims[2];
  } else if (hasPastKey || hasPastValue) {
    throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
  }

  let qkvFormat: AttentionQkvFormat = AttentionQkvFormat.qkvBNSH;
  if (!packedQKV) {
    if (query.dims.length !== 3) {
      throw new Error('Input "query" is expected to have 3 dimensions when key is given');
    }
    if (key.dims.length < 3 || key.dims.length > 5) {
      throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');
    }
    if (query.dims[0] !== key.dims[0]) {
      throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');
    }

    if (key.dims.length === 3) {
      if (query.dims[2] % key.dims[2] !== 0) {
        throw new Error('Dimension 2 of "query" should be a multiple of "key"');
      }
      kvSequenceLength = key.dims[1];
    } else if (key.dims.length === 5) {
      if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
        throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');
      }
      // An empty-dims placeholder counts as "no value", consistent with how key/past inputs are treated.
      if (value && value.dims.length > 0) {
        throw new Error('Expect "value" be none when "key" has packed kv format.');
      }
      kvSequenceLength = key.dims[1];
    } else {
      // key_dims.size() == 4 (cross-attention with past_key)
      if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
        throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
      }
      kvSequenceLength = key.dims[2];
    }
  } else {
    // Packed QKV. The query rank (3 or 5) was already validated above.
    if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) {
      throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed qkv');
    }
    qkvFormat = AttentionQkvFormat.qkvBSN3H;
  }

  const maskType: AttentionMaskType = AttentionMaskType.none;
  let passPastInKv = false;
  let vHiddenSize = attributes.kvNumHeads ? headSize * attributes.kvNumHeads : hiddenSize;
  if (value && value.dims.length > 0) {
    if (value.dims.length !== 3 && value.dims.length !== 4) {
      throw new Error('Input "value" is expected to have 3 or 4 dimensions');
    }

    if (query.dims[0] !== value.dims[0]) {
      throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');
    }

    if (value.dims.length === 3) {
      if (kvSequenceLength !== value.dims[1]) {
        throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');
      }
      vHiddenSize = value.dims[2];
    } else {
      if (kvSequenceLength !== value.dims[2]) {
        throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');
      }
      vHiddenSize = value.dims[1] * value.dims[3];
      passPastInKv = true;
    }
  }

  const seqlLens = inputs.length > 5 ? inputs[5] : undefined;
  // NOTE(review): with `&&` this only rejects seqlens that are BOTH not rank-1 AND whose dim 0 differs from
  // batch_size, which is looser than the error message states. Kept as-is so inputs accepted today are not
  // newly rejected; consider tightening to `||` once confirmed against the native GQA shape checks.
  if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) {
    throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size');
  }
  const totalSequenceLength = -1;
  const maxSequenceLength = -1;
  const broadcastResPosBias = false;

  return {
    batchSize,
    sequenceLength,
    pastSequenceLength,
    kvSequenceLength,
    totalSequenceLength,
    maxSequenceLength,
    inputHiddenSize: 0,
    hiddenSize,
    vHiddenSize,
    headSize,
    vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads),
    numHeads: attributes.numHeads,
    kvNumHeads: attributes.kvNumHeads,
    nReps: attributes.numHeads / attributes.kvNumHeads,
    pastPresentShareBuffer: false,
    maskType,
    scale: attributes.scale,
    broadcastResPosBias,
    passPastInKv,
    qkvFormat,
  };
};

/** Permutation that swaps axes 1 and 2: [B, L, N, H] -> [B, N, L, H]. */
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] });

/**
 * If `input` is 3-D and `params.kvSequenceLength` is non-zero, reshapes it to
 * [batchSize, kvSequenceLength, kvNumHeads, headSize] and transposes it to
 * [batchSize, kvNumHeads, kvSequenceLength, headSize] on the GPU. Any other input is returned unchanged.
 */
const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params: AttentionParameters) => {
  let reshapedInput = input;
  const numHeads = params.kvNumHeads!;
  if (input.dims.length === 3 && params.kvSequenceLength !== 0) {
    reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]);
    reshapedInput = context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
      inputs: [reshapedInput],
      outputs: [-1],
    })[0];
  }

  return reshapedInput;
};

/**
 * Builds a program that writes int64 position ids of shape [batchSize * sequenceLength] for rotary embedding.
 *
 * For every output element `global_idx`, the shader reads `seqlen = seq_lens[global_idx / sequenceLength]`
 * and `total_seq_lens[0]`, and then handles three cases:
 *  - First prompt (sequenceLength == total): writes the in-sequence index, or 1 once the index reaches
 *    `seqlen + 1`.
 *  - Subsequent prompt (sequenceLength > 1 and != total): writes `(seqlen + 1 - sequenceLength) + index`, or 1
 *    when that reaches `seqlen + 1`.
 *  - Otherwise: writes `seqlen` for the first `batchSize` elements.
 *
 * NOTE(review): `seq_lens` is presumably an int32 tensor of shape (batch_size) — confirm against callers.
 */
const generatePositionIdsProgramInfo = (
  batchSize: number,
  sequenceLength: number,
  seqLens: TensorView,
  totalSeqLen: TensorView,
) => {
  const outputDataType = DataType.int64;
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
  const outputShape = [batchSize * sequenceLength];
  const outputSize = batchSize * sequenceLength;
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: sequenceLength },
    { type: DataType.uint32, data: batchSize },
  ];
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const seqLensInputHelper = inputVariable('seq_lens', seqLens.dataType, seqLens.dims);
    const totalSeqLenInputHelper = inputVariable('total_seq_lens', totalSeqLen.dataType, totalSeqLen.dims);
    const positionIdsHelper = outputVariable('pos_ids', outputDataType, outputShape);

    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'sequence_length', type: 'u32' },
      { name: 'batch_size', type: 'u32' },
    ];

    return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(seqLensInputHelper, totalSeqLenInputHelper, positionIdsHelper)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let total_sequence_length = u32(${totalSeqLenInputHelper.getByOffset('0')});
    let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length;
    let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length;
    let batch_idx = global_idx / uniforms.sequence_length;
    let sequence_idx = i32(global_idx % uniforms.sequence_length);
    var pos_id: i32 = 0;
    let seqlen = ${seqLensInputHelper.getByOffset('batch_idx')};
    let total_seqlen = seqlen + 1;
    if (is_first_prompt) {
      if (sequence_idx < total_seqlen) {
        pos_id = sequence_idx;
      } else {
        pos_id = 1;
      }
      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
    } else if (is_subsequent_prompt) {
      let past_seqlen = total_seqlen - i32(uniforms.sequence_length);
      if (past_seqlen + sequence_idx < total_seqlen) {
        pos_id = past_seqlen + sequence_idx;
      } else {
        pos_id = 1;
      }
      ${positionIdsHelper.setByOffset('global_idx', 'pos_id')}
    } else if (global_idx < uniforms.batch_size) {
      ${positionIdsHelper.setByOffset('global_idx', 'seqlen')}
    };
  }
  `;
  };
  return {
    name: 'GeneratePositionIds',
    shaderCache: { hint: `${batchSize};${sequenceLength}`, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputDataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};

/**
 * WebGPU implementation of GroupQueryAttention.
 *
 * Validates inputs, splits a packed 3-D query into Q/K/V when key and value are absent, and optionally applies
 * rotary embedding to Q and K (position ids come from seqlens / total_sequence_length). It then transposes
 * Q/K/V to BNSH and delegates to `applyAttention`.
 *
 * @throws Error for 5-D packed QKV / packed KV inputs (not implemented), when only one of key/value is given,
 *   or for any validation failure reported by {@link validateInputs}.
 */
export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
  const params = validateInputs(context.inputs, attributes);
  if (context.inputs[0].dims.length === 5) {
    throw new Error('Packed QKV is not implemented');
  }

  if (context.inputs[1]?.dims.length === 5) {
    throw new Error('Packed KV is not implemented');
  }

  const q = context.inputs[0];
  const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
  const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
  const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
  const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
  const seqLens = context.inputs.length > 5 ? context.inputs[5] : undefined;
  const totalSequenceLengthInput = context.inputs.length > 6 ? context.inputs[6] : undefined;

  // Without this guard a lone key or value would surface later as an opaque TypeError (`undefined.dims`).
  if ((k === undefined) !== (v === undefined)) {
    throw new Error('Input "key" and "value" shall be both present or both absent');
  }

  // A kvNumHeads of 0/undefined falls back to numHeads.
  const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;

  // TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead.
  const splitAttributes: SplitAttributes = createAttributeWithCacheKey({
    axis: 2,
    numOutputs: 3,
    splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize],
  });
  // After the guard above, key and value are either both present or both absent (packed QKV -> split).
  const [query, key, value] =
    k && v
      ? [q, k, v]
      : context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] });

  let qRotary: TensorView | undefined;
  let kRotary: TensorView | undefined;
  if (attributes.doRotary) {
    // validateInputs guarantees more than 7 inputs here, so seqlens / total_sequence_length / caches exist.
    const posIds = context.compute(
      generatePositionIdsProgramInfo(params.batchSize, params.sequenceLength, seqLens!, totalSequenceLengthInput!),
      { inputs: [seqLens!, totalSequenceLengthInput!], outputs: [-1] },
    )[0];
    const cosCache = context.inputs[7];
    const sinCache = context.inputs[8];

    const qRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
      interleaved: attributes.rotaryInterleaved !== 0,
      numHeads: params.numHeads,
      rotaryEmbeddingDim: 0,
      scale: attributes.scale,
    });
    const qInputs = [query, posIds, cosCache, sinCache];
    qRotary = context.compute(createRotaryEmbeddingProgramInfo(qInputs, qRotaryEmbeddingAttributes), {
      inputs: qInputs,
      outputs: [-1],
    })[0];

    const kRotaryEmbeddingAttributes: RotaryEmbeddingAttributes = createAttributeWithCacheKey({
      interleaved: attributes.rotaryInterleaved !== 0,
      numHeads: params.kvNumHeads!,
      rotaryEmbeddingDim: 0,
      scale: attributes.scale,
    });
    const kInputs = [key, posIds, cosCache, sinCache];
    kRotary = context.compute(createRotaryEmbeddingProgramInfo(kInputs, kRotaryEmbeddingAttributes), {
      inputs: kInputs,
      outputs: [-1],
    })[0];
  }

  const Q = maybeTransposeToBNSHAndAddBias(
    context,
    params.batchSize,
    params.numHeads,
    params.sequenceLength,
    params.headSize,
    qRotary ?? query,
    undefined,
    0,
  );
  const K = maybeTransposeToBNSH(context, kRotary ?? key, params);
  const V = maybeTransposeToBNSH(context, value, params);

  applyAttention(
    context,
    Q,
    K,
    V,
    undefined,
    undefined,
    pastKey,
    pastValue,
    undefined,
    params,
    seqLens,
    totalSequenceLengthInput,
  );
};