@atjsh/llmlingua-2

JavaScript/TypeScript Implementation of LLMLingua-2

// SPDX-License-Identifier: MIT

/**
 * @categoryDescription Factory
 * A collection of utility functions and types for model-specific token handling.
 *
 * @showCategories
 */
import {
  AutoConfig,
  AutoModelForTokenClassification,
  AutoTokenizer,
} from "@huggingface/transformers";

import { PromptCompressorLLMLingua2 } from "./prompt-compressor.js";
import {
  get_pure_tokens_bert_base_multilingual_cased,
  get_pure_tokens_xlm_roberta_large,
  is_begin_of_new_word_bert_base_multilingual_cased,
  is_begin_of_new_word_xlm_roberta_large,
} from "./utils.js";

/**
 * Loads the pretrained config, tokenizer, and token-classification model for
 * `modelName`, merging any Transformers.js runtime config into both the
 * tokenizer and model options.
 */
async function prepareDependencies(
  modelName,
  transformerJSConfig,
  logger,
  pretrainedConfig,
  pretrainedTokenizerOptions,
  modelSpecificOptions
) {
  const config =
    pretrainedConfig ?? (await AutoConfig.from_pretrained(modelName));
  logger({ config });

  const tokenizerConfig = {
    config: {
      ...config,
      ...(transformerJSConfig
        ? { "transformers.js_config": transformerJSConfig }
        : {}),
    },
    ...pretrainedTokenizerOptions,
  };
  logger({ tokenizerConfig });

  const tokenizer = await AutoTokenizer.from_pretrained(
    modelName,
    tokenizerConfig
  );
  logger({ tokenizer });

  const modelConfig = {
    config: {
      ...config,
      ...(transformerJSConfig
        ? { "transformers.js_config": transformerJSConfig }
        : {}),
    },
    ...modelSpecificOptions,
  };
  logger({ modelConfig });

  const model = await AutoModelForTokenClassification.from_pretrained(
    modelName,
    modelConfig
  );
  logger({ model });

  return { model, tokenizer, config };
}

/**
 * Factory function that creates an instance of the LLMLingua-2
 * PromptCompressor with an XLM-RoBERTa model.
 *
 * @category Factory
 *
 * @example
 * ```ts
 * import { LLMLingua2 } from "@atjsh/llmlingua-2";
 * import { Tiktoken } from "js-tiktoken/lite";
 * import o200k_base from "js-tiktoken/ranks/o200k_base";
 *
 * const modelName = "atjsh/llmlingua-2-js-xlm-roberta-large-meetingbank";
 * const oai_tokenizer = new Tiktoken(o200k_base);
 *
 * const { promptCompressor } = await LLMLingua2.WithXLMRoBERTa(modelName, {
 *   transformerJSConfig: {
 *     device: "auto",
 *     dtype: "fp32",
 *   },
 *   oaiTokenizer: oai_tokenizer,
 *   modelSpecificOptions: {
 *     use_external_data_format: true,
 *   },
 * });
 *
 * const compressedText: string = await promptCompressor.compress_prompt(
 *   "LLMLingua-2, a small-size yet powerful prompt compression method trained via data distillation from GPT-4 for token classification with a BERT-level encoder, excels in task-agnostic compression. It surpasses LLMLingua in handling out-of-domain data, offering 3x-6x faster performance.",
 *   { rate: 0.8 }
 * );
 *
 * console.log({ compressedText });
 * ```
 */
export async function WithXLMRoBERTa(modelName, options) {
  const {
    transformerJSConfig,
    oaiTokenizer,
    pretrainedConfig,
    pretrainedTokenizerOptions,
    modelSpecificOptions,
    logger = console.log,
  } = options;

  const { model, tokenizer, config } = await prepareDependencies(
    modelName,
    transformerJSConfig,
    logger,
    pretrainedConfig,
    pretrainedTokenizerOptions,
    modelSpecificOptions
  );

  const promptCompressor = new PromptCompressorLLMLingua2(
    model,
    tokenizer,
    get_pure_tokens_xlm_roberta_large,
    is_begin_of_new_word_xlm_roberta_large,
    oaiTokenizer
  );
  logger({ promptCompressor });

  return {
    promptCompressor,
    model,
    tokenizer,
    config,
  };
}
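/*
 * A minimal usage sketch (an annotation, not part of the published file):
 * both factories accept the same options object, so the `logger` option,
 * which defaults to `console.log` above, can be swapped for a no-op to
 * silence the per-dependency logging, and the returned `model`, `tokenizer`,
 * and `config` can be kept around for reuse. The model name and option
 * values come from the example above; the rest is an assumption.
 *
 * ```ts
 * import { LLMLingua2 } from "@atjsh/llmlingua-2";
 * import { Tiktoken } from "js-tiktoken/lite";
 * import o200k_base from "js-tiktoken/ranks/o200k_base";
 *
 * const { promptCompressor, model, tokenizer, config } =
 *   await LLMLingua2.WithXLMRoBERTa(
 *     "atjsh/llmlingua-2-js-xlm-roberta-large-meetingbank",
 *     {
 *       transformerJSConfig: { device: "auto", dtype: "fp32" },
 *       oaiTokenizer: new Tiktoken(o200k_base),
 *       modelSpecificOptions: { use_external_data_format: true },
 *       logger: () => {}, // suppress the default console.log logging
 *     }
 *   );
 * ```
 */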
/**
 * Factory function that creates an instance of the LLMLingua-2
 * PromptCompressor with a multilingual BERT model.
 *
 * @category Factory
 *
 * @example
 * ```ts
 * import { LLMLingua2 } from "@atjsh/llmlingua-2";
 * import { Tiktoken } from "js-tiktoken/lite";
 * import o200k_base from "js-tiktoken/ranks/o200k_base";
 *
 * const modelName = "Arcoldd/llmlingua4j-bert-base-onnx";
 * const oai_tokenizer = new Tiktoken(o200k_base);
 *
 * const { promptCompressor } = await LLMLingua2.WithBERTMultilingual(modelName, {
 *   transformerJSConfig: {
 *     device: "auto",
 *     dtype: "fp32",
 *   },
 *   oaiTokenizer: oai_tokenizer,
 *   modelSpecificOptions: {
 *     subfolder: "",
 *   },
 * });
 *
 * const compressedText: string = await promptCompressor.compress_prompt(
 *   "LLMLingua-2, a small-size yet powerful prompt compression method trained via data distillation from GPT-4 for token classification with a BERT-level encoder, excels in task-agnostic compression. It surpasses LLMLingua in handling out-of-domain data, offering 3x-6x faster performance.",
 *   { rate: 0.8 }
 * );
 *
 * console.log({ compressedText });
 * ```
 */
export async function WithBERTMultilingual(modelName, options) {
  const {
    transformerJSConfig,
    oaiTokenizer,
    pretrainedConfig,
    pretrainedTokenizerOptions,
    modelSpecificOptions,
    logger = console.log,
  } = options;

  const { model, tokenizer, config } = await prepareDependencies(
    modelName,
    transformerJSConfig,
    logger,
    pretrainedConfig,
    pretrainedTokenizerOptions,
    modelSpecificOptions
  );

  const promptCompressor = new PromptCompressorLLMLingua2(
    model,
    tokenizer,
    get_pure_tokens_bert_base_multilingual_cased,
    is_begin_of_new_word_bert_base_multilingual_cased,
    oaiTokenizer
  );
  logger({ promptCompressor });

  return {
    promptCompressor,
    model,
    tokenizer,
    config,
  };
}

//# sourceMappingURL=factory.js.map
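For reference, a hedged end-to-end sketch built only from the pieces shown above: it reuses the sample prompt from the `@example` blocks and counts o200k_base tokens before and after compression with the same `js-tiktoken` encoder that is passed in as `oaiTokenizer`. The token-count comparison is illustrative and is not part of this module.

```ts
import { LLMLingua2 } from "@atjsh/llmlingua-2";
import { Tiktoken } from "js-tiktoken/lite";
import o200k_base from "js-tiktoken/ranks/o200k_base";

const oai_tokenizer = new Tiktoken(o200k_base);

const prompt =
  "LLMLingua-2, a small-size yet powerful prompt compression method trained via data distillation from GPT-4 for token classification with a BERT-level encoder, excels in task-agnostic compression. It surpasses LLMLingua in handling out-of-domain data, offering 3x-6x faster performance.";

const { promptCompressor } = await LLMLingua2.WithXLMRoBERTa(
  "atjsh/llmlingua-2-js-xlm-roberta-large-meetingbank",
  {
    transformerJSConfig: { device: "auto", dtype: "fp32" },
    oaiTokenizer: oai_tokenizer,
    modelSpecificOptions: { use_external_data_format: true },
  }
);

const compressedText = await promptCompressor.compress_prompt(prompt, {
  rate: 0.8,
});

// Count tokens with the same encoding the compressor targets.
const before = oai_tokenizer.encode(prompt).length;
const after = oai_tokenizer.encode(compressedText).length;
console.log({ compressedText, before, after });
```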