UNPKG

transformers-fork

Version:

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!

205 lines (176 loc) • 6.84 kB
/**
 * @module generation/logits_sampler
 */

import { Callable } from "../utils/generic.js";
import { Tensor, topk } from "../utils/tensor.js";
import {
    max,
    softmax,
} from '../utils/maths.js';
import { GenerationConfig } from '../generation/configuration_utils.js';

/**
 * Sampler is a base class for all sampling methods used for text generation.
 */
export class LogitsSampler extends Callable {
    /**
     * Creates a new Sampler object with the specified generation config.
     * @param {GenerationConfig} generation_config The generation config.
     */
    constructor(generation_config) {
        super();
        this.generation_config = generation_config;
    }

    /**
     * Executes the sampler, using the specified logits.
     * This simply delegates to `sample`, which subclasses implement.
     * @param {Tensor} logits The logits to sample from.
     * @returns {Promise<[bigint, number][]>} `[token_id, score]` pairs.
     */
    async _call(logits) {
        return this.sample(logits);
    }

    /**
     * Abstract method for sampling the logits.
     * @param {Tensor} logits
     * @throws {Error} If not implemented in subclass.
     * @returns {Promise<[bigint, number][]>}
     */
    async sample(logits) {
        throw new Error("sample should be implemented in subclasses.");
    }

    /**
     * Returns one row of the specified logits as a flat array (a copy of the underlying data).
     * No temperature or any other processing is applied.
     * @param {Tensor} logits The logits; the size of the last dimension is used as the row length.
     * @param {number} index The row to extract. Negative values count from the end (`-1` is the last row).
     * @returns {Float32Array} The values of the requested row.
     */
    getLogits(logits, index) {
        const vocabSize = logits.dims.at(-1);
        const data = /** @type {Float32Array} */(logits.data);
        if (index === -1) {
            // For the last row, `startIndex + vocabSize` would be 0 and `slice` would return
            // an empty array, so take the trailing `vocabSize` values directly.
            return data.slice(-vocabSize);
        }
        const startIndex = index * vocabSize;
        return data.slice(startIndex, startIndex + vocabSize);
    }

    /**
     * Selects an item randomly based on the specified probabilities.
     * The weights do not need to sum to 1; they are normalized by their total.
     * @param {import("../transformers.js").DataArray} probabilities An array of probabilities to use for selection.
     * @returns {number} The index of the selected item.
     */
    randomSelect(probabilities) {
        let sumProbabilities = 0;
        for (let i = 0; i < probabilities.length; ++i) {
            sumProbabilities += probabilities[i];
        }

        // r is drawn from [0, sumProbabilities); item i owns the half-open interval
        // [cumsum(i - 1), cumsum(i)), so only items with positive mass can be selected.
        let r = Math.random() * sumProbabilities;
        let lastPositive = 0;
        for (let i = 0; i < probabilities.length; ++i) {
            const p = probabilities[i];
            if (p > 0) {
                lastPositive = i;
            }
            r -= p;
            // Strict comparison: with `<=`, a draw of exactly 0 (or one landing exactly on a
            // cumulative boundary) could select an item whose probability is 0.
            if (r < 0) {
                return i;
            }
        }
        // Only reached through floating-point rounding (r ended marginally >= 0, i.e. the draw was at
        // the very top of the cumulative range) or when there is no positive mass at all (returns 0).
        return lastPositive;
    }

    /**
     * Returns a Sampler object based on the specified options.
     * @param {GenerationConfig} generation_config An object containing options for the sampler.
     * @returns {LogitsSampler} A Sampler object.
     * @throws {Error} If greedy search is requested with `num_return_sequences > 1`.
     */
    static getSampler(generation_config) {
        // - *greedy decoding*: `num_beams=1` and `do_sample=False`
        // - *contrastive search*: `penalty_alpha>0` and `top_k>1`
        // - *multinomial sampling*: `num_beams=1` and `do_sample=True`
        // - *beam-search decoding*: `num_beams>1` and `do_sample=False`
        // - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True`
        // - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1`
        // - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None`

        // NOTE: beam search is implemented directly into the generation function.
        // Only the multinomial, beam-search and greedy samplers are selected here.
        if (generation_config.do_sample) {
            return new MultinomialSampler(generation_config);
        } else if (generation_config.num_beams > 1) {
            return new BeamSearchSampler(generation_config);
        } else {
            if (generation_config.num_return_sequences > 1) {
                throw new Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`);
            }
            return new GreedySampler(generation_config);
        }
    }
}

/**
 * Class representing a Greedy Sampler.
 */
class GreedySampler extends LogitsSampler {
    /**
     * Sample the maximum probability of a given logits tensor.
     * NOTE(review): the argmax is taken over all of `logits.data`, so this assumes `logits`
     * holds a single row of vocabulary scores — TODO confirm against callers.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} An array with a single tuple, containing the index of the maximum value and a meaningless score (since this is a greedy search).
     */
    async sample(logits) {
        // NOTE: no need to do log_softmax here since we only take the maximum
        const argmax = max(logits.data)[1];

        // Note: score is meaningless in this context, since we are performing
        // greedy search (p = 1 => log(p) = 0)
        return [
            [BigInt(argmax), 0],
        ];
    }
}

/**
 * Keeps the `top_k` highest-scoring entries of `logits` and normalizes them with softmax.
 * When `top_k` is not a positive number, every entry along the last dimension is kept.
 * NOTE(review): softmax is applied to the flattened top-k values, so this assumes `logits`
 * holds a single row of vocabulary scores — TODO confirm against callers.
 * @param {Tensor} logits The logits to restrict.
 * @param {number} top_k The maximum number of candidates to keep.
 * @returns {Promise<[Tensor, ReturnType<typeof softmax>]>} The top-k index tensor and the
 * probabilities aligned with it (highest first, as ordered by `topk`).
 */
async function topKProbabilities(logits, top_k) {
    let k = logits.dims.at(-1); // defaults to vocab size
    if (top_k > 0) {
        k = Math.min(top_k, k);
    }

    // Get top k tokens
    const [values, indices] = await topk(logits, k);

    // Compute softmax over the retained logits
    const probabilities = softmax(/** @type {Float32Array} */(values.data));
    return [indices, probabilities];
}

/**
 * Class representing a MultinomialSampler.
 */
class MultinomialSampler extends LogitsSampler {
    /**
     * Sample from the logits.
     * Draws `num_beams` independent samples from the softmax over the top-k logits.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} `[token_id, log_probability]` pairs.
     */
    async sample(logits) {
        const [indices, probabilities] = await topKProbabilities(logits, this.generation_config.top_k);

        return Array.from({ length: this.generation_config.num_beams }, () => {
            const sampledIndex = this.randomSelect(probabilities);
            return [
                indices.data[sampledIndex], // token id
                Math.log(probabilities[sampledIndex]), // score
            ];
        });
    }
}

/**
 * Class representing a BeamSearchSampler.
 */
class BeamSearchSampler extends LogitsSampler {
    /**
     * Sample from the logits.
     * Returns the `num_beams` most probable candidates among the top-k logits, or fewer
     * if fewer candidates are available.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} `[token_id, log_probability]` pairs, most probable first.
     */
    async sample(logits) {
        const [indices, probabilities] = await topKProbabilities(logits, this.generation_config.top_k);

        // `top_k` (or the vocabulary) may provide fewer candidates than `num_beams`; indexing past
        // the end would produce `undefined` token ids, so only the available candidates are returned.
        const numCandidates = Math.min(this.generation_config.num_beams, indices.data.length);

        return Array.from({ length: numCandidates }, (_, x) => [
            indices.data[x], // token id
            Math.log(probabilities[x]), // score
        ]);
    }
}