UNPKG

transformers-fork

Version:

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!

364 lines • 15.6 kB
declare const LogitsProcessor_base: new () => { (...args: any[]): any; _call(...args: any[]): any; }; /** * Abstract base class for all logit processors that can be applied during generation. */ export class LogitsProcessor extends LogitsProcessor_base { /** * Apply the processor to the input logits. * * @abstract * @param {bigint[][]} input_ids The input ids. * @param {Tensor} logits The logits to process. * @throws {Error} Throws an error if `_call` is not implemented in the subclass. */ _call(input_ids: bigint[][], logits: Tensor): void; } declare const LogitsWarper_base: new () => { (...args: any[]): any; _call(...args: any[]): any; }; /** * Abstract base class for all logit warpers that can be applied during generation with multinomial sampling. */ export class LogitsWarper extends LogitsWarper_base { /** * Apply the processor to the input logits. * * @abstract * @param {bigint[][]} input_ids The input ids. * @param {Tensor} logits The logits to process. * @throws {Error} Throws an error if `_call` is not implemented in the subclass. */ _call(input_ids: bigint[][], logits: Tensor): void; } declare const LogitsProcessorList_base: new () => { (...args: any[]): any; _call(...args: any[]): any; }; /** * A class representing a list of logits processors. A logits processor is a function that modifies the logits * output of a language model. This class provides methods for adding new processors and applying all processors to a * batch of logits. */ export class LogitsProcessorList extends LogitsProcessorList_base { processors: any[]; /** * Adds a new logits processor to the list. * * @param {LogitsProcessor} item The logits processor function to add. */ push(item: LogitsProcessor): void; /** * Adds multiple logits processors to the list. * * @param {LogitsProcessor[]} items The logits processor functions to add. */ extend(items: LogitsProcessor[]): void; /** * Applies all logits processors in the list to a batch of logits, modifying them in-place. * * @param {bigint[][]} input_ids The input IDs for the language model. * @param {Tensor} logits */ _call(input_ids: bigint[][], logits: Tensor): Tensor; [Symbol.iterator](): ArrayIterator<any>; } /** * A LogitsProcessor that forces a BOS token at the beginning of the generated sequence. */ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor { /** * Create a ForcedBOSTokenLogitsProcessor. * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced. */ constructor(bos_token_id: number); bos_token_id: number; /** * Apply the BOS token forcing to the logits. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The logits with BOS token forcing. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * A logits processor that enforces the specified token as the last generated token when `max_length` is reached. */ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor { /** * Create a ForcedEOSTokenLogitsProcessor. * @param {number} max_length The maximum length of the sequence to be generated. * @param {number|number[]} eos_token_id The id(s) of the *end-of-sequence* token. */ constructor(max_length: number, eos_token_id: number | number[]); max_length: number; eos_token_id: number[]; /** * Apply the processor to input_ids and logits. * * @param {bigint[][]} input_ids The input ids. * @param {Tensor} logits The logits tensor. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * A LogitsProcessor that suppresses a list of tokens as soon as the `generate` function starts * generating using `begin_index` tokens. This should ensure that the tokens defined by * `begin_suppress_tokens` at not sampled at the begining of the generation. */ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor { /** * Create a SuppressTokensAtBeginLogitsProcessor. * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress. * @param {number} begin_index The number of tokens to generate before suppressing tokens. */ constructor(begin_suppress_tokens: number[], begin_index: number); begin_suppress_tokens: number[]; begin_index: number; /** * Apply the BOS token forcing to the logits. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The logits with BOS token forcing. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * A LogitsProcessor that handles adding timestamps to generated text. */ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor { /** * Constructs a new WhisperTimeStampLogitsProcessor. * @param {import('../models/whisper/generation_whisper.js').WhisperGenerationConfig} generate_config The config object passed to the `generate()` method of a transformer model. * @param {number[]} init_tokens The initial tokens of the input sequence. */ constructor(generate_config: import("../models/whisper/generation_whisper.js").WhisperGenerationConfig, init_tokens: number[]); eos_token_id: number; no_timestamps_token_id: number; timestamp_begin: number; begin_index: number; max_initial_timestamp_index: number; /** * Modify the logits to handle timestamp tokens. * @param {bigint[][]} input_ids The input sequence of tokens. * @param {Tensor} logits The logits output by the model. * @returns {Tensor} The modified logits. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * A logits processor that disallows ngrams of a certain size to be repeated. */ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { /** * Create a NoRepeatNGramLogitsProcessor. * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once. */ constructor(no_repeat_ngram_size: number); no_repeat_ngram_size: number; /** * Generate n-grams from a sequence of token ids. * @param {bigint[]} prevInputIds List of previous input ids * @returns {Map<string, number[]>} Map of generated n-grams */ getNgrams(prevInputIds: bigint[]): Map<string, number[]>; /** * Generate n-grams from a sequence of token ids. * @param {Map<string, number[]>} bannedNgrams Map of banned n-grams * @param {bigint[]} prevInputIds List of previous input ids * @returns {number[]} Map of generated n-grams */ getGeneratedNgrams(bannedNgrams: Map<string, number[]>, prevInputIds: bigint[]): number[]; /** * Calculate banned n-gram tokens * @param {bigint[]} prevInputIds List of previous input ids * @returns {number[]} Map of generated n-grams */ calcBannedNgramTokens(prevInputIds: bigint[]): number[]; /** * Apply the no-repeat-ngram processor to the logits. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The logits with no-repeat-ngram processing. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * A logits processor that prevents the repetition of previous tokens through a penalty. * This penalty is applied at most once per token. Note that, for decoder-only models like most LLMs, * the considered tokens include the prompt. * * In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a * penalty of around 1.2 to achieve a good balance between truthful generation and lack of repetition. * To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes * more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where * a lower value rewards more strongly. */ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { /** * Create a RepetitionPenaltyLogitsProcessor. * @param {number} penalty The parameter for repetition penalty. * - 1.0 means no penalty. Above 1.0 penalizes previously generated tokens. * - Between 0.0 and 1.0 rewards previously generated tokens. */ constructor(penalty: number); penalty: number; /** * Apply the repetition penalty to the logits. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The logits with repetition penalty processing. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * A logits processor that enforces a minimum number of tokens. */ export class MinLengthLogitsProcessor extends LogitsProcessor { /** * Create a MinLengthLogitsProcessor. * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity. * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. */ constructor(min_length: number, eos_token_id: number | number[]); min_length: number; eos_token_id: number[]; /** * Apply logit processor. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * A logits processor that enforces a minimum number of new tokens. */ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { /** * Create a MinNewTokensLengthLogitsProcessor. * @param {number} prompt_length_to_skip The input tokens length. * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity. * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. */ constructor(prompt_length_to_skip: number, min_new_tokens: number, eos_token_id: number | number[]); prompt_length_to_skip: number; min_new_tokens: number; eos_token_id: number[]; /** * Apply logit processor. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } export class NoBadWordsLogitsProcessor extends LogitsProcessor { /** * Create a `NoBadWordsLogitsProcessor`. * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated. * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. */ constructor(bad_words_ids: number[][], eos_token_id: number | number[]); bad_words_ids: number[][]; eos_token_id: number[]; /** * Apply logit processor. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, * where the first half correspond to the conditional logits (predicted from the input prompt) and the second half * correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a * weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`. * * See [the paper](https://arxiv.org/abs/2306.05284) for more information. */ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor { /** * Create a `ClassifierFreeGuidanceLogitsProcessor`. * @param {number} guidance_scale The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. * Higher guidance scale encourages the model to generate samples that are more closely linked to the input * prompt, usually at the expense of poorer quality. */ constructor(guidance_scale: number); guidance_scale: number; /** * Apply logit processor. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means * that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`]. */ export class TemperatureLogitsWarper extends LogitsWarper { /** * Create a `TemperatureLogitsWarper`. * @param {number} temperature Strictly positive float value used to modulate the logits distribution. * A value smaller than `1` decreases randomness (and vice versa), with `0` being equivalent to shifting * all probability mass to the most likely token. */ constructor(temperature: number); temperature: number; /** * Apply logit warper. * @param {bigint[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Tensor} The processed logits. */ _call(input_ids: bigint[][], logits: Tensor): Tensor; } /** * [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. * Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`]. */ export class TopPLogitsWarper extends LogitsWarper { /** * Create a `TopPLogitsWarper`. * @param {number} top_p If set to < 1, only the smallest set of most probable tokens with * probabilities that add up to `top_p` or higher are kept for generation. * @param {Object} options Additional options for the top-p sampling. * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value. * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered. */ constructor(top_p: number, { filter_value, min_tokens_to_keep, }?: { filter_value?: number; min_tokens_to_keep?: number; }); top_p: number; filter_value: number; min_tokens_to_keep: number; } /** * [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. * Often used together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. */ export class TopKLogitsWarper extends LogitsWarper { /** * Create a `TopKLogitsWarper`. * @param {number} top_k If set to > 0, only the top `top_k` tokens are kept for generation. * @param {Object} options Additional options for the top-k sampling. * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value. * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered. */ constructor(top_k: number, { filter_value, min_tokens_to_keep, }?: { filter_value?: number; min_tokens_to_keep?: number; }); top_k: number; filter_value: number; } import { Tensor } from "../utils/tensor.js"; export {}; //# sourceMappingURL=logits_process.d.ts.map