UNPKG

@huggingface/transformers

Version:

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!

1,236 lines (1,080 loc) • 282 kB
/** * @file Definitions of all models available in Transformers.js. * * **Example:** Load and run an `AutoModel`. * * ```javascript * import { AutoModel, AutoTokenizer } from '@huggingface/transformers'; * * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased'); * * let inputs = await tokenizer('I love transformers!'); * let { logits } = await model(inputs); * // Tensor { * // data: Float32Array(183132) [-7.117443084716797, -7.107812881469727, -7.092104911804199, ...] * // dims: (3) [1, 6, 30522], * // type: "float32", * // size: 183132, * // } * ``` * * We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example: * * **Example:** Load and run an `AutoModelForSeq2SeqLM`. * ```javascript * import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@huggingface/transformers'; * * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small'); * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small'); * * let { input_ids } = await tokenizer('translate English to German: I love transformers!'); * let outputs = await model.generate(input_ids); * let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true }); * // 'Ich liebe Transformatoren!' 
* ``` * * @module models */ import { AutoConfig, getKeyValueShapes, } from './configs.js'; import { deviceToExecutionProviders, createInferenceSession, isONNXTensor, isONNXProxy, } from './backends/onnx.js'; import { DATA_TYPES, DEFAULT_DEVICE_DTYPE_MAPPING, DEFAULT_DTYPE_SUFFIX_MAPPING, isWebGpuFp16Supported, } from './utils/dtypes.js'; import { Callable, } from './utils/generic.js'; import { isIntegralNumber, mergeArrays, pick, } from './utils/core.js'; import { getModelFile, getModelJSON, } from './utils/hub.js'; import { LogitsProcessorList, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, WhisperTimeStampLogitsProcessor, NoRepeatNGramLogitsProcessor, RepetitionPenaltyLogitsProcessor, NoBadWordsLogitsProcessor, MinLengthLogitsProcessor, MinNewTokensLengthLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, ClassifierFreeGuidanceLogitsProcessor, } from './generation/logits_process.js'; import { GenerationConfig, } from './generation/configuration_utils.js'; import { cat, full_like, mean, ones, ones_like, stack, std_mean, Tensor, zeros_like, } from './utils/tensor.js'; import { dynamic_time_warping, medianFilter } from './utils/maths.js'; import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; import { LogitsSampler } from './generation/logits_sampler.js'; import { apis } from './env.js'; import { WhisperGenerationConfig } from './models/whisper/generation_whisper.js'; import { whisper_language_to_code } from './models/whisper/common_whisper.js'; ////////////////////////////////////////////////// // Model types: used internally const MODEL_TYPES = { EncoderOnly: 0, EncoderDecoder: 1, Seq2Seq: 2, Vision2Seq: 3, DecoderOnly: 4, MaskGeneration: 5, ImageTextToText: 6, Musicgen: 7, } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // Helper functions // NOTE: These will be populated 
fully later const MODEL_TYPE_MAPPING = new Map(); const MODEL_NAME_TO_CLASS_MAPPING = new Map(); const MODEL_CLASS_TO_NAME_MAPPING = new Map(); /** * Constructs an InferenceSession using a model file located at the specified path. * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. * @param {string} fileName The name of the model file. * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. * @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object. * @private */ async function getSession(pretrained_model_name_or_path, fileName, options) { const custom_config = options.config?.['transformers.js_config'] ?? {}; let device = options.device ?? custom_config.device; if (device && typeof device !== 'string') { if (device.hasOwnProperty(fileName)) { device = device[fileName]; } else { console.warn(`device not specified for "${fileName}". Using the default device.`); device = null; } } // If the device is not specified, we use the default (supported) execution providers. const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */( device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm') ); const executionProviders = deviceToExecutionProviders(selectedDevice); // If options.dtype is specified, we use it to choose the suffix for the model file. // Otherwise, we use the default dtype for the device. let dtype = options.dtype ?? custom_config.dtype; if (typeof dtype !== 'string') { if (dtype && dtype.hasOwnProperty(fileName)) { dtype = dtype[fileName]; } else { dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32; console.warn(`dtype not specified for "${fileName}". 
Using the default dtype (${dtype}) for this device (${selectedDevice}).`); } } const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype); if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) { throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`); } else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) { throw new Error(`The device (${selectedDevice}) does not support fp16.`); } // Construct the model file name const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype]; const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`; const session_options = { ...options.session_options } ?? {}; // Overwrite `executionProviders` if not specified session_options.executionProviders ??= executionProviders; // Overwrite `freeDimensionOverrides` if specified in config and not set in session options const free_dimension_overrides = custom_config.free_dimension_overrides; if (free_dimension_overrides) { session_options.freeDimensionOverrides ??= free_dimension_overrides; } else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) { console.warn( 'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' + 'When `free_dimension_overrides` is not set, you may experience significant performance degradation.' 
); } const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options); // handle onnx external data files /** @type {Promise<{path: string, data: Uint8Array}>[]} */ let externalDataPromises = []; if (options.use_external_data_format && ( options.use_external_data_format === true || ( typeof options.use_external_data_format === 'object' && options.use_external_data_format.hasOwnProperty(fileName) && options.use_external_data_format[fileName] === true ) )) { if (apis.IS_NODE_ENV) { throw new Error('External data format is not yet supported in Node.js'); } const path = `${fileName}${suffix}.onnx_data`; const fullPath = `${options.subfolder ?? ''}/${path}`; externalDataPromises.push(new Promise(async (resolve, reject) => { const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options); resolve({ path, data }) })); } else if (session_options.externalData !== undefined) { externalDataPromises = session_options.externalData.map(async (ext) => { // if the external data is a string, fetch the file and replace the string with its content if (typeof ext.data === "string") { const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options); return { ...ext, data: ext_buffer }; } return ext; }); } if (externalDataPromises.length > 0) { session_options.externalData = await Promise.all(externalDataPromises); } if (selectedDevice === 'webgpu') { const shapes = getKeyValueShapes(options.config, { prefix: 'present', }); if (Object.keys(shapes).length > 0 && !isONNXProxy()) { // Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX /** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */ const preferredOutputLocation = {}; for (const key in shapes) { preferredOutputLocation[key] = 'gpu-buffer'; } session_options.preferredOutputLocation = preferredOutputLocation; } } const buffer = await bufferPromise; return { buffer, session_options }; } /** * 
Helper function to create multiple InferenceSession objects. * * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. * @param {Record<string, string>} names The names of the model files to load. * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. * @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects. * @private */ async function constructSessions(pretrained_model_name_or_path, names, options) { return Object.fromEntries(await Promise.all( Object.keys(names).map(async (name) => { const { buffer, session_options } = await getSession(pretrained_model_name_or_path, names[name], options); const session = await createInferenceSession(buffer, session_options); return [name, session]; }) )); } /** * Validate model inputs * @param {Object} session The InferenceSession object that will be run. * @param {Object} inputs The inputs to check. * @returns {Record<string, Tensor>} The checked inputs. * @throws {Error} If any inputs are missing. * @private */ function validateInputs(session, inputs) { /** * NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy` * @type {Record<string, Tensor>} */ const checkedInputs = Object.create(null); const missingInputs = []; for (const inputName of session.inputNames) { const tensor = inputs[inputName]; // Rare case where one of the model's input names corresponds to a built-in // object name (e.g., toString), which would cause a simple (!tensor) check to fail, // because it's not undefined but a function. if (!(tensor instanceof Tensor)) { missingInputs.push(inputName); continue; } // NOTE: When `env.wasm.proxy is true` the tensor is moved across the Worker // boundary, transferring ownership to the worker and invalidating the tensor. // So, in this case, we simply sacrifice a clone for it. checkedInputs[inputName] = isONNXProxy() ? 
tensor.clone() : tensor; } if (missingInputs.length > 0) { throw new Error( `An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(', ')}.`); } const numInputsProvided = Object.keys(inputs).length; const numInputsNeeded = session.inputNames.length; if (numInputsProvided > numInputsNeeded) { // No missing inputs, but too many inputs were provided. // Warn the user and ignore the extra inputs. let ignored = Object.keys(inputs).filter(inputName => !session.inputNames.includes(inputName)); console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`); } return checkedInputs; } /** * Executes an InferenceSession using the specified inputs. * NOTE: `inputs` must contain at least the input names of the model. * - If additional inputs are passed, they will be ignored. * - If inputs are missing, an error will be thrown. * * @param {Object} session The InferenceSession object to run. * @param {Object} inputs An object that maps input names to input tensors. * @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors. * @private */ async function sessionRun(session, inputs) { const checkedInputs = validateInputs(session, inputs); try { // pass the original ort tensor const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor])); let output = await session.run(ortFeed); output = replaceTensors(output); return output; } catch (e) { // This usually occurs when the inputs are of the wrong type. console.error(`An error occurred during model execution: "${e}".`); console.error('Inputs given to model:', checkedInputs); throw e; } } /** * Replaces ONNX Tensor objects with custom Tensor objects to support additional functions. * @param {Object} obj The object to replace tensor objects in. 
* @returns {Object} The object with tensor objects replaced by custom Tensor objects. * @private */ function replaceTensors(obj) { for (let prop in obj) { if (isONNXTensor(obj[prop])) { obj[prop] = new Tensor(obj[prop]); } else if (typeof obj[prop] === 'object') { replaceTensors(obj[prop]); } } return obj; } /** * Converts an array or Tensor of integers to an int64 Tensor. * @param {Array|Tensor} items The input integers to be converted. * @returns {Tensor} The int64 Tensor with the converted values. * @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length. * @private */ function toI64Tensor(items) { if (items instanceof Tensor) { return items; } // items is an array if (items.length === 0) { throw Error("items must be non-empty"); } if (Array.isArray(items[0])) { // batched if (items.some(x => x.length !== items[0].length)) { throw Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.") } return new Tensor('int64', BigInt64Array.from(items.flat().map(x => BigInt(x))), [items.length, items[0].length] ); } else { //flat return new Tensor('int64', BigInt64Array.from(items.map(x => BigInt(x))), [1, items.length] ); } } /** * Creates a boolean tensor with a single value. * @param {boolean} value The value of the tensor. * @returns {Tensor} The boolean tensor. * @private */ function boolTensor(value) { return new Tensor('bool', [value], [1]); } // JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in /** * Perform forward pass on the seq2seq model (both encoder and decoder). * @param {Object} self The seq2seq model object. * @param {Object} model_inputs The input object for the model containing encoder and decoder inputs. * @returns {Promise<Seq2SeqLMOutput>} Promise that resolves with the output of the seq2seq model. 
* @private */ async function seq2seqForward(self, model_inputs) { let { encoder_outputs, input_ids, decoder_input_ids, ...other_decoder_inputs } = model_inputs; // Encode if needed if (!encoder_outputs) { const encoder_inputs = pick(model_inputs, self.sessions['model'].inputNames); // Encoder outputs are not given, so we must compute them. encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state; } other_decoder_inputs.input_ids = decoder_input_ids; other_decoder_inputs.encoder_hidden_states = encoder_outputs; if (self.sessions['decoder_model_merged'].inputNames.includes('encoder_attention_mask')) { other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask } const decoderResults = await decoderForward(self, other_decoder_inputs, true); return decoderResults; } /** * Forward pass of an encoder model. * @param {Object} self The encoder model. * @param {Object} model_inputs The input data to be used for the forward pass. * @returns {Promise<Object>} The model's outputs. * @private */ async function encoderForward(self, model_inputs) { const session = self.sessions['model']; const encoderFeeds = pick(model_inputs, session.inputNames); if (session.inputNames.includes('inputs_embeds') && !encoderFeeds.inputs_embeds) { if (!model_inputs.input_ids) { throw new Error('Both `input_ids` and `inputs_embeds` are missing in the model inputs.'); } encoderFeeds.inputs_embeds = await self.encode_text({ input_ids: model_inputs.input_ids }); } if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) { // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it, // but they weren't created by the tokenizer. encoderFeeds.token_type_ids = new Tensor( 'int64', new BigInt64Array(encoderFeeds.input_ids.data.length), encoderFeeds.input_ids.dims ) } return await sessionRun(session, encoderFeeds); } /** * Forward pass of a decoder model. * @param {Object} self The decoder model. 
* @param {Object} model_inputs The input data to be used for the forward pass. * @returns {Promise<Object>} The logits and past key values. * @private */ async function decoderForward(self, model_inputs, is_encoder_decoder = false) { const session = self.sessions[ is_encoder_decoder ? 'decoder_model_merged' : 'model' ] const { past_key_values, ...new_model_inputs } = model_inputs; if (session.inputNames.includes('use_cache_branch')) { new_model_inputs.use_cache_branch = boolTensor(!!past_key_values); } if (session.inputNames.includes('position_ids') && new_model_inputs.attention_mask && !new_model_inputs.position_ids) { new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values); } // Unpack the `past_key_values` object into model inputs self.addPastKeyValues(new_model_inputs, past_key_values); // Select only the inputs that are needed for the current session const fixed = pick(new_model_inputs, session.inputNames); return await sessionRun(session, fixed); } /** * Forward pass of an image-text-to-text model. * @param {Object} self The image-text-to-text model model. * @param {Object} model_inputs The input data to be used for the forward pass. 
* @param {Tensor} [model_inputs.input_ids=null] Token ids; embedded via `self.encode_text` when `inputs_embeds` is not given.
 * @param {Tensor} [model_inputs.attention_mask=null] Attention mask; may be replaced by the merged mask, or rebuilt with leading ones for the cached positions.
 * @param {Tensor} [model_inputs.pixel_values=null] Image inputs; encoded via `self.encode_image` and merged into the text embeddings when more than one token is processed.
 * @param {Tensor} [model_inputs.position_ids=null] Passed through to `decoderForward`.
 * @param {Tensor} [model_inputs.inputs_embeds=null] Precomputed embeddings; when given, the embedding/merging step is skipped.
 * @param {Object} [model_inputs.past_key_values=null] Record of cached key/value tensors; the first entry's second-to-last dim is read as the past length.
 * @param {Object} [model_inputs.generation_config=null] Passed through to `decoderForward` (which keeps only inputs named by the session).
 * @param {Object} [model_inputs.logits_processor=null] Passed through to `decoderForward` (which keeps only inputs named by the session).
 * @returns {Promise<Object>} The outputs of `decoderForward`, run on the `decoder_model_merged` session.
 * @private
 */
async function imageTextToTextForward(self, {
    // Produced by the tokenizer/processor:
    input_ids = null,
    attention_mask = null,
    pixel_values = null,

    // Used during generation:
    position_ids = null,
    inputs_embeds = null,
    past_key_values = null,

    // Generic generation parameters
    generation_config = null,
    logits_processor = null,

    // TODO: needed?
    ...kwargs
}) {

    if (!inputs_embeds) {
        // 1. Extract the input embeddings
        inputs_embeds = await self.encode_text({ input_ids });

        // 2. Possibly, merge text and images
        if (pixel_values && input_ids.dims[1] !== 1) {
            const image_features = await self.encode_image({ pixel_values });

            ({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
                image_features,
                inputs_embeds,
                input_ids,
                attention_mask,
            }));

        } else if (past_key_values && pixel_values && input_ids.dims[1] === 1) {
            // This is the case when we are generating with cache
            const target_length = input_ids.dims[1]; // always 1
            const past_length = Object.values(past_key_values)[0].dims.at(-2);

            // Rebuild the mask as [ones for every cached position | last `target_length` columns of the given mask].
            // NOTE(review): assumes `attention_mask` is provided in this branch; `.slice` on null would throw.
            attention_mask = cat([
                ones([input_ids.dims[0], past_length]),
                attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
            ], 1);
        }
    }

    const outputs = await decoderForward(self, {
        inputs_embeds,
        past_key_values,
        attention_mask,
        position_ids,
        generation_config,
        logits_processor,
    }, true);
    return outputs;
}

/**
 * Creates `position_ids` from the attention mask.
 *
 * For each row, every position whose mask value is `0n` is set to 1. Every other position gets
 * the running sum of the mask values before it (i.e. cumsum - 1 for a 0/1 mask). When
 * `past_key_values` is given, only the last `(input_ids ?? inputs_embeds).dims[1]` columns are kept.
 *
 * NOTE(review): assumes `attention_mask` is a 2-D int64 tensor (its dims are unpacked as
 * `[bz, seq_len]` and its data is compared against `0n`) — confirm against callers.
 *
 * @param {Object} model_inputs Object holding `attention_mask` and, when slicing is needed, `input_ids` or `inputs_embeds`.
 * @param {Object} [past_key_values=null] If truthy, the result is sliced to the new (uncached) positions.
 * @returns {Tensor} An int64 tensor of position ids.
 * @private
 */
function createPositionIds(model_inputs, past_key_values = null) {
    // If the model supports providing position_ids, we create position_ids on the fly for batch generation,
    // by computing the cumulative sum of the attention mask along the sequence length dimension.
    //
    // Equivalent to:
    // position_ids = attention_mask.long().cumsum(-1) - 1
    // position_ids.masked_fill_(attention_mask == 0, 1)
    // if past_key_values:
    //     position_ids = position_ids[:, -input_ids.shape[1] :]

    const { input_ids, inputs_embeds, attention_mask } = model_inputs;
    const [bz, seq_len] = attention_mask.dims;

    const data = new BigInt64Array(attention_mask.data.length);
    for (let i = 0; i < bz; ++i) {
        const start = i * seq_len;
        let sum = BigInt(0);
        for (let j = 0; j < seq_len; ++j) {
            const index = start + j;
            if (attention_mask.data[index] === 0n) {
                data[index] = BigInt(1);
            } else { // === 1n
                data[index] = sum;
                sum += attention_mask.data[index];
            }
        }
    }

    let position_ids = new Tensor('int64', data, attention_mask.dims);
    if (past_key_values) {
        // Keep only the trailing columns that correspond to the tokens being processed now.
        const offset = -(input_ids ?? inputs_embeds).dims.at(1);
        position_ids = position_ids.slice(null, [offset, null]);
    }
    return position_ids;
}

/**
 * Prepares the inputs of a decoder-only model for the next generation step.
 *
 * When `model_inputs.past_key_values` is set, `model_inputs.input_ids` is trimmed so that it
 * only holds tokens not yet covered by the cache. For VLMs, `attention_mask` may also be
 * rebuilt. The past length is read from the second-to-last dim of the first cached tensor.
 * `model_inputs` is modified in place and returned.
 *
 * @param {Object} self The model; `self.config.image_token_index` / `num_image_tokens` are read for VLMs.
 * @param {any} input_ids Unused here (shadowed by `model_inputs.input_ids` inside the cached branch).
 * @param {Object} model_inputs The current model inputs.
 * @param {Object} generation_config Unused here.
 * @returns {Object} `model_inputs`, possibly with a trimmed `input_ids` / new `attention_mask`.
 * @throws {Error} If an image token is present but `num_image_tokens` is missing from the config.
 * @private
 */
function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    if (model_inputs.past_key_values) {
        const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
        const { input_ids, attention_mask } = model_inputs;

        // Keep only the unprocessed tokens:
        // 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        // some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        // input)
        if (attention_mask && attention_mask.dims[1] > input_ids.dims[1]) {
            // NOTE: not needed since we only pass the generated tokens to the next forward pass
            // const offset = -(attention_mask.dims[1] - past_length);
            // model_inputs.input_ids = input_ids.slice(null, [offset, null]);
        }
        // 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens.
        // We can discard input_ids based on the past_length.
        else if (past_length < input_ids.dims[1]) {
            // NOTE: Required for phi models.
            // See https://github.com/huggingface/transformers/issues/30809#issuecomment-2111918479 for more information.
            model_inputs.input_ids = input_ids.slice(null, [past_length, null]);
        }
        // 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
        else {
            if (
                // NOTE: Only used by VLMs (!= so that null matches undefined)
                self.config.image_token_index != null &&
                // Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
                input_ids.data.some(x => x == self.config.image_token_index)
            ) {
                // TODO: Support multiple image tokens
                const num_image_tokens = self.config.num_image_tokens;
                if (!num_image_tokens) {
                    throw new Error('`num_image_tokens` is missing in the model configuration.');
                }

                // The cache holds `num_image_tokens` positions for the single image token, so the
                // number of not-yet-processed tokens is computed against `past_length - num_image_tokens`.
                const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
                model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);

                // TODO: The attention mask should be formed from the attention mask passed in model_inputs
                model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
            }
        }
    }

    return model_inputs;
}

/**
 * Prepares the inputs of an encoder-decoder model for the next generation step.
 *
 * When `model_inputs.past_key_values` is set, each sequence is reduced to its last token. The
 * result is returned as `decoder_input_ids` (via `toI64Tensor`) on a shallow copy of `model_inputs`.
 *
 * NOTE(review): assumes `input_ids` is an array of per-sequence token arrays (it is mapped
 * with `x.at(-1)`) — confirm against the generation loop.
 *
 * @param {Object} self The model (unused here).
 * @param {any} input_ids The token ids generated so far.
 * @param {Object} model_inputs The current model inputs.
 * @param {Object} generation_config Unused here.
 * @returns {Object} A shallow copy of `model_inputs` with `decoder_input_ids` set.
 * @private
 */
function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    if (model_inputs.past_key_values) {
        input_ids = input_ids.map(x => [x.at(-1)]);
    }

    return {
        ...model_inputs,
        decoder_input_ids: toI64Tensor(input_ids),
    };
}

/**
 * Prepares generation inputs for an image-text-to-text model. Dispatches to the
 * encoder-decoder or decoder-only variant based on `self.config.is_encoder_decoder`.
 * @param {Object} self The model.
 * @param {...any} args Forwarded unchanged to the selected variant.
 * @returns {Object} The prepared model inputs.
 * @private
 */
function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
    if (self.config.is_encoder_decoder) {
        return encoder_decoder_prepare_inputs_for_generation(self, ...args);
    } else {
        return decoder_prepare_inputs_for_generation(self, ...args);
    }
}

//////////////////////////////////////////////////

//////////////////////////////////////////////////
/**
 * A base class for pre-trained models that provides the model configuration and an ONNX session.
*/ export class PreTrainedModel extends Callable { main_input_name = 'input_ids'; forward_params = ['input_ids', 'attention_mask']; /** * Creates a new instance of the `PreTrainedModel` class. * @param {import('./configs.js').PretrainedConfig} config The model configuration. * @param {Record<string, any>} sessions The inference sessions for the model. */ constructor(config, sessions) { super(); this.config = config; this.sessions = sessions; const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor); const modelType = MODEL_TYPE_MAPPING.get(modelName); this.can_generate = false; this._forward = null; this._prepare_inputs_for_generation = null; switch (modelType) { case MODEL_TYPES.DecoderOnly: this.can_generate = true; this._forward = decoderForward; this._prepare_inputs_for_generation = decoder_prepare_inputs_for_generation; break; case MODEL_TYPES.Seq2Seq: case MODEL_TYPES.Vision2Seq: case MODEL_TYPES.Musicgen: this.can_generate = true; this._forward = seq2seqForward; this._prepare_inputs_for_generation = encoder_decoder_prepare_inputs_for_generation; break; case MODEL_TYPES.EncoderDecoder: this._forward = seq2seqForward; break; case MODEL_TYPES.ImageTextToText: this.can_generate = true; this._forward = imageTextToTextForward; this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation; break; default: // should be MODEL_TYPES.EncoderOnly this._forward = encoderForward; break; } if (this.can_generate) { this.forward_params.push('past_key_values'); } /** @type {import('./configs.js').TransformersJSConfig} */ this.custom_config = this.config['transformers.js_config'] ?? {}; } /** * Disposes of all the ONNX sessions that were created during inference. * @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed. 
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry */ async dispose() { const promises = []; for (const session of Object.values(this.sessions)) { if (session?.handler?.dispose) { promises.push(session.handler.dispose()) } } return await Promise.all(promises); } /** * Instantiate one of the model classes of the library from a pretrained model. * * The model class to instantiate is selected based on the `model_type` property of the config object * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) * * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a * user or organization name, like `dbmdz/bert-base-german-cased`. * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. * * @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class. 
*/ static async from_pretrained(pretrained_model_name_or_path, { progress_callback = null, config = null, cache_dir = null, local_files_only = false, revision = 'main', model_file_name = null, subfolder = 'onnx', device = null, dtype = null, use_external_data_format = null, session_options = {}, } = {}) { let options = { progress_callback, config, cache_dir, local_files_only, revision, model_file_name, subfolder, device, dtype, use_external_data_format, session_options, } const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this); const modelType = MODEL_TYPE_MAPPING.get(modelName); config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); let info; if (modelType === MODEL_TYPES.DecoderOnly) { info = await Promise.all([ constructSessions(pretrained_model_name_or_path, { model: options.model_file_name ?? 'model', }, options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) { info = await Promise.all([ constructSessions(pretrained_model_name_or_path, { model: 'encoder_model', decoder_model_merged: 'decoder_model_merged', }, options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.MaskGeneration) { info = await Promise.all([ constructSessions(pretrained_model_name_or_path, { model: 'vision_encoder', prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder', }, options), ]); } else if (modelType === MODEL_TYPES.EncoderDecoder) { info = await Promise.all([ constructSessions(pretrained_model_name_or_path, { model: 'encoder_model', decoder_model_merged: 'decoder_model_merged', }, options), ]); } else if (modelType === MODEL_TYPES.ImageTextToText) { const sessions = { embed_tokens: 'embed_tokens', vision_encoder: 'vision_encoder', decoder_model_merged: 'decoder_model_merged', } if (config.is_encoder_decoder) { sessions['model'] 
= 'encoder_model'; } info = await Promise.all([ constructSessions(pretrained_model_name_or_path, sessions, options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.Musicgen) { info = await Promise.all([ constructSessions(pretrained_model_name_or_path, { model: 'text_encoder', decoder_model_merged: 'decoder_model_merged', encodec_decode: 'encodec_decode', }, options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else { // should be MODEL_TYPES.EncoderOnly if (modelType !== MODEL_TYPES.EncoderOnly) { console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`) } info = await Promise.all([ constructSessions(pretrained_model_name_or_path, { model: options.model_file_name ?? 'model', }, options), ]); } // @ts-ignore return new this(config, ...info); } /** * Runs the model with the provided inputs * @param {Object} model_inputs Object containing input tensors * @returns {Promise<Object>} Object containing output tensors */ async _call(model_inputs) { return await this.forward(model_inputs); } /** * Forward method for a pretrained model. If not overridden by a subclass, the correct forward method * will be chosen based on the model type. * @param {Object} model_inputs The input data to the model in the format specified in the ONNX model. * @returns {Promise<Object>} The output data from the model in the format specified in the ONNX model. * @throws {Error} This method must be implemented in subclasses. */ async forward(model_inputs) { return await this._forward(this, model_inputs); } /** * This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] * instances used for multinomial sampling. * @param {GenerationConfig} generation_config The generation config. 
* @returns {LogitsProcessorList} generation_config */ _get_logits_warper(generation_config) { // instantiate warpers list const warpers = new LogitsProcessorList(); if (generation_config.temperature !== null && generation_config.temperature !== 1.0) { warpers.push(new TemperatureLogitsWarper(generation_config.temperature)); } if (generation_config.top_k !== null && generation_config.top_k !== 0) { // TODO: add min_tokens_to_keep warpers.push(new TopKLogitsWarper(generation_config.top_k)); } if (generation_config.top_p !== null && generation_config.top_p < 1.0) { // TODO: add min_tokens_to_keep warpers.push(new TopPLogitsWarper(generation_config.top_p)); } return warpers; } /** * @param {GenerationConfig} generation_config * @param {number} input_ids_seq_length The starting sequence length for the input ids. * @returns {LogitsProcessorList} * @private */ _get_logits_processor( generation_config, input_ids_seq_length, // encoder_input_ids, TODO // prefix_allowed_tokens_fn, TODO logits_processor = null ) { const processors = new LogitsProcessorList(); // if (generation_config.diversity_penalty !== null && generation_config.diversity_penalty > 0.0) { // processors.push(new HammingDiversityLogitsProcessor( // generation_config.diversity_penalty, // generation_config.num_beams, // generation_config.num_beam_groups // )); // } // if (generation_config.encoder_repetition_penalty !== null && generation_config.encoder_repetition_penalty !== 1.0) { // processors.push(new EncoderRepetitionPenaltyLogitsProcessor( // generation_config.encoder_repetition_penalty, // encoder_input_ids // )); // } if (generation_config.repetition_penalty !== null && generation_config.repetition_penalty !== 1.0) { processors.push(new RepetitionPenaltyLogitsProcessor(generation_config.repetition_penalty)); } if (generation_config.no_repeat_ngram_size !== null && generation_config.no_repeat_ngram_size > 0) { processors.push(new NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)); } 
// if (generation_config.encoder_no_repeat_ngram_size !== null && generation_config.encoder_no_repeat_ngram_size > 0) { // if (this.config.is_encoder_decoder) { // processors.push(new EncoderNoRepeatNGramLogitsProcessor( // generation_config.encoder_no_repeat_ngram_size, // encoder_input_ids // )); // } else { // throw new Error("It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"); // } // } if (generation_config.bad_words_ids !== null) { processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)); } if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) { processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)); } if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) { processors.push(new MinNewTokensLengthLogitsProcessor( input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id )); } // if (prefix_allowed_tokens_fn !== null) { // processors.push(new PrefixConstrainedLogitsProcessor( // prefix_allowed_tokens_fn, // generation_config.num_beams / generation_config.num_beam_groups // )); // } if (generation_config.forced_bos_token_id !== null) { processors.push(new ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id)); } if (generation_config.forced_eos_token_id !== null) { processors.push(new ForcedEOSTokenLogitsProcessor( generation_config.max_length, generation_config.forced_eos_token_id )); } // if (generation_config.remove_invalid_values === true) { // processors.push(new InfNanRemoveLogitsProcessor()); // } // if (generation_config.exponential_decay_length_penalty !== null) { // processors.push(new ExponentialDecayLengthPenalty( // generation_config.exponential_decay_length_penalty, // generation_config.eos_token_id, // input_ids_seq_length 
// )); // } // if (generation_config.suppress_tokens !== null) { // processors.push(new SuppressTokensLogitsProcessor(generation_config.suppress_tokens)); // } if (generation_config.begin_suppress_tokens !== null) { const begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null) ? input_ids_seq_length : input_ids_seq_length + 1; processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)); } // DEPRECATED: https://github.com/huggingface/transformers/pull/29485 // if (generation_config.forced_decoder_ids !== null) { // processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)); // } // 8. prepare batched CFG externally if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { processors.push(new ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)); } if (logits_processor !== null) { processors.extend(logits_processor) } // `LogitNormalization` should always be the last logit processor, when present // if (generation_config.renormalize_logits === true) { // processors.push(new LogitNormalization()); // } return processors; } /** * This function merges multiple generation configs together to form a final generation config to be used by the model for text generation. * It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object. * @param {GenerationConfig|null} generation_config A `GenerationConfig` object containing generation parameters. * @param {Object} kwargs Additional generation parameters to be used in place of those in the `generation_config` object. * @returns {GenerationConfig} The final generation config object to be used by the model for text generation. 
*/ _prepare_generation_config(generation_config, kwargs, cls = GenerationConfig) { // Create empty generation config (contains defaults) // We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them const config = { ...this.config }; for (const key of ["decoder", "generator", "text_config"]) { // Special case: some models have generation attributes set in the decoder. // Use them if still unset in the generation config. if (key in config) { Object.assign(config, config[key]); } } const gen_config = new cls(config); // Apply model's generation config, if it exists if ('generation_config' in this) { Object.assign(gen_config, this.generation_config); } // Next, use any generation config specified by the user // when calling `generate` if (generation_config) { Object.assign(gen_config, generation_config); } // Finally, if any kwargs were passed, use them to overwrite if (kwargs) { Object.assign(gen_config, pick(kwargs, Object.getOwnPropertyNames(gen_config))); } return gen_config; } /** * * @param {GenerationConfig} generation_config * @param {StoppingCriteriaList} [stopping_criteria=null] */ _get_stopping_criteria(generation_config, stopping_criteria = null) { const criteria = new StoppingCriteriaList(); if (generation_config.max_length !== null) { criteria.push(new MaxLengthCriteria( generation_config.max_length, this.config.max_position_embeddings ?? null, )); } // if (generation_config.max_time !== null) { // criteria.push(new MaxTimeCriteria(generation_config.max_time)); // } if (generation_config.eos_token_id !== null) { criteria.push(new EosTokenCriteria(generation_config.eos_token_id)); } if (stopping_criteria) { criteria.extend(stopping_criteria); } return criteria; } /** * Confirms that the model class is compatible with generation. * If not, raises an exception that points to the right class to use. 
*/ _validate_model_class() { if (!this.can_generate) { const generate_compatible_mappings = [ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, // MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, // TODO MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, ]; const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor); const generate_compatible_classes = new Set(); const modelType = this.config.model_type; for (const model_mapping of generate_compatible_mappings) { const supported_models = model_mapping.get(modelType); if (supported_models) { generate_compatible_classes.add(supported_models[0]); } } let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.` if (generate_compatible_classes.size > 0) { errorMessage += ` Please use the following class instead: ${[...generate_compatible_classes].join(', ')}`; } throw Error(errorMessage); } } prepare_inputs_for_generation(...args) { return this._prepare_inputs_for_generation(this, ...args); } /** * * @param {Object} inputs * @param {bigint[][]} inputs.generated_input_ids * @param {Object} inputs.outputs * @param {Object} inputs.model_inputs * @param {boolean} inputs.is_encoder_decoder * @returns {Object} The updated model inputs for the next generation iteration. 
*/ _update_model_kwargs_for_generation({ generated_input_ids, outputs, model_inputs, is_encoder_decoder }) { // update past_key_values model_inputs['past_key_values'] = this.getPastKeyValues(outputs, model_inputs.past_key_values); // update inputs for next run model_inputs['input_ids'] = new Tensor('int64', generated_input_ids.flat(), [generated_input_ids.length, 1]); if (!is_encoder_decoder) { // update attention mask model_inputs.attention_mask = cat( [ model_inputs.attention_mask, ones([model_inputs.attention_mask.dims[0], 1]), ], 1 ); } else if ('decoder_attention_mask' in model_inputs) { // TODO: update decoder attention mask if the model requires it } // force recreate position_ids in next iteration model_inputs['position_ids'] = null; return model_inputs; } /** * This function extracts the model-specific `inputs` for generation. * @param {Object} params * @param {Tensor} [params.inputs=null] * @param {number} [params.bos_token_id=null] * @param {Record<string, Tensor|number[]>} [params.model_kwargs] * @returns {{inputs_tensor: Tensor, model_inputs: Record<string, Tensor>, model_input_name: string}} The model-specific inputs for generation. */ _prepare_model_inputs({ inputs, bos_token_id, model_kwargs }) { const model_inputs = pick(model_kwargs, this.forward_params); const input_name = this.main_input_name; if (input_name in model_inputs) { if (inputs) { throw new Error( "`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " + "Make sure to either pass {inputs} or {input_name}=..." 
); } } else { model_inputs[input_name] = inputs; } const inputs_tensor = model_inputs[input_name]; return { inputs_tensor, model_inputs, model_input_name: input_name }; } async _prepare_encoder_decoder_kwargs_for_generation({ inputs_tensor, model_inputs, model_input_name, generation_config }) { if ( this.sessions['model'].inputNames.includes('inputs_embeds') && !model_inputs.inputs_embeds && '_prepare_inputs_embeds' in this ) { // Encoder expects `inputs_embeds` instead of `input_ids` const { input_ids, pixel_values, attention_mask, ...kwargs } = model_inputs; // @ts-ignore const prepared_inputs = await this._prepare_inputs_embeds(model_inputs); model_inputs = { ...kwargs, ...pick(prepared_inp