UNPKG

transformers-fork

Version:

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!

157 lines (137 loc) • 4.83 kB
/** * @module generation/stopping_criteria */ import { Callable } from "../utils/generic.js"; // NOTE: // Stopping Criteria returns a list of `batch_size` booleans, indicating whether each sequence in the batch should be stopped. /** * Abstract base class for all stopping criteria that can be applied during generation. */ export class StoppingCriteria extends Callable { /** * * @param {number[][]} input_ids (`number[][]` of shape `(batch_size, sequence_length)`): * Indices of input sequence tokens in the vocabulary. * @param {number[][]} scores scores (`number[][]` of shape `(batch_size, config.vocab_size)`): * Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax * or scores for each vocabulary token after SoftMax. * @returns {boolean[]} A list of booleans indicating whether each sequence should be stopped. */ _call(input_ids, scores) { throw Error("StoppingCriteria needs to be subclassed"); } } /** */ export class StoppingCriteriaList extends Callable { /** * Constructs a new instance of `StoppingCriteriaList`. */ constructor() { super(); this.criteria = []; } /** * Adds a new stopping criterion to the list. * * @param {StoppingCriteria} item The stopping criterion to add. */ push(item) { this.criteria.push(item); } /** * Adds multiple stopping criteria to the list. * * @param {StoppingCriteria|StoppingCriteriaList|StoppingCriteria[]} items The stopping criteria to add. */ extend(items) { if (items instanceof StoppingCriteriaList) { items = items.criteria; } else if (items instanceof StoppingCriteria) { items = [items]; } this.criteria.push(...items); } _call(input_ids, scores) { const is_done = new Array(input_ids.length).fill(false); for (const criterion of this.criteria) { const criterion_done = criterion(input_ids, scores); for (let i = 0; i < is_done.length; ++i) { is_done[i] ||= criterion_done[i]; } } return is_done; } [Symbol.iterator]() { return this.criteria.values(); } } /** * This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. * Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens. */ export class MaxLengthCriteria extends StoppingCriteria { /** * * @param {number} max_length The maximum length that the output sequence can have in number of tokens. * @param {number} [max_position_embeddings=null] The maximum model length, as defined by the model's `config.max_position_embeddings` attribute. */ constructor(max_length, max_position_embeddings = null) { super(); this.max_length = max_length; this.max_position_embeddings = max_position_embeddings; } _call(input_ids) { return input_ids.map(ids => ids.length >= this.max_length); } } // TODO: add MaxTimeCriteria /** * This class can be used to stop generation whenever the "end-of-sequence" token is generated. * By default, it uses the `model.generation_config.eos_token_id`. */ export class EosTokenCriteria extends StoppingCriteria { /** * * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. * Optionally, use a list to set multiple *end-of-sequence* tokens. */ constructor(eos_token_id) { super(); if (!Array.isArray(eos_token_id)) { eos_token_id = [eos_token_id]; } this.eos_token_id = eos_token_id; } /** * * @param {number[][]} input_ids * @param {number[][]} scores * @returns {boolean[]} */ _call(input_ids, scores) { return input_ids.map(ids => { const last = ids.at(-1); // NOTE: We use == instead of === to allow for number/bigint comparison return this.eos_token_id.some(eos_id => last == eos_id); }); } } /** * This class can be used to stop generation whenever the user interrupts the process. */ export class InterruptableStoppingCriteria extends StoppingCriteria { constructor() { super(); this.interrupted = false; } interrupt() { this.interrupted = true; } reset() { this.interrupted = false; } _call(input_ids, scores) { return new Array(input_ids.length).fill(this.interrupted); } }