chromadb-default-embed
Chroma's fork of @xenova/transformers serving as our default embedding function
/**
* @file Definitions of all models available in Transformers.js.
*
* **Example:** Load and run an `AutoModel`.
*
* ```javascript
* import { AutoModel, AutoTokenizer } from '@xenova/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
*
* let inputs = await tokenizer('I love transformers!');
* let { logits } = await model(inputs);
* // Tensor {
* // data: Float32Array(183132) [-7.117443084716797, -7.107812881469727, -7.092104911804199, ...]
* // dims: (3) [1, 6, 30522],
* // type: "float32",
* // size: 183132,
* // }
* ```
*
* We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example:
*
* **Example:** Load and run an `AutoModelForSeq2SeqLM`.
* ```javascript
* import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@xenova/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
* let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
*
* let { input_ids } = await tokenizer('translate English to German: I love transformers!');
* let outputs = await model.generate(input_ids);
* let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
* // 'Ich liebe Transformatoren!'
* ```
*
* @module models
*/
import {
AutoConfig,
} from './configs.js';
import {
add_token_types,
} from './tokenizers.js';
import {
Callable,
isIntegralNumber,
isTypedArray,
mergeArrays,
} from './utils/core.js';
import {
getModelFile,
getModelJSON,
} from './utils/hub.js';
import {
LogitsProcessorList,
GenerationConfig,
ForceTokensLogitsProcessor,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
NoBadWordsLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
Sampler,
} from './utils/generation.js';
import {
cat,
dynamicTimeWarping,
mean,
ones_like,
stack,
std_mean,
Tensor,
} from './utils/tensor.js';
import { executionProviders, ONNX } from './backends/onnx.js';
import { medianFilter } from './transformers.js';
const { InferenceSession, Tensor: ONNXTensor, env } = ONNX;
/** @typedef {import('onnxruntime-web').InferenceSession} InferenceSession */
//////////////////////////////////////////////////
// Model types: used internally
const MODEL_TYPES = {
EncoderOnly: 0,
EncoderDecoder: 1,
Seq2Seq: 2,
Vision2Seq: 3,
DecoderOnly: 4,
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Helper functions
// NOTE: These will be populated fully later
const MODEL_TYPE_MAPPING = new Map();
const MODEL_NAME_TO_CLASS_MAPPING = new Map();
const MODEL_CLASS_TO_NAME_MAPPING = new Map();
/**
* Constructs an InferenceSession using a model file located at the specified path.
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model.
* @returns {Promise<InferenceSession>} A Promise that resolves to an InferenceSession object.
* @private
*/
async function constructSession(pretrained_model_name_or_path, fileName, options) {
// TODO add option for user to force specify their desired execution provider
let modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`;
let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
try {
return await InferenceSession.create(buffer, {
executionProviders,
});
} catch (err) {
// If wasm was the only execution provider available, rethrow the error (there is no fallback)
if (executionProviders.length === 1 && executionProviders[0] === 'wasm') {
throw err;
}
console.warn(err);
console.warn(
'Something went wrong during model construction (most likely a missing operation). ' +
'Using `wasm` as a fallback. '
)
return await InferenceSession.create(buffer, {
executionProviders: ['wasm']
});
}
}
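/**
 * **Illustrative sketch:** the ONNX weights are resolved relative to the model repo's
 * `onnx/` subfolder, with a `_quantized` suffix when `options.quantized` is true:
 * ```javascript
 * // fileName = 'model', options.quantized = true
 * //   => 'onnx/model_quantized.onnx'
 * // fileName = 'decoder_model_merged', options.quantized = false
 * //   => 'onnx/decoder_model_merged.onnx'
 * ```
 */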
/**
* Validate model inputs
* @param {InferenceSession} session The InferenceSession object that will be run.
* @param {Record<string, Tensor>} inputs The inputs to check.
* @returns {Record<string, Tensor>} The checked inputs.
* @throws {Error} If any inputs are missing.
* @private
*/
function validateInputs(session, inputs) {
/**
* NOTE: Holds either the original tensors or clones of them, depending on `env.wasm.proxy`
* @type {Record<string, Tensor>}
*/
const checkedInputs = Object.create(null);
const missingInputs = [];
for (const inputName of session.inputNames) {
const tensor = inputs[inputName];
// Rare case where one of the model's input names corresponds to a built-in
// object property (e.g., toString): a simple `!tensor` check would not detect the
// missing input, because the value resolves to a function rather than undefined.
if (!(tensor instanceof Tensor)) {
missingInputs.push(inputName);
continue;
}
// NOTE: When `env.wasm.proxy` is true, the tensor is moved across the Worker
// boundary, transferring ownership to the worker and invalidating the original.
// In that case, we pass a clone instead.
checkedInputs[inputName] = env.wasm.proxy ? tensor.clone() : tensor;
}
if (missingInputs.length > 0) {
throw new Error(
`An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(', ')}".`);
}
const numInputsProvided = Object.keys(inputs).length;
const numInputsNeeded = session.inputNames.length;
if (numInputsProvided > numInputsNeeded) {
// No missing inputs, but too many inputs were provided.
// Warn the user and ignore the extra inputs.
let ignored = Object.keys(inputs).filter(inputName => !session.inputNames.includes(inputName));
console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`);
}
return checkedInputs;
}
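/**
 * **Illustrative example:** why `!(tensor instanceof Tensor)` is used above instead of a
 * simple truthiness check. If a model declares an input named `toString` and the caller
 * omits it, `inputs['toString']` resolves to the inherited built-in function, so `!tensor`
 * would not flag it as missing:
 * ```javascript
 * const inputs = {};                      // caller supplied nothing
 * !inputs['toString'];                    // false (resolves to the inherited function)
 * inputs['toString'] instanceof Tensor;   // false (correctly treated as missing)
 * ```
 */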
/**
* Executes an InferenceSession using the specified inputs.
* NOTE: `inputs` must contain at least the input names of the model.
* - If additional inputs are passed, they will be ignored.
* - If inputs are missing, an error will be thrown.
*
* @param {InferenceSession} session The InferenceSession object to run.
* @param {Object} inputs An object that maps input names to input tensors.
* @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
* @private
*/
async function sessionRun(session, inputs) {
const checkedInputs = validateInputs(session, inputs);
try {
// @ts-ignore
let output = await session.run(checkedInputs);
output = replaceTensors(output);
return output;
} catch (e) {
// This usually occurs when the inputs are of the wrong type.
console.error(`An error occurred during model execution: "${e}".`);
console.error('Inputs given to model:', checkedInputs);
throw e;
}
}
/**
* Replaces ONNX Tensor objects with custom Tensor objects to support additional functions.
* @param {Object} obj The object to replace tensor objects in.
* @returns {Object} The object with tensor objects replaced by custom Tensor objects.
* @private
*/
function replaceTensors(obj) {
for (let prop in obj) {
if (obj[prop] instanceof ONNXTensor) {
obj[prop] = new Tensor(obj[prop]);
} else if (typeof obj[prop] === 'object') {
replaceTensors(obj[prop]);
}
}
return obj;
}
/**
* Converts an array or Tensor of integers to an int64 Tensor.
* @param {Array|Tensor} items The input integers to be converted.
* @returns {Tensor} The int64 Tensor with the converted values.
* @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
* @private
*/
function toI64Tensor(items) {
if (items instanceof Tensor) {
return items;
}
// items is an array
if (items.length === 0) {
throw Error("items must be non-empty");
}
if (Array.isArray(items[0])) {
// batched
if (items.some(x => x.length !== items[0].length)) {
throw Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.")
}
return new Tensor('int64',
BigInt64Array.from(items.flat().map(x => BigInt(x))),
[items.length, items[0].length]
);
} else {
//flat
return new Tensor('int64',
BigInt64Array.from(items.map(x => BigInt(x))),
[1, items.length]
);
}
}
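/**
 * **Illustrative example:** shapes produced by `toI64Tensor`:
 * ```javascript
 * toI64Tensor([101, 2023, 102]).dims;           // [1, 3] (flat input, batch dimension added)
 * toI64Tensor([[101, 102], [101, 103]]).dims;   // [2, 2] (already batched)
 * toI64Tensor([[101, 102], [101]]);             // throws: sequences must have equal length
 * ```
 */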
/**
* Prepares an attention mask for a sequence of tokens based on configuration options.
* @param {Object} self The calling object instance.
* @param {Tensor} tokens The input tokens.
* @returns {Tensor} The attention mask tensor.
* @private
*/
function prepareAttentionMask(self, tokens) {
// Prepare attention mask
let pad_token_id = self.config.pad_token_id ?? null;
let eos_token_id = self.config.eos_token_id ?? null;
if (isIntegralNumber(eos_token_id)) {
eos_token_id = [eos_token_id];
}
let is_pad_token_in_inputs = tokens.indexOf(pad_token_id) !== -1;
let is_pad_token_not_equal_to_eos_token_id = (eos_token_id === null) || !eos_token_id.includes(pad_token_id)
if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
let data = BigInt64Array.from(
// Note: != so that int matches bigint
// @ts-ignore
tokens.data.map(x => x != pad_token_id)
)
return new Tensor('int64', data, tokens.dims)
} else {
return ones_like(tokens);
}
}
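/**
 * **Illustrative example:** assuming `pad_token_id = 0` and `eos_token_id = 2`, padded
 * positions are zeroed out in the mask:
 * ```javascript
 * // tokens.data = [101n, 2023n, 102n, 0n, 0n]  =>  mask.data = [1n, 1n, 1n, 0n, 0n]
 * ```
 * If the pad token is absent from the inputs (or coincides with the EOS token), the mask
 * falls back to all ones via `ones_like(tokens)`.
 */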
/**
* Add position IDs to the feeds object.
* @param {Object} session The inference session.
* @param {Object} feeds The input to the model.
* @param {boolean} use_cache_branch Whether to use the cache branch of the model.
* @returns {void}
* @private
*/
function preparePositionIds(session, feeds, use_cache_branch) {
if (!session.inputNames.includes('position_ids')) return;
const data = new BigInt64Array(feeds.attention_mask.data.length);
// Compute cumulative sum of the attention mask along the sequence length dimension
for (let i = 0; i < feeds.attention_mask.dims[0]; ++i) {
let start = i * feeds.attention_mask.dims[1];
let sum = BigInt(0);
for (let j = 0; j < feeds.attention_mask.dims[1]; ++j) {
const index = start + j;
if (feeds.attention_mask.data[index] === 0n) {
data[index] = BigInt(1);
} else { // === 1n
data[index] = sum;
sum += feeds.attention_mask.data[index];
}
}
}
feeds.position_ids = new Tensor('int64', data, feeds.attention_mask.dims);
if (use_cache_branch) {
feeds.position_ids = feeds.position_ids.slice(null, -1).unsqueeze_(-1);
}
}
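/**
 * **Illustrative example:** position ids are the cumulative sum of the attention mask,
 * with masked (padded) positions pinned to 1:
 * ```javascript
 * // attention_mask = [[0n, 0n, 1n, 1n, 1n]]
 * // position_ids   = [[1n, 1n, 0n, 1n, 2n]]
 * ```
 * When `use_cache_branch` is true, only the position id of the last token is kept.
 */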
/**
* Creates a boolean tensor with a single value.
* @param {boolean} value The value of the tensor.
* @returns {Tensor} The boolean tensor.
* @private
*/
function boolTensor(value) {
return new Tensor('bool', [value], [1]);
}
// JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in
/**
* Perform forward pass on the seq2seq model (both encoder and decoder).
* @param {Object} self The seq2seq model object.
* @param {Object} model_inputs The input object for the model containing encoder and decoder inputs.
* @returns {Promise<Seq2SeqLMOutput>} Promise that resolves with the output of the seq2seq model.
* @private
*/
async function seq2seqForward(self, model_inputs) {
let { encoder_outputs, past_key_values } = model_inputs;
if (!encoder_outputs) {
// Encoder outputs are not given, so we must compute them.
encoder_outputs = (await encoderForward(self, model_inputs)).last_hidden_state;
}
let decoderFeeds = {
input_ids: model_inputs.decoder_input_ids,
encoder_hidden_states: encoder_outputs,
};
const use_cache_branch = !!past_key_values;
if (self.decoder_merged_session.inputNames.includes('use_cache_branch')) {
decoderFeeds.use_cache_branch = boolTensor(use_cache_branch);
}
if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) {
decoderFeeds.encoder_attention_mask = model_inputs.attention_mask
}
preparePositionIds(self.decoder_merged_session, decoderFeeds, use_cache_branch);
self.addPastKeyValues(decoderFeeds, past_key_values);
const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds);
let logits = decoderResults.logits;
past_key_values = self.getPastKeyValues(decoderResults, past_key_values);
// Get cross attention and/or decoder attentions if they are present
const attns = self.getAttentions(decoderResults);
return new Seq2SeqLMOutput({ logits, past_key_values, encoder_outputs, ...attns });
}
/**
* Start the beam search process for the seq2seq model.
* @param {PreTrainedModel} self The seq2seq model object.
* @param {Tensor} inputTokenIds Array of input token ids for each input sequence.
* @param {Object} generation_config The generation config.
* @param {number} numOutputTokens The maximum number of output tokens for the model.
* @returns {Object[]} Array of beam search objects.
* @private
*/
function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputTokens) {
let beams = [];
let beamId = 0;
// @ts-ignore
const requires_attention_mask = self.requires_attention_mask ?? true;
// decoder_input_ids == output_token_ids
let decoder_input_ids =
generation_config.decoder_input_ids
?? generation_config.decoder_start_token_id
?? generation_config.bos_token_id
?? generation_config.eos_token_id;
// Support input as tensor or list
// TODO support batched decoder_input_ids
if (decoder_input_ids instanceof Tensor) {
decoder_input_ids = decoder_input_ids.tolist().flat();
} else if (!Array.isArray(decoder_input_ids)) {
decoder_input_ids = [decoder_input_ids];
}
for (let tokens of inputTokenIds) {
// TODO: Improve
// Currently, just add back batch dimension.
// In future, allow for true parallel execution
tokens.dims = [1, ...tokens.dims]
// Create beam
let start = {
inputs: tokens,
encoder_outputs: null,
prev_model_outputs: null,
output_token_ids: decoder_input_ids,
done: false,
score: 0,
id: beamId++ // assign unique id to beams
}
if (requires_attention_mask) {
start.attention_mask = prepareAttentionMask(self, tokens);
}
beams.push(start);
}
return beams;
}
/**
* Run beam search on the seq2seq model for a single beam.
* @param {PreTrainedModel} self The seq2seq model object.
* @param {Object} beam The beam search object for which to run the model.
* @returns {Promise<Object>} Promise that resolves with the output of the seq2seq model for the given beam.
* @private
*/
async function seq2seqRunBeam(self, beam) {
const input_name = self.main_input_name;
let decoder_input_ids = beam.output_token_ids;
if (beam.prev_model_outputs) {
// After the first step, `prev_model_outputs` won't be null.
// So, we cut decoder_input_ids if past is used
decoder_input_ids = decoder_input_ids.slice(-1);
}
// 1. Prepare
let model_inputs = {
[input_name]: beam.inputs,
decoder_input_ids: toI64Tensor(decoder_input_ids),
encoder_outputs: beam.encoder_outputs,
past_key_values: beam.prev_model_outputs?.past_key_values,
}
if (beam.attention_mask) {
model_inputs.attention_mask = beam.attention_mask
}
// 2. Run
let output = await self.forward(model_inputs);
// 3. Update
beam.prev_model_outputs = output;
beam.encoder_outputs = output.encoder_outputs;
return output;
}
/**
* Update a beam with a new token ID.
* @param {Object} beam The beam to update.
* @param {number} newTokenId The new token ID to add to the beam's output.
* @private
*/
function seq2seqUpdatebeam(beam, newTokenId) {
beam.output_token_ids = [...beam.output_token_ids, newTokenId];
}
/**
* Forward pass of an encoder model.
* @param {Object} self The encoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @returns {Promise<Object>} Promise that resolves with an object containing the model's outputs.
* @private
*/
async function encoderForward(self, model_inputs) {
const encoderFeeds = Object.create(null);
for (const key of self.session.inputNames) {
encoderFeeds[key] = model_inputs[key];
}
if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
// Assign default `token_type_ids` to the `encoderFeeds` if the model expects it,
// but they weren't created by the tokenizer.
add_token_types(encoderFeeds);
}
return await sessionRun(self.session, encoderFeeds);
}
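/**
 * **Illustrative sketch:** for a BERT-style encoder (as in the module header example),
 * the feeds assembled above are simply the tokenizer outputs, assuming the usual input names:
 * ```javascript
 * // session.inputNames = ['input_ids', 'attention_mask', 'token_type_ids']
 * // encoderFeeds = { input_ids, attention_mask, token_type_ids }
 * ```
 * If the model expects `token_type_ids` but the tokenizer did not produce them,
 * `add_token_types` supplies a default tensor.
 */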
/**
* Forward pass of a decoder model.
* @param {Object} self The decoder model.
* @param {Object} model_inputs The input data to be used for the forward pass.
* @returns {Promise<Object>} Promise that resolves with an object containing the logits and past key values.
* @private
*/
async function decoderForward(self, model_inputs) {
let { input_ids, past_key_values, attention_mask } = model_inputs;
let decoderFeeds = {
input_ids: input_ids,
attention_mask: attention_mask ?? prepareAttentionMask(self, input_ids),
}
const use_cache_branch = !!past_key_values;
if (self.session.inputNames.includes('use_cache_branch')) {
decoderFeeds.use_cache_branch = boolTensor(use_cache_branch);
}
preparePositionIds(self.session, decoderFeeds, use_cache_branch);
self.addPastKeyValues(decoderFeeds, past_key_values);
let decoderResults = await sessionRun(self.session, decoderFeeds);
let logits = decoderResults.logits;
past_key_values = self.getPastKeyValues(decoderResults, past_key_values);
return { logits, past_key_values };
}
/**
* Starts the generation of text by initializing the beams for the given input token IDs.
* @param {Object} self The text generation model object.
* @param {Tensor} inputTokenIds A tensor of input token IDs to generate text from.
* @param {Object} generation_config The generation config.
* @param {number} numOutputTokens The maximum number of tokens to generate for each beam.
* @param {Tensor} [inputs_attention_mask] The attention mask tensor for the input token IDs.
* @returns {Object[]} An array of beams initialized with the given inputs and parameters.
* @private
*/
function decoderStartBeams(self, inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) {
let beams = [];
let beamId = 0;
for (let tokens of inputTokenIds) {
let output_token_ids = tokens.tolist().map(Number);
// TODO: Improve
// Currently, just add back batch dimension.
// In future, allow for true parallel execution
tokens.dims = [1, ...tokens.dims]
let attn_mask;
if (inputs_attention_mask) {
attn_mask = inputs_attention_mask[beamId];
attn_mask.dims = [1, ...attn_mask.dims]
} else {
attn_mask = prepareAttentionMask(self, tokens)
}
let start = {
input: tokens,
model_input_ids: tokens,
attention_mask: attn_mask,
prev_model_outputs: null,
output_token_ids: output_token_ids,
num_output_tokens: numOutputTokens,
done: false,
score: 0,
id: beamId++ // assign unique id to beams
}
beams.push(start);
}
return beams;
}
/**
* Runs a single step of the text generation process for a given beam.
*
* @param {Object} self The decoder object.
* @param {Object} beam The beam to run.
* @param {Tensor} beam.input The input tensor.
* @param {Tensor} beam.model_input_ids The input ids to the model.
* @param {Tensor} beam.attention_mask The attention mask.
* @param {Object} beam.prev_model_outputs The past key values.
* @param {number[]} beam.output_token_ids The output token ids.
* @returns {Promise<Object>} The output of the generation step.
* @private
*/
async function decoderRunBeam(self, beam) {
let attnMaskData = new BigInt64Array(beam.output_token_ids.length).fill(1n)
// 1. Prepare
let model_inputs = {
input_ids: beam.model_input_ids,
attention_mask: new Tensor(
'int64',
attnMaskData,
[1, attnMaskData.length]
),
past_key_values: beam.prev_model_outputs?.past_key_values,
}
// 2. Run
let output = await self.forward(model_inputs);
// 3. Update
beam.prev_model_outputs = output;
return output;
}
/**
* Update a beam with a new token ID.
* @param {Object} beam The beam to update.
* @param {number} newTokenId The new token ID to add to the beam's output.
* @private
*/
function decoderUpdatebeam(beam, newTokenId) {
beam.output_token_ids = [...beam.output_token_ids, newTokenId];
beam.model_input_ids = new Tensor('int64', [BigInt(newTokenId)], [1, 1]);
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
/**
* A base class for pre-trained models that provides the model configuration and an ONNX session.
*/
export class PreTrainedModel extends Callable {
main_input_name = 'input_ids';
/**
* Creates a new instance of the `PreTrainedModel` class.
* @param {Object} config The model configuration.
* @param {any} session The ONNX inference session for the model.
*/
constructor(config, session) {
super();
this.config = config;
this.session = session;
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
this.can_generate = false;
this._runBeam = null;
this._getStartBeams = null;
this._updateBeam = null;
this._forward = null;
if (modelType === MODEL_TYPES.DecoderOnly) {
this.can_generate = true;
this._runBeam = decoderRunBeam;
this._getStartBeams = decoderStartBeams;
this._updateBeam = decoderUpdatebeam;
this._forward = decoderForward;
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
this.can_generate = true;
this._runBeam = seq2seqRunBeam;
this._getStartBeams = seq2seqStartBeams;
this._updateBeam = seq2seqUpdatebeam;
this._forward = seq2seqForward;
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
this._forward = encoderForward;
} else { // should be MODEL_TYPES.EncoderOnly
this._forward = encoderForward;
}
}
/**
* Disposes of all the ONNX sessions that were created during inference.
* @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed.
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
*/
async dispose() {
const promises = [];
for (let key of Object.keys(this)) {
const item = this[key];
// @ts-ignore
if (item instanceof InferenceSession) {
promises.push(item.handler.dispose())
}
}
return await Promise.all(promises);
}
/**
* Instantiate one of the model classes of the library from a pretrained model.
*
* The model class to instantiate is selected based on the `model_type` property of the config object
* (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible)
*
* @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
* - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model.
*
* @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class.
*/
static async from_pretrained(pretrained_model_name_or_path, {
quantized = true,
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
model_file_name = null,
} = {}) {
let options = {
quantized,
progress_callback,
config,
cache_dir,
local_files_only,
revision,
model_file_name,
}
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
let info;
if (modelType === MODEL_TYPES.DecoderOnly) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, 'encoder_model', options),
constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, 'encoder_model', options),
constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options),
]);
} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
console.warn(`Model type for '${modelName}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`)
}
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'model', options)
]);
}
// @ts-ignore
return new this(...info);
}
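/**
 * **Illustrative example:** loading with non-default options (a sketch using the model
 * from the module header):
 * ```javascript
 * const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased', {
 *     quantized: false,                       // use the full-precision ONNX weights
 *     revision: 'main',
 *     progress_callback: data => console.log(data),
 * });
 * ```
 */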
/**
* Runs the model with the provided inputs
* @param {Object} model_inputs Object containing input tensors
* @returns {Promise<Object>} Object containing output tensors
*/
async _call(model_inputs) {
return await this.forward(model_inputs);
}
/**
* Forward method for a pretrained model. If not overridden by a subclass, the correct forward method
* will be chosen based on the model type.
* @param {Object} model_inputs The input data to the model in the format specified in the ONNX model.
* @returns {Promise<Object>} The output data from the model in the format specified in the ONNX model.
*/
async forward(model_inputs) {
return await this._forward(this, model_inputs);
}
/**
* @param {import('./utils/generation.js').GenerationConfigType} generation_config
* @param {number} input_ids_seq_length The starting sequence length for the input ids.
* @returns {LogitsProcessorList}
* @private
*/
_get_logits_processor(
generation_config,
input_ids_seq_length,
// encoder_input_ids, TODO
// prefix_allowed_tokens_fn, TODO
logits_processor = null
) {
const processors = new LogitsProcessorList();
// if (generation_config.diversity_penalty !== null && generation_config.diversity_penalty > 0.0) {
// processors.push(new HammingDiversityLogitsProcessor(
// generation_config.diversity_penalty,
// generation_config.num_beams,
// generation_config.num_beam_groups
// ));
// }
// if (generation_config.encoder_repetition_penalty !== null && generation_config.encoder_repetition_penalty !== 1.0) {
// processors.push(new EncoderRepetitionPenaltyLogitsProcessor(
// generation_config.encoder_repetition_penalty,
// encoder_input_ids
// ));
// }
if (generation_config.repetition_penalty !== null && generation_config.repetition_penalty !== 1.0) {
processors.push(new RepetitionPenaltyLogitsProcessor(generation_config.repetition_penalty));
}
if (generation_config.no_repeat_ngram_size !== null && generation_config.no_repeat_ngram_size > 0) {
processors.push(new NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size));
}
// if (generation_config.encoder_no_repeat_ngram_size !== null && generation_config.encoder_no_repeat_ngram_size > 0) {
// if (this.config.is_encoder_decoder) {
// processors.push(new EncoderNoRepeatNGramLogitsProcessor(
// generation_config.encoder_no_repeat_ngram_size,
// encoder_input_ids
// ));
// } else {
// throw new Error("It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture");
// }
// }
if (generation_config.bad_words_ids !== null) {
processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
}
if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
}
if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
processors.push(new MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id
));
}
// if (prefix_allowed_tokens_fn !== null) {
// processors.push(new PrefixConstrainedLogitsProcessor(
// prefix_allowed_tokens_fn,
// generation_config.num_beams / generation_config.num_beam_groups
// ));
// }
if (generation_config.forced_bos_token_id !== null) {
processors.push(new ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id));
}
if (generation_config.forced_eos_token_id !== null) {
processors.push(new ForcedEOSTokenLogitsProcessor(
generation_config.max_length,
generation_config.forced_eos_token_id
));
}
// if (generation_config.remove_invalid_values === true) {
// processors.push(new InfNanRemoveLogitsProcessor());
// }
// if (generation_config.exponential_decay_length_penalty !== null) {
// processors.push(new ExponentialDecayLengthPenalty(
// generation_config.exponential_decay_length_penalty,
// generation_config.eos_token_id,
// input_ids_seq_length
// ));
// }
// if (generation_config.suppress_tokens !== null) {
// processors.push(new SuppressTokensLogitsProcessor(generation_config.suppress_tokens));
// }
if (generation_config.begin_suppress_tokens !== null) {
let begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null)
? input_ids_seq_length
: input_ids_seq_length + 1;
if (generation_config.forced_decoder_ids !== null) {
// generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[generation_config.forced_decoder_ids.length - 1][0];
}
processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index));
}
if (generation_config.forced_decoder_ids !== null) {
processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids));
}
if (logits_processor !== null) {
processors.extend(logits_processor)
}
// `LogitNormalization` should always be the last logit processor, when present
// if (generation_config.renormalize_logits === true) {
// processors.push(new LogitNormalization());
// }
return processors;
}
/**
* This function merges multiple generation configs together to form a final generation config to be used by the model for text generation.
* It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object.
* @param {import('./utils/generation.js').GenerationConfigType} generation_config A `GenerationConfig` object containing generation parameters.
* @returns {import('./utils/generation.js').GenerationConfigType} The final generation config object to be used by the model for text generation.
*/
_get_generation_config(generation_config) {
// Create empty generation config (contains defaults)
// We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
let gen_config = new GenerationConfig(this.config);
// Apply model's generation config, if it exists
if ('generation_config' in this) {
Object.assign(gen_config, this.generation_config);
}
// Finally, use any generation config specified by the user
// when calling `generate`
if (generation_config !== null) {
Object.assign(gen_config, generation_config);
}
return gen_config;
}
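/**
 * **Illustrative example:** the merge order is library defaults, then the model's own
 * `generation_config`, then the user-supplied config (assumed values for the sketch):
 * ```javascript
 * // new GenerationConfig(this.config)  ->  { eos_token_id: 2, max_length: 20, ... }
 * // this.generation_config             ->  { max_length: 200 }
 * // user config                        ->  { max_new_tokens: 50 }
 * // merged                             ->  { eos_token_id: 2, max_length: 200, max_new_tokens: 50, ... }
 * ```
 */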
/**
* @typedef {import('./utils/maths.js').TypedArray} TypedArray
*/
/**
* @typedef {{ sequences: Tensor, decoder_attentions: Tensor, cross_attentions: Tensor }} EncoderDecoderOutput
* @typedef {Object} DecoderOutput
*
* Generates text based on the given inputs and generation configuration using the model.
* @param {Tensor|Array|TypedArray} inputs An array of input token IDs.
* @param {Object|GenerationConfig|null} generation_config The generation configuration to use. If null, default configuration will be used.
* @param {Object|null} logits_processor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
* @param {Object} options options
* @param {Object} [options.inputs_attention_mask=null] An optional attention mask for the inputs.
* @returns {Promise<number[][]|EncoderDecoderOutput|DecoderOutput>} An array of generated output sequences, where each sequence is an array of token IDs.
* @throws {Error} Throws an error if the inputs array is empty.
*/
async generate(
inputs,
generation_config = null,
logits_processor = null,
{
inputs_attention_mask = null
} = {},
) {
if (!this.can_generate) {
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.`
const modelType = this.config.model_type;
const possibleInfo =
MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(modelType)
?? MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.get(modelType)
?? MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.get(modelType)
// ?? MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES.get(modelType) // TODO
?? MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType);
if (possibleInfo) {
// TODO: support multiple possible classes
errorMessage += ` Please use the following class instead: '${possibleInfo[0]}'`;
}
throw Error(errorMessage);
}
if (!(inputs instanceof Tensor) && !isTypedArray(inputs) && !Array.isArray(inputs)) {
throw Error(`\`inputs\` must be a Tensor, TypedArray, or Array, but is "${inputs.constructor.name}".`);
}
let input_ids_seq_length;
// Prepare `input_ids` which will be used for auto-regressive generation
// TODO: Update to align with HF transformers' implementation
if (this.config.is_encoder_decoder) {
// Generating from the encoder outputs
input_ids_seq_length = 0;
} else {
input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length;
// decoder-only
if (input_ids_seq_length === 0) {
throw Error("Must supply a non-empty array of input token ids.")
}
}
// Update generation config with defaults
generation_config = this._get_generation_config(generation_config);
logits_processor = logits_processor ?? new LogitsProcessorList()
// Update logits processor
logits_processor = this._get_logits_processor(
generation_config,
input_ids_seq_length,
logits_processor
)
/** @type {number[]} */
let eos_token_ids = generation_config.eos_token_id;
if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) {
eos_token_ids = [eos_token_ids];
}
// TODO implement early_stopping
// https://huggingface.co/blog/how-to-generate
let numOutputTokens = 1;
const maxOutputTokens = numOutputTokens + (generation_config.max_new_tokens ?? Infinity);
// Only use max length if max_new_tokens is not provided
const useMaxLength = Number.isInteger(generation_config.max_length) && (generation_config.max_new_tokens ?? null) === null;
let sampler = Sampler.getSampler(generation_config);
// @ts-ignore
let beams = this.getStartBeams(inputs, generation_config, numOutputTokens, inputs_attention_mask);
while (beams.some(x => !x.done) && numOutputTokens < maxOutputTokens) {
let newest_beams = [];
for (let beam of beams) {
if (beam.done) {
// Add this beam back into the pool
newest_beams.push(beam);
continue
}
if (useMaxLength && beam.output_token_ids.length >= generation_config.max_length) {
// Set this beam to done and add it back into the pool
beam.done = true;
newest_beams.push(beam);
continue
}
// @ts-ignore
let output = await this.runBeam(beam);
// add attentions/scores to beam only if user requested
if (generation_config.output_attentions) {
this.addAttentionsToBeam(beam, output);
}
if (generation_config.output_scores) {
// TODO add
}
// Logits are of the form [batch_size, out_seq_length, vocab_size]
// In most cases, this will be [batch_size, 1, vocab_size]
// So, we select the last token's logits:
// (equivalent to `logits = outputs.logits[:, -1, :]`)
let logits = output.logits.slice(null, -1, null);
// Apply logits processor
logits_processor(beam.output_token_ids, logits);
let sampledTokens = sampler(logits);
for (let [newTokenId, logProb] of sampledTokens) {
// use previous beam as a starting point
let newBeam = { ...beam };
// update new beam
// @ts-ignore
this.updateBeam(newBeam, newTokenId);
newBeam.score += logProb;
if (eos_token_ids && eos_token_ids.includes(newTokenId)) {
newBeam.done = true;
}
newest_beams.push(newBeam);
}
}
++numOutputTokens;
// Next, we get the best beams, per ID
newest_beams = this.groupBeams(newest_beams).map(
group => group
.sort((a, b) => b.score - a.score) // sort by score
.slice(0, generation_config.num_beams) // remove outside beam width
);
// Flatten beams
beams = newest_beams.flat();
// Run callback
if (generation_config.callback_function) {
generation_config.callback_function(beams);
}
}
// TODO: Ensure that we can return non-batched outputs
const groupedBeams = this.groupBeams(beams);
const getFlattened = (key) => groupedBeams.map(
batch => {
if (generation_config.num_return_sequences > 1) {
return batch.slice(0, generation_config.num_return_sequences).map(x => x[key]);
} else {
return [batch[0][key]];
}
}
).flat(); // Flatten across batches (depth=1)
const sequences = getFlattened('output_token_ids'); // [1, seqLength]
if (generation_config.return_dict_in_generate) {
// NOTE: `decoder_attentions` and `cross_attentions` should be:
// list (one element for each generated token)
// of list (one element for each layer of the decoder)
// of torch.FloatTensor of shape (batch_size, num_heads, generated_length, sequence_length)
// However, since we are only generating one batch at a time, they are of the form:
// list (batches)
// of list (one element for each generated token)
// of list (one element for each layer of the decoder)
// of torch.FloatTensor of shape (1, num_heads, generated_length, sequence_length)
//
// TODO: In future (once true parallelism is supported), we should be able to return the correct shape.
const decoder_attentions = getFlattened('decoder_attentions');
const cross_attentions = getFlattened('cross_attentions');
return {
sequences,
decoder_attentions,
cross_attentions,
}
} else {
return sequences;
}
}
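/**
 * **Illustrative example:** by default `generate` returns arrays of token ids; with
 * `return_dict_in_generate` it returns a dictionary that also carries attentions (a sketch):
 * ```javascript
 * const sequences = await model.generate(input_ids);
 * // e.g. [[0, 1674, 549, ...]]
 * const out = await model.generate(input_ids, {
 *     return_dict_in_generate: true,
 *     output_attentions: true,
 * });
 * // out.sequences, out.decoder_attentions, out.cross_attentions
 * ```
 */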
/**
* Helper function to add attentions to beam
* @param {Object} beam
* @param {Object} output
* @private
*/
addAttentionsToBeam(beam, output) {
if (this.config.is_encoder_decoder) {
if (!output.cross_attentions || output.cross_attentions.length === 0) {
throw Error(
"`output_attentions` is true, but the model did not produce cross-attentions. " +
"This is most likely because the model was not exported with `output_attentions=True`."
)
}
if (!beam.cross_attentions) {
beam.cross_attentions = [];
}
beam.cross_attentions.push(output.cross_attentions);
}
if (!output.decoder_attentions || output.decoder_attentions.length === 0) {
throw Error(
"`output_attentions` is true, but the model did not produce decoder-attentions. " +
"This is most likely because the model was not exported with `output_attentions=True`."
)
}
if (!beam.decoder_attentions) {
beam.decoder_attentions = [];
}
beam.decoder_attentions.push(output.decoder_attentions);
}
/**
* Groups an array of beam objects by their ids.
*
* @param {Array} beams The array of beam objects to group.
* @returns {Array} An array of arrays, where each inner array contains beam objects with the same id.
*/
groupBeams(beams) {
// Group beams by their ids
const groups = Object.create(null);
for (const obj of beams) {
if (groups[obj.id] === undefined) {
groups[obj.id] = [obj];
} else {
groups[obj.id].push(obj);
}
}
return Object.values(groups);
}
/**
* Returns an object containing past key values from the given decoder results object.
*
* @param {Object} decoderResults The decoder results object.
* @param {Object} pastKeyValues The previous past key values.
* @returns {Object} An object containing past key values.
*/
getPastKeyValues(decoderResults, pastKeyValues) {
const pkvs = Object.create(null);
for (const name in decoderResults) {
if (name.startsWith('present')) {
let newName = name.replace('present', 'past_key_values');
if (pastKeyValues && name.includes('encoder')) {
// Optimization introduced by optimum to reuse past key values. So, we just replace the constant
// outputs with the previous past key values.
// https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
pkvs[newName] = pastKeyValues[newName];
} else {
pkvs[newName] = decoderResults[name];
}
}
}
return pkvs;
}
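/**
 * **Illustrative example:** output names are remapped so the values can be fed back in on
 * the next decoding step (assumed naming from optimum-exported merged decoders):
 * ```javascript
 * // 'present.0.decoder.key'  ->  'past_key_values.0.decoder.key'
 * // 'present.0.encoder.key'  ->  reused from the previous step when a cache is present
 * ```
 */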
/**
* Returns an object containing attentions from the given decoder results object.
*
* @param {Object} decoderResults The decoder results object.
* @returns {Object} An object containing attentions.
*/
getAttentions(decoderResults) {
const attns = Object.create(null);
for (const attnName of ['cross_attentions', 'decoder_attentions']) {
const result = [];
for (const name in decoderResults) {
if (name.startsWith(attnName)) {
const index = name.split('.').pop()
result[index] = decoderResults[name];
}
}
attns[attnName] = result;
}
return attns;
}
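/**
 * **Illustrative example:** attention outputs are collected by their numeric suffix
 * (assumed naming, matching the prefix check above):
 * ```javascript
 * // 'cross_attentions.0', 'cross_attentions.1', ...  ->  attns.cross_attentions = [Tensor, Tensor, ...]
 * ```
 */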
/**
* Adds past key values to the decoder feeds object. If pastKeyValues is null, creates new tensors for past key values.
*
* @param {Object} decoderFeeds The decoder feeds object to add past key values to.
* @param {Object} pastKeyValues An object containing past key values.
*/
addPastKeyValues(decoderFeeds, pastKeyValues) {
if (pastKeyValues) {
Object.assign(decoderFeeds, pastKeyValues)
} else {
// TODO support batches (i.e., batch_size > 1)
const batch_size = 1;
// @ts-ignore
if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) {
// @ts-ignore
let encoder_dims = [batch_size, this.num_encoder_heads, 0, this.encoder_dim_kv];
// @ts-ignore
let decoder_dims = [batch_size, this.num_decoder_heads, 0, this.decoder_dim_kv];
// @ts-ignore
for (let i = 0;