@huggingface/transformers
State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
/**
* @file Definitions of all models available in Transformers.js.
*
* **Example:** Load and run an `AutoModel`.
*
* ```javascript
* import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
*
* let inputs = await tokenizer('I love transformers!');
* let { logits } = await model(inputs);
* // Tensor {
* // data: Float32Array(183132) [-7.117443084716797, -7.107812881469727, -7.092104911804199, ...]
* // dims: (3) [1, 6, 30522],
* // type: "float32",
* // size: 183132,
* // }
* ```
*
* We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example:
*
* **Example:** Load and run an `AutoModelForSeq2SeqLM`.
* ```javascript
* import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
* let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
*
* let { input_ids } = await tokenizer('translate English to German: I love transformers!');
* let outputs = await model.generate(input_ids);
* let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
* // 'Ich liebe Transformatoren!'
* ```
*
* @module models
*/
import {
AutoConfig,
getKeyValueShapes,
} from './configs.js';
import {
deviceToExecutionProviders,
createInferenceSession,
isONNXTensor,
isONNXProxy,
} from './backends/onnx.js';
import {
DATA_TYPES,
DEFAULT_DEVICE_DTYPE_MAPPING,
DEFAULT_DTYPE_SUFFIX_MAPPING,
isWebGpuFp16Supported,
} from './utils/dtypes.js';
import {
Callable,
} from './utils/generic.js';
import {
mergeArrays,
pick,
} from './utils/core.js';
import {
getModelFile,
getModelJSON,
MAX_EXTERNAL_DATA_CHUNKS,
} from './utils/hub.js';
import {
GITHUB_ISSUE_URL,
} from './utils/constants.js';
import {
LogitsProcessorList,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
NoBadWordsLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
ClassifierFreeGuidanceLogitsProcessor,
} from './generation/logits_process.js';
import {
GenerationConfig,
} from './generation/configuration_utils.js';
import {
cat,
mean,
zeros,
zeros_like,
ones,
ones_like,
full,
full_like,
stack,
std_mean,
Tensor,
DataTypeMap,
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';
import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { LogitsSampler } from './generation/logits_sampler.js';
import { apis, env } from './env.js';
import { WhisperGenerationConfig } from './models/whisper/generation_whisper.js';
import { whisper_language_to_code } from './models/whisper/common_whisper.js';
//////////////////////////////////////////////////
// Model types: used internally
const MODEL_TYPES = {
EncoderOnly: 0,
EncoderDecoder: 1,
Seq2Seq: 2,
Vision2Seq: 3,
DecoderOnly: 4,
MaskGeneration: 5,
ImageTextToText: 6,
Musicgen: 7,
MultiModality: 8,
Phi3V: 9,
AudioTextToText: 10,
AutoEncoder: 11,
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Helper functions
// NOTE: These will be populated fully later
const MODEL_TYPE_MAPPING = new Map();
const MODEL_NAME_TO_CLASS_MAPPING = new Map();
const MODEL_CLASS_TO_NAME_MAPPING = new Map();
/**
* Constructs an InferenceSession using a model file located at the specified path.
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
* @private
*/
async function getSession(pretrained_model_name_or_path, fileName, options) {
let custom_config = options.config?.['transformers.js_config'] ?? {};
let device = options.device ?? custom_config.device;
if (device && typeof device !== 'string') {
if (device.hasOwnProperty(fileName)) {
device = device[fileName];
} else {
console.warn(`device not specified for "${fileName}". Using the default device.`);
device = null;
}
}
// If the device is not specified, we use the default (supported) execution providers.
const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
);
const executionProviders = deviceToExecutionProviders(selectedDevice);
// Update custom config with the selected device's config, if it exists
const device_config = custom_config.device_config ?? {};
if (device_config.hasOwnProperty(selectedDevice)) {
custom_config = {
...custom_config,
...device_config[selectedDevice],
};
}
// If options.dtype is specified, we use it to choose the suffix for the model file.
// Otherwise, we use the default dtype for the device.
let dtype = options.dtype ?? custom_config.dtype;
if (typeof dtype !== 'string') {
if (dtype && dtype.hasOwnProperty(fileName)) {
dtype = dtype[fileName];
} else {
dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
console.warn(`dtype not specified for "${fileName}". Using the default dtype (${dtype}) for this device (${selectedDevice}).`);
}
}
if (dtype === DATA_TYPES.auto) {
// Try to choose the auto dtype based on the custom config
let config_dtype = custom_config.dtype;
if (typeof config_dtype !== 'string') {
config_dtype = config_dtype?.[fileName];
}
if (config_dtype && config_dtype !== DATA_TYPES.auto && DATA_TYPES.hasOwnProperty(config_dtype)) {
// Defined by the config, and is not "auto"
dtype = config_dtype;
} else {
// Choose default dtype based on device, falling back to fp32
dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
}
}
const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);
if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
} else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) {
throw new Error(`The device (${selectedDevice}) does not support fp16.`);
}
// Only valid for models with a decoder
const kv_cache_dtype_config = custom_config.kv_cache_dtype;
const kv_cache_dtype = kv_cache_dtype_config
? (typeof kv_cache_dtype_config === 'string'
? kv_cache_dtype_config
: kv_cache_dtype_config[selectedDtype] ?? 'float32')
: undefined;
if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
}
const session_config = {
dtype: selectedDtype,
kv_cache_dtype,
}
// Construct the model file name
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
const baseName = `${fileName}${suffix}.onnx`;
const modelFileName = `${options.subfolder ?? ''}/${baseName}`;
const session_options = { ...options.session_options };
// Overwrite `executionProviders` if not specified
session_options.executionProviders ??= executionProviders;
// Overwrite `freeDimensionOverrides` if specified in config and not set in session options
const free_dimension_overrides = custom_config.free_dimension_overrides;
if (free_dimension_overrides) {
session_options.freeDimensionOverrides ??= free_dimension_overrides;
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
console.warn(
`WebNN does not currently support dynamic shapes and requires 'free_dimension_overrides' to be set in config.json, preferably as a field within config["transformers.js_config"]["device_config"]["${selectedDevice}"]. ` +
`When 'free_dimension_overrides' is not set, you may experience significant performance degradation.`
);
}
const return_path = apis.IS_NODE_ENV && env.useFSCache;
const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, return_path);
// Handle onnx external data files
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
/** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
let externalDataPromises = [];
if (use_external_data_format) {
let external_data_format;
if (typeof use_external_data_format === 'object') {
if (use_external_data_format.hasOwnProperty(baseName)) {
external_data_format = use_external_data_format[baseName];
} else if (use_external_data_format.hasOwnProperty(fileName)) {
external_data_format = use_external_data_format[fileName];
} else {
external_data_format = false;
}
} else {
external_data_format = use_external_data_format;
}
const num_chunks = +external_data_format; // (false=0, true=1, number remains the same)
if (num_chunks > MAX_EXTERNAL_DATA_CHUNKS) {
throw new Error(`The number of external data chunks (${num_chunks}) exceeds the maximum allowed value (${MAX_EXTERNAL_DATA_CHUNKS}).`);
}
for (let i = 0; i < num_chunks; ++i) {
const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
const fullPath = `${options.subfolder ?? ''}/${path}`;
externalDataPromises.push(new Promise(async (resolve, reject) => {
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, return_path);
resolve(data instanceof Uint8Array ? { path, data } : path);
}));
}
} else if (session_options.externalData !== undefined) {
externalDataPromises = session_options.externalData.map(async (ext) => {
// if the external data is a string, fetch the file and replace the string with its content
// @ts-expect-error TS2339
if (typeof ext.data === "string") {
// @ts-expect-error TS2339
const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
// @ts-expect-error TS2698
return { ...ext, data: ext_buffer };
}
return ext;
});
}
if (externalDataPromises.length > 0) {
const externalData = await Promise.all(externalDataPromises);
if (!apis.IS_NODE_ENV) {
session_options.externalData = externalData;
}
}
if (selectedDevice === 'webgpu') {
const shapes = getKeyValueShapes(options.config, {
prefix: 'present',
});
if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
// Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
/** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
const preferredOutputLocation = {};
for (const key in shapes) {
preferredOutputLocation[key] = 'gpu-buffer';
}
session_options.preferredOutputLocation = preferredOutputLocation;
}
}
const buffer_or_path = await bufferOrPathPromise;
return { buffer_or_path, session_options, session_config };
}
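/**
 * **Example (illustrative):** the `device` and `dtype` handled above are passed through `from_pretrained`.
 * A single string applies to every model file, while an object maps file names (e.g. `encoder_model`,
 * `decoder_model_merged`) to their own dtype. Which dtype variants actually exist depends on the
 * repository, so treat the exact values below as assumptions rather than guarantees:
 *
 * ```javascript
 * import { AutoModelForSpeechSeq2Seq } from '@huggingface/transformers';
 *
 * const model = await AutoModelForSpeechSeq2Seq.from_pretrained('onnx-community/whisper-tiny.en', {
 *     device: 'webgpu', // defaults to 'wasm' in browsers and 'cpu' in Node.js when unspecified
 *     dtype: {
 *         encoder_model: 'fp32',      // fp16 requires WebGPU fp16 support (checked above)
 *         decoder_model_merged: 'q4', // quantized decoder to reduce download size
 *     },
 * });
 * ```
 */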
/**
* Helper function to create multiple InferenceSession objects.
*
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {Record<string, string>} names The names of the model files to load.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects.
* @private
*/
async function constructSessions(pretrained_model_name_or_path, names, options) {
return Object.fromEntries(await Promise.all(
Object.keys(names).map(async (name) => {
const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
const session = await createInferenceSession(buffer_or_path, session_options, session_config);
return [name, session];
})
));
}
/**
* Helper function to load multiple optional configuration files
* @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
* @param {Record<string, string>} names The names of the config files to load.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of configuration objects.
* @private
*/
async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
return Object.fromEntries(await Promise.all(
Object.keys(names).map(async (name) => {
const config = await getModelJSON(pretrained_model_name_or_path, names[name], false, options);
return [name, config];
})
));
}
/**
* Validate model inputs
* @param {Object} session The InferenceSession object that will be run.
* @param {Object} inputs The inputs to check.
* @returns {Record<string, Tensor>} The checked inputs.
* @throws {Error} If any inputs are missing.
* @private
*/
function validateInputs(session, inputs) {
/**
* NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy`
* @type {Record<string, Tensor>}
*/
const checkedInputs = Object.create(null);
const missingInputs = [];
for (const inputName of session.inputNames) {
const tensor = inputs[inputName];
// Rare case where one of the model's input names corresponds to a built-in
// object name (e.g., toString), which would cause a simple (!tensor) check to fail,
// because it's not undefined but a function.
if (!(tensor instanceof Tensor)) {
missingInputs.push(inputName);
continue;
}
// NOTE: When `env.wasm.proxy` is true, the tensor is moved across the Worker
// boundary, transferring ownership to the worker and invalidating the tensor.
// So, in this case, we simply sacrifice a clone for it.
checkedInputs[inputName] = isONNXProxy() ? tensor.clone() : tensor;
}
if (missingInputs.length > 0) {
throw new Error(
`An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(', ')}".`);
}
const numInputsProvided = Object.keys(inputs).length;
const numInputsNeeded = session.inputNames.length;
if (numInputsProvided > numInputsNeeded) {
// No missing inputs, but too many inputs were provided.
// Warn the user and ignore the extra inputs.
let ignored = Object.keys(inputs).filter(inputName => !session.inputNames.includes(inputName));
console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`);
}
return checkedInputs;
}
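/**
 * **Example (illustrative sketch):** from the caller's perspective, inputs that the underlying ONNX graph
 * does not declare are dropped (or warned about) before inference, while missing required inputs throw an
 * Error listing their names. Whether a given input (e.g. `attention_mask`) is required depends on the
 * exported graph, so the failing call is shown commented out:
 *
 * ```javascript
 * import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
 *
 * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
 * const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
 *
 * const inputs = await tokenizer('I love transformers!');
 *
 * // Extra keys are simply ignored:
 * await model({ ...inputs, unused_input: inputs.input_ids });
 *
 * // Missing a declared input throws, e.g.:
 * // await model({ input_ids: inputs.input_ids });
 * // -> Error: ... Missing the following inputs: attention_mask ...
 * ```
 */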
/**
* Executes an InferenceSession using the specified inputs.
* NOTE: `inputs` must contain at least the input names of the model.
* - If additional inputs are passed, they will be ignored.
* - If inputs are missing, an error will be thrown.
*
* @param {Object} session The InferenceSession object to run.
* @param {Object} inputs An object that maps input names to input tensors.
* @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
* @private
*/
async function sessionRun(session, inputs) {
const checkedInputs = validateInputs(session, inputs);
try {
// pass the original ort tensor
const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
let output = await session.run(ortFeed);
output = replaceTensors(output);
return output;
} catch (e) {
// Error messages can be long (nested) and uninformative. For this reason,
// we apply minor formatting to show the most important information
const formatted = Object.fromEntries(Object.entries(checkedInputs)
.map(([k, { type, dims, data }]) => [k, {
// Extract these properties from the underlying ORT tensor
type, dims, data,
}]));
// This usually occurs when the inputs are of the wrong type.
console.error(`An error occurred during model execution: "${e}".`);
console.error('Inputs given to model:', formatted);
throw e;
}
}
/**
* Replaces ONNX Tensor objects with custom Tensor objects to support additional functions.
* @param {Object} obj The object to replace tensor objects in.
* @returns {Object} The object with tensor objects replaced by custom Tensor objects.
* @private
*/
function replaceTensors(obj) {
for (let prop in obj) {
if (isONNXTensor(obj[prop])) {
obj[prop] = new Tensor(obj[prop]);
} else if (typeof obj[prop] === 'object') {
replaceTensors(obj[prop]);
}
}
return obj;
}
/**
* Converts an array or Tensor of integers to an int64 Tensor.
* @param {any[]|Tensor} items The input integers to be converted.
* @returns {Tensor} The int64 Tensor with the converted values.
* @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
* @private
*/
function toI64Tensor(items) {
if (items instanceof Tensor) {
return items;
}
// items is an array
if (items.length === 0) {
throw Error("items must be non-empty");
}
if (Array.isArray(items[0])) {
// batched
if (items.some(x => x.length !== items[0].length)) {
throw Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.")
}
return new Tensor('int64',
BigInt64Array.from(items.flat().map(x => BigInt(x))),
[items.length, items[0].length]
);
} else {
//flat
return new Tensor('int64',
BigInt64Array.from(items.map(x => BigInt(x))),
[1, items.length]
);
}
}
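/**
 * **Example (illustrative):** `toI64Tensor` is internal, but its behaviour is easy to picture. A flat
 * array of token ids becomes a `[1, seq_len]` int64 tensor, and an array of equal-length arrays becomes a
 * `[batch_size, seq_len]` tensor:
 *
 * ```javascript
 * toI64Tensor([101, 2023, 102]);           // -> Tensor { type: 'int64', dims: [1, 3] }
 * toI64Tensor([[101, 102], [101, 103]]);   // -> Tensor { type: 'int64', dims: [2, 2] }
 * toI64Tensor([[101, 102], [101]]);        // -> throws: ragged batches need padding/truncation
 * ```
 */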
/**
* Creates a boolean tensor with a single value.
* @param {boolean} value The value of the tensor.
* @returns {Tensor} The boolean tensor.
* @private
*/
function boolTensor(value) {
return new Tensor('bool', [value], [1]);
}
// JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in
/**
* Perform forward pass on the seq2seq model (both encoder and decoder).
* @param {Object} self The seq2seq model object.
* @param {Object} model_inputs The input object for the model containing encoder and decoder inputs.
* @returns {Promise<Seq2SeqLMOutput>} Promise that resolves with the output of the seq2seq model.
* @private
*/
async function seq2seqForward(self, model_inputs) {
let { encoder_outputs, input_ids, decoder_input_ids, ...other_decoder_inputs } = model_inputs;
// Encode if needed
if (!encoder_outputs) {
const encoder_inputs = pick(model_inputs, self.sessions['model'].inputNames);
// Encoder outputs are not given, so we must compute them.
encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state;
}
other_decoder_inputs.input_ids = decoder_input_ids;
other_decoder_inputs.encoder_hidden_states = encoder_outputs;
if (self.sessions['decoder_model_merged'].inputNames.includes('encoder_attention_mask')) {
other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask
}
const decoderResults = await decoderForward(self, other_decoder_inputs, true);
return decoderResults;
}
/**
* Forward pass of an encoder model.
* @param {Object} self The encoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @returns {Promise<Object>} The model's outputs.
* @private
*/
async function encoderForward(self, model_inputs) {
const session = self.sessions['model'];
const encoderFeeds = pick(model_inputs, session.inputNames);
if (session.inputNames.includes('inputs_embeds') && !encoderFeeds.inputs_embeds) {
if (!model_inputs.input_ids) {
throw new Error('Both `input_ids` and `inputs_embeds` are missing in the model inputs.');
}
encoderFeeds.inputs_embeds = await self.encode_text({ input_ids: model_inputs.input_ids });
}
if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
if (!encoderFeeds.input_ids) {
throw new Error('Both `input_ids` and `token_type_ids` are missing in the model inputs.');
}
// Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects them
// but the tokenizer did not create them.
encoderFeeds.token_type_ids = zeros_like(encoderFeeds.input_ids);
}
if (session.inputNames.includes('pixel_mask') && !encoderFeeds.pixel_mask) {
if (!encoderFeeds.pixel_values) {
throw new Error('Both `pixel_values` and `pixel_mask` are missing in the model inputs.');
}
// Assign a default `pixel_mask` (all ones) to the `encoderFeeds` if the model expects it
// but the processor did not create one.
const dims = encoderFeeds.pixel_values.dims;
encoderFeeds.pixel_mask = ones([dims[0], dims[2], dims[3]]);
}
return await sessionRun(session, encoderFeeds);
}
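/**
 * **Example (illustrative):** shapes of the defaults filled in above. For `pixel_values` of shape
 * `[batch_size, channels, height, width]`, the generated `pixel_mask` covers the spatial dimensions only:
 *
 * ```javascript
 * // encoderFeeds.token_type_ids = zeros_like(encoderFeeds.input_ids);   // dims [batch_size, seq_len]
 * // const [b, , h, w] = encoderFeeds.pixel_values.dims;
 * // encoderFeeds.pixel_mask = ones([b, h, w]);                          // dims [batch_size, height, width]
 * ```
 */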
async function autoEncoderForward(self, model_inputs) {
const encoded = await self.encode(model_inputs);
const decoded = await self.decode(encoded);
return decoded;
}
/**
* Forward pass of a decoder model.
* @param {Object} self The decoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @returns {Promise<Object>} The logits and past key values.
* @private
*/
async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
const session = self.sessions[
is_encoder_decoder ? 'decoder_model_merged' : 'model'
]
const { past_key_values, ...new_model_inputs } = model_inputs;
if (session.inputNames.includes('use_cache_branch')) {
new_model_inputs.use_cache_branch = boolTensor(!!past_key_values);
}
if (session.inputNames.includes('position_ids') && new_model_inputs.attention_mask && !new_model_inputs.position_ids) {
// NOTE: Handle a special case for paligemma/gemma3 models, where positions are 1-indexed
const start_index = ['paligemma', 'gemma3_text', 'gemma3'].includes(self.config.model_type) ? 1 : 0;
new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values, start_index);
}
// Unpack the `past_key_values` object into model inputs
self.addPastKeyValues(new_model_inputs, past_key_values);
// Select only the inputs that are needed for the current session
const fixed = pick(new_model_inputs, session.inputNames);
return await sessionRun(session, fixed);
}
function default_merge_input_ids_with_features({
modality_token_id,
inputs_embeds,
modality_features,
input_ids,
attention_mask,
}) {
const token_positions = input_ids.tolist().map(ids =>
ids.reduce((acc, x, idx) => {
if (x == modality_token_id) acc.push(idx);
return acc;
}, [])
);
const n_tokens = token_positions.reduce((acc, x) => acc + x.length, 0);
const n_features = modality_features.dims[0];
if (n_tokens !== n_features) {
throw new Error(`Number of tokens and features do not match: tokens: ${n_tokens}, features: ${n_features}`);
}
// Equivalent to performing a masked_scatter
let img = 0;
for (let i = 0; i < token_positions.length; ++i) {
const tokens = token_positions[i];
const embeds = inputs_embeds[i];
for (let j = 0; j < tokens.length; ++j) {
embeds[tokens[j]].data.set(modality_features[img++].data)
}
}
return { inputs_embeds, attention_mask }
}
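/**
 * **Example (illustrative):** a schematic of the merge above. Suppose one sequence whose `input_ids` are
 * `[bos, <image>, <image>, hi]`, `modality_token_id` is the id of `<image>`, and `modality_features` has
 * `dims[0] === 2` (one row per `<image>` token). The feature rows are scattered into `inputs_embeds` at
 * the `<image>` positions, in order; all other embeddings are left untouched:
 *
 * ```javascript
 * // token_positions          -> [[1, 2]]   (indices of <image> tokens, per sequence)
 * // n_tokens === n_features  -> 2 === 2    (a mismatch throws an Error)
 * // inputs_embeds[0][1].data <- modality_features[0].data
 * // inputs_embeds[0][2].data <- modality_features[1].data
 * ```
 */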
function default_merge_input_ids_with_image_features({
image_token_id,
inputs_embeds,
image_features,
input_ids,
attention_mask,
}) {
return default_merge_input_ids_with_features({
modality_token_id: image_token_id,
inputs_embeds,
modality_features: image_features,
input_ids,
attention_mask,
})
}
function default_merge_input_ids_with_audio_features({
audio_token_id,
inputs_embeds,
audio_features,
input_ids,
attention_mask,
}) {
return default_merge_input_ids_with_features({
modality_token_id: audio_token_id,
inputs_embeds,
modality_features: audio_features,
input_ids,
attention_mask,
})
}
/**
* Abstract forward pass function for image-text-to-text or audio-text-to-text models.
* @param {Object} self The model object.
* @param {Object} params Additional parameters.
* @param {Function} [params.encode_function] The function to encode the modality values.
* @param {Function} [params.merge_function] The function to merge the modality features with the input embeddings.
* @param {string} [params.modality_input_name] The modality input name.
* @param {string} [params.modality_output_name] The modality output name.
* @param {Tensor} [params.input_ids=null]
* @param {Tensor} [params.attention_mask=null]
* @param {Tensor} [params.position_ids=null]
* @param {Tensor} [params.inputs_embeds=null]
* @param {Tensor} [params.past_key_values=null]
* @param {Object} [params.generation_config=null]
* @param {Object} [params.logits_processor=null]
* @returns {Promise<Tensor>} The model's output tensor
* @private
*/
async function genericTextToTextForward(self, {
// Generic parameters:
encode_function,
merge_function,
modality_input_name,
modality_output_name,
// Produced by the tokenizer/processor:
input_ids = null,
attention_mask = null,
// Used during generation:
position_ids = null,
inputs_embeds = null,
past_key_values = null,
// Generic generation parameters
generation_config = null,
logits_processor = null,
// Additional parameters
...kwargs
}) {
const modality_values = kwargs[modality_input_name];
if (!inputs_embeds) {
// 1. Extract the text embeddings.
inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
// 2. Possibly, merge text and modality values
if (modality_values && input_ids.dims[1] !== 1) {
const modality_features = await encode_function({
// Pass the modality values under its expected key.
// The caller knows whether this is audio or image.
[modality_input_name]: modality_values,
...kwargs
});
({ inputs_embeds, attention_mask } = merge_function({
[modality_output_name]: modality_features,
inputs_embeds,
input_ids,
attention_mask,
}));
} else if (past_key_values && modality_values && input_ids.dims[1] === 1) {
// This branch handles the cache case.
const target_length = input_ids.dims[1]; // always 1
const past_length = Object.values(past_key_values)[0].dims.at(-2);
attention_mask = cat([
ones([input_ids.dims[0], past_length]),
attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
], 1);
}
}
if (!position_ids) {
if (self.config.model_type === 'qwen2_vl') {
// Special case for qwen2_vl models
// @ts-ignore
const { image_grid_thw, video_grid_thw } = kwargs;
[position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
}
}
// 3. Call the decoder forward using the updated inputs.
const outputs = await decoderForward(self, {
inputs_embeds,
past_key_values,
attention_mask,
position_ids,
generation_config,
logits_processor,
}, true);
return outputs;
}
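/**
 * **Example (illustrative):** the attention-mask arithmetic in the cached branch above. With a batch size
 * of 1, a KV cache covering 7 past positions, and a single new token id (hypothetical numbers), the mask
 * is rebuilt as ones over the cached positions plus the last column of the previous mask:
 *
 * ```javascript
 * // past_length = 7, target_length = 1
 * // attention_mask = cat([
 * //     ones([1, 7]),                                          // attend to all cached positions
 * //     previous_mask.slice(null, [prev_len - 1, prev_len]),   // keep the newest token's mask entry
 * // ], 1);                                                     // -> dims [1, 8]
 * ```
 */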
/**
* Forward pass of an audio-text-to-text model.
* @param {Object} self The audio-text-to-text model.
* @param {Object} params The inputs for the audio-text-to-text forward pass.
* @returns {Promise<Tensor>} The model's output tensor.
* @private
*/
async function audioTextToTextForward(self, params) {
return await genericTextToTextForward(self, {
...params,
modality_input_name: 'audio_values',
modality_output_name: 'audio_features',
encode_function: self.encode_audio.bind(self),
merge_function: self._merge_input_ids_with_audio_features.bind(self),
});
}
/**
* Forward pass of an image-text-to-text model.
* @param {Object} self The image-text-to-text model.
* @param {Object} params The inputs for the image-text-to-text forward pass.
* @returns {Promise<Tensor>} The model's output tensor.
* @private
*/
async function imageTextToTextForward(self, params) {
return await genericTextToTextForward(self, {
...params,
modality_input_name: 'pixel_values',
modality_output_name: 'image_features',
encode_function: self.encode_image.bind(self),
merge_function: self._merge_input_ids_with_image_features.bind(self),
});
}
/**
* Helper function to perform the following:
* ```python
* x = attention_mask.long().cumsum(-1) - 1
* x.masked_fill_(attention_mask == 0, 1)
* ```
* @param {Tensor} attention_mask The attention mask, of shape `[batch_size, seq_len]`.
* @param {number} [start_index=0] The value at which the cumulative sum starts (e.g. `1` for models with 1-indexed positions).
* @returns {{data: BigInt64Array, dims: number[]}} The computed values and the original dims.
*/
function cumsum_masked_fill(attention_mask, start_index = 0) {
const [bz, seq_len] = attention_mask.dims;
const attn_mask_data = attention_mask.data;
const data = new BigInt64Array(attn_mask_data.length);
for (let i = 0; i < bz; ++i) {
const start = i * seq_len;
let sum = BigInt(start_index);
for (let j = 0; j < seq_len; ++j) {
const index = start + j;
if (attn_mask_data[index] === 0n) {
data[index] = BigInt(1);
} else { // === 1n
data[index] = sum;
sum += attn_mask_data[index];
}
}
}
return { data, dims: attention_mask.dims };
}
/**
* If the model supports providing position_ids, we create position_ids on the fly for batch generation,
* by computing the cumulative sum of the attention mask along the sequence length dimension.
*
* Equivalent to:
* ```python
* position_ids = attention_mask.long().cumsum(-1) - 1
* position_ids.masked_fill_(attention_mask == 0, 1)
* if past_key_values:
* position_ids = position_ids[:, -input_ids.shape[1] :]
* ```
*/
function createPositionIds(model_inputs, past_key_values = null, start_index = 0) {
const { input_ids, inputs_embeds, attention_mask } = model_inputs;
const { data, dims } = cumsum_masked_fill(attention_mask, start_index);
let position_ids = new Tensor('int64', data, dims);
if (past_key_values) {
const offset = -(input_ids ?? inputs_embeds).dims.at(1);
position_ids = position_ids.slice(null, [offset, null]);
}
return position_ids;
}
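/**
 * **Example (illustrative):** position ids computed for a left-padded batch with `start_index = 0`.
 * Padded positions (mask 0) get position id 1, and real tokens are numbered cumulatively:
 *
 * ```javascript
 * // attention_mask: [[0, 0, 1, 1, 1],
 * //                  [1, 1, 1, 1, 1]]
 * // position_ids:   [[1, 1, 0, 1, 2],
 * //                  [0, 1, 2, 3, 4]]
 * //
 * // With past_key_values and a single new input token, only the last column is kept ([[2], [4]]).
 * ```
 */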
function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
if (model_inputs.past_key_values) {
const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
const { input_ids, attention_mask } = model_inputs;
// Keep only the unprocessed tokens:
// 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
// some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
// input)
if (attention_mask && attention_mask.dims[1] > input_ids.dims[1]) {
// NOTE: not needed since we only pass the generated tokens to the next forward pass
// const offset = -(attention_mask.dims[1] - past_length);
// model_inputs.input_ids = input_ids.slice(null, [offset, null]);
}
// 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens,
// so we discard the first past_length (already-processed) tokens.
else if (past_length < input_ids.dims[1]) {
// NOTE: Required for phi models.
// See https://github.com/huggingface/transformers/issues/30809#issuecomment-2111918479 for more information.
model_inputs.input_ids = input_ids.slice(null, [past_length, null]);
}
// 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
else {
if (
// NOTE: Only used by VLMs (!= so that null matches undefined)
self.config.image_token_index != null &&
// Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
input_ids.data.some(x => x == self.config.image_token_index)
) {
// TODO: Support multiple image tokens
const num_image_tokens = self.config.num_image_tokens;
if (!num_image_tokens) {
throw new Error('`num_image_tokens` is missing in the model configuration.');
}
const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);
// TODO: The attention mask should be formed from the attention mask passed in model_inputs
model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
}
}
}
return model_inputs;
}
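/**
 * **Example (illustrative):** case 2 above in numbers. If the KV cache already covers `past_length = 5`
 * positions but `input_ids` has dims `[1, 8]`, only the 3 unprocessed tokens are kept for the next
 * forward pass:
 *
 * ```javascript
 * // model_inputs.input_ids = input_ids.slice(null, [5, null]);   // -> dims [1, 3]
 * ```
 */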
function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
if (model_inputs.past_key_values) {
input_ids = input_ids.map(x => [x.at(-1)]);
}
return {
...model_inputs,
decoder_input_ids: toI64Tensor(input_ids),
};
}
function multimodal_text_to_text_prepare_inputs_for_generation(self, ...args) {
if (self.config.is_encoder_decoder) {
return encoder_decoder_prepare_inputs_for_generation(self, ...args);
} else {
return decoder_prepare_inputs_for_generation(self, ...args);
}
}
function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
const has_past_key_values = !!model_inputs.past_key_values;
if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
if (has_past_key_values) {
model_inputs.input_ids = cat([
model_inputs.input_ids,
model_inputs.input_ids,
], 0)
// NOTE: attention_mask handled in generation
} else {
model_inputs.input_ids = cat([
model_inputs.input_ids,
full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)),
], 0);
model_inputs.attention_mask = cat([
model_inputs.attention_mask,
full_like(model_inputs.attention_mask, 0n),
], 0);
}
}
if (has_past_key_values || !model_inputs.pixel_values) {
model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0);
}
if (has_past_key_values) {
const num_img_tokens = 0;
const num_text_tokens = 1;
const has_image = num_img_tokens > 0 ? 1 : 0;
const batch_size = 1;
model_inputs.images_seq_mask = new Tensor(
'bool',
new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens),
[batch_size, num_img_tokens + num_text_tokens],
);
model_inputs.images_emb_mask = new Tensor(
'bool',
new Array(num_img_tokens).fill(!!has_image),
[batch_size, 1, num_img_tokens],
);
}
return model_inputs;
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
/**
* A base class for pre-trained models that provides the model configuration and an ONNX session.
*/
export class PreTrainedModel extends Callable {
main_input_name = 'input_ids';
forward_params = ['input_ids', 'attention_mask'];
/**
* Creates a new instance of the `PreTrainedModel` class.
* @param {import('./configs.js').PretrainedConfig} config The model configuration.
* @param {Record<string, any>} sessions The inference sessions for the model.
* @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
*/
constructor(config, sessions, configs) {
super();
this.config = config;
this.sessions = sessions;
this.configs = configs;
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
this.can_generate = false;
this._forward = null;
this._prepare_inputs_for_generation = null;
switch (modelType) {
case MODEL_TYPES.DecoderOnly:
this.can_generate = true;
this._forward = decoderForward;
this._prepare_inputs_for_generation = decoder_prepare_inputs_for_generation;
break;
case MODEL_TYPES.Seq2Seq:
case MODEL_TYPES.Vision2Seq:
case MODEL_TYPES.Musicgen:
this.can_generate = true;
this._forward = seq2seqForward;
this._prepare_inputs_for_generation = encoder_decoder_prepare_inputs_for_generation;
break;
case MODEL_TYPES.EncoderDecoder:
this._forward = seq2seqForward;
break;
case MODEL_TYPES.ImageTextToText:
this.can_generate = true;
this._forward = imageTextToTextForward;
this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
break;
case MODEL_TYPES.AudioTextToText:
this.can_generate = true;
this._forward = audioTextToTextForward;
this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
break;
case MODEL_TYPES.Phi3V:
this.can_generate = true;
this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
break;
case MODEL_TYPES.MultiModality:
this.can_generate = true;
this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
break;
case MODEL_TYPES.AutoEncoder:
this._forward = autoEncoderForward;
break;
default:
// should be MODEL_TYPES.EncoderOnly
this._forward = encoderForward;
break;
}
if (this.can_generate) {
this.forward_params.push('past_key_values');
}
/** @type {import('./configs.js').TransformersJSConfig} */
this.custom_config = this.config['transformers.js_config'] ?? {};
}
/**
* Disposes of all the ONNX sessions that were created during inference.
* @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed.
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
*/
async dispose() {
const promises = [];
for (const session of Object.values(this.sessions)) {
if (session?.handler?.dispose) {
promises.push(session.handler.dispose())
}
}
return await Promise.all(promises);
}
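/**
 * **Example:** release the model's ONNX sessions when it is no longer needed (the model id is the same as
 * in the example at the top of this file):
 *
 * ```javascript
 * import { AutoModel } from '@huggingface/transformers';
 *
 * const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
 * // ... run inference ...
 * await model.dispose();
 * ```
 */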
/**
* Instantiate one of the model classes of the library from a pretrained model.
*
* The model class to instantiate is selected based on the `model_type` property of the config object
* (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible)
*
* @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
* - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
*
* @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class.
*/
static async from_pretrained(pretrained_model_name_or_path, {
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
model_file_name = null,
subfolder = 'onnx',
device = null,
dtype = null,
use_external_data_format = null,
session_options = {},
} = {}) {
let options = {
progress_callback,
config,
cache_dir,
local_files_only,
revision,
model_file_name,
subfolder,
device,
dtype,
use_external_data_format,
session_options,
}
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
let info;
if (modelType === MODEL_TYPES.DecoderOnly) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.MaskGeneration) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'vision_encoder',
prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
}, options),
]);
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
]);
} else if (modelType === MODEL_TYPES.ImageTextToText) {
const sessions = {
embed_tokens: 'embed_tokens',
vision_encoder: 'vision_encoder',
decoder_model_merged: 'decoder_model_merged',
}
if (config.is_encoder_decoder) {
sessions['model'] = 'encoder_model';
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.AudioTextToText) {
const sessions = {
embed_tokens: 'embed_tokens',
audio_encoder: 'audio_encoder',
decoder_model_merged: 'decoder_model_merged',
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Musicgen) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'text_encoder',
decoder_model_merged: 'decoder_model_merged',
encodec_decode: 'encodec_decode',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.MultiModality) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
prepare_inputs_embeds: 'prepare_inputs_embeds',
model: 'language_model',
lm_head: 'lm_head',
gen_head: 'gen_head',
gen_img_embeds: 'gen_img_embeds',
image_decode: 'image_decode',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Phi3V) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
prepare_inputs_embeds: 'prepare_inputs_embeds',
model: 'model',
vision_encoder: 'vision_encoder',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.AutoEncoder) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
encoder_model: 'encoder_model',
decoder_model: 'decoder_model',
}, options),
]);
} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
const type = modelName ?? config?.model_type;
if (type !== 'custom') {
console.warn(`Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
}
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
]);
}
// @ts-ignore
return new this(config, ...info);
}
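/**
 * **Example:** load a model from the Hugging Face Hub, or from a local directory, and report download
 * progress (the local path below is hypothetical):
 *
 * ```javascript
 * import { AutoModel } from '@huggingface/transformers';
 *
 * // From the Hub, tracking download progress
 * const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased', {
 *     progress_callback: (progress) => console.log(progress),
 * });
 *
 * // From a local directory containing config.json and an onnx/ subfolder (the default `subfolder`)
 * // const local = await AutoModel.from_pretrained('./my_model_directory/');
 * ```
 */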
/**
* Runs the model with the provided inputs
* @param {Object} model_inputs Object containing input tensors
* @returns {Promise<Object>} Object containing output tensors
*/
async _call(model_inputs) {
return await this.forward(model_inputs);
}
/**
* Forward method for a pretrained model. If not overridden by a subclass, the correct forward method
* will be chosen based on the model type.
* @param {Object} model_inputs The input data to the model in the format specified in the ONNX model.
* @