UNPKG

chromadb-default-embed

Version:

Chroma's fork of @xenova/transformers serving as our default embedding function

1,299 lines (1,115 loc) 222 kB
/** * @file Definitions of all models available in Transformers.js. * * **Example:** Load and run an `AutoModel`. * * ```javascript * import { AutoModel, AutoTokenizer } from '@xenova/transformers'; * * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased'); * * let inputs = await tokenizer('I love transformers!'); * let { logits } = await model(inputs); * // Tensor { * // data: Float32Array(183132) [-7.117443084716797, -7.107812881469727, -7.092104911804199, ...] * // dims: (3) [1, 6, 30522], * // type: "float32", * // size: 183132, * // } * ``` * * We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example: * * **Example:** Load and run an `AutoModelForSeq2SeqLM`. * ```javascript * import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@xenova/transformers'; * * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small'); * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small'); * * let { input_ids } = await tokenizer('translate English to German: I love transformers!'); * let outputs = await model.generate(input_ids); * let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true }); * // 'Ich liebe Transformatoren!' * ``` * * @module models */ import { AutoConfig, } from './configs.js'; import { add_token_types, } from './tokenizers.js'; import { Callable, isIntegralNumber, isTypedArray, mergeArrays, } from './utils/core.js'; import { getModelFile, getModelJSON, } from './utils/hub.js'; import { LogitsProcessorList, GenerationConfig, ForceTokensLogitsProcessor, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, WhisperTimeStampLogitsProcessor, NoRepeatNGramLogitsProcessor, RepetitionPenaltyLogitsProcessor, NoBadWordsLogitsProcessor, MinLengthLogitsProcessor, MinNewTokensLengthLogitsProcessor, Sampler, } from './utils/generation.js'; import { cat, dynamicTimeWarping, mean, ones_like, stack, std_mean, Tensor, } from './utils/tensor.js'; import { executionProviders, ONNX } from './backends/onnx.js'; import { medianFilter } from './transformers.js'; const { InferenceSession, Tensor: ONNXTensor, env } = ONNX; /** @typedef {import('onnxruntime-web').InferenceSession} InferenceSession */ ////////////////////////////////////////////////// // Model types: used internally const MODEL_TYPES = { EncoderOnly: 0, EncoderDecoder: 1, Seq2Seq: 2, Vision2Seq: 3, DecoderOnly: 4, } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // Helper functions // NOTE: These will be populated fully later const MODEL_TYPE_MAPPING = new Map(); const MODEL_NAME_TO_CLASS_MAPPING = new Map(); const MODEL_CLASS_TO_NAME_MAPPING = new Map(); /** * Constructs an InferenceSession using a model file located at the specified path. * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. * @param {string} fileName The name of the model file. * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model. * @returns {Promise<InferenceSession>} A Promise that resolves to an InferenceSession object. * @private */ async function constructSession(pretrained_model_name_or_path, fileName, options) { // TODO add option for user to force specify their desired execution provider let modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`; let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); try { return await InferenceSession.create(buffer, { executionProviders, }); } catch (err) { // If the execution provided was only wasm, throw the error if (executionProviders.length === 1 && executionProviders[0] === 'wasm') { throw err; } console.warn(err); console.warn( 'Something went wrong during model construction (most likely a missing operation). ' + 'Using `wasm` as a fallback. ' ) return await InferenceSession.create(buffer, { executionProviders: ['wasm'] }); } } /** * Validate model inputs * @param {InferenceSession} session The InferenceSession object that will be run. * @param {Record<string, Tensor>} inputs The inputs to check. * @returns {Record<string, Tensor>} The checked inputs. * @throws {Error} If any inputs are missing. * @private */ function validateInputs(session, inputs) { /** * NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy` * @type {Record<string, Tensor>} */ const checkedInputs = Object.create(null); const missingInputs = []; for (const inputName of session.inputNames) { const tensor = inputs[inputName]; // Rare case where one of the model's input names corresponds to a built-in // object name (e.g., toString), which would cause a simple (!tensor) check to fail, // because it's not undefined but a function. if (!(tensor instanceof Tensor)) { missingInputs.push(inputName); continue; } // NOTE: When `env.wasm.proxy is true` the tensor is moved across the Worker // boundary, transferring ownership to the worker and invalidating the tensor. // So, in this case, we simply sacrifice a clone for it. checkedInputs[inputName] = env.wasm.proxy ? tensor.clone() : tensor; } if (missingInputs.length > 0) { throw new Error( `An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(', ')}.`); } const numInputsProvided = Object.keys(inputs).length; const numInputsNeeded = session.inputNames.length; if (numInputsProvided > numInputsNeeded) { // No missing inputs, but too many inputs were provided. // Warn the user and ignore the extra inputs. let ignored = Object.keys(inputs).filter(inputName => !session.inputNames.includes(inputName)); console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`); } return checkedInputs; } /** * Executes an InferenceSession using the specified inputs. * NOTE: `inputs` must contain at least the input names of the model. * - If additional inputs are passed, they will be ignored. * - If inputs are missing, an error will be thrown. * * @param {InferenceSession} session The InferenceSession object to run. * @param {Object} inputs An object that maps input names to input tensors. * @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors. * @private */ async function sessionRun(session, inputs) { const checkedInputs = validateInputs(session, inputs); try { // @ts-ignore let output = await session.run(checkedInputs); output = replaceTensors(output); return output; } catch (e) { // This usually occurs when the inputs are of the wrong type. console.error(`An error occurred during model execution: "${e}".`); console.error('Inputs given to model:', checkedInputs); throw e; } } /** * Replaces ONNX Tensor objects with custom Tensor objects to support additional functions. * @param {Object} obj The object to replace tensor objects in. * @returns {Object} The object with tensor objects replaced by custom Tensor objects. * @private */ function replaceTensors(obj) { for (let prop in obj) { if (obj[prop] instanceof ONNXTensor) { obj[prop] = new Tensor(obj[prop]); } else if (typeof obj[prop] === 'object') { replaceTensors(obj[prop]); } } return obj; } /** * Converts an array or Tensor of integers to an int64 Tensor. * @param {Array|Tensor} items The input integers to be converted. * @returns {Tensor} The int64 Tensor with the converted values. * @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length. * @private */ function toI64Tensor(items) { if (items instanceof Tensor) { return items; } // items is an array if (items.length === 0) { throw Error("items must be non-empty"); } if (Array.isArray(items[0])) { // batched if (items.some(x => x.length !== items[0].length)) { throw Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.") } return new Tensor('int64', BigInt64Array.from(items.flat().map(x => BigInt(x))), [items.length, items[0].length] ); } else { //flat return new Tensor('int64', BigInt64Array.from(items.map(x => BigInt(x))), [1, items.length] ); } } /** * Prepares an attention mask for a sequence of tokens based on configuration options. * @param {Object} self The calling object instance. * @param {Tensor} tokens The input tokens. * @returns {Tensor} The attention mask tensor. * @private */ function prepareAttentionMask(self, tokens) { // Prepare attention mask let pad_token_id = self.config.pad_token_id ?? null; let eos_token_id = self.config.eos_token_id ?? null; if (isIntegralNumber(eos_token_id)) { eos_token_id = [eos_token_id]; } let is_pad_token_in_inputs = tokens.indexOf(pad_token_id) !== -1; let is_pad_token_not_equal_to_eos_token_id = (eos_token_id === null) || !eos_token_id.includes(pad_token_id) if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) { let data = BigInt64Array.from( // Note: != so that int matches bigint // @ts-ignore tokens.data.map(x => x != pad_token_id) ) return new Tensor('int64', data, tokens.dims) } else { return ones_like(tokens); } } /** * Add position IDs to the feeds object. * @param {Object} session The inference session. * @param {Object} feeds The input to the model. * @param {boolean} use_cache_branch Whether to use the cache branch of the model. * @returns {void} * @private */ function preparePositionIds(session, feeds, use_cache_branch) { if (!session.inputNames.includes('position_ids')) return; const data = new BigInt64Array(feeds.attention_mask.data.length); // Compute cumulative sum of the attention mask along the sequence length dimension for (let i = 0; i < feeds.attention_mask.dims[0]; ++i) { let start = i * feeds.attention_mask.dims[1]; let sum = BigInt(0); for (let j = 0; j < feeds.attention_mask.dims[1]; ++j) { const index = start + j; if (feeds.attention_mask.data[index] === 0n) { data[index] = BigInt(1); } else { // === 1n data[index] = sum; sum += feeds.attention_mask.data[index]; } } } feeds.position_ids = new Tensor('int64', data, feeds.attention_mask.dims); if (use_cache_branch) { feeds.position_ids = feeds.position_ids.slice(null, -1).unsqueeze_(-1); } } /** * Creates a boolean tensor with a single value. * @param {boolean} value The value of the tensor. * @returns {Tensor} The boolean tensor. * @private */ function boolTensor(value) { return new Tensor('bool', [value], [1]); } // JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in /** * Perform forward pass on the seq2seq model (both encoder and decoder). * @param {Object} self The seq2seq model object. * @param {Object} model_inputs The input object for the model containing encoder and decoder inputs. * @returns {Promise<Seq2SeqLMOutput>} Promise that resolves with the output of the seq2seq model. * @private */ async function seq2seqForward(self, model_inputs) { let { encoder_outputs, past_key_values } = model_inputs; if (!encoder_outputs) { // Encoder outputs are not given, so we must compute them. encoder_outputs = (await encoderForward(self, model_inputs)).last_hidden_state; } let decoderFeeds = { input_ids: model_inputs.decoder_input_ids, encoder_hidden_states: encoder_outputs, }; const use_cache_branch = !!past_key_values; if (self.decoder_merged_session.inputNames.includes('use_cache_branch')) { decoderFeeds.use_cache_branch = boolTensor(use_cache_branch); } if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { decoderFeeds.encoder_attention_mask = model_inputs.attention_mask } preparePositionIds(self.decoder_merged_session, decoderFeeds, use_cache_branch); self.addPastKeyValues(decoderFeeds, past_key_values); const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds); let logits = decoderResults.logits; past_key_values = self.getPastKeyValues(decoderResults, past_key_values); // Get cross attention and/or decoder attentions if they are present const attns = self.getAttentions(decoderResults); return new Seq2SeqLMOutput({ logits, past_key_values, encoder_outputs, ...attns }); } /** * Start the beam search process for the seq2seq model. * @param {PreTrainedModel} self The seq2seq model object. * @param {Tensor} inputTokenIds Array of input token ids for each input sequence. * @param {Object} generation_config The generation config. * @param {number} numOutputTokens The maximum number of output tokens for the model. * @returns {Object[]} Array of beam search objects. * @private */ function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputTokens) { let beams = []; let beamId = 0; // @ts-ignore const requires_attention_mask = self.requires_attention_mask ?? true; // decoder_input_ids == output_token_ids let decoder_input_ids = generation_config.decoder_input_ids ?? generation_config.decoder_start_token_id ?? generation_config.bos_token_id ?? generation_config.eos_token_id; // Support input as tensor or list // TODO support batched decoder_input_ids if (decoder_input_ids instanceof Tensor) { decoder_input_ids = decoder_input_ids.tolist().flat(); } else if (!Array.isArray(decoder_input_ids)) { decoder_input_ids = [decoder_input_ids]; } for (let tokens of inputTokenIds) { // TODO: Improve // Currently, just add back batch dimension. // In future, allow for true parallel execution tokens.dims = [1, ...tokens.dims] // Create beam let start = { inputs: tokens, encoder_outputs: null, prev_model_outputs: null, output_token_ids: decoder_input_ids, done: false, score: 0, id: beamId++ // assign unique id to beams } if (requires_attention_mask) { start.attention_mask = prepareAttentionMask(self, tokens); } beams.push(start); } return beams; } /** * Run beam search on the seq2seq model for a single beam. * @param {PreTrainedModel} self The seq2seq model object. * @param {Object} beam The beam search object for which to run the model. * @param {Object} options options * @param {string} [options.input_name='input_ids'] The name of the input tensor for the encoder. * @returns {Promise<Object>} Promise that resolves with the output of the seq2seq model for the given beam. * @private */ async function seq2seqRunBeam(self, beam) { const input_name = self.main_input_name; let decoder_input_ids = beam.output_token_ids; if (beam.prev_model_outputs) { // After the first step, `prev_model_outputs` won't be null. // So, we cut decoder_input_ids if past is used decoder_input_ids = decoder_input_ids.slice(-1); } // 1. Prepare let model_inputs = { [input_name]: beam.inputs, decoder_input_ids: toI64Tensor(decoder_input_ids), encoder_outputs: beam.encoder_outputs, past_key_values: beam.prev_model_outputs?.past_key_values, } if (beam.attention_mask) { model_inputs.attention_mask = beam.attention_mask } // 2. Run let output = await self.forward(model_inputs); // 3. Update beam.prev_model_outputs = output; beam.encoder_outputs = output.encoder_outputs; return output; } /** * Update a beam with a new token ID. * @param {Object} beam The beam to update. * @param {number} newTokenId The new token ID to add to the beam's output. * @private */ function seq2seqUpdatebeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; } /** * Forward pass of an encoder model. * @param {Object} self The encoder model. * @param {Object} model_inputs The input data to be used for the forward pass. * @returns {Promise<Object>} Promise that resolves with an object containing the model's outputs. * @private */ async function encoderForward(self, model_inputs) { const encoderFeeds = Object.create(null); for (const key of self.session.inputNames) { encoderFeeds[key] = model_inputs[key]; } if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) { // Assign default `token_type_ids` to the `encoderFeeds` if the model expects it, // but they weren't created by the tokenizer. add_token_types(encoderFeeds); } return await sessionRun(self.session, encoderFeeds); } /** * Forward pass of a decoder model. * @param {Object} self The decoder model. * @param {Object} model_inputs The input data to be used for the forward pass. * @returns {Promise<Object>} Promise that resolves with an object containing the logits and past key values. * @private */ async function decoderForward(self, model_inputs) { let { input_ids, past_key_values, attention_mask } = model_inputs; let decoderFeeds = { input_ids: input_ids, attention_mask: attention_mask ?? prepareAttentionMask(self, input_ids), } const use_cache_branch = !!past_key_values; if (self.session.inputNames.includes('use_cache_branch')) { decoderFeeds.use_cache_branch = boolTensor(use_cache_branch); } preparePositionIds(self.session, decoderFeeds, use_cache_branch); self.addPastKeyValues(decoderFeeds, past_key_values); let decoderResults = await sessionRun(self.session, decoderFeeds); let logits = decoderResults.logits; past_key_values = self.getPastKeyValues(decoderResults, past_key_values); return { logits, past_key_values }; } /** * Starts the generation of text by initializing the beams for the given input token IDs. * @param {Object} self The text generation model object. * @param {Tensor} inputTokenIds An tensor of input token IDs to generate text from. * @param {Object} generation_config The generation config. * @param {number} numOutputTokens The maximum number of tokens to generate for each beam. * @param {Tensor} [inputs_attention_mask] The attention mask tensor for the input token IDs. * @returns {Object[]} An array of beams initialized with the given inputs and parameters. * @private */ function decoderStartBeams(self, inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) { let beams = []; let beamId = 0; for (let tokens of inputTokenIds) { let output_token_ids = tokens.tolist().map(Number); // TODO: Improve // Currently, just add back batch dimension. // In future, allow for true parallel execution tokens.dims = [1, ...tokens.dims] let attn_mask; if (inputs_attention_mask) { attn_mask = inputs_attention_mask[beamId]; attn_mask.dims = [1, ...attn_mask.dims] } else { attn_mask = prepareAttentionMask(self, tokens) } let start = { input: tokens, model_input_ids: tokens, attention_mask: attn_mask, prev_model_outputs: null, output_token_ids: output_token_ids, num_output_tokens: numOutputTokens, done: false, score: 0, id: beamId++ // assign unique id to beams } beams.push(start); } return beams; } /** * Runs a single step of the text generation process for a given beam. * * @param {Object} self The decoder object. * @param {Object} beam The beam to run. * @param {Tensor} beam.input The input tensor. * @param {Tensor} beam.model_input_ids The input ids to the model. * @param {Tensor} beam.attention_mask The attention mask. * @param {Object} beam.prev_model_outputs The past key values. * @param {number[]} beam.output_token_ids The output token ids. * @returns {Promise<Object>} The output of the generation step. * @private */ async function decoderRunBeam(self, beam) { let attnMaskData = new BigInt64Array(beam.output_token_ids.length).fill(1n) // 1. Prepare let model_inputs = { input_ids: beam.model_input_ids, attention_mask: new Tensor( 'int64', attnMaskData, [1, attnMaskData.length] ), past_key_values: beam.prev_model_outputs?.past_key_values, } // 2. Run let output = await self.forward(model_inputs); // 3. Update beam.prev_model_outputs = output; return output; } /** * Update a beam with a new token ID. * @param {Object} beam The beam to update. * @param {number} newTokenId The new token ID to add to the beam's output. * @private */ function decoderUpdatebeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; beam.model_input_ids = new Tensor('int64', [BigInt(newTokenId)], [1, 1]); } ////////////////////////////////////////////////// ////////////////////////////////////////////////// /** * A base class for pre-trained models that provides the model configuration and an ONNX session. */ export class PreTrainedModel extends Callable { main_input_name = 'input_ids'; /** * Creates a new instance of the `PreTrainedModel` class. * @param {Object} config The model configuration. * @param {any} session session for the model. */ constructor(config, session) { super(); this.config = config; this.session = session; const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor); const modelType = MODEL_TYPE_MAPPING.get(modelName); this.can_generate = false; this._runBeam = null; this._getStartBeams = null; this._updateBeam = null; this._forward = null; if (modelType === MODEL_TYPES.DecoderOnly) { this.can_generate = true; this._runBeam = decoderRunBeam; this._getStartBeams = decoderStartBeams; this._updateBeam = decoderUpdatebeam; this._forward = decoderForward; } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) { this.can_generate = true; this._runBeam = seq2seqRunBeam; this._getStartBeams = seq2seqStartBeams; this._updateBeam = seq2seqUpdatebeam; this._forward = seq2seqForward; } else if (modelType === MODEL_TYPES.EncoderDecoder) { this._forward = encoderForward; } else { // should be MODEL_TYPES.EncoderOnly this._forward = encoderForward; } } /** * Disposes of all the ONNX sessions that were created during inference. * @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed. * @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry */ async dispose() { const promises = []; for (let key of Object.keys(this)) { const item = this[key]; // @ts-ignore if (item instanceof InferenceSession) { promises.push(item.handler.dispose()) } } return await Promise.all(promises); } /** * Instantiate one of the model classes of the library from a pretrained model. * * The model class to instantiate is selected based on the `model_type` property of the config object * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) * * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a * user or organization name, like `dbmdz/bert-base-german-cased`. * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model. * * @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class. */ static async from_pretrained(pretrained_model_name_or_path, { quantized = true, progress_callback = null, config = null, cache_dir = null, local_files_only = false, revision = 'main', model_file_name = null, } = {}) { let options = { quantized, progress_callback, config, cache_dir, local_files_only, revision, model_file_name, } const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this); const modelType = MODEL_TYPE_MAPPING.get(modelName); let info; if (modelType === MODEL_TYPES.DecoderOnly) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), constructSession(pretrained_model_name_or_path, 'encoder_model', options), constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.EncoderDecoder) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), constructSession(pretrained_model_name_or_path, 'encoder_model', options), constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), ]); } else { // should be MODEL_TYPES.EncoderOnly if (modelType !== MODEL_TYPES.EncoderOnly) { console.warn(`Model type for '${modelName}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`) } info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'model', options) ]); } // @ts-ignore return new this(...info); } /** * Runs the model with the provided inputs * @param {Object} model_inputs Object containing input tensors * @returns {Promise<Object>} Object containing output tensors */ async _call(model_inputs) { return await this.forward(model_inputs); } /** * Forward method for a pretrained model. If not overridden by a subclass, the correct forward method * will be chosen based on the model type. * @param {Object} model_inputs The input data to the model in the format specified in the ONNX model. * @returns {Promise<Object>} The output data from the model in the format specified in the ONNX model. * @throws {Error} This method must be implemented in subclasses. */ async forward(model_inputs) { return await this._forward(this, model_inputs); } /** * @param {import('./utils/generation.js').GenerationConfigType} generation_config * @param {number} input_ids_seq_length The starting sequence length for the input ids. * @returns {LogitsProcessorList} * @private */ _get_logits_processor( generation_config, input_ids_seq_length, // encoder_input_ids, TODO // prefix_allowed_tokens_fn, TODO logits_processor = null ) { const processors = new LogitsProcessorList(); // if (generation_config.diversity_penalty !== null && generation_config.diversity_penalty > 0.0) { // processors.push(new HammingDiversityLogitsProcessor( // generation_config.diversity_penalty, // generation_config.num_beams, // generation_config.num_beam_groups // )); // } // if (generation_config.encoder_repetition_penalty !== null && generation_config.encoder_repetition_penalty !== 1.0) { // processors.push(new EncoderRepetitionPenaltyLogitsProcessor( // generation_config.encoder_repetition_penalty, // encoder_input_ids // )); // } if (generation_config.repetition_penalty !== null && generation_config.repetition_penalty !== 1.0) { processors.push(new RepetitionPenaltyLogitsProcessor(generation_config.repetition_penalty)); } if (generation_config.no_repeat_ngram_size !== null && generation_config.no_repeat_ngram_size > 0) { processors.push(new NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)); } // if (generation_config.encoder_no_repeat_ngram_size !== null && generation_config.encoder_no_repeat_ngram_size > 0) { // if (this.config.is_encoder_decoder) { // processors.push(new EncoderNoRepeatNGramLogitsProcessor( // generation_config.encoder_no_repeat_ngram_size, // encoder_input_ids // )); // } else { // throw new Error("It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"); // } // } if (generation_config.bad_words_ids !== null) { processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)); } if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) { processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)); } if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) { processors.push(new MinNewTokensLengthLogitsProcessor( input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id )); } // if (prefix_allowed_tokens_fn !== null) { // processors.push(new PrefixConstrainedLogitsProcessor( // prefix_allowed_tokens_fn, // generation_config.num_beams / generation_config.num_beam_groups // )); // } if (generation_config.forced_bos_token_id !== null) { processors.push(new ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)); } if (generation_config.forced_eos_token_id !== null) { processors.push(new ForcedEOSTokenLogitsProcessor( generation_config.max_length, generation_config.forced_eos_token_id )); } // if (generation_config.remove_invalid_values === true) { // processors.push(new InfNanRemoveLogitsProcessor()); // } // if (generation_config.exponential_decay_length_penalty !== null) { // processors.push(new ExponentialDecayLengthPenalty( // generation_config.exponential_decay_length_penalty, // generation_config.eos_token_id, // input_ids_seq_length // )); // } // if (generation_config.suppress_tokens !== null) { // processors.push(new SuppressTokensLogitsProcessor(generation_config.suppress_tokens)); // } if (generation_config.begin_suppress_tokens !== null) { let begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null) ? input_ids_seq_length : input_ids_seq_length + 1; if (generation_config.forced_decoder_ids !== null) { // generation starts after the last token that is forced begin_index += generation_config.forced_decoder_ids[generation_config.forced_decoder_ids.length - 1][0]; } processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)); } if (generation_config.forced_decoder_ids !== null) { processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)); } if (logits_processor !== null) { processors.extend(logits_processor) } // `LogitNormalization` should always be the last logit processor, when present // if (generation_config.renormalize_logits === true) { // processors.push(new LogitNormalization()); // } return processors; } /** * This function merges multiple generation configs together to form a final generation config to be used by the model for text generation. * It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object. * @param {import('./utils/generation.js').GenerationConfigType} generation_config A `GenerationConfig` object containing generation parameters. * @returns {import('./utils/generation.js').GenerationConfigType} The final generation config object to be used by the model for text generation. */ _get_generation_config(generation_config) { // Create empty generation config (contains defaults) // We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them let gen_config = new GenerationConfig(this.config); // Apply model's generation config, if it exists if ('generation_config' in this) { Object.assign(gen_config, this.generation_config); } // Finally, use any generation config specified by the user // when calling `generate` if (generation_config !== null) { Object.assign(gen_config, generation_config); } return gen_config; } /** * @typedef {import('./utils/maths.js').TypedArray} TypedArray */ /** * @typedef {{ sequences: Tensor, decoder_attentions: Tensor, cross_attentions: Tensor }} EncoderDecoderOutput * @typedef {Object} DecoderOutput * * Generates text based on the given inputs and generation configuration using the model. * @param {Tensor|Array|TypedArray} inputs An array of input token IDs. * @param {Object|GenerationConfig|null} generation_config The generation configuration to use. If null, default configuration will be used. * @param {Object|null} logits_processor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created. * @param {Object} options options * @param {Object} [options.inputs_attention_mask=null] An optional attention mask for the inputs. * @returns {Promise<number[][]|EncoderDecoderOutput|DecoderOutput>} An array of generated output sequences, where each sequence is an array of token IDs. * @throws {Error} Throws an error if the inputs array is empty. */ async generate( inputs, generation_config = null, logits_processor = null, { inputs_attention_mask = null } = {}, ) { if (!this.can_generate) { const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor); let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.` const modelType = this.config.model_type; const possibleInfo = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(modelType) ?? MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.get(modelType) ?? MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.get(modelType) // ?? MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES.get(modelType) // TODO ?? MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType); if (possibleInfo) { // TODO: support multiple possible classes errorMessage += ` Please use the following class instead: '${possibleInfo[0]}'`; } throw Error(errorMessage); } if (!(inputs instanceof Tensor) && !isTypedArray(inputs) && !Array.isArray(inputs)) { throw Error(`\`inputs\` must be a Tensor, TypedArray, or Array, but is "${inputs.constructor.name}".`); } let input_ids_seq_length; // Prepare `input_ids` which will be used for auto-regressive generation // TODO: Update to align with HF transformers' implementation if (this.config.is_encoder_decoder) { // Generating from the encoder outputs input_ids_seq_length = 0; } else { input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length; // decoder-only if (input_ids_seq_length === 0) { throw Error("Must supply a non-empty array of input token ids.") } } // Update generation config with defaults generation_config = this._get_generation_config(generation_config); logits_processor = logits_processor ?? new LogitsProcessorList() // Update logits processor logits_processor = this._get_logits_processor( generation_config, input_ids_seq_length, logits_processor ) /** @type {number[]} */ let eos_token_ids = generation_config.eos_token_id; if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) { eos_token_ids = [eos_token_ids]; } // TODO implement early_stopping // https://huggingface.co/blog/how-to-generate let numOutputTokens = 1; const maxOutputTokens = numOutputTokens + (generation_config.max_new_tokens ?? Infinity); // Only use max length if max_new_tokens is not provided const useMaxLength = Number.isInteger(generation_config.max_length) && (generation_config.max_new_tokens ?? null) === null; let sampler = Sampler.getSampler(generation_config); // @ts-ignore let beams = this.getStartBeams(inputs, generation_config, numOutputTokens, inputs_attention_mask); while (beams.some(x => !x.done) && numOutputTokens < maxOutputTokens) { let newest_beams = []; for (let beam of beams) { if (beam.done) { // Add this beam back into the pool newest_beams.push(beam); continue } if (useMaxLength && beam.output_token_ids.length >= generation_config.max_length) { // Set this beam to done and add it back into the pool beam.done = true; newest_beams.push(beam); continue } // @ts-ignore let output = await this.runBeam(beam); // add attentions/scores to beam only if user requested if (generation_config.output_attentions) { this.addAttentionsToBeam(beam, output); } if (generation_config.output_scores) { // TODO add } // Logits are of the form [batch_size, out_seq_length, vocab_size] // In most cases, this will be [batch_size, 1, vocab_size] // So, we select the last token's logits: // (equivalent to `logits = outputs.logits[:, -1, :]`) let logits = output.logits.slice(null, -1, null); // Apply logits processor logits_processor(beam.output_token_ids, logits); let sampledTokens = sampler(logits); for (let [newTokenId, logProb] of sampledTokens) { // use previous beam as a starting point let newBeam = { ...beam }; // update new beam // @ts-ignore this.updateBeam(newBeam, newTokenId); newBeam.score += logProb; if (eos_token_ids && eos_token_ids.includes(newTokenId)) { newBeam.done = true; } newest_beams.push(newBeam); } } ++numOutputTokens; // Next, we get the best beams, per ID newest_beams = this.groupBeams(newest_beams).map( group => group .sort((a, b) => b.score - a.score) // sort by score .slice(0, generation_config.num_beams) // remove outside beam width ); // Flatten beams beams = newest_beams.flat(); // Run callback if (generation_config.callback_function) { generation_config.callback_function(beams); } } // TODO: Ensure that we can return non-batched outputs const groupedBeams = this.groupBeams(beams); const getFlattened = (key) => groupedBeams.map( batch => { if (generation_config.num_return_sequences > 1) { return batch.slice(0, generation_config.num_return_sequences).map(x => x[key]); } else { return [batch[0][key]]; } } ).flat(); // Flatten across batches (depth=1) const sequences = getFlattened('output_token_ids'); // [1, seqLength] if (generation_config.return_dict_in_generate) { // NOTE: `decoder_attentions` and `cross_attentions` should be: // list (one element for each generated token) // of list (one element for each layer of the decoder) // of torch.FloatTensor of shape (batch_size, num_heads, generated_length, sequence_length) // However, since we are only generating one batch at a time, they are of the form: // list (batches) // of list (one element for each generated token) // of list (one element for each layer of the decoder) // of torch.FloatTensor of shape (1, num_heads, generated_length, sequence_length) // // TODO: In future (when true parallelism, we should be able to return the correct shape) const decoder_attentions = getFlattened('decoder_attentions'); const cross_attentions = getFlattened('cross_attentions'); return { sequences, decoder_attentions, cross_attentions, } } else { return sequences; } } /** * Helper function to add attentions to beam * @param {Object} beam * @param {Object} output * @private */ addAttentionsToBeam(beam, output) { if (this.config.is_encoder_decoder) { if (!output.cross_attentions || output.cross_attentions.length === 0) { throw Error( "`output_attentions` is true, but the model did not produce cross-attentions. " + "This is most likely because the model was not exported with `output_attentions=True`." ) } if (!beam.cross_attentions) { beam.cross_attentions = []; } beam.cross_attentions.push(output.cross_attentions); } if (!output.decoder_attentions || output.decoder_attentions.length === 0) { throw Error( "`output_attentions` is true, but the model did not produce decoder-attentions. " + "This is most likely because the model was not exported with `output_attentions=True`." ) } if (!beam.decoder_attentions) { beam.decoder_attentions = []; } beam.decoder_attentions.push(output.decoder_attentions); } /** * Groups an array of beam objects by their ids. * * @param {Array} beams The array of beam objects to group. * @returns {Array} An array of arrays, where each inner array contains beam objects with the same id. */ groupBeams(beams) { // Group beams by their ids const groups = Object.create(null); for (const obj of beams) { if (groups[obj.id] === undefined) { groups[obj.id] = [obj]; } else { groups[obj.id].push(obj); } } return Object.values(groups); } /** * Returns an object containing past key values from the given decoder results object. * * @param {Object} decoderResults The decoder results object. * @param {Object} pastKeyValues The previous past key values. * @returns {Object} An object containing past key values. */ getPastKeyValues(decoderResults, pastKeyValues) { const pkvs = Object.create(null); for (const name in decoderResults) { if (name.startsWith('present')) { let newName = name.replace('present', 'past_key_values'); if (pastKeyValues && name.includes('encoder')) { // Optimization introduced by optimum to reuse past key values. So, we just replace the constant // outputs with the previous past key values. // https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704 pkvs[newName] = pastKeyValues[newName]; } else { pkvs[newName] = decoderResults[name]; } } } return pkvs; } /** * Returns an object containing attentions from the given decoder results object. * * @param {Object} decoderResults The decoder results object. * @returns {Object} An object containing attentions. */ getAttentions(decoderResults) { const attns = Object.create(null); for (const attnName of ['cross_attentions', 'decoder_attentions']) { const result = []; for (const name in decoderResults) { if (name.startsWith(attnName)) { const index = name.split('.').pop() result[index] = decoderResults[name]; } } attns[attnName] = result; } return attns; } /** * Adds past key values to the decoder feeds object. If pastKeyValues is null, creates new tensors for past key values. * * @param {Object} decoderFeeds The decoder feeds object to add past key values to. * @param {Object} pastKeyValues An object containing past key values. */ addPastKeyValues(decoderFeeds, pastKeyValues) { if (pastKeyValues) { Object.assign(decoderFeeds, pastKeyValues) } else { // TODO support batches (i.e., batch_size > 1) const batch_size = 1; // @ts-ignore if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) { // @ts-ignore let encoder_dims = [batch_size, this.num_encoder_heads, 0, this.encoder_dim_kv]; // @ts-ignore let decoder_dims = [batch_size, this.num_decoder_heads, 0, this.decoder_dim_kv]; // @ts-ignore for (let i = 0;