/**
* @file Definitions of all models available in Transformers.js.
*
* **Example:** Load and run an `AutoModel`.
*
* ```javascript
* import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
*
* let inputs = await tokenizer('I love transformers!');
* let { logits } = await model(inputs);
* // Tensor {
* // data: Float32Array(183132) [-7.117443084716797, -7.107812881469727, -7.092104911804199, ...]
* // dims: (3) [1, 6, 30522],
* // type: "float32",
* // size: 183132,
* // }
* ```
*
* We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example:
*
* **Example:** Load and run an `AutoModelForSeq2SeqLM`.
* ```javascript
* import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
* let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
*
* let { input_ids } = await tokenizer('translate English to German: I love transformers!');
* let outputs = await model.generate(input_ids);
* let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
* // 'Ich liebe Transformatoren!'
* ```
*
* @module models
*/
import {
AutoConfig,
getKeyValueShapes,
} from './configs.js';
import {
deviceToExecutionProviders,
createInferenceSession,
isONNXTensor,
isONNXProxy,
} from './backends/onnx.js';
import {
DATA_TYPES,
DEFAULT_DEVICE_DTYPE_MAPPING,
DEFAULT_DTYPE_SUFFIX_MAPPING,
isWebGpuFp16Supported,
} from './utils/dtypes.js';
import {
Callable,
} from './utils/generic.js';
import {
isIntegralNumber,
mergeArrays,
pick,
} from './utils/core.js';
import {
getModelFile,
getModelJSON,
} from './utils/hub.js';
import {
LogitsProcessorList,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
NoBadWordsLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
ClassifierFreeGuidanceLogitsProcessor,
} from './generation/logits_process.js';
import {
GenerationConfig,
} from './generation/configuration_utils.js';
import {
cat,
full_like,
mean,
ones,
ones_like,
stack,
std_mean,
Tensor,
zeros_like,
} from './utils/tensor.js';
import { dynamic_time_warping, medianFilter } from './utils/maths.js';
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { LogitsSampler } from './generation/logits_sampler.js';
import { apis } from './env.js';
import { WhisperGenerationConfig } from './models/whisper/generation_whisper.js';
import { whisper_language_to_code } from './models/whisper/common_whisper.js';
//////////////////////////////////////////////////
// Model types: used internally
const MODEL_TYPES = {
EncoderOnly: 0,
EncoderDecoder: 1,
Seq2Seq: 2,
Vision2Seq: 3,
DecoderOnly: 4,
MaskGeneration: 5,
ImageTextToText: 6,
Musicgen: 7,
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Helper functions
// NOTE: These will be populated fully later
const MODEL_TYPE_MAPPING = new Map();
const MODEL_NAME_TO_CLASS_MAPPING = new Map();
const MODEL_CLASS_TO_NAME_MAPPING = new Map();
/**
* Constructs an InferenceSession using a model file located at the specified path.
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
* @private
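*
* **Illustrative sketch** (internal helper, not part of the public API): `device` and `dtype` may be
* plain strings, or objects keyed by model file name; the model id below is hypothetical.
* ```javascript
* const { buffer, session_options } = await getSession('user/model-id', 'decoder_model_merged', {
*     subfolder: 'onnx',
*     device: { decoder_model_merged: 'webgpu' }, // per-file device selection
*     dtype: { decoder_model_merged: 'fp32' },    // per-file dtype selection
* });
* ```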
*/
async function getSession(pretrained_model_name_or_path, fileName, options) {
const custom_config = options.config?.['transformers.js_config'] ?? {};
let device = options.device ?? custom_config.device;
if (device && typeof device !== 'string') {
if (device.hasOwnProperty(fileName)) {
device = device[fileName];
} else {
console.warn(`device not specified for "${fileName}". Using the default device.`);
device = null;
}
}
// If the device is not specified, we use the default (supported) execution providers.
const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
);
const executionProviders = deviceToExecutionProviders(selectedDevice);
// If options.dtype is specified, we use it to choose the suffix for the model file.
// Otherwise, we use the default dtype for the device.
let dtype = options.dtype ?? custom_config.dtype;
if (typeof dtype !== 'string') {
if (dtype && dtype.hasOwnProperty(fileName)) {
dtype = dtype[fileName];
} else {
dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
console.warn(`dtype not specified for "${fileName}". Using the default dtype (${dtype}) for this device (${selectedDevice}).`);
}
}
const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);
if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
} else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) {
throw new Error(`The device (${selectedDevice}) does not support fp16.`);
}
// Construct the model file name
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
const session_options = { ...options.session_options };
// Overwrite `executionProviders` if not specified
session_options.executionProviders ??= executionProviders;
// Overwrite `freeDimensionOverrides` if specified in config and not set in session options
const free_dimension_overrides = custom_config.free_dimension_overrides;
if (free_dimension_overrides) {
session_options.freeDimensionOverrides ??= free_dimension_overrides;
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
console.warn(
'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
);
}
const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
// handle onnx external data files
/** @type {Promise<{path: string, data: Uint8Array}>[]} */
let externalDataPromises = [];
if (options.use_external_data_format && (
options.use_external_data_format === true ||
(
typeof options.use_external_data_format === 'object' &&
options.use_external_data_format.hasOwnProperty(fileName) &&
options.use_external_data_format[fileName] === true
)
)) {
if (apis.IS_NODE_ENV) {
throw new Error('External data format is not yet supported in Node.js');
}
const path = `${fileName}${suffix}.onnx_data`;
const fullPath = `${options.subfolder ?? ''}/${path}`;
externalDataPromises.push(new Promise(async (resolve, reject) => {
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options);
resolve({ path, data })
}));
} else if (session_options.externalData !== undefined) {
externalDataPromises = session_options.externalData.map(async (ext) => {
// if the external data is a string, fetch the file and replace the string with its content
if (typeof ext.data === "string") {
const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
return { ...ext, data: ext_buffer };
}
return ext;
});
}
if (externalDataPromises.length > 0) {
session_options.externalData = await Promise.all(externalDataPromises);
}
if (selectedDevice === 'webgpu') {
const shapes = getKeyValueShapes(options.config, {
prefix: 'present',
});
if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
// Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
/** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
const preferredOutputLocation = {};
for (const key in shapes) {
preferredOutputLocation[key] = 'gpu-buffer';
}
session_options.preferredOutputLocation = preferredOutputLocation;
}
}
const buffer = await bufferPromise;
return { buffer, session_options };
}
/**
* Helper function to create multiple InferenceSession objects.
*
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {Record<string, string>} names The names of the model files to load.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects.
* @private
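*
* **Illustrative sketch** (the model id is hypothetical): `names` maps session keys to model file names,
* e.g. as done for seq2seq models further below.
* ```javascript
* const sessions = await constructSessions('user/model-id', {
*     model: 'encoder_model',
*     decoder_model_merged: 'decoder_model_merged',
* }, options);
* // sessions.model and sessions.decoder_model_merged are now InferenceSession objects
* ```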
*/
async function constructSessions(pretrained_model_name_or_path, names, options) {
return Object.fromEntries(await Promise.all(
Object.keys(names).map(async (name) => {
const { buffer, session_options } = await getSession(pretrained_model_name_or_path, names[name], options);
const session = await createInferenceSession(buffer, session_options);
return [name, session];
})
));
}
/**
* Validate model inputs
* @param {Object} session The InferenceSession object that will be run.
* @param {Object} inputs The inputs to check.
* @returns {Record<string, Tensor>} The checked inputs.
* @throws {Error} If any inputs are missing.
* @private
*/
function validateInputs(session, inputs) {
/**
* NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy`
* @type {Record<string, Tensor>}
*/
const checkedInputs = Object.create(null);
const missingInputs = [];
for (const inputName of session.inputNames) {
const tensor = inputs[inputName];
// Rare case where one of the model's input names corresponds to a built-in
// object name (e.g., toString), which would cause a simple (!tensor) check to fail,
// because it's not undefined but a function.
if (!(tensor instanceof Tensor)) {
missingInputs.push(inputName);
continue;
}
// NOTE: When `env.wasm.proxy` is true, the tensor is moved across the Worker
// boundary, transferring ownership to the worker and invalidating the tensor.
// So, in this case, we simply sacrifice a clone for it.
checkedInputs[inputName] = isONNXProxy() ? tensor.clone() : tensor;
}
if (missingInputs.length > 0) {
throw new Error(
`An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(', ')}".`);
}
const numInputsProvided = Object.keys(inputs).length;
const numInputsNeeded = session.inputNames.length;
if (numInputsProvided > numInputsNeeded) {
// No missing inputs, but too many inputs were provided.
// Warn the user and ignore the extra inputs.
let ignored = Object.keys(inputs).filter(inputName => !session.inputNames.includes(inputName));
console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`);
}
return checkedInputs;
}
/**
* Executes an InferenceSession using the specified inputs.
* NOTE: `inputs` must contain at least the input names of the model.
* - If additional inputs are passed, they will be ignored.
* - If inputs are missing, an error will be thrown.
*
* @param {Object} session The InferenceSession object to run.
* @param {Object} inputs An object that maps input names to input tensors.
* @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
* @private
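*
* **Illustrative sketch:** the exact input and output names depend on the exported ONNX graph.
* ```javascript
* const { logits } = await sessionRun(session, { input_ids, attention_mask });
* ```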
*/
async function sessionRun(session, inputs) {
const checkedInputs = validateInputs(session, inputs);
try {
// pass the original ort tensor
const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
let output = await session.run(ortFeed);
output = replaceTensors(output);
return output;
} catch (e) {
// This usually occurs when the inputs are of the wrong type.
console.error(`An error occurred during model execution: "${e}".`);
console.error('Inputs given to model:', checkedInputs);
throw e;
}
}
/**
* Replaces ONNX Tensor objects with custom Tensor objects to support additional functions.
* @param {Object} obj The object to replace tensor objects in.
* @returns {Object} The object with tensor objects replaced by custom Tensor objects.
* @private
*/
function replaceTensors(obj) {
for (let prop in obj) {
if (isONNXTensor(obj[prop])) {
obj[prop] = new Tensor(obj[prop]);
} else if (typeof obj[prop] === 'object') {
replaceTensors(obj[prop]);
}
}
return obj;
}
/**
* Converts an array or Tensor of integers to an int64 Tensor.
* @param {Array|Tensor} items The input integers to be converted.
* @returns {Tensor} The int64 Tensor with the converted values.
* @throws {Error} If the input array is empty, or if it is a batched (nested) array whose sequences do not all have the same length.
* @private
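*
* For example (illustrative):
* ```javascript
* toI64Tensor([1, 2, 3]);        // Tensor { type: 'int64', dims: [1, 3], ... }
* toI64Tensor([[1, 2], [3, 4]]); // Tensor { type: 'int64', dims: [2, 2], ... }
* ```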
*/
function toI64Tensor(items) {
if (items instanceof Tensor) {
return items;
}
// items is an array
if (items.length === 0) {
throw Error("items must be non-empty");
}
if (Array.isArray(items[0])) {
// batched
if (items.some(x => x.length !== items[0].length)) {
throw Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.")
}
return new Tensor('int64',
BigInt64Array.from(items.flat().map(x => BigInt(x))),
[items.length, items[0].length]
);
} else {
//flat
return new Tensor('int64',
BigInt64Array.from(items.map(x => BigInt(x))),
[1, items.length]
);
}
}
/**
* Creates a boolean tensor with a single value.
* @param {boolean} value The value of the tensor.
* @returns {Tensor} The boolean tensor.
* @private
*/
function boolTensor(value) {
return new Tensor('bool', [value], [1]);
}
// JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in
/**
* Perform forward pass on the seq2seq model (both encoder and decoder).
* @param {Object} self The seq2seq model object.
* @param {Object} model_inputs The input object for the model containing encoder and decoder inputs.
* @returns {Promise<Seq2SeqLMOutput>} Promise that resolves with the output of the seq2seq model.
* @private
*/
async function seq2seqForward(self, model_inputs) {
let { encoder_outputs, input_ids, decoder_input_ids, ...other_decoder_inputs } = model_inputs;
// Encode if needed
if (!encoder_outputs) {
const encoder_inputs = pick(model_inputs, self.sessions['model'].inputNames);
// Encoder outputs are not given, so we must compute them.
encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state;
}
other_decoder_inputs.input_ids = decoder_input_ids;
other_decoder_inputs.encoder_hidden_states = encoder_outputs;
if (self.sessions['decoder_model_merged'].inputNames.includes('encoder_attention_mask')) {
other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask
}
const decoderResults = await decoderForward(self, other_decoder_inputs, true);
return decoderResults;
}
/**
* Forward pass of an encoder model.
* @param {Object} self The encoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @returns {Promise<Object>} The model's outputs.
* @private
*/
async function encoderForward(self, model_inputs) {
const session = self.sessions['model'];
const encoderFeeds = pick(model_inputs, session.inputNames);
if (session.inputNames.includes('inputs_embeds') && !encoderFeeds.inputs_embeds) {
if (!model_inputs.input_ids) {
throw new Error('Both `input_ids` and `inputs_embeds` are missing in the model inputs.');
}
encoderFeeds.inputs_embeds = await self.encode_text({ input_ids: model_inputs.input_ids });
}
if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
// Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it,
// but they weren't created by the tokenizer.
encoderFeeds.token_type_ids = new Tensor(
'int64',
new BigInt64Array(encoderFeeds.input_ids.data.length),
encoderFeeds.input_ids.dims
)
}
return await sessionRun(session, encoderFeeds);
}
/**
* Forward pass of a decoder model.
* @param {Object} self The decoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @param {boolean} [is_encoder_decoder=false] Whether to use the merged decoder session of an encoder-decoder model.
* @returns {Promise<Object>} The logits and past key values.
* @private
*/
async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
const session = self.sessions[
is_encoder_decoder ? 'decoder_model_merged' : 'model'
]
const { past_key_values, ...new_model_inputs } = model_inputs;
if (session.inputNames.includes('use_cache_branch')) {
new_model_inputs.use_cache_branch = boolTensor(!!past_key_values);
}
if (session.inputNames.includes('position_ids') && new_model_inputs.attention_mask && !new_model_inputs.position_ids) {
new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values);
}
// Unpack the `past_key_values` object into model inputs
self.addPastKeyValues(new_model_inputs, past_key_values);
// Select only the inputs that are needed for the current session
const fixed = pick(new_model_inputs, session.inputNames);
return await sessionRun(session, fixed);
}
/**
* Forward pass of an image-text-to-text model.
* @param {Object} self The image-text-to-text model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @param {Tensor} [model_inputs.input_ids=null]
* @param {Tensor} [model_inputs.attention_mask=null]
* @param {Tensor} [model_inputs.pixel_values=null]
* @param {Tensor} [model_inputs.position_ids=null]
* @param {Tensor} [model_inputs.inputs_embeds=null]
* @param {Tensor} [model_inputs.past_key_values=null]
* @param {Object} [model_inputs.generation_config=null]
* @param {Object} [model_inputs.logits_processor=null]
* @returns {Promise<Object>} The model's outputs (logits and past key values).
* @private
*/
async function imageTextToTextForward(self, {
// Produced by the tokenizer/processor:
input_ids = null,
attention_mask = null,
pixel_values = null,
// Used during generation:
position_ids = null,
inputs_embeds = null,
past_key_values = null,
// Generic generation parameters
generation_config = null,
logits_processor = null,
// TODO: needed?
...kwargs
}) {
if (!inputs_embeds) {
// 1. Extract the input embeddings
inputs_embeds = await self.encode_text({ input_ids });
// 2. Possibly, merge text and images
if (pixel_values && input_ids.dims[1] !== 1) {
const image_features = await self.encode_image({ pixel_values });
({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
image_features,
inputs_embeds,
input_ids,
attention_mask,
}));
} else if (past_key_values && pixel_values && input_ids.dims[1] === 1) {
// This is the case when we are generating with cache
const target_length = input_ids.dims[1]; // always 1
const past_length = Object.values(past_key_values)[0].dims.at(-2);
attention_mask = cat([
ones([input_ids.dims[0], past_length]),
attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
], 1);
}
}
const outputs = await decoderForward(self, {
inputs_embeds,
past_key_values,
attention_mask,
position_ids,
generation_config,
logits_processor,
}, true);
return outputs;
}
function createPositionIds(model_inputs, past_key_values = null) {
// If the model supports providing position_ids, we create position_ids on the fly for batch generation,
// by computing the cumulative sum of the attention mask along the sequence length dimension.
//
// Equivalent to:
// position_ids = attention_mask.long().cumsum(-1) - 1
// position_ids.masked_fill_(attention_mask == 0, 1)
// if past_key_values:
// position_ids = position_ids[:, -input_ids.shape[1] :]
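//
// Worked example (illustrative): for a left-padded attention_mask of [[0, 0, 1, 1, 1]],
// the computed position_ids are [[1, 1, 0, 1, 2]]: padding positions are filled with 1,
// and real tokens receive positions 0, 1, 2, ...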
const { input_ids, inputs_embeds, attention_mask } = model_inputs;
const [bz, seq_len] = attention_mask.dims;
const data = new BigInt64Array(attention_mask.data.length);
for (let i = 0; i < bz; ++i) {
const start = i * seq_len;
let sum = BigInt(0);
for (let j = 0; j < seq_len; ++j) {
const index = start + j;
if (attention_mask.data[index] === 0n) {
data[index] = BigInt(1);
} else { // === 1n
data[index] = sum;
sum += attention_mask.data[index];
}
}
}
let position_ids = new Tensor('int64', data, attention_mask.dims);
if (past_key_values) {
const offset = -(input_ids ?? inputs_embeds).dims.at(1);
position_ids = position_ids.slice(null, [offset, null]);
}
return position_ids;
}
function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
if (model_inputs.past_key_values) {
const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
const { input_ids, attention_mask } = model_inputs;
// Keep only the unprocessed tokens:
// 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
// some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
// input)
if (attention_mask && attention_mask.dims[1] > input_ids.dims[1]) {
// NOTE: not needed since we only pass the generated tokens to the next forward pass
// const offset = -(attention_mask.dims[1] - past_length);
// model_inputs.input_ids = input_ids.slice(null, [offset, null]);
}
// 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens.
// We can discard the already-processed prefix of input_ids based on past_length.
else if (past_length < input_ids.dims[1]) {
// NOTE: Required for phi models.
// See https://github.com/huggingface/transformers/issues/30809#issuecomment-2111918479 for more information.
model_inputs.input_ids = input_ids.slice(null, [past_length, null]);
}
// 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
else {
if (
// NOTE: Only used by VLMs (!= so that null matches undefined)
self.config.image_token_index != null &&
// Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
input_ids.data.some(x => x == self.config.image_token_index)
) {
// TODO: Support multiple image tokens
const num_image_tokens = self.config.num_image_tokens;
if (!num_image_tokens) {
throw new Error('`num_image_tokens` is missing in the model configuration.');
}
const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);
// TODO: The attention mask should be formed from the attention mask passed in model_inputs
model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
}
}
}
return model_inputs;
}
function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
if (model_inputs.past_key_values) {
input_ids = input_ids.map(x => [x.at(-1)]);
}
return {
...model_inputs,
decoder_input_ids: toI64Tensor(input_ids),
};
}
function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
if (self.config.is_encoder_decoder) {
return encoder_decoder_prepare_inputs_for_generation(self, ...args);
} else {
return decoder_prepare_inputs_for_generation(self, ...args);
}
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
/**
* A base class for pre-trained models that provides the model configuration and an ONNX session.
*/
export class PreTrainedModel extends Callable {
main_input_name = 'input_ids';
forward_params = ['input_ids', 'attention_mask'];
/**
* Creates a new instance of the `PreTrainedModel` class.
* @param {import('./configs.js').PretrainedConfig} config The model configuration.
* @param {Record<string, any>} sessions The inference sessions for the model.
*/
constructor(config, sessions) {
super();
this.config = config;
this.sessions = sessions;
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
this.can_generate = false;
this._forward = null;
this._prepare_inputs_for_generation = null;
switch (modelType) {
case MODEL_TYPES.DecoderOnly:
this.can_generate = true;
this._forward = decoderForward;
this._prepare_inputs_for_generation = decoder_prepare_inputs_for_generation;
break;
case MODEL_TYPES.Seq2Seq:
case MODEL_TYPES.Vision2Seq:
case MODEL_TYPES.Musicgen:
this.can_generate = true;
this._forward = seq2seqForward;
this._prepare_inputs_for_generation = encoder_decoder_prepare_inputs_for_generation;
break;
case MODEL_TYPES.EncoderDecoder:
this._forward = seq2seqForward;
break;
case MODEL_TYPES.ImageTextToText:
this.can_generate = true;
this._forward = imageTextToTextForward;
this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
break;
default:
// should be MODEL_TYPES.EncoderOnly
this._forward = encoderForward;
break;
}
if (this.can_generate) {
this.forward_params.push('past_key_values');
}
/** @type {import('./configs.js').TransformersJSConfig} */
this.custom_config = this.config['transformers.js_config'] ?? {};
}
/**
* Disposes of all the ONNX sessions that were created during inference.
* @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed.
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
*/
async dispose() {
const promises = [];
for (const session of Object.values(this.sessions)) {
if (session?.handler?.dispose) {
promises.push(session.handler.dispose())
}
}
return await Promise.all(promises);
}
/**
* Instantiate one of the model classes of the library from a pretrained model.
*
* The model class to instantiate is selected based on the `model_type` property of the config object
* (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible)
*
* @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
* - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
*
* @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class.
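*
* **Example (illustrative):** loading with explicit `device` and `dtype` options, using the `AutoModel`
* helper shown in the module header.
* ```javascript
* const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased', {
*     device: 'webgpu', // when omitted, defaults to 'wasm' in the browser and 'cpu' in Node.js
*     dtype: 'fp32',
* });
* ```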
*/
static async from_pretrained(pretrained_model_name_or_path, {
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
model_file_name = null,
subfolder = 'onnx',
device = null,
dtype = null,
use_external_data_format = null,
session_options = {},
} = {}) {
let options = {
progress_callback,
config,
cache_dir,
local_files_only,
revision,
model_file_name,
subfolder,
device,
dtype,
use_external_data_format,
session_options,
}
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
let info;
if (modelType === MODEL_TYPES.DecoderOnly) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);
} else if (modelType === MODEL_TYPES.MaskGeneration) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'vision_encoder',
prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
}, options),
]);
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
]);
} else if (modelType === MODEL_TYPES.ImageTextToText) {
const sessions = {
embed_tokens: 'embed_tokens',
vision_encoder: 'vision_encoder',
decoder_model_merged: 'decoder_model_merged',
}
if (config.is_encoder_decoder) {
sessions['model'] = 'encoder_model';
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);
} else if (modelType === MODEL_TYPES.Musicgen) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'text_encoder',
decoder_model_merged: 'decoder_model_merged',
encodec_decode: 'encodec_decode',
}, options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);
} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`)
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
]);
}
// @ts-ignore
return new this(config, ...info);
}
/**
* Runs the model with the provided inputs
* @param {Object} model_inputs Object containing input tensors
* @returns {Promise<Object>} Object containing output tensors
*/
async _call(model_inputs) {
return await this.forward(model_inputs);
}
/**
* Forward method for a pretrained model. If not overridden by a subclass, the correct forward method
* will be chosen based on the model type.
* @param {Object} model_inputs The input data to the model in the format specified in the ONNX model.
* @returns {Promise<Object>} The output data from the model in the format specified in the ONNX model.
* @throws {Error} This method must be implemented in subclasses.
*/
async forward(model_inputs) {
return await this._forward(this, model_inputs);
}
/**
* This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
* instances used for multinomial sampling.
* @param {GenerationConfig} generation_config The generation config.
* @returns {LogitsProcessorList} The list of logits warpers to apply.
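*
* For example (illustrative), a generation config with `{ temperature: 0.7, top_k: 50, top_p: 0.9 }`
* produces warpers equivalent to:
* ```javascript
* [new TemperatureLogitsWarper(0.7), new TopKLogitsWarper(50), new TopPLogitsWarper(0.9)]
* ```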
*/
_get_logits_warper(generation_config) {
// instantiate warpers list
const warpers = new LogitsProcessorList();
if (generation_config.temperature !== null && generation_config.temperature !== 1.0) {
warpers.push(new TemperatureLogitsWarper(generation_config.temperature));
}
if (generation_config.top_k !== null && generation_config.top_k !== 0) {
// TODO: add min_tokens_to_keep
warpers.push(new TopKLogitsWarper(generation_config.top_k));
}
if (generation_config.top_p !== null && generation_config.top_p < 1.0) {
// TODO: add min_tokens_to_keep
warpers.push(new TopPLogitsWarper(generation_config.top_p));
}
return warpers;
}
/**
* @param {GenerationConfig} generation_config
* @param {number} input_ids_seq_length The starting sequence length for the input ids.
* @returns {LogitsProcessorList}
* @private
*/
_get_logits_processor(
generation_config,
input_ids_seq_length,
// encoder_input_ids, TODO
// prefix_allowed_tokens_fn, TODO
logits_processor = null
) {
const processors = new LogitsProcessorList();
// if (generation_config.diversity_penalty !== null && generation_config.diversity_penalty > 0.0) {
// processors.push(new HammingDiversityLogitsProcessor(
// generation_config.diversity_penalty,
// generation_config.num_beams,
// generation_config.num_beam_groups
// ));
// }
// if (generation_config.encoder_repetition_penalty !== null && generation_config.encoder_repetition_penalty !== 1.0) {
// processors.push(new EncoderRepetitionPenaltyLogitsProcessor(
// generation_config.encoder_repetition_penalty,
// encoder_input_ids
// ));
// }
if (generation_config.repetition_penalty !== null && generation_config.repetition_penalty !== 1.0) {
processors.push(new RepetitionPenaltyLogitsProcessor(generation_config.repetition_penalty));
}
if (generation_config.no_repeat_ngram_size !== null && generation_config.no_repeat_ngram_size > 0) {
processors.push(new NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size));
}
// if (generation_config.encoder_no_repeat_ngram_size !== null && generation_config.encoder_no_repeat_ngram_size > 0) {
// if (this.config.is_encoder_decoder) {
// processors.push(new EncoderNoRepeatNGramLogitsProcessor(
// generation_config.encoder_no_repeat_ngram_size,
// encoder_input_ids
// ));
// } else {
// throw new Error("It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture");
// }
// }
if (generation_config.bad_words_ids !== null) {
processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
}
if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
}
if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
processors.push(new MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id
));
}
// if (prefix_allowed_tokens_fn !== null) {
// processors.push(new PrefixConstrainedLogitsProcessor(
// prefix_allowed_tokens_fn,
// generation_config.num_beams / generation_config.num_beam_groups
// ));
// }
if (generation_config.forced_bos_token_id !== null) {
processors.push(new ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id));
}
if (generation_config.forced_eos_token_id !== null) {
processors.push(new ForcedEOSTokenLogitsProcessor(
generation_config.max_length,
generation_config.forced_eos_token_id
));
}
// if (generation_config.remove_invalid_values === true) {
// processors.push(new InfNanRemoveLogitsProcessor());
// }
// if (generation_config.exponential_decay_length_penalty !== null) {
// processors.push(new ExponentialDecayLengthPenalty(
// generation_config.exponential_decay_length_penalty,
// generation_config.eos_token_id,
// input_ids_seq_length
// ));
// }
// if (generation_config.suppress_tokens !== null) {
// processors.push(new SuppressTokensLogitsProcessor(generation_config.suppress_tokens));
// }
if (generation_config.begin_suppress_tokens !== null) {
const begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null)
? input_ids_seq_length
: input_ids_seq_length + 1;
processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index));
}
// DEPRECATED: https://github.com/huggingface/transformers/pull/29485
// if (generation_config.forced_decoder_ids !== null) {
// processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids));
// }
// 8. prepare batched CFG externally
if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
processors.push(new ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale));
}
if (logits_processor !== null) {
processors.extend(logits_processor)
}
// `LogitNormalization` should always be the last logit processor, when present
// if (generation_config.renormalize_logits === true) {
// processors.push(new LogitNormalization());
// }
return processors;
}
/**
* This function merges multiple generation configs together to form a final generation config to be used by the model for text generation.
* It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object.
* @param {GenerationConfig|null} generation_config A `GenerationConfig` object containing generation parameters.
* @param {Object} kwargs Additional generation parameters to be used in place of those in the `generation_config` object.
* @returns {GenerationConfig} The final generation config object to be used by the model for text generation.
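*
* In other words, values are merged with increasing precedence (illustrative sketch; `max_new_tokens`
* is just an example parameter):
* ```javascript
* // model config (+ decoder/generator/text_config) < model.generation_config
* //   < generation_config argument < kwargs
* const gen_config = model._prepare_generation_config(null, { max_new_tokens: 100 });
* ```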
*/
_prepare_generation_config(generation_config, kwargs, cls = GenerationConfig) {
// Create empty generation config (contains defaults)
// We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
const config = { ...this.config };
for (const key of ["decoder", "generator", "text_config"]) {
// Special case: some models have generation attributes set in the decoder.
// Use them if still unset in the generation config.
if (key in config) {
Object.assign(config, config[key]);
}
}
const gen_config = new cls(config);
// Apply model's generation config, if it exists
if ('generation_config' in this) {
Object.assign(gen_config, this.generation_config);
}
// Next, use any generation config specified by the user
// when calling `generate`
if (generation_config) {
Object.assign(gen_config, generation_config);
}
// Finally, if any kwargs were passed, use them to overwrite
if (kwargs) {
Object.assign(gen_config, pick(kwargs, Object.getOwnPropertyNames(gen_config)));
}
return gen_config;
}
/**
*
* @param {GenerationConfig} generation_config
* @param {StoppingCriteriaList} [stopping_criteria=null]
*/
_get_stopping_criteria(generation_config, stopping_criteria = null) {
const criteria = new StoppingCriteriaList();
if (generation_config.max_length !== null) {
criteria.push(new MaxLengthCriteria(
generation_config.max_length,
this.config.max_position_embeddings ?? null,
));
}
// if (generation_config.max_time !== null) {
// criteria.push(new MaxTimeCriteria(generation_config.max_time));
// }
if (generation_config.eos_token_id !== null) {
criteria.push(new EosTokenCriteria(generation_config.eos_token_id));
}
if (stopping_criteria) {
criteria.extend(stopping_criteria);
}
return criteria;
}
/**
* Confirms that the model class is compatible with generation.
* If not, raises an exception that points to the right class to use.
*/
_validate_model_class() {
if (!this.can_generate) {
const generate_compatible_mappings = [
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
// MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, // TODO
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
];
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const generate_compatible_classes = new Set();
const modelType = this.config.model_type;
for (const model_mapping of generate_compatible_mappings) {
const supported_models = model_mapping.get(modelType);
if (supported_models) {
generate_compatible_classes.add(supported_models[0]);
}
}
let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.`
if (generate_compatible_classes.size > 0) {
errorMessage += ` Please use the following class instead: ${[...generate_compatible_classes].join(', ')}`;
}
throw Error(errorMessage);
}
}
prepare_inputs_for_generation(...args) {
return this._prepare_inputs_for_generation(this, ...args);
}
/**
*
* @param {Object} inputs
* @param {bigint[][]} inputs.generated_input_ids
* @param {Object} inputs.outputs
* @param {Object} inputs.model_inputs
* @param {boolean} inputs.is_encoder_decoder
* @returns {Object} The updated model inputs for the next generation iteration.
*/
_update_model_kwargs_for_generation({ generated_input_ids, outputs, model_inputs, is_encoder_decoder }) {
// update past_key_values
model_inputs['past_key_values'] = this.getPastKeyValues(outputs, model_inputs.past_key_values);
// update inputs for next run
model_inputs['input_ids'] = new Tensor('int64', generated_input_ids.flat(), [generated_input_ids.length, 1]);
if (!is_encoder_decoder) {
// update attention mask
model_inputs.attention_mask = cat(
[
model_inputs.attention_mask,
ones([model_inputs.attention_mask.dims[0], 1]),
], 1
);
} else if ('decoder_attention_mask' in model_inputs) {
// TODO: update decoder attention mask if the model requires it
}
// force recreate position_ids in next iteration
model_inputs['position_ids'] = null;
return model_inputs;
}
/**
* This function extracts the model-specific `inputs` for generation.
* @param {Object} params
* @param {Tensor} [params.inputs=null]
* @param {number} [params.bos_token_id=null]
* @param {Record<string, Tensor|number[]>} [params.model_kwargs]
* @returns {{inputs_tensor: Tensor, model_inputs: Record<string, Tensor>, model_input_name: string}} The model-specific inputs for generation.
*/
_prepare_model_inputs({ inputs, bos_token_id, model_kwargs }) {
const model_inputs = pick(model_kwargs, this.forward_params);
const input_name = this.main_input_name;
if (input_name in model_inputs) {
if (inputs) {
throw new Error(
`\`inputs\` were passed alongside \`${input_name}\`, which is not allowed. ` +
`Make sure to either pass \`inputs\` or \`${input_name}\`=...`
);
}
} else {
model_inputs[input_name] = inputs;
}
const inputs_tensor = model_inputs[input_name];
return { inputs_tensor, model_inputs, model_input_name: input_name };
}
async _prepare_encoder_decoder_kwargs_for_generation({ inputs_tensor, model_inputs, model_input_name, generation_config }) {
if (
this.sessions['model'].inputNames.includes('inputs_embeds')
&& !model_inputs.inputs_embeds
&& '_prepare_inputs_embeds' in this
) {
// Encoder expects `inputs_embeds` instead of `input_ids`
const { input_ids, pixel_values, attention_mask, ...kwargs } = model_inputs;
// @ts-ignore
const prepared_inputs = await this._prepare_inputs_embeds(model_inputs);
model_inputs = {
...kwargs,
...pick(prepared_inp