@atjsh/llmlingua-2
JavaScript/TypeScript Implementation of LLMLingua-2
// SPDX-License-Identifier: MIT
/**
 * @categoryDescription Factory
 * Factory functions for creating LLMLingua-2 PromptCompressor instances with model-specific token handling.
*
* @showCategories
*/
import {
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
} from "@huggingface/transformers";
import { PromptCompressorLLMLingua2 } from "./prompt-compressor.js";
import {
    get_pure_tokens_bert_base_multilingual_cased,
    get_pure_tokens_xlm_roberta_large,
    is_begin_of_new_word_bert_base_multilingual_cased,
    is_begin_of_new_word_xlm_roberta_large,
} from "./utils.js";
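/**
 * Loads the shared dependencies for a prompt compressor: resolves the
 * pretrained config (unless the caller supplies one), then loads the
 * tokenizer and the token-classification model with the optional
 * Transformers.js runtime config merged in. Every intermediate value is
 * passed to `logger` for debugging.
 */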
async function prepareDependencies(modelName, transformerJSConfig, logger, pretrainedConfig, pretrainedTokenizerOptions, modelSpecificOptions) {
    // Resolve the pretrained config unless the caller supplied one.
    const config = pretrainedConfig ?? (await AutoConfig.from_pretrained(modelName));
    logger({ config });
    // Merge the optional Transformers.js runtime config (device, dtype, ...)
    // into the pretrained config before loading the tokenizer.
    const tokenizerConfig = {
        config: {
            ...config,
            ...(transformerJSConfig
                ? { "transformers.js_config": transformerJSConfig }
                : {}),
        },
        ...pretrainedTokenizerOptions,
    };
    logger({ tokenizerConfig });
    const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig);
    logger({ tokenizer });
    // The model receives the same merged config, plus any model-specific
    // loading options (e.g. use_external_data_format or subfolder).
    const modelConfig = {
        config: {
            ...config,
            ...(transformerJSConfig
                ? { "transformers.js_config": transformerJSConfig }
                : {}),
        },
        ...modelSpecificOptions,
    };
    logger({ modelConfig });
    const model = await AutoModelForTokenClassification.from_pretrained(modelName, modelConfig);
    logger({ model });
    return { model, tokenizer, config };
}
/**
 * Factory function that creates an LLMLingua-2 PromptCompressor
 * backed by an XLM-RoBERTa token-classification model, returning the
 * compressor together with the underlying model, tokenizer, and config.
 *
 * @category Factory
 *
 * @example
 * ```ts
 * import { LLMLingua2 } from "@atjsh/llmlingua-2";
 * import { Tiktoken } from "js-tiktoken/lite";
 * import o200k_base from "js-tiktoken/ranks/o200k_base";
 *
 * const modelName = "atjsh/llmlingua-2-js-xlm-roberta-large-meetingbank";
 * const oai_tokenizer = new Tiktoken(o200k_base);
 *
 * const { promptCompressor } = await LLMLingua2.WithXLMRoBERTa(modelName, {
 *   transformerJSConfig: {
 *     device: "auto",
 *     dtype: "fp32",
 *   },
 *   oaiTokenizer: oai_tokenizer,
 *   modelSpecificOptions: {
 *     use_external_data_format: true,
 *   },
 * });
 *
 * const compressedText: string = await promptCompressor.compress_prompt(
 *   "LLMLingua-2, a small-size yet powerful prompt compression method trained via data distillation from GPT-4 for token classification with a BERT-level encoder, excels in task-agnostic compression. It surpasses LLMLingua in handling out-of-domain data, offering 3x-6x faster performance.",
 *   { rate: 0.8 }
 * );
 * console.log({ compressedText });
 * ```
 */
export async function WithXLMRoBERTa(modelName, options) {
    const {
        transformerJSConfig,
        oaiTokenizer,
        pretrainedConfig,
        pretrainedTokenizerOptions,
        modelSpecificOptions,
        logger = console.log,
    } = options;
    const { model, tokenizer, config } = await prepareDependencies(modelName, transformerJSConfig, logger, pretrainedConfig, pretrainedTokenizerOptions, modelSpecificOptions);
    // Wire the model and tokenizer to the XLM-RoBERTa-specific token helpers.
    const promptCompressor = new PromptCompressorLLMLingua2(
        model,
        tokenizer,
        get_pure_tokens_xlm_roberta_large,
        is_begin_of_new_word_xlm_roberta_large,
        oaiTokenizer
    );
    logger({ promptCompressor });
    return {
        promptCompressor,
        model,
        tokenizer,
        config,
    };
}
/**
 * Factory function that creates an LLMLingua-2 PromptCompressor
 * backed by a multilingual BERT token-classification model, returning the
 * compressor together with the underlying model, tokenizer, and config.
 *
 * @category Factory
 *
 * @example
 * ```ts
 * import { LLMLingua2 } from "@atjsh/llmlingua-2";
 * import { Tiktoken } from "js-tiktoken/lite";
 * import o200k_base from "js-tiktoken/ranks/o200k_base";
 *
 * const modelName = "Arcoldd/llmlingua4j-bert-base-onnx";
 * const oai_tokenizer = new Tiktoken(o200k_base);
 *
 * const { promptCompressor } = await LLMLingua2.WithBERTMultilingual(modelName, {
 *   transformerJSConfig: {
 *     device: "auto",
 *     dtype: "fp32",
 *   },
 *   oaiTokenizer: oai_tokenizer,
 *   modelSpecificOptions: {
 *     subfolder: "",
 *   },
 * });
 *
 * const compressedText: string = await promptCompressor.compress_prompt(
 *   "LLMLingua-2, a small-size yet powerful prompt compression method trained via data distillation from GPT-4 for token classification with a BERT-level encoder, excels in task-agnostic compression. It surpasses LLMLingua in handling out-of-domain data, offering 3x-6x faster performance.",
 *   { rate: 0.8 }
 * );
 * console.log({ compressedText });
 * ```
 */
export async function WithBERTMultilingual(modelName, options) {
    const {
        transformerJSConfig,
        oaiTokenizer,
        pretrainedConfig,
        pretrainedTokenizerOptions,
        modelSpecificOptions,
        logger = console.log,
    } = options;
    const { model, tokenizer, config } = await prepareDependencies(modelName, transformerJSConfig, logger, pretrainedConfig, pretrainedTokenizerOptions, modelSpecificOptions);
    // Wire the model and tokenizer to the multilingual-BERT-specific token helpers.
    const promptCompressor = new PromptCompressorLLMLingua2(
        model,
        tokenizer,
        get_pure_tokens_bert_base_multilingual_cased,
        is_begin_of_new_word_bert_base_multilingual_cased,
        oaiTokenizer
    );
    logger({ promptCompressor });
    return {
        promptCompressor,
        model,
        tokenizer,
        config,
    };
}
//# sourceMappingURL=factory.js.map