UNPKG

@huggingface/transformers

Version:

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!

222 lines (199 loc) • 8.16 kB
/** * @module generation/streamers */ import { mergeArrays } from '../utils/core.js'; import { is_chinese_char } from '../tokenizers.js'; import { apis } from '../env.js'; export class BaseStreamer { /** * Function that is called by `.generate()` to push new tokens * @param {bigint[][]} value */ put(value) { throw Error('Not implemented'); } /** * Function that is called by `.generate()` to signal the end of generation */ end() { throw Error('Not implemented'); } } const stdout_write = apis.IS_PROCESS_AVAILABLE ? x => process.stdout.write(x) : x => console.log(x); /** * Simple text streamer that prints the token(s) to stdout as soon as entire words are formed. */ export class TextStreamer extends BaseStreamer { /** * * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer * @param {Object} options * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method */ constructor(tokenizer, { skip_prompt = false, callback_function = null, token_callback_function = null, skip_special_tokens = true, decode_kwargs = {}, ...kwargs } = {}) { super(); this.tokenizer = tokenizer; this.skip_prompt = skip_prompt; this.callback_function = callback_function ?? stdout_write; this.token_callback_function = token_callback_function; this.decode_kwargs = { skip_special_tokens, ...decode_kwargs, ...kwargs }; // variables used in the streaming process this.token_cache = []; this.print_len = 0; this.next_tokens_are_prompt = true; } /** * Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. * @param {bigint[][]} value */ put(value) { if (value.length > 1) { throw Error('TextStreamer only supports batch size of 1'); } const is_prompt = this.next_tokens_are_prompt; if (is_prompt) { this.next_tokens_are_prompt = false; if (this.skip_prompt) return; } const tokens = value[0]; this.token_callback_function?.(tokens) // Add the new token to the cache and decodes the entire thing. this.token_cache = mergeArrays(this.token_cache, tokens); const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs); let printable_text; if (is_prompt || text.endsWith('\n')) { // After the symbol for a new line, we flush the cache. printable_text = text.slice(this.print_len); this.token_cache = []; this.print_len = 0; } else if (text.length > 0 && is_chinese_char(text.charCodeAt(text.length - 1))) { // If the last token is a CJK character, we print the characters. printable_text = text.slice(this.print_len); this.print_len += printable_text.length; } else { // Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words, // which may change with the subsequent token -- there are probably smarter ways to do this!) printable_text = text.slice(this.print_len, text.lastIndexOf(' ') + 1); this.print_len += printable_text.length; } this.on_finalized_text(printable_text, false); } /** * Flushes any remaining cache and prints a newline to stdout. */ end() { let printable_text; if (this.token_cache.length > 0) { const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs); printable_text = text.slice(this.print_len); this.token_cache = []; this.print_len = 0; } else { printable_text = ''; } this.next_tokens_are_prompt = true; this.on_finalized_text(printable_text, true); } /** * Prints the new text to stdout. If the stream is ending, also prints a newline. * @param {string} text * @param {boolean} stream_end */ on_finalized_text(text, stream_end) { if (text.length > 0) { this.callback_function?.(text); } if (stream_end && this.callback_function === stdout_write && apis.IS_PROCESS_AVAILABLE) { this.callback_function?.('\n'); } } } /** * Utility class to handle streaming of tokens generated by whisper speech-to-text models. * Callback functions are invoked when each of the following events occur: * - A new chunk starts (on_chunk_start) * - A new token is generated (callback_function) * - A chunk ends (on_chunk_end) * - The stream is finalized (on_finalize) */ export class WhisperTextStreamer extends TextStreamer { /** * @param {import('../tokenizers.js').WhisperTokenizer} tokenizer * @param {Object} options * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated * @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts * @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends * @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized * @param {number} [options.time_precision=0.02] Precision of the timestamps * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method */ constructor(tokenizer, { skip_prompt = false, callback_function = null, token_callback_function = null, on_chunk_start = null, on_chunk_end = null, on_finalize = null, time_precision = 0.02, skip_special_tokens = true, decode_kwargs = {}, } = {}) { super(tokenizer, { skip_prompt, skip_special_tokens, callback_function, token_callback_function, decode_kwargs, }); this.timestamp_begin = tokenizer.timestamp_begin; this.on_chunk_start = on_chunk_start; this.on_chunk_end = on_chunk_end; this.on_finalize = on_finalize; this.time_precision = time_precision; this.waiting_for_timestamp = false; } /** * @param {bigint[][]} value */ put(value) { if (value.length > 1) { throw Error('WhisperTextStreamer only supports batch size of 1'); } const tokens = value[0]; // Check if the token is a timestamp if (tokens.length === 1) { const offset = Number(tokens[0]) - this.timestamp_begin; if (offset >= 0) { const time = offset * this.time_precision; if (this.waiting_for_timestamp) { this.on_chunk_end?.(time); } else { this.on_chunk_start?.(time); } this.waiting_for_timestamp = !this.waiting_for_timestamp; // Toggle value = [[]]; // Skip timestamp } } return super.put(value); } end() { super.end(); this.on_finalize?.(); } }