transformers-fork

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
/**
 * @file Pipelines provide a high-level, easy to use, API for running machine learning models.
 *
 * **Example:** Instantiate pipeline using the `pipeline` function.
 * ```javascript
 * import { pipeline } from '@huggingface/transformers';
 *
 * const classifier = await pipeline('sentiment-analysis');
 * const output = await classifier('I love transformers!');
 * // [{'label': 'POSITIVE', 'score': 0.999817686}]
 * ```
 *
 * @module pipelines
 */

import {
    AutoTokenizer,
    PreTrainedTokenizer,
} from './tokenizers.js';
import {
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForAudioClassification,
    AutoModelForTokenClassification,
    AutoModelForQuestionAnswering,
    AutoModelForMaskedLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSpeechSeq2Seq,
    AutoModelForTextToWaveform,
    AutoModelForTextToSpectrogram,
    AutoModelForCTC,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoModelForImageClassification,
    AutoModelForImageSegmentation,
    AutoModelForSemanticSegmentation,
    AutoModelForUniversalSegmentation,
    AutoModelForObjectDetection,
    AutoModelForZeroShotObjectDetection,
    AutoModelForDocumentQuestionAnswering,
    AutoModelForImageToImage,
    AutoModelForDepthEstimation,
    AutoModelForImageFeatureExtraction,
    PreTrainedModel,
} from './models.js';
import {
    AutoProcessor,
} from './models/auto/processing_auto.js';
import {
    Processor,
} from './base/processing_utils.js';
import {
    Callable,
} from './utils/generic.js';
import {
    dispatchCallback,
    product,
} from './utils/core.js';
import {
    softmax,
    max,
    round,
} from './utils/maths.js';
import { read_audio } from './utils/audio.js';
import {
    Tensor,
    mean_pooling,
    interpolate,
    quantize_embeddings,
    topk,
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';

/**
 * @typedef {string | RawImage | URL} ImageInput
 * @typedef {ImageInput|ImageInput[]} ImagePipelineInputs
 */

/**
 * Prepare images for further tasks.
 * @param {ImagePipelineInputs} images images to prepare.
 * @returns {Promise<RawImage[]>} returns processed images.
 * @private
 */
async function prepareImages(images) {
    if (!Array.isArray(images)) {
        images = [images];
    }

    // Possibly convert any non-images to images
    return await Promise.all(images.map(x => RawImage.read(x)));
}

/**
 * @typedef {string | URL | Float32Array | Float64Array} AudioInput
 * @typedef {AudioInput|AudioInput[]} AudioPipelineInputs
 */

/**
 * Prepare audios for further tasks.
 * @param {AudioPipelineInputs} audios audios to prepare.
 * @param {number} sampling_rate sampling rate of the audios.
 * @returns {Promise<Float32Array[]>} The preprocessed audio data.
 * @private
 */
async function prepareAudios(audios, sampling_rate) {
    if (!Array.isArray(audios)) {
        audios = [audios];
    }

    return await Promise.all(audios.map(x => {
        if (typeof x === 'string' || x instanceof URL) {
            return read_audio(x, sampling_rate);
        } else if (x instanceof Float64Array) {
            return new Float32Array(x);
        }
        return x;
    }));
}
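/**
 * **Example:** A minimal usage sketch of `prepareAudios` (the URL is a placeholder
 * and 16000 Hz is an assumed target sampling rate):
 * ```javascript
 * const audios = await prepareAudios([
 *     'https://example.com/sample.wav', // read and resampled via `read_audio`
 *     new Float64Array(16000),          // downcast to Float32Array
 *     new Float32Array(16000),          // passed through unchanged
 * ], 16000);
 * // `audios` is now a Float32Array[] ready for a processor.
 * ```
 */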
/**
 * @typedef {Object} BoundingBox
 * @property {number} xmin The minimum x coordinate of the bounding box.
 * @property {number} ymin The minimum y coordinate of the bounding box.
 * @property {number} xmax The maximum x coordinate of the bounding box.
 * @property {number} ymax The maximum y coordinate of the bounding box.
 */

/**
 * Helper function to convert list [xmin, ymin, xmax, ymax] into object { "xmin": xmin, ... }
 * @param {number[]} box The bounding box as a list.
 * @param {boolean} asInteger Whether to cast to integers.
 * @returns {BoundingBox} The bounding box as an object.
 * @private
 */
function get_bounding_box(box, asInteger) {
    if (asInteger) {
        box = box.map(x => x | 0);
    }
    const [xmin, ymin, xmax, ymax] = box;

    return { xmin, ymin, xmax, ymax };
}

/**
 * @callback DisposeType Disposes the item.
 * @returns {Promise<void>} A promise that resolves when the item has been disposed.
 *
 * @typedef {Object} Disposable
 * @property {DisposeType} dispose A method that disposes the pipeline, returning a promise
 * that resolves once disposal is complete.
 */

/**
 * The Pipeline class is the class from which all pipelines inherit.
 * Refer to this class for methods shared across different pipelines.
 */
export class Pipeline extends Callable {
    /**
     * Create a new Pipeline.
     * @param {Object} options An object containing the following properties:
     * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks.
     * @param {PreTrainedModel} [options.model] The model used by the pipeline.
     * @param {PreTrainedTokenizer} [options.tokenizer=null] The tokenizer used by the pipeline (if any).
     * @param {Processor} [options.processor=null] The processor used by the pipeline (if any).
     */
    constructor({ task, model, tokenizer = null, processor = null }) {
        super();
        this.task = task;
        this.model = model;
        this.tokenizer = tokenizer;
        this.processor = processor;
    }

    /** @type {DisposeType} */
    async dispose() {
        await this.model.dispose();
    }
}
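/**
 * **Example:** A minimal sketch of pipeline disposal (model name chosen for illustration):
 * ```javascript
 * const pipe = await pipeline('sentiment-analysis', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english');
 * const result = await pipe('Great library!');
 * await pipe.dispose(); // release the underlying model's resources when done
 * ```
 */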
/**
 * @typedef {Object} ModelTokenizerConstructorArgs
 * @property {string} task The task of the pipeline. Useful for specifying subtasks.
 * @property {PreTrainedModel} model The model used by the pipeline.
 * @property {PreTrainedTokenizer} tokenizer The tokenizer used by the pipeline.
 *
 * @typedef {ModelTokenizerConstructorArgs} TextPipelineConstructorArgs An object used to instantiate a text-based pipeline.
 */

/**
 * @typedef {Object} ModelProcessorConstructorArgs
 * @property {string} task The task of the pipeline. Useful for specifying subtasks.
 * @property {PreTrainedModel} model The model used by the pipeline.
 * @property {Processor} processor The processor used by the pipeline.
 *
 * @typedef {ModelProcessorConstructorArgs} AudioPipelineConstructorArgs An object used to instantiate an audio-based pipeline.
 * @typedef {ModelProcessorConstructorArgs} ImagePipelineConstructorArgs An object used to instantiate an image-based pipeline.
 */

/**
 * @typedef {Object} ModelTokenizerProcessorConstructorArgs
 * @property {string} task The task of the pipeline. Useful for specifying subtasks.
 * @property {PreTrainedModel} model The model used by the pipeline.
 * @property {PreTrainedTokenizer} tokenizer The tokenizer used by the pipeline.
 * @property {Processor} processor The processor used by the pipeline.
 *
 * @typedef {ModelTokenizerProcessorConstructorArgs} TextAudioPipelineConstructorArgs An object used to instantiate a text- and audio-based pipeline.
 * @typedef {ModelTokenizerProcessorConstructorArgs} TextImagePipelineConstructorArgs An object used to instantiate a text- and image-based pipeline.
 */

/**
 * @typedef {Object} TextClassificationSingle
 * @property {string} label The label predicted.
 * @property {number} score The corresponding probability.
 * @typedef {TextClassificationSingle[]} TextClassificationOutput
 *
 * @typedef {Object} TextClassificationPipelineOptions Parameters specific to text classification pipelines.
 * @property {number} [top_k=1] The number of top predictions to be returned.
 *
 * @callback TextClassificationPipelineCallback Classify the text(s) given as inputs.
 * @param {string|string[]} texts The input text(s) to be classified.
 * @param {TextClassificationPipelineOptions} [options] The options to use for text classification.
 * @returns {Promise<TextClassificationOutput|TextClassificationOutput[]>} An array or object containing the predicted labels and scores.
 *
 * @typedef {TextPipelineConstructorArgs & TextClassificationPipelineCallback & Disposable} TextClassificationPipelineType
 */

/**
 * Text classification pipeline using any `ModelForSequenceClassification`.
 *
 * **Example:** Sentiment-analysis w/ `Xenova/distilbert-base-uncased-finetuned-sst-2-english`.
 * ```javascript
 * const classifier = await pipeline('sentiment-analysis', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english');
 * const output = await classifier('I love transformers!');
 * // [{ label: 'POSITIVE', score: 0.999788761138916 }]
 * ```
 *
 * **Example:** Multilingual sentiment-analysis w/ `Xenova/bert-base-multilingual-uncased-sentiment` (and return top 5 classes).
 * ```javascript
 * const classifier = await pipeline('sentiment-analysis', 'Xenova/bert-base-multilingual-uncased-sentiment');
 * const output = await classifier('Le meilleur film de tous les temps.', { top_k: 5 });
 * // [
 * //   { label: '5 stars', score: 0.9610759615898132 },
 * //   { label: '4 stars', score: 0.03323351591825485 },
 * //   { label: '3 stars', score: 0.0036155181005597115 },
 * //   { label: '1 star', score: 0.0011325967498123646 },
 * //   { label: '2 stars', score: 0.0009423971059732139 }
 * // ]
 * ```
 *
 * **Example:** Toxic comment classification w/ `Xenova/toxic-bert` (and return all classes).
 * ```javascript
 * const classifier = await pipeline('text-classification', 'Xenova/toxic-bert');
 * const output = await classifier('I hate you!', { top_k: null });
 * // [
 * //   { label: 'toxic', score: 0.9593140482902527 },
 * //   { label: 'insult', score: 0.16187334060668945 },
 * //   { label: 'obscene', score: 0.03452680632472038 },
 * //   { label: 'identity_hate', score: 0.0223250575363636 },
 * //   { label: 'threat', score: 0.019197041168808937 },
 * //   { label: 'severe_toxic', score: 0.005651099607348442 }
 * // ]
 * ```
 */
export class TextClassificationPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => TextClassificationPipelineType} */ (Pipeline)) {

    /**
     * Create a new TextClassificationPipeline.
     * @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }

    /** @type {TextClassificationPipelineCallback} */
    async _call(texts, {
        top_k = 1
    } = {}) {

        // Run tokenization
        const model_inputs = this.tokenizer(texts, {
            padding: true,
            truncation: true,
        });

        // Run model
        const outputs = await this.model(model_inputs);

        // TODO: Use softmax tensor function
        const function_to_apply =
            this.model.config.problem_type === 'multi_label_classification'
                ? batch => batch.sigmoid()
                : batch => new Tensor(
                    'float32',
                    softmax(batch.data),
                    batch.dims,
                ); // single_label_classification (default)

        const id2label = this.model.config.id2label;

        const toReturn = [];
        for (const batch of outputs.logits) {
            const output = function_to_apply(batch);

            const scores = await topk(output, top_k);

            const values = scores[0].tolist();
            const indices = scores[1].tolist();
            const vals = indices.map((x, i) => ({
                label: id2label ? id2label[x] : `LABEL_${x}`,
                score: values[i],
            }));
            if (top_k === 1) {
                toReturn.push(...vals);
            } else {
                toReturn.push(vals);
            }
        }

        return Array.isArray(texts) || top_k === 1
            ? /** @type {TextClassificationOutput} */ (toReturn)
            : /** @type {TextClassificationOutput[]} */ (toReturn)[0];
    }
}
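/**
 * **Example:** A sketch of the batched return shape (model as in the examples above; scores elided):
 * ```javascript
 * const classifier = await pipeline('sentiment-analysis', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english');
 * // With top_k === 1, each input contributes one flat entry:
 * const flat = await classifier(['I love this!', 'I hate this!']);
 * // [{ label: 'POSITIVE', score: ... }, { label: 'NEGATIVE', score: ... }]
 * // With top_k > 1, each input contributes its own array of predictions:
 * const nested = await classifier(['I love this!', 'I hate this!'], { top_k: 2 });
 * // [[{ ... }, { ... }], [{ ... }, { ... }]]
 * ```
 */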
/**
 * @typedef {Object} TokenClassificationSingle
 * @property {string} word The token/word classified. This is obtained by decoding the selected tokens.
 * @property {number} score The corresponding probability for `entity`.
 * @property {string} entity The entity predicted for that token/word.
 * @property {number} index The index of the corresponding token in the sentence.
 * @property {number} [start] The index of the start of the corresponding entity in the sentence.
 * @property {number} [end] The index of the end of the corresponding entity in the sentence.
 * @typedef {TokenClassificationSingle[]} TokenClassificationOutput
 *
 * @typedef {Object} TokenClassificationPipelineOptions Parameters specific to token classification pipelines.
 * @property {string[]} [ignore_labels] A list of labels to ignore.
 * @property {string} [aggregation_strategy] The aggregation strategy: one of 'none', 'simple', 'first', 'max', or 'average'.
 *
 * @callback TokenClassificationPipelineCallback Classify each token of the text(s) given as inputs.
 * @param {string|string[]} texts One or several texts (or one list of texts) for token classification.
 * @param {TokenClassificationPipelineOptions} [options] The options to use for token classification.
 * @returns {Promise<TokenClassificationOutput|TokenClassificationOutput[]>} The result.
 *
 * @typedef {TextPipelineConstructorArgs & TokenClassificationPipelineCallback & Disposable} TokenClassificationPipelineType
 */

/**
 * Named Entity Recognition pipeline using any `ModelForTokenClassification`.
 *
 * **Example:** Perform named entity recognition with `Xenova/bert-base-NER`.
 * ```javascript
 * const classifier = await pipeline('token-classification', 'Xenova/bert-base-NER');
 * const output = await classifier('My name is Sarah and I live in London');
 * // [
 * //   { entity: 'B-PER', score: 0.9980202913284302, index: 4, word: 'Sarah' },
 * //   { entity: 'B-LOC', score: 0.9994474053382874, index: 9, word: 'London' }
 * // ]
 * ```
 *
 * **Example:** Perform named entity recognition with `Xenova/bert-base-NER` (and return all labels).
 * ```javascript
 * const classifier = await pipeline('token-classification', 'Xenova/bert-base-NER');
 * const output = await classifier('Sarah lives in the United States of America', { ignore_labels: [] });
 * // [
 * //   { entity: 'B-PER', score: 0.9966587424278259, index: 1, word: 'Sarah' },
 * //   { entity: 'O', score: 0.9987385869026184, index: 2, word: 'lives' },
 * //   { entity: 'O', score: 0.9990072846412659, index: 3, word: 'in' },
 * //   { entity: 'O', score: 0.9988298416137695, index: 4, word: 'the' },
 * //   { entity: 'B-LOC', score: 0.9995510578155518, index: 5, word: 'United' },
 * //   { entity: 'I-LOC', score: 0.9990395307540894, index: 6, word: 'States' },
 * //   { entity: 'I-LOC', score: 0.9986724853515625, index: 7, word: 'of' },
 * //   { entity: 'I-LOC', score: 0.9975294470787048, index: 8, word: 'America' }
 * // ]
 *
 * const output2 = await classifier('Sarah lives in the United States of America', { aggregation_strategy: 'average' });
 * // [
 * //   { entity: 'PER', score: 0.983073890209198, index: null, word: 'Sarah', start: 0, end: 5 },
 * //   { entity: 'LOC', score: 0.9850180596113205, index: null, word: 'United States of America', start: 19, end: 43 }
 * // ]
 * ```
 */
export class TokenClassificationPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => TokenClassificationPipelineType} */ (Pipeline)) {

    /**
     * Create a new TokenClassificationPipeline.
     * @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }

    /** @type {TokenClassificationPipelineCallback} */
    async _call(texts, {
        ignore_labels = ['O'],
        aggregation_strategy = 'none',
    } = {}) {
        const isBatched = Array.isArray(texts);

        // Run tokenization
        const model_inputs = this.tokenizer(isBatched ? texts : [texts], {
            padding: true,
            truncation: true,
        });

        // Run model
        const outputs = await this.model(model_inputs);

        const logits = outputs.logits;
        const id2label = this.model.config.id2label;

        const toReturn = [];
        for (let i = 0; i < logits.dims[0]; ++i) {
            const ids = model_inputs.input_ids[i];
            const batch = logits[i];

            // List of tokens that aren't ignored
            const tokens = [];
            for (let j = 0; j < batch.dims[0]; ++j) {
                const tokenData = batch[j];
                const topScoreIndex = max(tokenData.data)[1];

                const entity = id2label ? id2label[topScoreIndex] : `LABEL_${topScoreIndex}`;
                if (ignore_labels.includes(entity)) {
                    // We predicted a token that should be ignored. So, we skip it.
                    continue;
                }

                // TODO add option to keep special tokens?
                const word = this.tokenizer.decode([ids[j].item()], { skip_special_tokens: true });
                if (word === '') {
                    // Was a special token. So, we skip it.
                    continue;
                }

                const scores = softmax(tokenData.data);

                tokens.push({
                    entity: entity,
                    score: scores[topScoreIndex],
                    index: j,
                    word: word,
                    start: null,
                    end: null,
                });
            }
            toReturn.push(tokens);
        }

        // Validate the requested aggregation strategy
        if (!['none', 'simple', 'first', 'max', 'average'].includes(aggregation_strategy)) {
            console.warn('Unknown aggregation_strategy.');
            aggregation_strategy = 'none';
        }

        let toReturn2 = [];
        if (aggregation_strategy != 'none') {
            toReturn2 = Array.from(toReturn);
            toReturn.length = 0;
        }

        // Tagging schemes in NER:
        // I => "inside", O => "outside", B => "beginning", E => "end", S => "single token entity".
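        // Example of the rewrite below (illustrative tags, not tied to a specific model):
        // an IOBES tag string like 'IIE' becomes 'BII', and any remaining 'E' or 'S'
        // becomes 'B', so the aggregation loop below only ever sees B/I/O tags.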
        // Convert to BIO
        toReturn2.forEach(tokens => {
            let tags = '';
            tokens.forEach((token, i) => {
                tags += token.entity[0];
            })
            if (tags.includes('E')) {
                tags = tags.replaceAll(/I(I*)E/g, 'B$1I').replaceAll(/E/g, 'B');
            }
            if (tags.includes('S')) {
                tags = tags.replaceAll(/S/g, 'B');
            }
            tokens.forEach((token, i) => {
                tokens[i].entity = tags[i] + tokens[i].entity.substring(1);
            })
        })

        // Aggregate
        toReturn2.forEach(tokens => {
            let agg_token = {};
            toReturn.push([]);
            tokens.forEach((token, i) => {
                if (!agg_token.entity) {
                    agg_token = {
                        entity: [token.entity],
                        score: [token.score],
                        index: [token.index],
                        word: token.word,
                        start: null,
                        end: null,
                    };
                } else {
                    agg_token.entity.push(token.entity);
                    agg_token.score.push(token.score);
                    agg_token.index.push(token.index);
                    agg_token.word += (token.word.includes('#') ? '' : ' ') + token.word.replaceAll('#', '');
                }
                if (
                    i == tokens.length - 1
                    || (tokens[i + 1].index - token.index > 1)
                    || (tokens[i + 1].entity[0] != 'I' && tokens[i + 1].entity[0] != token.entity[0])
                    || (aggregation_strategy == 'simple' && tokens[i + 1].entity[0] == 'B')
                ) {
                    if (aggregation_strategy == 'simple' || aggregation_strategy == 'first') {
                        agg_token.entity = agg_token.entity[0].substring(2);
                        agg_token.score = agg_token.score[0];
                    } else {
                        const _max = Math.max(...agg_token.score);
                        agg_token.entity = agg_token.entity[agg_token.score.indexOf(_max)].substring(2);
                        if (aggregation_strategy == 'max') {
                            agg_token.score = _max;
                        } else if (aggregation_strategy == 'average') {
                            // Mean of the per-token scores
                            agg_token.score = (arr => arr.reduce((a, b, c, d) => (a + b / d.length), 0))(agg_token.score);
                        }
                    }
                    agg_token.index = null;
                    toReturn.at(-1).push(agg_token);
                    agg_token = {};
                }
            })
        })

        // Compute character-level start/end offsets for each returned word
        this.tokenizer.get_offsets_mapping(toReturn.map(x => x.map(x => x.word)), texts).forEach((offsets, i) => {
            offsets.forEach((offset, j) => {
                if (j > 0 && toReturn[i][j - 1]) {
                    toReturn[i][j].start = offset[2].includes('#') ? toReturn[i][j - 1].end : offset[0];
                    toReturn[i][j].end = offset[2].includes('#') ? toReturn[i][j - 1].end + offset[2].replaceAll('#', '').length : offset[1];
                } else {
                    toReturn[i][j].start = offset[0];
                    toReturn[i][j].end = offset[1];
                }
            })
        })

        return isBatched ? toReturn : toReturn[0];
    }
}
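/**
 * **Example:** A sketch of how aggregation merges subword tokens (tokens and scores illustrative):
 * ```javascript
 * const classifier = await pipeline('token-classification', 'Xenova/bert-base-NER');
 * // Without aggregation, a name may come back as subword pieces:
 * // { entity: 'B-PER', word: 'Love' }, { entity: 'I-PER', word: '##lace' }
 * // With an aggregation strategy, the pieces are merged into one entity whose score is
 * // the first ('simple'/'first'), highest ('max'), or mean ('average') of the piece scores:
 * const output = await classifier('Ada Lovelace', { aggregation_strategy: 'max' });
 * // [{ entity: 'PER', word: 'Ada Lovelace', ... }]
 * ```
 */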
/**
 * @typedef {Object} QuestionAnsweringOutput
 * @property {number} score The probability associated to the answer.
 * @property {number} [start] The character start index of the answer (in the tokenized version of the input).
 * @property {number} [end] The character end index of the answer (in the tokenized version of the input).
 * @property {string} answer The answer to the question.
 *
 * @typedef {Object} QuestionAnsweringPipelineOptions Parameters specific to question answering pipelines.
 * @property {number} [top_k=1] The number of top answer predictions to be returned.
 *
 * @callback QuestionAnsweringPipelineCallback Answer the question(s) given as inputs by using the context(s).
 * @param {string|string[]} question One or several question(s) (must be used in conjunction with the `context` argument).
 * @param {string|string[]} context One or several context(s) associated with the question(s) (must be used in conjunction with the `question` argument).
 * @param {QuestionAnsweringPipelineOptions} [options] The options to use for question answering.
 * @returns {Promise<QuestionAnsweringOutput|QuestionAnsweringOutput[]>} An array or object containing the predicted answers and scores.
 *
 * @typedef {TextPipelineConstructorArgs & QuestionAnsweringPipelineCallback & Disposable} QuestionAnsweringPipelineType
 */

/**
 * Question Answering pipeline using any `ModelForQuestionAnswering`.
 *
 * **Example:** Run question answering with `Xenova/distilbert-base-uncased-distilled-squad`.
 * ```javascript
 * const answerer = await pipeline('question-answering', 'Xenova/distilbert-base-uncased-distilled-squad');
 * const question = 'Who was Jim Henson?';
 * const context = 'Jim Henson was a nice puppet.';
 * const output = await answerer(question, context);
 * // {
 * //   answer: "a nice puppet",
 * //   score: 0.5768911502526741
 * // }
 * ```
 */
export class QuestionAnsweringPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => QuestionAnsweringPipelineType} */ (Pipeline)) {

    /**
     * Create a new QuestionAnsweringPipeline.
     * @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }

    /** @type {QuestionAnsweringPipelineCallback} */
    async _call(question, context, {
        top_k = 1
    } = {}) {

        // Run tokenization
        const inputs = this.tokenizer(question, {
            text_pair: context,
            padding: true,
            truncation: true,
        });

        const { start_logits, end_logits } = await this.model(inputs);

        const input_ids = inputs.input_ids.tolist();
        const attention_mask = inputs.attention_mask.tolist();

        // TODO: add support for `return_special_tokens_mask`
        const special_tokens = this.tokenizer.all_special_ids;

        /** @type {QuestionAnsweringOutput[]} */
        const toReturn = [];
        for (let j = 0; j < start_logits.dims[0]; ++j) {
            const ids = input_ids[j];
            const sepIndex = ids.findIndex(x =>
                // We use == to match bigint with number
                // @ts-ignore
                x == this.tokenizer.sep_token_id
            );

            const valid_mask = attention_mask[j].map((y, ix) => (
                y == 1
                && (
                    ix === 0 // is cls_token
                    || (
                        ix > sepIndex
                        && special_tokens.findIndex(x => x == ids[ix]) === -1 // token is not a special token (special_tokens_mask == 0)
                    )
                )
            ));

            const start = start_logits[j].tolist();
            const end = end_logits[j].tolist();

            // Now, we mask out values that can't be in the answer
            // NOTE: We keep the cls_token unmasked (some models use it to indicate unanswerable questions)
            for (let i = 1; i < start.length; ++i) {
                if (
                    attention_mask[j][i] == 0 // is part of padding
                    || i <= sepIndex // is before the sep_token
                    || special_tokens.findIndex(x => x == ids[i]) !== -1 // is a special token
                ) {
                    // Make sure non-context indexes in the tensor cannot contribute to the softmax
                    start[i] = -Infinity;
                    end[i] = -Infinity;
                }
            }

            // Normalize logits and spans to retrieve the answer
            const start_scores = softmax(start).map((x, i) => [x, i]);
            const end_scores = softmax(end).map((x, i) => [x, i]);

            // Mask CLS
            start_scores[0][0] = 0;
            end_scores[0][0] = 0;

            // Generate all valid spans and select best ones
            const options = product(start_scores, end_scores)
                .filter(x => x[0][1] <= x[1][1])
                .map(x => [x[0][1], x[1][1], x[0][0] * x[1][0]])
                .sort((a, b) => b[2] - a[2]);

            for (let k = 0; k < Math.min(options.length, top_k); ++k) {
                const [start, end, score] = options[k];

                const answer_tokens = ids.slice(start, end + 1);

                const answer = this.tokenizer.decode(answer_tokens, {
                    skip_special_tokens: true,
                });

                toReturn.push({
                    answer,
                    score,
                    start: null,
                    end: null,
                    input_ids: answer_tokens,
                });
            }
        }

        // Map each answer back to character-level start/end offsets in the context
        this.tokenizer.get_offsets_mapping(toReturn, context, 'closest').forEach((offsets, i) => {
            toReturn[i].start = offsets.at(0)[0];
            toReturn[i].end = offsets.at(-1)[1];
            delete toReturn[i].input_ids;
        })

        // Mimic HF's return type based on top_k
        return (top_k === 1) ? toReturn[0] : toReturn;
    }
}
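/**
 * **Example:** A sketch of retrieving several candidate spans (output shape only; scores will vary):
 * ```javascript
 * const answerer = await pipeline('question-answering', 'Xenova/distilbert-base-uncased-distilled-squad');
 * const output = await answerer('Who was Jim Henson?', 'Jim Henson was a nice puppet.', { top_k: 3 });
 * // Returns the 3 highest-scoring (start, end) spans instead of a single object,
 * // each with `answer`, `score`, and character-level `start`/`end` offsets.
 * ```
 */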
/**
 * @typedef {Object} FillMaskSingle
 * @property {string} sequence The corresponding input with the mask token prediction.
 * @property {number} score The corresponding probability.
 * @property {number} token The predicted token id (to replace the masked one).
 * @property {string} token_str The predicted token (to replace the masked one).
 * @typedef {FillMaskSingle[]} FillMaskOutput
 *
 * @typedef {Object} FillMaskPipelineOptions Parameters specific to fill mask pipelines.
 * @property {number} [top_k=5] When passed, overrides the number of predictions to return.
 *
 * @callback FillMaskPipelineCallback Fill the masked token in the text(s) given as inputs.
 * @param {string|string[]} texts One or several texts (or one list of prompts) with masked tokens.
 * @param {FillMaskPipelineOptions} [options] The options to use for masked language modelling.
 * @returns {Promise<FillMaskOutput|FillMaskOutput[]>} An array of objects containing the score, predicted token, predicted token string,
 * and the sequence with the predicted token filled in, or an array of such arrays (one for each input text).
 * If only one input text is given, the output will be an array of objects.
 * @throws {Error} When the mask token is not found in the input text.
 *
 * @typedef {TextPipelineConstructorArgs & FillMaskPipelineCallback & Disposable} FillMaskPipelineType
 */

/**
 * Masked language modeling prediction pipeline using any `ModelWithLMHead`.
 *
 * **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `Xenova/bert-base-cased`.
 * ```javascript
 * const unmasker = await pipeline('fill-mask', 'Xenova/bert-base-cased');
 * const output = await unmasker('The goal of life is [MASK].');
 * // [
 * //   { token_str: 'survival', score: 0.06137419492006302, token: 8115, sequence: 'The goal of life is survival.' },
 * //   { token_str: 'love', score: 0.03902450203895569, token: 1567, sequence: 'The goal of life is love.' },
 * //   { token_str: 'happiness', score: 0.03253183513879776, token: 9266, sequence: 'The goal of life is happiness.' },
 * //   { token_str: 'freedom', score: 0.018736306577920914, token: 4438, sequence: 'The goal of life is freedom.' },
 * //   { token_str: 'life', score: 0.01859794743359089, token: 1297, sequence: 'The goal of life is life.' }
 * // ]
 * ```
 *
 * **Example:** Perform masked language modelling (a.k.a. "fill-mask") with `Xenova/bert-base-cased` (and return top result).
 * ```javascript
 * const unmasker = await pipeline('fill-mask', 'Xenova/bert-base-cased');
 * const output = await unmasker('The Milky Way is a [MASK] galaxy.', { top_k: 1 });
 * // [{ token_str: 'spiral', score: 0.6299987435340881, token: 14061, sequence: 'The Milky Way is a spiral galaxy.' }]
 * ```
 */
export class FillMaskPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => FillMaskPipelineType} */ (Pipeline)) {

    /**
     * Create a new FillMaskPipeline.
     * @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }

    /** @type {FillMaskPipelineCallback} */
    async _call(texts, {
        top_k = 5
    } = {}) {

        // Run tokenization
        const model_inputs = this.tokenizer(texts, {
            padding: true,
            truncation: true,
        });

        // Run model
        const { logits } = await this.model(model_inputs);

        const toReturn = [];

        /** @type {bigint[][]} */
        const input_ids = model_inputs.input_ids.tolist();
        for (let i = 0; i < input_ids.length; ++i) {
            const ids = input_ids[i];
            const mask_token_index = ids.findIndex(x =>
                // We use == to match bigint with number
                // @ts-ignore
                x == this.tokenizer.mask_token_id
            );
            if (mask_token_index === -1) {
                throw Error(`Mask token (${this.tokenizer.mask_token}) not found in text.`)
            }
            const itemLogits = logits[i][mask_token_index];

            const scores = await topk(new Tensor(
                'float32',
                softmax(itemLogits.data),
                itemLogits.dims,
            ), top_k);
            const values = scores[0].tolist();
            const indices = scores[1].tolist();

            toReturn.push(indices.map((x, i) => {
                const sequence = ids.slice();
                sequence[mask_token_index] = x;

                return {
                    score: values[i],
                    token: Number(x),
                    token_str: this.tokenizer.model.vocab[x],
                    sequence: this.tokenizer.decode(sequence, { skip_special_tokens: true }),
                }
            }));
        }
        return Array.isArray(texts) ? toReturn : toReturn[0];
    }
}
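/**
 * **Note:** The mask token is tokenizer-specific ('[MASK]' for BERT-style models,
 * '<mask>' for RoBERTa-style models). A portable sketch reads it from the pipeline:
 * ```javascript
 * const unmasker = await pipeline('fill-mask', 'Xenova/bert-base-cased');
 * const output = await unmasker(`The capital of France is ${unmasker.tokenizer.mask_token}.`);
 * ```
 */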
/**
 * @typedef {Object} Text2TextGenerationSingle
 * @property {string} generated_text The generated text.
 * @typedef {Text2TextGenerationSingle[]} Text2TextGenerationOutput
 *
 * @callback Text2TextGenerationPipelineCallback Generate the output text(s) using text(s) given as inputs.
 * @param {string|string[]} texts Input text for the encoder.
 * @param {Partial<import('./generation/configuration_utils.js').GenerationConfig>} [options] Additional keyword arguments to pass along to the generate method of the model.
 * @returns {Promise<Text2TextGenerationOutput|Text2TextGenerationOutput[]>}
 *
 * @typedef {TextPipelineConstructorArgs & Text2TextGenerationPipelineCallback & Disposable} Text2TextGenerationPipelineType
 */

/**
 * Text2TextGenerationPipeline class for generating text using a model that performs text-to-text generation tasks.
 *
 * **Example:** Text-to-text generation w/ `Xenova/LaMini-Flan-T5-783M`.
 * ```javascript
 * const generator = await pipeline('text2text-generation', 'Xenova/LaMini-Flan-T5-783M');
 * const output = await generator('how can I become more healthy?', {
 *   max_new_tokens: 100,
 * });
 * // [{ generated_text: "To become more healthy, you can: 1. Eat a balanced diet with plenty of fruits, vegetables, whole grains, lean proteins, and healthy fats. 2. Stay hydrated by drinking plenty of water. 3. Get enough sleep and manage stress levels. 4. Avoid smoking and excessive alcohol consumption. 5. Regularly exercise and maintain a healthy weight. 6. Practice good hygiene and sanitation. 7. Seek medical attention if you experience any health issues." }]
 * ```
 */
export class Text2TextGenerationPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => Text2TextGenerationPipelineType} */ (Pipeline)) {
    /** @type {'generated_text'} */
    _key = 'generated_text';

    /**
     * Create a new Text2TextGenerationPipeline.
     * @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }

    /** @type {Text2TextGenerationPipelineCallback} */
    async _call(texts, generate_kwargs = {}) {
        if (!Array.isArray(texts)) {
            texts = [texts];
        }

        // Add global prefix, if present
        if (this.model.config.prefix) {
            texts = texts.map(x => this.model.config.prefix + x)
        }

        // Handle task specific params:
        const task_specific_params = this.model.config.task_specific_params
        if (task_specific_params && task_specific_params[this.task]) {
            // Add prefixes, if present
            if (task_specific_params[this.task].prefix) {
                texts = texts.map(x => task_specific_params[this.task].prefix + x)
            }

            // TODO update generation config
        }

        const tokenizer = this.tokenizer;
        const tokenizer_options = {
            padding: true,
            truncation: true,
        }

        let inputs;
        if (this instanceof TranslationPipeline && '_build_translation_inputs' in tokenizer) {
            // TODO: move to Translation pipeline?
            // Currently put here to avoid code duplication
            // @ts-ignore
            inputs = tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs);
        } else {
            inputs = tokenizer(texts, tokenizer_options);
        }

        const outputTokenIds = await this.model.generate({ ...inputs, ...generate_kwargs });
        return tokenizer.batch_decode(/** @type {Tensor} */(outputTokenIds), {
            skip_special_tokens: true,
        }).map(text => ({ [this._key]: text }));
    }
}
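/**
 * **Note:** A sketch of the prefix handling above, with hypothetical config values:
 * T5-style checkpoints may carry task-specific prefixes in their config, which the
 * pipeline prepends automatically, so callers pass raw text only.
 * ```javascript
 * // Hypothetical: config.task_specific_params = { summarization: { prefix: 'summarize: ' } }
 * // Calling generator('Long article ...') would then tokenize 'summarize: Long article ...'.
 * ```
 */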
/**
 * @typedef {Object} SummarizationSingle
 * @property {string} summary_text The summary text.
 * @typedef {SummarizationSingle[]} SummarizationOutput
 *
 * @callback SummarizationPipelineCallback Summarize the text(s) given as inputs.
 * @param {string|string[]} texts One or several articles (or one list of articles) to summarize.
 * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
 * @returns {Promise<SummarizationOutput|SummarizationOutput[]>}
 *
 * @typedef {TextPipelineConstructorArgs & SummarizationPipelineCallback & Disposable} SummarizationPipelineType
 */

/**
 * A pipeline for summarization tasks, inheriting from Text2TextGenerationPipeline.
 *
 * **Example:** Summarization w/ `Xenova/distilbart-cnn-6-6`.
 * ```javascript
 * const generator = await pipeline('summarization', 'Xenova/distilbart-cnn-6-6');
 * const text = 'The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, ' +
 *   'and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. ' +
 *   'During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest ' +
 *   'man-made structure in the world, a title it held for 41 years until the Chrysler Building in New ' +
 *   'York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to ' +
 *   'the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the ' +
 *   'Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second ' +
 *   'tallest free-standing structure in France after the Millau Viaduct.';
 * const output = await generator(text, {
 *   max_new_tokens: 100,
 * });
 * // [{ summary_text: ' The Eiffel Tower is about the same height as an 81-storey building and the tallest structure in Paris. It is the second tallest free-standing structure in France after the Millau Viaduct.' }]
 * ```
 */
export class SummarizationPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => SummarizationPipelineType} */ (/** @type {any} */ (Text2TextGenerationPipeline))) {
    /** @type {'summary_text'} */
    _key = 'summary_text';

    /**
     * Create a new SummarizationPipeline.
     * @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }
}

/**
 * @typedef {Object} TranslationSingle
 * @property {string} translation_text The translated text.
 * @typedef {TranslationSingle[]} TranslationOutput
 *
 * @callback TranslationPipelineCallback Translate the text(s) given as inputs.
 * @param {string|string[]} texts Texts to be translated.
 * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
 * @returns {Promise<TranslationOutput|TranslationOutput[]>}
 *
 * @typedef {TextPipelineConstructorArgs & TranslationPipelineCallback & Disposable} TranslationPipelineType
 */

/**
 * Translates text from one language to another.
 *
 * **Example:** Multilingual translation w/ `Xenova/nllb-200-distilled-600M`.
 *
 * See [here](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200)
 * for the full list of languages and their corresponding codes.
 *
 * ```javascript
 * const translator = await pipeline('translation', 'Xenova/nllb-200-distilled-600M');
 * const output = await translator('जीवन एक चॉकलेट बॉक्स की तरह है।', {
 *   src_lang: 'hin_Deva', // Hindi
 *   tgt_lang: 'fra_Latn', // French
 * });
 * // [{ translation_text: 'La vie est comme une boîte à chocolat.' }]
 * ```
 *
 * **Example:** Multilingual translation w/ `Xenova/m2m100_418M`.
 *
 * See [here](https://huggingface.co/facebook/m2m100_418M#languages-covered)
 * for the full list of languages and their corresponding codes.
 *
 * ```javascript
 * const translator = await pipeline('translation', 'Xenova/m2m100_418M');
 * const output = await translator('生活就像一盒巧克力。', {
 *   src_lang: 'zh', // Chinese
 *   tgt_lang: 'en', // English
 * });
 * // [{ translation_text: 'Life is like a box of chocolate.' }]
 * ```
 *
 * **Example:** Multilingual translation w/ `Xenova/mbart-large-50-many-to-many-mmt`.
 *
 * See [here](https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered)
 * for the full list of languages and their corresponding codes.
 *
 * ```javascript
 * const translator = await pipeline('translation', 'Xenova/mbart-large-50-many-to-many-mmt');
 * const output = await translator('संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है', {
 *   src_lang: 'hi_IN', // Hindi
 *   tgt_lang: 'fr_XX', // French
 * });
 * // [{ translation_text: "Le chef des Nations affirme qu'il n'y a military solution in Syria." }]
 * ```
 */
export class TranslationPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => TranslationPipelineType} */ (/** @type {any} */ (Text2TextGenerationPipeline))) {
    /** @type {'translation_text'} */
    _key = 'translation_text';

    /**
     * Create a new TranslationPipeline.
     * @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }
}

function isChat(x) {
    return Array.isArray(x) && x.every(x => 'role' in x && 'content' in x);
}
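/**
 * **Example:** A sketch of the chat-detection helper above:
 * ```javascript
 * isChat([{ role: 'user', content: 'Hi there!' }]); // true: an array of {role, content} messages
 * isChat('Hi there!');                              // false: a plain prompt string
 * ```
 */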
/**
 * @typedef {import('./tokenizers.js').Message[]} Chat
 *
 * @typedef {Object} TextGenerationSingle
 * @property {string|Chat} generated_text The generated text.
 * @typedef {TextGenerationSingle[]} TextGenerationOutput
 *
 * @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines.
 * @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences.
 * @property {boolean} [return_full_text=true] If set to `false` only added text is returned, otherwise the full text is returned.
 * @typedef {import('./generation/configuration_utils.js').GenerationConfig & TextGenerationSpecificParams} TextGenerationConfig
 *
 * @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs.
 * @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete.
 * @param {Partial<TextGenerationConfig>} [options] Additional keyword arguments to pass along to the generate method of the model.
 * @returns {Promise<TextGenerationOutput|TextGenerationOutput[]>} An array or object containing the generated texts.
 *
 * @typedef {TextPipelineConstructorArgs & TextGenerationPipelineCallback & Disposable} TextGenerationPipelineType
 */

/**
 * Language generation pipeline using any `ModelWithLMHead` or `ModelForCausalLM`.
 * This pipeline predicts the words that will follow a specified text prompt.
 * NOTE: For the full list of generation parameters, see [`GenerationConfig`](./utils/generation#module_utils/generation.GenerationConfig).
 *
 * **Example:** Text generation with `Xenova/distilgpt2` (default settings).
 * ```javascript
 * const generator = await pipeline('text-generation', 'Xenova/distilgpt2');
 * const text = 'I enjoy walking with my cute dog,';
 * const output = await generator(text);
 * // [{ generated_text: "I enjoy walking with my cute dog, and I love to play with the other dogs." }]
 * ```
 *
 * **Example:** Text generation with `Xenova/distilgpt2` (custom settings).
 * ```javascript
 * const generator = await pipeline('text-generation', 'Xenova/distilgpt2');
 * const text = 'Once upon a time, there was';
 * const output = await generator(text, {
 *   temperature: 2,
 *   max_new_tokens: 10,
 *   repetition_penalty: 1.5,
 *   no_repeat_ngram_size: 2,
 *   num_beams: 2,
 *   num_return_sequences: 2,
 * });
 * // [{
 * //   "generated_text": "Once upon a time, there was an abundance of information about the history and activities that"
 * // }, {
 * //   "generated_text": "Once upon a time, there was an abundance of information about the most important and influential"
 * // }]
 * ```
 *
 * **Example:** Run code generation with `Xenova/codegen-350M-mono`.
 * ```javascript
 * const generator = await pipeline('text-generation', 'Xenova/codegen-350M-mono');
 * const text = 'def fib(n):';
 * const output = await generator(text, {
 *   max_new_tokens: 44,
 * });
 * // [{
 * //   generated_text: 'def fib(n):\n' +
 * //     '    if n == 0:\n' +
 * //     '        return 0\n' +
 * //     '    elif n == 1:\n' +
 * //     '        return 1\n' +
 * //     '    else:\n' +
 * //     '        return fib(n-1) + fib(n-2)\n'
 * // }]
 * ```
 */
export class TextGenerationPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => TextGenerationPipelineType} */ (Pipeline)) {

    /**
     * Create a new TextGenerationPipeline.
     * @param {TextPipelineConstructorArgs} options An object used to instantiate the pipeline.
     */
    constructor(options) {
        super(options);
    }

    /** @type {TextGenerationPipelineCallback} */
    async _call(texts, generate_kwargs = {}) {
        let isBatched = false;
        let isChatInput = false;

        // Normalize inputs
        /** @type {string[]} */
        let inputs;
        if (typeof texts === 'string') {
            inputs = texts = [texts];
        } else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) {
            isBatched = true;
            inputs = /** @type {string[]} */(texts);
        } else {
            if (isChat(texts)) {
                texts = [/** @type {Chat} */(texts)];
            } else if (Array.isArray(texts) && texts.every(isChat)) {
                isBatched = true;
            } else {
                throw new Error('Input must be a string, an array of strings, a Chat, or an array of Chats');
            }
            isChatInput = true;

            // If the input is a chat, we need to apply the chat template
            inputs = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map(
                x => this.tokenizer.apply_chat_template(x, {
                    tokenize: false,
                    add_generation_prompt: true,
                })
            ));
        }

        // By default, do not add special tokens
        const add_special_tokens = generate_kwargs.add_special_tokens ?? false;

        // By default, return full text
        const return_full_text = isChatInput
            ? false
            : generate_kwargs.return_full_text ?? true;

        this.tokenizer.padding_side = 'left';
        const text_inputs = this.tokenizer(inputs, {
            add_special_tokens,
            padding: true,
            truncation: true,
        });

        const outputTokenIds = /** @type {Tensor} */(await this.model.generate({
            ...text_inputs,
            ...generate_kwargs
        }));

        const decoded = this.tokenizer.batch_decode(outputTokenIds, {
            skip_special_tokens: true,
        });

        let promptLengths;
        if (!return_full_text && text_inputs.input_ids.dims.at(-1) > 0) {
            promptLengths = this.tokenizer.batch_decode(text_inputs.input_ids, {
                skip_special_tokens: true,
            }).map(x => x.length);
        }

        /** @type {TextGenerationOutput[]} */
        const toReturn = Array.from({ length: texts.length }, _ => []);
        for (let i = 0; i < decoded.length; ++i) {
            const textIndex = Math.floor(i / outputTokenIds.dims[0] * texts.length);

            if (promptLengths) {
                // Trim the decoded text to only include the generated part
                decoded[i] = decoded[i].slice(promptLengths[textIndex]);
            }
            toReturn[textIndex].push({
                generated_text: isChatInput
                    ? [
                        ...((/** @type {Chat[]} */(texts)[textIndex])),
                        { role: 'assistant', content: decoded[i] },
                    ]
                    : decoded[i]
            });
        }
        return (!isBatched && toReturn.length === 1) ? toReturn[0] : toReturn;
    }
}
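/**
 * **Example:** A sketch of chat-formatted input (model chosen for illustration; any
 * chat-capable model with a chat template should behave the same way):
 * ```javascript
 * const generator = await pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');
 * const messages = [
 *   { role: 'system', content: 'You are a helpful assistant.' },
 *   { role: 'user', content: 'Tell me a joke.' },
 * ];
 * // The chat template is applied automatically; the result is the input chat
 * // plus one appended { role: 'assistant', content: ... } message.
 * const output = await generator(messages, { max_new_tokens: 64 });
 * ```
 */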
/**
 * @typedef {Object} ZeroShotClassificationOutput
 * @property {string} sequence The sequence for which this is the output.
 * @property {string[]} labels The labels sorted by order of likelihood.
 * @property {number[]} scores The probabilities for each of the labels.
 *
 * @typedef {Object} ZeroShotClassificationPipelineOptions Parameters specific to zero-shot classification pipelines.
 * @property {string} [hypothesis_template="This example is {}."] The template used to turn each
 * candidate label into an NLI-style hypothesis. The c