
@huggingface/transformers


State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!

/**
 * @file Definitions of all models available in Transformers.js.
 *
 * **Example:** Load and run an `AutoModel`.
 *
 * ```javascript
 * import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
 *
 * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
 * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
 *
 * let inputs = await tokenizer('I love transformers!');
 * let { logits } = await model(inputs);
 * // Tensor {
 * //     data: Float32Array(183132) [-7.117443084716797, -7.107812881469727, -7.092104911804199, ...]
 * //     dims: (3) [1, 6, 30522],
 * //     type: "float32",
 * //     size: 183132,
 * // }
 * ```
 *
 * We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example:
 *
 * **Example:** Load and run an `AutoModelForSeq2SeqLM`.
 * ```javascript
 * import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@huggingface/transformers';
 *
 * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
 * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
 *
 * let { input_ids } = await tokenizer('translate English to German: I love transformers!');
 * let outputs = await model.generate(input_ids);
 * let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
 * // 'Ich liebe Transformatoren!'
 * ```
 *
 * @module models
 */

import {
    AutoConfig,
    getKeyValueShapes,
} from './configs.js';

import {
    deviceToExecutionProviders,
    createInferenceSession,
    isONNXTensor,
    isONNXProxy,
} from './backends/onnx.js';

import {
    DATA_TYPES,
    DEFAULT_DEVICE_DTYPE_MAPPING,
    DEFAULT_DTYPE_SUFFIX_MAPPING,
    isWebGpuFp16Supported,
} from './utils/dtypes.js';

import {
    Callable,
} from './utils/generic.js';

import {
    mergeArrays,
    pick,
} from './utils/core.js';

import {
    getModelFile,
    getModelJSON,
    MAX_EXTERNAL_DATA_CHUNKS,
} from './utils/hub.js';

import {
    GITHUB_ISSUE_URL,
} from './utils/constants.js';

import {
    LogitsProcessorList,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    WhisperTimeStampLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    NoBadWordsLogitsProcessor,
    MinLengthLogitsProcessor,
    MinNewTokensLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    ClassifierFreeGuidanceLogitsProcessor,
} from './generation/logits_process.js';

import {
    GenerationConfig,
} from './generation/configuration_utils.js';

import {
    cat,
    mean,
    zeros,
    zeros_like,
    ones,
    ones_like,
    full,
    full_like,
    stack,
    std_mean,
    Tensor,
    DataTypeMap,
} from './utils/tensor.js';

import { RawImage } from './utils/image.js';

import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { LogitsSampler } from './generation/logits_sampler.js';

import { apis, env } from './env.js';

import { WhisperGenerationConfig } from './models/whisper/generation_whisper.js';
import { whisper_language_to_code } from './models/whisper/common_whisper.js';

//////////////////////////////////////////////////
// Model types: used internally
const MODEL_TYPES = {
    EncoderOnly: 0,
    EncoderDecoder: 1,
    Seq2Seq: 2,
    Vision2Seq: 3,
    DecoderOnly: 4,
    MaskGeneration: 5,
    ImageTextToText: 6,
    Musicgen: 7,
    MultiModality: 8,
    Phi3V: 9,
    AudioTextToText: 10,
    AutoEncoder: 11,
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// Helper functions

// NOTE: These will be populated fully later
const MODEL_TYPE_MAPPING = new Map();
const MODEL_NAME_TO_CLASS_MAPPING = new Map();
const MODEL_CLASS_TO_NAME_MAPPING = new Map();

/**
 * Constructs an InferenceSession using a model file located at the specified path.
 * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
 * @param {string} fileName The name of the model file.
 * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
 * @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
 * @private
 */
async function getSession(pretrained_model_name_or_path, fileName, options) {
    let custom_config = options.config?.['transformers.js_config'] ?? {};
    let device = options.device ?? custom_config.device;
    if (device && typeof device !== 'string') {
        if (device.hasOwnProperty(fileName)) {
            device = device[fileName];
        } else {
            console.warn(`device not specified for "${fileName}". Using the default device.`);
            device = null;
        }
    }

    // If the device is not specified, we use the default (supported) execution providers.
    const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
        device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
    );

    const executionProviders = deviceToExecutionProviders(selectedDevice);

    // Update custom config with the selected device's config, if it exists
    const device_config = custom_config.device_config ?? {};
    if (device_config.hasOwnProperty(selectedDevice)) {
        custom_config = {
            ...custom_config,
            ...device_config[selectedDevice],
        };
    }

    // If options.dtype is specified, we use it to choose the suffix for the model file.
    // Otherwise, we use the default dtype for the device.
    let dtype = options.dtype ?? custom_config.dtype;
    if (typeof dtype !== 'string') {
        if (dtype && dtype.hasOwnProperty(fileName)) {
            dtype = dtype[fileName];
        } else {
            dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
            console.warn(`dtype not specified for "${fileName}". Using the default dtype (${dtype}) for this device (${selectedDevice}).`);
        }
    }

    if (dtype === DATA_TYPES.auto) {
        // Try to choose the auto dtype based on the custom config
        let config_dtype = custom_config.dtype;
        if (typeof config_dtype !== 'string') {
            config_dtype = config_dtype?.[fileName];
        }

        if (config_dtype && config_dtype !== DATA_TYPES.auto && DATA_TYPES.hasOwnProperty(config_dtype)) {
            // Defined by the config, and is not "auto"
            dtype = config_dtype;
        } else {
            // Choose default dtype based on device, falling back to fp32
            dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
        }
    }

    const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);

    if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
        throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
    } else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) {
        throw new Error(`The device (${selectedDevice}) does not support fp16.`);
    }

    // Only valid for models with a decoder
    const kv_cache_dtype_config = custom_config.kv_cache_dtype;
    const kv_cache_dtype = kv_cache_dtype_config
        ? (typeof kv_cache_dtype_config === 'string'
            ? kv_cache_dtype_config
            : kv_cache_dtype_config[selectedDtype] ?? 'float32')
        : undefined;
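
    // Illustrative (hypothetical) config: a repository may map each quantization level to a
    // KV-cache dtype via config["transformers.js_config"]["kv_cache_dtype"], for example
    //     { "q4f16": "float16", "fp16": "float16" }
    // in which case selecting dtype "q4f16" above would resolve kv_cache_dtype to "float16".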

    if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
        throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
    }

    const session_config = {
        dtype: selectedDtype,
        kv_cache_dtype,
    }

    // Construct the model file name
    const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
    const baseName = `${fileName}${suffix}.onnx`;
    const modelFileName = `${options.subfolder ?? ''}/${baseName}`;

    const session_options = { ...options.session_options };

    // Overwrite `executionProviders` if not specified
    session_options.executionProviders ??= executionProviders;

    // Overwrite `freeDimensionOverrides` if specified in config and not set in session options
    const free_dimension_overrides = custom_config.free_dimension_overrides;
    if (free_dimension_overrides) {
        session_options.freeDimensionOverrides ??= free_dimension_overrides;
    } else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
        console.warn(
            `WebNN does not currently support dynamic shapes and requires 'free_dimension_overrides' to be set in config.json, preferably as a field within config["transformers.js_config"]["device_config"]["${selectedDevice}"]. ` +
            `When 'free_dimension_overrides' is not set, you may experience significant performance degradation.`
        );
    }

    const return_path = apis.IS_NODE_ENV && env.useFSCache;
    const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, return_path);

    // Handle onnx external data files
    const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
    /** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
    let externalDataPromises = [];
    if (use_external_data_format) {
        let external_data_format;
        if (typeof use_external_data_format === 'object') {
            if (use_external_data_format.hasOwnProperty(baseName)) {
                external_data_format = use_external_data_format[baseName];
            } else if (use_external_data_format.hasOwnProperty(fileName)) {
                external_data_format = use_external_data_format[fileName];
            } else {
                external_data_format = false;
            }
        } else {
            external_data_format = use_external_data_format;
        }

        const num_chunks = +external_data_format; // (false=0, true=1, number remains the same)
        if (num_chunks > MAX_EXTERNAL_DATA_CHUNKS) {
            throw new Error(`The number of external data chunks (${num_chunks}) exceeds the maximum allowed value (${MAX_EXTERNAL_DATA_CHUNKS}).`);
        }
        for (let i = 0; i < num_chunks; ++i) {
            const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
            const fullPath = `${options.subfolder ?? ''}/${path}`;
            externalDataPromises.push(new Promise(async (resolve, reject) => {
                const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, return_path);
                resolve(data instanceof Uint8Array ? { path, data } : path);
            }));
        }
    } else if (session_options.externalData !== undefined) {
        externalDataPromises = session_options.externalData.map(async (ext) => {
            // if the external data is a string, fetch the file and replace the string with its content
            // @ts-expect-error TS2339
            if (typeof ext.data === "string") {
                // @ts-expect-error TS2339
                const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
                // @ts-expect-error TS2698
                return { ...ext, data: ext_buffer };
            }
            return ext;
        });
    }

    if (externalDataPromises.length > 0) {
        const externalData = await Promise.all(externalDataPromises);
        if (!apis.IS_NODE_ENV) {
            session_options.externalData = externalData;
        }
    }

    if (selectedDevice === 'webgpu') {
        const shapes = getKeyValueShapes(options.config, {
            prefix: 'present',
        });
        if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
            // Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
            /** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
            const preferredOutputLocation = {};
            for (const key in shapes) {
                preferredOutputLocation[key] = 'gpu-buffer';
            }
            session_options.preferredOutputLocation = preferredOutputLocation;
        }
    }

    const buffer_or_path = await bufferOrPathPromise;

    return { buffer_or_path, session_options, session_config };
}

/**
 * Helper function to create multiple InferenceSession objects.
 *
 * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
 * @param {Record<string, string>} names The names of the model files to load.
 * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
 * @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects.
 * @private
 */
async function constructSessions(pretrained_model_name_or_path, names, options) {
    return Object.fromEntries(await Promise.all(
        Object.keys(names).map(async (name) => {
            const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
            const session = await createInferenceSession(buffer_or_path, session_options, session_config);
            return [name, session];
        })
    ));
}

/**
 * Helper function to load multiple optional configuration files
 * @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
 * @param {Record<string, string>} names The names of the config files to load.
 * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
 * @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of configuration objects.
 * @private
 */
async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
    return Object.fromEntries(await Promise.all(
        Object.keys(names).map(async (name) => {
            const config = await getModelJSON(pretrained_model_name_or_path, names[name], false, options);
            return [name, config];
        })
    ));
}
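
/*
 * **Example (illustrative sketch):** How the helpers above resolve `device` and `dtype` from
 * the options passed to `from_pretrained`. The model id and the per-file keys below are
 * assumptions for illustration; they simply mirror the lookup logic in `getSession`.
 *
 * ```javascript
 * import { AutoModelForSeq2SeqLM } from '@huggingface/transformers';
 *
 * // A single string applies to every model file of the model:
 * const model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small', {
 *     device: 'webgpu', // mapped to execution providers via deviceToExecutionProviders()
 *     dtype: 'fp32',    // selects the file suffix via DEFAULT_DTYPE_SUFFIX_MAPPING
 * });
 *
 * // Alternatively, an object keyed by file name configures each session separately
 * // (unlisted files fall back to the defaults, with a console warning):
 * const model2 = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small', {
 *     dtype: {
 *         encoder_model: 'fp32',
 *         decoder_model_merged: 'q8', // assumes the repository ships this quantized file
 *     },
 * });
 * ```
 */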

/**
 * Validate model inputs
 * @param {Object} session The InferenceSession object that will be run.
 * @param {Object} inputs The inputs to check.
 * @returns {Record<string, Tensor>} The checked inputs.
 * @throws {Error} If any inputs are missing.
 * @private
 */
function validateInputs(session, inputs) {
    /**
     * NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy`
     * @type {Record<string, Tensor>}
     */
    const checkedInputs = Object.create(null);
    const missingInputs = [];
    for (const inputName of session.inputNames) {
        const tensor = inputs[inputName];
        // Rare case where one of the model's input names corresponds to a built-in
        // object name (e.g., toString), which would cause a simple (!tensor) check to fail,
        // because it's not undefined but a function.
        if (!(tensor instanceof Tensor)) {
            missingInputs.push(inputName);
            continue;
        }
        // NOTE: When `env.wasm.proxy is true` the tensor is moved across the Worker
        // boundary, transferring ownership to the worker and invalidating the tensor.
        // So, in this case, we simply sacrifice a clone for it.
        checkedInputs[inputName] = isONNXProxy() ? tensor.clone() : tensor;
    }
    if (missingInputs.length > 0) {
        throw new Error(
            `An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(', ')}.`);
    }

    const numInputsProvided = Object.keys(inputs).length;
    const numInputsNeeded = session.inputNames.length;
    if (numInputsProvided > numInputsNeeded) {
        // No missing inputs, but too many inputs were provided.
        // Warn the user and ignore the extra inputs.
        let ignored = Object.keys(inputs).filter(inputName => !session.inputNames.includes(inputName));
        console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`);
    }
    return checkedInputs;
}

/**
 * Executes an InferenceSession using the specified inputs.
 * NOTE: `inputs` must contain at least the input names of the model.
 *  - If additional inputs are passed, they will be ignored.
 *  - If inputs are missing, an error will be thrown.
 *
 * @param {Object} session The InferenceSession object to run.
 * @param {Object} inputs An object that maps input names to input tensors.
 * @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
 * @private
 */
async function sessionRun(session, inputs) {
    const checkedInputs = validateInputs(session, inputs);
    try {
        // pass the original ort tensor
        const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
        let output = await session.run(ortFeed);
        output = replaceTensors(output);
        return output;
    } catch (e) {
        // Error messages can be long (nested) and uninformative. For this reason,
        // we apply minor formatting to show the most important information
        const formatted = Object.fromEntries(Object.entries(checkedInputs)
            .map(([k, { type, dims, data }]) => [k, {
                // Extract these properties from the underlying ORT tensor
                type, dims, data,
            }]));

        // This usually occurs when the inputs are of the wrong type.
        console.error(`An error occurred during model execution: "${e}".`);
        console.error('Inputs given to model:', formatted);
        throw e;
    }
}

/**
 * Replaces ONNX Tensor objects with custom Tensor objects to support additional functions.
 * @param {Object} obj The object to replace tensor objects in.
 * @returns {Object} The object with tensor objects replaced by custom Tensor objects.
 * @private
 */
function replaceTensors(obj) {
    for (let prop in obj) {
        if (isONNXTensor(obj[prop])) {
            obj[prop] = new Tensor(obj[prop]);
        } else if (typeof obj[prop] === 'object') {
            replaceTensors(obj[prop]);
        }
    }
    return obj;
}
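
/*
 * Illustrative note (hypothetical session and input names, not part of the public API):
 * given a session whose `inputNames` are ['input_ids', 'attention_mask'], the helpers
 * above behave as follows.
 *
 * ```javascript
 * // Missing input -> throws: "Missing the following inputs: attention_mask."
 * await sessionRun(session, { input_ids });
 *
 * // Extra input -> runs, but logs a warning and ignores `token_type_ids`.
 * await sessionRun(session, { input_ids, attention_mask, token_type_ids });
 * ```
 *
 * All outputs returned by sessionRun are wrapped (via replaceTensors) into the library's
 * Tensor class, so they support e.g. `.dims`, `.data`, and `.slice()`.
 */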

/**
 * Converts an array or Tensor of integers to an int64 Tensor.
 * @param {any[]|Tensor} items The input integers to be converted.
 * @returns {Tensor} The int64 Tensor with the converted values.
 * @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
 * @private
 */
function toI64Tensor(items) {
    if (items instanceof Tensor) {
        return items;
    }
    // items is an array
    if (items.length === 0) {
        throw Error("items must be non-empty");
    }

    if (Array.isArray(items[0])) {
        // batched
        if (items.some(x => x.length !== items[0].length)) {
            throw Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.")
        }

        return new Tensor('int64',
            BigInt64Array.from(items.flat().map(x => BigInt(x))),
            [items.length, items[0].length]
        );
    } else {
        //flat
        return new Tensor('int64',
            BigInt64Array.from(items.map(x => BigInt(x))),
            [1, items.length]
        );
    }
}

/**
 * Creates a boolean tensor with a single value.
 * @param {boolean} value The value of the tensor.
 * @returns {Tensor} The boolean tensor.
 * @private
 */
function boolTensor(value) {
    return new Tensor('bool', [value], [1]);
}

// JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in

/**
 * Perform forward pass on the seq2seq model (both encoder and decoder).
 * @param {Object} self The seq2seq model object.
 * @param {Object} model_inputs The input object for the model containing encoder and decoder inputs.
 * @returns {Promise<Seq2SeqLMOutput>} Promise that resolves with the output of the seq2seq model.
 * @private
 */
async function seq2seqForward(self, model_inputs) {
    let { encoder_outputs, input_ids, decoder_input_ids, ...other_decoder_inputs } = model_inputs;
    // Encode if needed
    if (!encoder_outputs) {
        const encoder_inputs = pick(model_inputs, self.sessions['model'].inputNames);
        // Encoder outputs are not given, so we must compute them.
        encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state;
    }
    other_decoder_inputs.input_ids = decoder_input_ids;
    other_decoder_inputs.encoder_hidden_states = encoder_outputs;

    if (self.sessions['decoder_model_merged'].inputNames.includes('encoder_attention_mask')) {
        other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask
    }

    const decoderResults = await decoderForward(self, other_decoder_inputs, true);

    return decoderResults;
}

/**
 * Forward pass of an encoder model.
 * @param {Object} self The encoder model.
 * @param {Object} model_inputs The input data to be used for the forward pass.
 * @returns {Promise<Object>} The model's outputs.
 * @private
 */
async function encoderForward(self, model_inputs) {
    const session = self.sessions['model'];
    const encoderFeeds = pick(model_inputs, session.inputNames);

    if (session.inputNames.includes('inputs_embeds') && !encoderFeeds.inputs_embeds) {
        if (!model_inputs.input_ids) {
            throw new Error('Both `input_ids` and `inputs_embeds` are missing in the model inputs.');
        }
        encoderFeeds.inputs_embeds = await self.encode_text({ input_ids: model_inputs.input_ids });
    }

    if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
        if (!encoderFeeds.input_ids) {
            throw new Error('Both `input_ids` and `token_type_ids` are missing in the model inputs.');
        }
        // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it,
        // but they weren't created by the tokenizer.
        encoderFeeds.token_type_ids = zeros_like(encoderFeeds.input_ids);
    }

    if (session.inputNames.includes('pixel_mask') && !encoderFeeds.pixel_mask) {
        if (!encoderFeeds.pixel_values) {
            throw new Error('Both `pixel_values` and `pixel_mask` are missing in the model inputs.');
        }
        // Assign default `pixel_mask` (all ones) to the `encoderFeeds` if the model expects it,
        // but they weren't created by the processor.
        const dims = encoderFeeds.pixel_values.dims;
        encoderFeeds.pixel_mask = ones([dims[0], dims[2], dims[3]]);
    }

    return await sessionRun(session, encoderFeeds);
}

async function autoEncoderForward(self, model_inputs) {
    const encoded = await self.encode(model_inputs);
    const decoded = await self.decode(encoded);
    return decoded;
}

/**
 * Forward pass of a decoder model.
 * @param {Object} self The decoder model.
 * @param {Object} model_inputs The input data to be used for the forward pass.
 * @returns {Promise<Object>} The logits and past key values.
 * @private
 */
async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
    const session = self.sessions[
        is_encoder_decoder ? 'decoder_model_merged' : 'model'
    ]

    const { past_key_values, ...new_model_inputs } = model_inputs;

    if (session.inputNames.includes('use_cache_branch')) {
        new_model_inputs.use_cache_branch = boolTensor(!!past_key_values);
    }
    if (session.inputNames.includes('position_ids') && new_model_inputs.attention_mask && !new_model_inputs.position_ids) {
        // NOTE: Handle a special case for paligemma/gemma3 models, where positions are 1-indexed
        const start_index = ['paligemma', 'gemma3_text', 'gemma3'].includes(self.config.model_type) ? 1 : 0;
        new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values, start_index);
    }

    // Unpack the `past_key_values` object into model inputs
    self.addPastKeyValues(new_model_inputs, past_key_values);

    // Select only the inputs that are needed for the current session
    const fixed = pick(new_model_inputs, session.inputNames);
    return await sessionRun(session, fixed);
}

function default_merge_input_ids_with_features({
    modality_token_id,
    inputs_embeds,
    modality_features,
    input_ids,
    attention_mask,
}) {
    const token_positions = input_ids.tolist().map(ids =>
        ids.reduce((acc, x, idx) => {
            if (x == modality_token_id) acc.push(idx);
            return acc;
        }, [])
    );
    const n_tokens = token_positions.reduce((acc, x) => acc + x.length, 0);
    const n_features = modality_features.dims[0];
    if (n_tokens !== n_features) {
        throw new Error(`Number of tokens and features do not match: tokens: ${n_tokens}, features ${n_features}`);
    }

    // Equivalent to performing a masked_scatter
    let img = 0;
    for (let i = 0; i < token_positions.length; ++i) {
        const tokens = token_positions[i];
        const embeds = inputs_embeds[i];
        for (let j = 0; j < tokens.length; ++j) {
            embeds[tokens[j]].data.set(modality_features[img++].data)
        }
    }
    return { inputs_embeds, attention_mask }
}

function default_merge_input_ids_with_image_features({
    image_token_id,
    inputs_embeds,
    image_features,
    input_ids,
    attention_mask,
}) {
    return default_merge_input_ids_with_features({
        modality_token_id: image_token_id,
        inputs_embeds,
        modality_features: image_features,
        input_ids,
        attention_mask,
    })
}

function default_merge_input_ids_with_audio_features({
    audio_token_id,
    inputs_embeds,
    audio_features,
    input_ids,
    attention_mask,
}) {
    return default_merge_input_ids_with_features({
        modality_token_id: audio_token_id,
        inputs_embeds,
        modality_features: audio_features,
        input_ids,
        attention_mask,
    })
}
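
/*
 * Illustrative note on the merge helpers above (values are hypothetical):
 * with input_ids = [[5, 32000, 32000, 9]] and modality_token_id = 32000,
 * token_positions is [[1, 2]], so modality_features must contain exactly 2 rows
 * (n_tokens === n_features). Each feature row is then written in place over the
 * embedding at the corresponding placeholder position, which is the JS equivalent
 * of a masked_scatter on inputs_embeds.
 */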

/**
 * Abstract forward pass function for image-text-to-text or audio-text-to-text models.
 * @param {Object} self The model object.
 * @param {Object} params Additional parameters.
 * @param {Function} [params.encode_function] The function to encode the modality values.
 * @param {Function} [params.merge_function] The function to merge the modality features with the input embeddings.
 * @param {string} [params.modality_input_name] The modality input name.
 * @param {string} [params.modality_output_name] The modality output name.
 * @param {Tensor} [params.input_ids=null]
 * @param {Tensor} [params.attention_mask=null]
 * @param {Tensor} [params.position_ids=null]
 * @param {Tensor} [params.inputs_embeds=null]
 * @param {Tensor} [params.past_key_values=null]
 * @param {Object} [params.generation_config=null]
 * @param {Object} [params.logits_processor=null]
 * @returns {Promise<Tensor>} The model's output tensor
 * @private
 */
async function genericTextToTextForward(self, {
    // Generic parameters:
    encode_function,
    merge_function,
    modality_input_name,
    modality_output_name,

    // Produced by the tokenizer/processor:
    input_ids = null,
    attention_mask = null,

    // Used during generation:
    position_ids = null,
    inputs_embeds = null,
    past_key_values = null,

    // Generic generation parameters
    generation_config = null,
    logits_processor = null,

    // Additional parameters
    ...kwargs
}) {
    const modality_values = kwargs[modality_input_name];
    if (!inputs_embeds) {
        // 1. Extract the text embeddings.
        inputs_embeds = await self.encode_text({ input_ids, ...kwargs });

        // 2. Possibly, merge text and modality values
        if (modality_values && input_ids.dims[1] !== 1) {
            const modality_features = await encode_function({
                // Pass the modality values under its expected key.
                // The caller knows whether this is audio or image.
                [modality_input_name]: modality_values,
                ...kwargs
            });
            ({ inputs_embeds, attention_mask } = merge_function({
                [modality_output_name]: modality_features,
                inputs_embeds,
                input_ids,
                attention_mask,
            }));
        } else if (past_key_values && modality_values && input_ids.dims[1] === 1) {
            // This branch handles the cache case.
            const target_length = input_ids.dims[1]; // always 1
            const past_length = Object.values(past_key_values)[0].dims.at(-2);

            attention_mask = cat([
                ones([input_ids.dims[0], past_length]),
                attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
            ], 1);
        }
    }

    if (!position_ids) {
        if (self.config.model_type === 'qwen2_vl') {
            // Special case for qwen2_vl models
            // @ts-ignore
            const { image_grid_thw, video_grid_thw } = kwargs;
            [position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
        }
    }

    // 3. Call the decoder forward using the updated inputs.
    const outputs = await decoderForward(self, {
        inputs_embeds,
        past_key_values,
        attention_mask,
        position_ids,
        generation_config,
        logits_processor,
    }, true);
    return outputs;
}

/**
 * Forward pass of an audio-text-to-text model.
 * @param {Object} self The audio-text-to-text model.
 * @param {Object} params The inputs for the audio-text-to-text forward pass.
 * @returns {Promise<Tensor>} The model's output tensor.
 * @private
 */
async function audioTextToTextForward(self, params) {
    return await genericTextToTextForward(self, {
        ...params,
        modality_input_name: 'audio_values',
        modality_output_name: 'audio_features',
        encode_function: self.encode_audio.bind(self),
        merge_function: self._merge_input_ids_with_audio_features.bind(self),
    });
}

/**
 * Forward pass of an image-text-to-text model.
 * @param {Object} self The image-text-to-text model.
 * @param {Object} params The inputs for the image-text-to-text forward pass.
 * @returns {Promise<Tensor>} The model's output tensor.
 * @private
 */
async function imageTextToTextForward(self, params) {
    return await genericTextToTextForward(self, {
        ...params,
        modality_input_name: 'pixel_values',
        modality_output_name: 'image_features',
        encode_function: self.encode_image.bind(self),
        merge_function: self._merge_input_ids_with_image_features.bind(self),
    });
}

/**
 * Helper function to perform the following:
 * ```python
 * x = attention_mask.long().cumsum(-1) - 1
 * x.masked_fill_(attention_mask == 0, 1)
 * ```
 * @param {Tensor} attention_mask
 * @returns {{data: BigInt64Array, dims: number[]}}
 */
function cumsum_masked_fill(attention_mask, start_index = 0) {
    const [bz, seq_len] = attention_mask.dims;
    const attn_mask_data = attention_mask.data;

    const data = new BigInt64Array(attn_mask_data.length);
    for (let i = 0; i < bz; ++i) {
        const start = i * seq_len;
        let sum = BigInt(start_index);
        for (let j = 0; j < seq_len; ++j) {
            const index = start + j;
            if (attn_mask_data[index] === 0n) {
                data[index] = BigInt(1);
            } else { // === 1n
                data[index] = sum;
                sum += attn_mask_data[index];
            }
        }
    }
    return { data, dims: attention_mask.dims };
}

/**
 * If the model supports providing position_ids, we create position_ids on the fly for batch generation,
 * by computing the cumulative sum of the attention mask along the sequence length dimension.
 *
 * Equivalent to:
 * ```python
 * position_ids = attention_mask.long().cumsum(-1) - 1
 * position_ids.masked_fill_(attention_mask == 0, 1)
 * if past_key_values:
 *     position_ids = position_ids[:, -input_ids.shape[1] :]
 * ```
 */
function createPositionIds(model_inputs, past_key_values = null, start_index = 0) {
    const { input_ids, inputs_embeds, attention_mask } = model_inputs;

    const { data, dims } = cumsum_masked_fill(attention_mask, start_index);

    let position_ids = new Tensor('int64', data, dims);
    if (past_key_values) {
        const offset = -(input_ids ?? inputs_embeds).dims.at(1);
        position_ids = position_ids.slice(null, [offset, null]);
    }
    return position_ids;
}

function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    if (model_inputs.past_key_values) {
        const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
        const { input_ids, attention_mask } = model_inputs;

        // Keep only the unprocessed tokens:
        // 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        // some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        // input)
        if (attention_mask && attention_mask.dims[1] > input_ids.dims[1]) {
            // NOTE: not needed since we only pass the generated tokens to the next forward pass
            // const offset = -(attention_mask.dims[1] - past_length);
            // model_inputs.input_ids = input_ids.slice(null, [offset, null]);
        }
        // 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens.
        // We can discard input_ids based on the past_length.
        else if (past_length < input_ids.dims[1]) {
            // NOTE: Required for phi models.
            // See https://github.com/huggingface/transformers/issues/30809#issuecomment-2111918479 for more information.
            model_inputs.input_ids = input_ids.slice(null, [past_length, null]);
        }
        // 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
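        //     (Illustrative arithmetic with hypothetical numbers: if past_length = 300, the model's
        //     num_image_tokens = 256, and input_ids has 50 columns, the branch below keeps the last
        //     50 - (300 - 256) = 6 tokens as the "new" tokens.)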
        else {
            if (
                // NOTE: Only used by VLMs (!= so that null matches undefined)
                self.config.image_token_index != null &&
                // Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
                input_ids.data.some(x => x == self.config.image_token_index)
            ) {
                // TODO: Support multiple image tokens
                const num_image_tokens = self.config.num_image_tokens;
                if (!num_image_tokens) {
                    throw new Error('`num_image_tokens` is missing in the model configuration.');
                }

                const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
                model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);

                // TODO: The attention mask should be formed from the attention mask passed in model_inputs
                model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
            }
        }
    }

    return model_inputs;
}

function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    if (model_inputs.past_key_values) {
        input_ids = input_ids.map(x => [x.at(-1)]);
    }

    return {
        ...model_inputs,
        decoder_input_ids: toI64Tensor(input_ids),
    };
}

function multimodal_text_to_text_prepare_inputs_for_generation(self, ...args) {
    if (self.config.is_encoder_decoder) {
        return encoder_decoder_prepare_inputs_for_generation(self, ...args);
    } else {
        return decoder_prepare_inputs_for_generation(self, ...args);
    }
}

function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    const has_past_key_values = !!model_inputs.past_key_values;

    if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
        if (has_past_key_values) {
            model_inputs.input_ids = cat([
                model_inputs.input_ids,
                model_inputs.input_ids,
            ], 0)
            // NOTE: attention_mask handled in generation
        } else {
            model_inputs.input_ids = cat([
                model_inputs.input_ids,
                full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)),
            ], 0);
            model_inputs.attention_mask = cat([
                model_inputs.attention_mask,
                full_like(model_inputs.attention_mask, 0n),
            ], 0);
        }
    }

    if (has_past_key_values || !model_inputs.pixel_values) {
        model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0);
    }

    if (has_past_key_values) {
        const num_img_tokens = 0;
        const num_text_tokens = 1;
        const has_image = num_img_tokens > 0 ? 1 : 0;

        const batch_size = 1;
        model_inputs.images_seq_mask = new Tensor(
            'bool',
            new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens),
            [batch_size, num_img_tokens + num_text_tokens],
        );
        model_inputs.images_emb_mask = new Tensor(
            'bool',
            new Array(num_img_tokens).fill(!!has_image),
            [batch_size, 1, num_img_tokens],
        );
    }
    return model_inputs;
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
/**
 * A base class for pre-trained models that provides the model configuration and an ONNX session.
 */
export class PreTrainedModel extends Callable {
    main_input_name = 'input_ids';
    forward_params = ['input_ids', 'attention_mask'];

    /**
     * Creates a new instance of the `PreTrainedModel` class.
     * @param {import('./configs.js').PretrainedConfig} config The model configuration.
     * @param {Record<string, any>} sessions The inference sessions for the model.
     * @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
     */
    constructor(config, sessions, configs) {
        super();

        this.config = config;
        this.sessions = sessions;
        this.configs = configs;

        const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
        const modelType = MODEL_TYPE_MAPPING.get(modelName);

        this.can_generate = false;
        this._forward = null;
        this._prepare_inputs_for_generation = null;
        switch (modelType) {
            case MODEL_TYPES.DecoderOnly:
                this.can_generate = true;
                this._forward = decoderForward;
                this._prepare_inputs_for_generation = decoder_prepare_inputs_for_generation;
                break;
            case MODEL_TYPES.Seq2Seq:
            case MODEL_TYPES.Vision2Seq:
            case MODEL_TYPES.Musicgen:
                this.can_generate = true;
                this._forward = seq2seqForward;
                this._prepare_inputs_for_generation = encoder_decoder_prepare_inputs_for_generation;
                break;
            case MODEL_TYPES.EncoderDecoder:
                this._forward = seq2seqForward;
                break;
            case MODEL_TYPES.ImageTextToText:
                this.can_generate = true;
                this._forward = imageTextToTextForward;
                this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
                break;
            case MODEL_TYPES.AudioTextToText:
                this.can_generate = true;
                this._forward = audioTextToTextForward;
                this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
                break;
            case MODEL_TYPES.Phi3V:
                this.can_generate = true;
                this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
                break;
            case MODEL_TYPES.MultiModality:
                this.can_generate = true;
                this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
                break;
            case MODEL_TYPES.AutoEncoder:
                this._forward = autoEncoderForward;
                break;
            default:
                // should be MODEL_TYPES.EncoderOnly
                this._forward = encoderForward;
                break;
        }

        if (this.can_generate) {
            this.forward_params.push('past_key_values');
        }

        /** @type {import('./configs.js').TransformersJSConfig} */
        this.custom_config = this.config['transformers.js_config'] ?? {};
    }

    /**
     * Disposes of all the ONNX sessions that were created during inference.
     * @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed.
     * @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
     */
    async dispose() {
        const promises = [];
        for (const session of Object.values(this.sessions)) {
            if (session?.handler?.dispose) {
                promises.push(session.handler.dispose())
            }
        }
        return await Promise.all(promises);
    }
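
    // Illustrative sketch (assumes `AutoModel` is imported as in the module header example):
    //
    //     const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
    //     // ... run inference ...
    //     await model.dispose(); // release the underlying ONNX session handlers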

    /**
     * Instantiate one of the model classes of the library from a pretrained model.
     *
     * The model class to instantiate is selected based on the `model_type` property of the config object
     * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible)
     *
     * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
     * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
     *   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
     *   user or organization name, like `dbmdz/bert-base-german-cased`.
     * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
     * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
     *
     * @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class.
     */
    static async from_pretrained(pretrained_model_name_or_path, {
        progress_callback = null,
        config = null,
        cache_dir = null,
        local_files_only = false,
        revision = 'main',
        model_file_name = null,
        subfolder = 'onnx',
        device = null,
        dtype = null,
        use_external_data_format = null,
        session_options = {},
    } = {}) {
        let options = {
            progress_callback,
            config,
            cache_dir,
            local_files_only,
            revision,
            model_file_name,
            subfolder,
            device,
            dtype,
            use_external_data_format,
            session_options,
        }

        const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
        const modelType = MODEL_TYPE_MAPPING.get(modelName);

        config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);

        let info;
        if (modelType === MODEL_TYPES.DecoderOnly) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: options.model_file_name ?? 'model',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: 'encoder_model',
                    decoder_model_merged: 'decoder_model_merged',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.MaskGeneration) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: 'vision_encoder',
                    prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.EncoderDecoder) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: 'encoder_model',
                    decoder_model_merged: 'decoder_model_merged',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.ImageTextToText) {
            const sessions = {
                embed_tokens: 'embed_tokens',
                vision_encoder: 'vision_encoder',
                decoder_model_merged: 'decoder_model_merged',
            }
            if (config.is_encoder_decoder) {
                sessions['model'] = 'encoder_model';
            }
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, sessions, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.AudioTextToText) {
            const sessions = {
                embed_tokens: 'embed_tokens',
                audio_encoder: 'audio_encoder',
                decoder_model_merged: 'decoder_model_merged',
            }
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, sessions, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.Musicgen) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: 'text_encoder',
                    decoder_model_merged: 'decoder_model_merged',
                    encodec_decode: 'encodec_decode',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.MultiModality) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    prepare_inputs_embeds: 'prepare_inputs_embeds',
                    model: 'language_model',
                    lm_head: 'lm_head',
                    gen_head: 'gen_head',
                    gen_img_embeds: 'gen_img_embeds',
                    image_decode: 'image_decode',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.Phi3V) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    prepare_inputs_embeds: 'prepare_inputs_embeds',
                    model: 'model',
                    vision_encoder: 'vision_encoder',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.AutoEncoder) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    encoder_model: 'encoder_model',
                    decoder_model: 'decoder_model',
                }, options),
            ]);

        } else { // should be MODEL_TYPES.EncoderOnly
            if (modelType !== MODEL_TYPES.EncoderOnly) {
                const type = modelName ?? config?.model_type;
                if (type !== 'custom') {
                    console.warn(`Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
                }
            }
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: options.model_file_name ?? 'model',
                }, options),
            ]);
        }

        // @ts-ignore
        return new this(config, ...info);
    }

    /**
     * Runs the model with the provided inputs
     * @param {Object} model_inputs Object containing input tensors
     * @returns {Promise<Object>} Object containing output tensors
     */
    async _call(model_inputs) {
        return await this.forward(model_inputs);
    }

    /**
     * Forward method for a pretrained model. If not overridden by a subclass, the correct forward method
     * will be chosen based on the model type.
     * @param {Object} model_inputs The input data to the model in the format specified in the ONNX model.
     * @