@atjsh/llmlingua-2

JavaScript/TypeScript Implementation of LLMLingua-2

// SPDX-License-Identifier: MIT
/**
 * @categoryDescription Core
 * Class & functions for customized use of prompt compression
 */
import { Tensor } from "@huggingface/transformers";
import { softmax, tensor3d } from "@tensorflow/tfjs";
import { chunk } from "es-toolkit/array";
import { percentile, replace_added_token } from "./utils.js";
/**
 * The TypeScript implementation of the original `PromptCompressor`, a class for compressing prompts using a language model.
 *
 * @see [Original Implementation](https://github.com/microsoft/LLMLingua/blob/e4e172afb42d8ae3c0b6cb271a3f5d6a812846a0/llmlingua/prompt_compressor.py)
 * @category Core
 */
export class PromptCompressorLLMLingua2 {
    model;
    tokenizer;
    getPureToken;
    isBeginOfNewWord;
    oaiTokenizer;
    llmlingua2Config;
    logger;
    addedTokens = [];
    specialTokens;
    constructor(
    /**
     * The pre-trained model to use for compression.
     */
    model,
    /**
     * The pre-trained tokenizer to use for compression.
     */
    tokenizer,
    /**
     * Function to get the pure token from a token.
     * This is used to normalize tokens before processing.
     */
    getPureToken,
    /**
     * Function to check if a token is the beginning of a new word.
     * This is used to determine how to merge tokens into words.
     */
    isBeginOfNewWord,
    /**
     * The tokenizer to use when calculating the compression rate.
     */
    oaiTokenizer,
    /**
     * Configuration for LLMLingua-2.
     */
    llmlingua2Config = {
        /**
         * Maximum batch size for processing prompts.
         * This is used to limit the number of prompts processed in a single batch.
         */
        max_batch_size: 50,
        /**
         * Maximum number of tokens to force in the compression.
         * This is used to ensure that certain tokens are always included in the compressed prompt.
         */
        max_force_token: 100,
        /**
         * Maximum sequence length for the model.
         * This is used to limit the length of the input sequences to the model.
         */
        max_seq_length: 512,
    },
    /**
     * Logger function to log messages.
     */
    logger = console.log) {
        this.model = model;
        this.tokenizer = tokenizer;
        this.getPureToken = getPureToken;
        this.isBeginOfNewWord = isBeginOfNewWord;
        this.oaiTokenizer = oaiTokenizer;
        this.llmlingua2Config = llmlingua2Config;
        this.logger = logger;
        // Reserve placeholder tokens used to stand in for multi-token force tokens during compression.
        for (let i = 0; i < this.llmlingua2Config.max_force_token; i++) {
            this.addedTokens.push(`[NEW${i}]`);
        }
        const specialTokensMap = this.tokenizer.special_tokens || {};
        this.specialTokens = new Set();
        for (const [key, value] of Object.entries(specialTokensMap)) {
            if (key !== "additional_special_tokens") {
                this.specialTokens.add(value);
            }
        }
    }
    /**
     * Compresses a prompt based on the given options.
     */
    async compress(context, {
        rate,
        targetToken = -1,
        tokenToWord = "mean",
        forceTokens = [],
        forceReserveDigit = false,
        dropConsecutive = false,
        chunkEndTokens = [".", "\n"],
    }) {
        return this.compressSingleContext({
            context,
            rate,
            target_token: targetToken,
            token_to_word: tokenToWord,
            force_tokens: forceTokens,
            force_reserve_digit: forceReserveDigit,
            drop_consecutive: dropConsecutive,
            chunk_end_tokens: chunkEndTokens,
        });
    }
    /**
     * Compresses a prompt based on the given options. Alias for `compress`, but uses snake_case for options.
     *
     * @alias compress
     */
    async compress_prompt(context, options) {
        return this.compress(context, {
            rate: options.rate,
            targetToken: options.target_token,
            tokenToWord: options.token_to_word,
            forceTokens: options.force_tokens,
            forceReserveDigit: options.force_reserve_digit,
            dropConsecutive: options.drop_consecutive,
            chunkEndTokens: options.chunk_end_tokens,
        });
    }
    /**
     * Compresses a single context string: maps multi-token force tokens to placeholder
     * tokens, chunks the context, and delegates token-level compression to `compressContexts`.
     */
    async compressSingleContext(options) {
        let { context } = options;
        const {
            rate,
            target_token,
            token_to_word,
            force_tokens,
            force_reserve_digit,
            drop_consecutive,
            chunk_end_tokens,
        } = options;
        let token_map = {};
        for (let i = 0; i < force_tokens.length; i++) {
            const token = force_tokens[i];
            const tokenized = this.tokenizer.tokenize(token);
            if (tokenized.length !== 1) {
                token_map[token] = this.addedTokens[i];
            }
        }
        const chunkEndTokenSet = new Set(chunk_end_tokens);
        chunk_end_tokens.forEach((token) => {
            if (token_map[token]) {
                chunkEndTokenSet.add(token_map[token]);
            }
        });
        const n_original_token = this.getTokenLength(context);
        this.logger("original token length: appx. ", n_original_token.toLocaleString());
        for (const [original, newToken] of Object.entries(token_map)) {
            context = context.replace(new RegExp(original, "g"), newToken);
        }
        const chunkedContexts = this.chunkContext(context, chunkEndTokenSet);
        this.logger("chunking finished. chunk count: ", chunkedContexts.length.toLocaleString());
        // The reduce rate is the fraction of tokens to drop; `target_token` (if set) overrides `rate`.
        let final_reduce_rate = 1.0 - rate;
        if (target_token > 0 && n_original_token > 0) {
            const rate_to_keep_for_token_level = Math.min(target_token / n_original_token, 1.0);
            final_reduce_rate = 1.0 - rate_to_keep_for_token_level;
        }
        const compressed_context_strs = await this.compressContexts(chunkedContexts, {
            reduce_rate: Math.max(0, final_reduce_rate),
            token_to_word,
            force_tokens,
            token_map,
            force_reserve_digit,
            drop_consecutive,
        });
        this.logger("compression finished");
        const final_compressed_context = compressed_context_strs.join("\n");
        return final_compressed_context;
    }
    /**
     * Splits the context into chunks of at most `max_seq_length - 2` tokens,
     * preferring to end each chunk on one of `chunkEndTokens`.
     */
    chunkContext(originText, chunkEndTokens) {
        const maxLenTokens = this.llmlingua2Config.max_seq_length - 2;
        const origin_list = [];
        const origin_tokens = this.tokenizer.tokenize(originText);
        const n = origin_tokens.length;
        let st = 0;
        while (st < n) {
            if (st + maxLenTokens > n - 1) {
                const chunk = this.tokenizer.decoder.decode(origin_tokens.slice(st, n - 1));
                origin_list.push(chunk);
                break;
            }
            else {
                let ed = st + maxLenTokens;
                for (let j = 0; j < ed - st; j++) {
                    if (chunkEndTokens.has(origin_tokens[ed - j])) {
                        ed = ed - j;
                        break;
                    }
                }
                const chunk = this.tokenizer.decoder.decode(origin_tokens.slice(st, ed + 1));
                origin_list.push(chunk);
                st = ed + 1;
            }
        }
        return origin_list;
    }
    /**
     * Returns the number of tokens produced by the model tokenizer for `text`.
     */
    getTokenLength(text) {
        return this.tokenizer.tokenize(text).length;
    }
    /**
     * Merges sub-word tokens back into words, carrying per-token probabilities along.
     * Force tokens (and, when `force_reserve_digit` is set, words starting with a digit)
     * get probability 1.0 so they are always kept.
     */
    mergeTokenToWord(tokens, token_probs, force_tokens_original, token_map, force_reserve_digit) {
        const words = [];
        const word_probs_with_force_logic = [];
        const valid_token_probs_no_force = [];
        for (let i = 0; i < tokens.length; i++) {
            let token = tokens[i];
            let prob = token_probs[i];
            if (this.specialTokens.has(token)) {
                // Skip special tokens entirely.
            }
            else if (this.isBeginOfNewWord(token, force_tokens_original, token_map)) {
                const pure_token = this.getPureToken(token);
                const prob_no_force = prob;
                if (force_tokens_original.includes(pure_token) ||
                    Object.values(token_map).includes(pure_token)) {
                    prob = 1.0;
                }
                token = replace_added_token(token, token_map);
                words.push(token);
                word_probs_with_force_logic.push([
                    force_reserve_digit && token.match(/^\d/) ? 1.0 : prob,
                ]);
                valid_token_probs_no_force.push([prob_no_force]);
            }
            else {
                // Continuation of the previous word: append the pure token and its probability.
                const pure_token = this.getPureToken(token);
                words[words.length - 1] += pure_token;
                if (word_probs_with_force_logic.length === 0) {
                    word_probs_with_force_logic.push([
                        force_reserve_digit && token.match(/^\d/) ? 1.0 : prob,
                    ]);
                }
                else {
                    word_probs_with_force_logic[word_probs_with_force_logic.length - 1].push(force_reserve_digit && token.match(/^\d/) ? 1.0 : prob);
                }
                if (valid_token_probs_no_force.length === 0) {
                    valid_token_probs_no_force.push([prob]);
                }
                else {
                    valid_token_probs_no_force[valid_token_probs_no_force.length - 1].push(prob);
                }
            }
        }
        return {
            words,
            word_probs_with_force_logic,
            valid_token_probs_no_force,
        };
    }
    /**
     * Converts per-token probabilities into a single per-word probability,
     * either by averaging ("mean") or by taking the first token's probability ("first").
     */
    tokenProbToWordProb(tokenProbsPerWord, convertMode = "mean") {
        if (convertMode === "mean") {
            return tokenProbsPerWord.map((probs) => probs.reduce((sum, prob) => sum + prob, 0) / probs.length);
        }
        else if (convertMode === "first") {
            return tokenProbsPerWord.map((probs) => probs[0]);
        }
        throw new Error(`Unknown convertMode: ${convertMode}`);
    }
    /**
     * Runs the token-classification model over batches of chunks and keeps the words whose
     * predicted probability exceeds the percentile threshold implied by `reduce_rate`.
     */
    async compressContexts(contexts, options) {
        const {
            reduce_rate,
            token_to_word,
            force_tokens,
            token_map,
            force_reserve_digit,
            drop_consecutive,
        } = options;
        if (reduce_rate <= 0) {
            return contexts;
        }
        else if (contexts.length === 0) {
            return [];
        }
        const compressed_chunk_strings_flat = [];
        const chunked_contexts = chunk(contexts, this.llmlingua2Config.max_batch_size);
        for (const context of chunked_contexts) {
            const { input_ids, attention_mask } = await this.tokenizer(context, {
                padding: true,
                truncation: true,
            });
            this.logger("input tokenization finished");
            const input_ids_dims = input_ids.dims;
            const outputs = await this.model({
                input_ids,
                attention_mask,
            });
            this.logger("model inference finished");
            const [batch_size, seq_len, num_classes] = outputs.logits.dims;
            this.logger("logits shape:", outputs.logits.dims);
            const logits = tensor3d(outputs.logits.data, [batch_size, seq_len, num_classes], "float32");
            this.logger("logits tensor created with shape:", logits.shape);
            // Softmax over the class dimension gives per-token class probabilities.
            const probs = softmax(logits, -1);
            for (let j = 0; j < batch_size; j++) {
                // Class index 1 is used as the per-token keep score.
                const chunk_probs_class1 = probs.slice([j, 0, 1], [1, -1, 1]);
                const chunk_ids = input_ids[j];
                const chunk_mask = attention_mask[j];
                const chunk_mask_number_array = Array.from(chunk_mask.data, (v) => Number(v));
                const active_probs = chunk_probs_class1
                    .dataSync()
                    .filter((_, i) => chunk_mask_number_array[i] > 0);
                const active_ids = chunk_ids.data
                    .filter((_, i) => chunk_mask_number_array[i] > 0n)
                    .filter((v) => v !== 0n);
                if (active_ids.length === 0) {
                    compressed_chunk_strings_flat.push("");
                    continue;
                }
                const token_list = this.tokenizer.model.convert_ids_to_tokens(new Tensor("int64", active_ids, [active_ids.length]).tolist());
                const token_prob_list = Array.from(active_probs);
                const { words, word_probs_with_force_logic } = this.mergeTokenToWord(token_list, token_prob_list, force_tokens, token_map, force_reserve_digit);
                const word_probs = this.tokenProbToWordProb(word_probs_with_force_logic, token_to_word);
                // Expand each word probability to the number of tokens the OpenAI-style tokenizer
                // produces for it, so the percentile threshold is computed in target-tokenizer space.
                const new_token_probs = [];
                for (let i = 0; i < words.length; i++) {
                    const word = words[i];
                    const word_prob = word_probs[i];
                    const new_token = this.oaiTokenizer.encode(word);
                    new_token_probs.push(...Array(new_token.length).fill(word_prob));
                }
                const threshold = percentile(new_token_probs, 100 * reduce_rate);
                const keep_words = [];
                for (let i = 0; i < words.length; i++) {
                    const word = words[i];
                    const word_prob = word_probs[i];
                    if (word_prob > threshold ||
                        (threshold === 1.0 && word_prob == threshold)) {
                        keep_words.push(word);
                    }
                }
                const keep_str = replace_added_token(this.tokenizer.decoder.decode(keep_words), token_map);
                compressed_chunk_strings_flat.push(keep_str);
            }
        }
        return compressed_chunk_strings_flat;
    }
}
//# sourceMappingURL=prompt-compressor.js.map
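
For orientation, here is a minimal usage sketch of `PromptCompressorLLMLingua2`. It assumes the class is imported from the package root, an XLM-RoBERTa-style LLMLingua-2 token-classification checkpoint loaded via `@huggingface/transformers`, SentencePiece `▁` word-boundary markers, and `gpt-tokenizer` as the OpenAI-style tokenizer; the checkpoint id, the two helper callbacks, and the option values are illustrative assumptions, not part of this file.

// Minimal usage sketch (illustrative assumptions; not part of prompt-compressor.js).
import { AutoModelForTokenClassification, AutoTokenizer } from "@huggingface/transformers";
import { encode } from "gpt-tokenizer";
import { PromptCompressorLLMLingua2 } from "@atjsh/llmlingua-2";

// Hypothetical checkpoint id; substitute the LLMLingua-2 token-classification model you actually use.
const modelId = "your-org/llmlingua-2-token-classification";

const model = await AutoModelForTokenClassification.from_pretrained(modelId);
const tokenizer = await AutoTokenizer.from_pretrained(modelId);

const compressor = new PromptCompressorLLMLingua2(
    model,
    tokenizer,
    // getPureToken: strip the SentencePiece word-boundary marker "▁" (assumes an XLM-RoBERTa tokenizer).
    (token) => token.replace(/^▁/, ""),
    // isBeginOfNewWord: a token starts a new word when it carries the boundary marker
    // or is one of the force / added placeholder tokens.
    (token, forceTokens, tokenMap) =>
        token.startsWith("▁") ||
        forceTokens.includes(token) ||
        Object.values(tokenMap).includes(token),
    // oaiTokenizer: anything exposing encode(); gpt-tokenizer is used here as an example.
    { encode },
);

const longPrompt = "…"; // the prompt you want to compress
const compressed = await compressor.compress(longPrompt, {
    rate: 0.33,               // keep roughly one third of the tokens
    forceTokens: ["\n", "."], // tokens that must survive compression
});
console.log(compressed);

The `rate` passed here is the fraction of tokens to keep; internally the class turns it into a reduce rate of `1 - rate`, or derives the reduce rate from `targetToken` when that option is set.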