@atjsh/llmlingua-2
JavaScript/TypeScript Implementation of LLMLingua-2
// SPDX-License-Identifier: MIT
/**
* @categoryDescription Core
* Class & functions for customized use of prompt compression
*/
import { Tensor } from "@huggingface/transformers";
import { softmax, tensor3d } from "@tensorflow/tfjs";
import { chunk } from "es-toolkit/array";
import { percentile, replace_added_token } from "./utils.js";
/**
 * The TypeScript implementation of the original `PromptCompressor`, a class for compressing prompts using a language model.
*
* @see [Original Implementation](https://github.com/microsoft/LLMLingua/blob/e4e172afb42d8ae3c0b6cb271a3f5d6a812846a0/llmlingua/prompt_compressor.py)
* @category Core
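 *
 * @example
 * // Minimal usage sketch. `AutoTokenizer` / `AutoModelForTokenClassification` are standard
 * // @huggingface/transformers factories; the checkpoint placeholder, the `oaiTokenizer`
 * // (any encoder exposing `encode()`, e.g. from js-tiktoken) and the two token helpers
 * // below are illustrative assumptions, not part of this file.
 * import { AutoModelForTokenClassification, AutoTokenizer } from "@huggingface/transformers";
 *
 * const tokenizer = await AutoTokenizer.from_pretrained("<llmlingua-2 checkpoint>");
 * const model = await AutoModelForTokenClassification.from_pretrained("<llmlingua-2 checkpoint>");
 *
 * const compressor = new PromptCompressorLLMLingua2(
 *     model,
 *     tokenizer,
 *     (token) => token.replace("▁", ""),   // getPureToken: assumed SentencePiece-style marker stripping
 *     (token) => token.startsWith("▁"),    // isBeginOfNewWord: simplified; the real check also receives force tokens
 *     oaiTokenizer,
 * );
 *
 * // Keep roughly a third of the original tokens.
 * const compressed = await compressor.compress(longPrompt, { rate: 0.33 });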
*/
export class PromptCompressorLLMLingua2 {
model;
tokenizer;
getPureToken;
isBeginOfNewWord;
oaiTokenizer;
llmlingua2Config;
logger;
addedTokens = [];
specialTokens;
constructor(
/**
* The pre-trained model to use for compression.
*/
model,
/**
* The pre-trained tokenizer to use for compression.
*/
tokenizer,
/**
* Function to get the pure token from a token.
* This is used to normalize tokens before processing.
*/
getPureToken,
/**
* Function to check if a token is the beginning of a new word.
* This is used to determine how to merge tokens into words.
*/
isBeginOfNewWord,
/**
     * The tokenizer used when calculating the compression rate.
*/
oaiTokenizer,
/**
* Configuration for LLMLingua2.
*/
llmlingua2Config = {
/**
* Maximum batch size for processing prompts.
* This is used to limit the number of prompts processed in a single batch.
*/
max_batch_size: 50,
/**
* Maximum number of tokens to force in the compression.
* This is used to ensure that certain tokens are always included in the compressed prompt.
*/
max_force_token: 100,
/**
* Maximum sequence length for the model.
* This is used to limit the length of the input sequences to the model.
*/
max_seq_length: 512,
},
/**
* Logger function to log messages.
*/
logger = console.log) {
this.model = model;
this.tokenizer = tokenizer;
this.getPureToken = getPureToken;
this.isBeginOfNewWord = isBeginOfNewWord;
this.oaiTokenizer = oaiTokenizer;
this.llmlingua2Config = llmlingua2Config;
this.logger = logger;
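        // Reserve `[NEW0]`..`[NEW{max_force_token - 1}]` placeholders; force tokens that map to
        // more than one model token are temporarily replaced by these during compression.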
for (let i = 0; i < this.llmlingua2Config.max_force_token; i++) {
this.addedTokens.push(`[NEW${i}]`);
}
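        // Collect the tokenizer's special tokens (except `additional_special_tokens`) so they
        // can be skipped when merging tokens into words.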
const specialTokensMap = this.tokenizer.special_tokens || {};
this.specialTokens = new Set();
for (const [key, value] of Object.entries(specialTokensMap)) {
if (key !== "additional_special_tokens") {
this.specialTokens.add(value);
}
}
}
/**
* Compresses a prompt based on the given options.
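     *
     * `rate` is the approximate fraction of tokens to keep (e.g. 0.33 keeps roughly a third).
     * When `targetToken` is positive, it takes precedence and the keep ratio becomes
     * `targetToken / originalTokenCount` (capped at 1).
     *
     * @example
     * // Hedged sketch: aim for ~200 tokens, always keep newlines, and protect digits.
     * const out = await compressor.compress(text, {
     *     rate: 0.5,
     *     targetToken: 200,
     *     forceTokens: ["\n"],
     *     forceReserveDigit: true,
     * });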
*/
async compress(context, { rate, targetToken = -1, tokenToWord = "mean", forceTokens = [], forceReserveDigit = false, dropConsecutive = false, chunkEndTokens = [".", "\n"], }) {
return this.compressSingleContext({
context,
rate,
target_token: targetToken,
token_to_word: tokenToWord,
force_tokens: forceTokens,
force_reserve_digit: forceReserveDigit,
drop_consecutive: dropConsecutive,
chunk_end_tokens: chunkEndTokens,
});
}
/**
* Compresses a prompt based on the given options. Alias for `compress`, but uses snake_case for options.
*
* @alias compress
*/
async compress_prompt(context, options) {
return this.compress(context, {
rate: options.rate,
targetToken: options.target_token,
            tokenToWord: options.token_to_word,
forceTokens: options.force_tokens,
forceReserveDigit: options.force_reserve_digit,
dropConsecutive: options.drop_consecutive,
chunkEndTokens: options.chunk_end_tokens,
});
}
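    /**
     * Single-context pipeline (mirroring the original Python implementation):
     * 1. swap multi-token force tokens for reserved `[NEW{i}]` placeholders,
     * 2. split the context into chunks that fit `max_seq_length`,
     * 3. classify tokens chunk by chunk and drop low-probability words,
     * 4. join the compressed chunks with newlines.
     */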
async compressSingleContext(options) {
let { context } = options;
const { rate, target_token, token_to_word, force_tokens, force_reserve_digit, drop_consecutive, chunk_end_tokens, } = options;
let token_map = {};
for (let i = 0; i < force_tokens.length; i++) {
const token = force_tokens[i];
const tokenized = this.tokenizer.tokenize(token);
if (tokenized.length !== 1) {
token_map[token] = this.addedTokens[i];
}
}
const chunkEndTokenSet = new Set(chunk_end_tokens);
chunk_end_tokens.forEach((token) => {
if (token_map[token]) {
chunkEndTokenSet.add(token_map[token]);
}
});
const n_original_token = this.getTokenLength(context);
this.logger("original token length: appx. ", n_original_token.toLocaleString());
for (const [original, newToken] of Object.entries(token_map)) {
            // Plain string replacement; a RegExp built from the raw token would break on regex metacharacters.
            context = context.replaceAll(original, newToken);
}
const chunkedContexts = this.chunkContext(context, chunkEndTokenSet);
this.logger("chunking finished. chunk count: ", chunkedContexts.length.toLocaleString());
let final_reduce_rate = 1.0 - rate;
if (target_token > 0 && n_original_token > 0) {
const rate_to_keep_for_token_level = Math.min(target_token / n_original_token, 1.0);
final_reduce_rate = 1.0 - rate_to_keep_for_token_level;
}
const compressed_context_strs = await this.compressContexts(chunkedContexts, {
reduce_rate: Math.max(0, final_reduce_rate),
token_to_word,
force_tokens,
token_map,
force_reserve_digit,
drop_consecutive,
});
this.logger("compression finished");
const final_compressed_context = compressed_context_strs.join("\n");
return final_compressed_context;
}
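    /**
     * Splits `originText` into chunks of at most `max_seq_length - 2` tokens. Each cut point
     * is moved back from the hard limit to the nearest chunk-end token (e.g. "." or "\n")
     * when one is found, so chunks tend to end on sentence boundaries.
     */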
chunkContext(originText, chunkEndTokens) {
const maxLenTokens = this.llmlingua2Config.max_seq_length - 2;
const origin_list = [];
const origin_tokens = this.tokenizer.tokenize(originText);
const n = origin_tokens.length;
let st = 0;
while (st < n) {
if (st + maxLenTokens > n - 1) {
                const chunk = this.tokenizer.decoder.decode(origin_tokens.slice(st, n)); // include the final token of the text
origin_list.push(chunk);
break;
}
else {
let ed = st + maxLenTokens;
for (let j = 0; j < ed - st; j++) {
if (chunkEndTokens.has(origin_tokens[ed - j])) {
ed = ed - j;
break;
}
}
const chunk = this.tokenizer.decoder.decode(origin_tokens.slice(st, ed + 1));
origin_list.push(chunk);
st = ed + 1;
}
}
return origin_list;
}
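    /**
     * Number of tokens `text` produces under the compression model's tokenizer.
     */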
getTokenLength(text) {
return this.tokenizer.tokenize(text).length;
}
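    /**
     * Merges sub-word tokens back into words and collects each word's token probabilities.
     * A token that begins a new word opens a new entry; continuation tokens are appended
     * (via `getPureToken`) to the previous word. Force tokens get probability 1.0, and when
     * `force_reserve_digit` is set, tokens starting with a digit are forced to 1.0 as well,
     * so both survive the percentile cut in `compressContexts`.
     */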
mergeTokenToWord(tokens, token_probs, force_tokens_original, token_map, force_reserve_digit) {
const words = [];
const word_probs_with_force_logic = [];
const valid_token_probs_no_force = [];
for (let i = 0; i < tokens.length; i++) {
let token = tokens[i];
let prob = token_probs[i];
            if (this.specialTokens.has(token)) {
                // Skip special tokens; they contribute neither words nor probabilities.
            }
else if (this.isBeginOfNewWord(token, force_tokens_original, token_map)) {
const pure_token = this.getPureToken(token);
const prob_no_force = prob;
if (force_tokens_original.includes(pure_token) ||
Object.values(token_map).includes(pure_token)) {
prob = 1.0;
}
token = replace_added_token(token, token_map);
words.push(token);
word_probs_with_force_logic.push([
force_reserve_digit && token.match(/^\d/) ? 1.0 : prob,
]);
valid_token_probs_no_force.push([prob_no_force]);
}
else {
const pure_token = this.getPureToken(token);
words[words.length - 1] += pure_token;
if (word_probs_with_force_logic.length === 0) {
word_probs_with_force_logic.push([
force_reserve_digit && token.match(/^\d/) ? 1.0 : prob,
]);
}
else {
word_probs_with_force_logic[word_probs_with_force_logic.length - 1].push(force_reserve_digit && token.match(/^\d/) ? 1.0 : prob);
}
if (valid_token_probs_no_force.length === 0) {
valid_token_probs_no_force.push([prob]);
}
else {
valid_token_probs_no_force[valid_token_probs_no_force.length - 1].push(prob);
}
}
}
return {
words,
word_probs_with_force_logic,
valid_token_probs_no_force,
};
}
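    /**
     * Collapses each word's per-token probabilities into one score per word:
     * "mean" averages them, "first" takes the first token's probability.
     *
     * @example
     * // [[0.9, 0.5], [0.2]] → "mean": [0.7, 0.2]; "first": [0.9, 0.2]
     */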
tokenProbToWordProb(tokenProbsPerWord, convertMode = "mean") {
if (convertMode === "mean") {
return tokenProbsPerWord.map((probs) => probs.reduce((sum, prob) => sum + prob, 0) / probs.length);
}
else if (convertMode === "first") {
return tokenProbsPerWord.map((probs) => probs[0]);
}
throw new Error(`Unknown convertMode: ${convertMode}`);
}
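    /**
     * Batched compression core: tokenize each batch with padding/truncation, run the token
     * classification model, softmax the logits, and take the class-1 probability of each
     * active token as its "keep" score. Tokens are merged into words, word scores are
     * expanded to OpenAI token counts, and the `100 * reduce_rate` percentile of those
     * scores becomes the threshold below which words are dropped.
     */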
async compressContexts(contexts, options) {
const { reduce_rate, token_to_word, force_tokens, token_map, force_reserve_digit, drop_consecutive, } = options;
if (reduce_rate <= 0) {
return contexts;
}
else if (contexts.length === 0) {
return [];
}
const compressed_chunk_strings_flat = [];
const chunked_contexts = chunk(contexts, this.llmlingua2Config.max_batch_size);
for (const context of chunked_contexts) {
const { input_ids, attention_mask } = await this.tokenizer(context, {
padding: true,
truncation: true,
});
this.logger("input tokenization finished");
const input_ids_dims = input_ids.dims;
const outputs = await this.model({
input_ids,
attention_mask,
});
this.logger("model inference finished");
const [batch_size, seq_len, num_classes] = outputs.logits.dims;
this.logger("logits shape:", outputs.logits.dims);
const logits = tensor3d(outputs.logits.data, [batch_size, seq_len, num_classes], "float32");
this.logger("logits tensor created with shape:", logits.shape);
const probs = softmax(logits, -1);
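            // probs has shape [batch_size, seq_len, num_classes]; class index 1 is read as the
            // per-token "keep" probability below.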
for (let j = 0; j < batch_size; j++) {
const chunk_probs_class1 = probs.slice([j, 0, 1], [1, -1, 1]);
const chunk_ids = input_ids[j];
const chunk_mask = attention_mask[j];
const chunk_mask_number_array = Array.from(chunk_mask.data, (v) => Number(v));
const active_probs = chunk_probs_class1
.dataSync()
.filter((_, i) => chunk_mask_number_array[i] > 0);
                const active_ids = chunk_ids.data
                    .filter((_, i) => chunk_mask_number_array[i] > 0)
                    .filter((v) => v !== 0n);
if (active_ids.length === 0) {
compressed_chunk_strings_flat.push("");
continue;
}
const token_list = this.tokenizer.model.convert_ids_to_tokens(new Tensor("int64", active_ids, [active_ids.length]).tolist());
const token_prob_list = Array.from(active_probs);
const { words, word_probs_with_force_logic } = this.mergeTokenToWord(token_list, token_prob_list, force_tokens, token_map, force_reserve_digit);
const word_probs = this.tokenProbToWordProb(word_probs_with_force_logic, token_to_word);
const new_token_probs = [];
for (let i = 0; i < words.length; i++) {
const word = words[i];
const word_prob = word_probs[i];
const new_token = this.oaiTokenizer.encode(word);
new_token_probs.push(...Array(new_token.length).fill(word_prob));
}
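                // Expanding each word's score to its OpenAI token count makes the percentile cut
                // operate in the same token space as the target rate, so roughly a
                // `1 - reduce_rate` fraction of those tokens survives.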
const threshold = percentile(new_token_probs, 100 * reduce_rate);
const keep_words = [];
for (let i = 0; i < words.length; i++) {
const word = words[i];
const word_prob = word_probs[i];
                    if (word_prob > threshold ||
                        (threshold === 1.0 && word_prob === threshold)) {
keep_words.push(word);
}
}
const keep_str = replace_added_token(this.tokenizer.decoder.decode(keep_words), token_map);
compressed_chunk_strings_flat.push(keep_str);
}
}
return compressed_chunk_strings_flat;
}
}
//# sourceMappingURL=prompt-compressor.js.map