@huggingface/transformers
Version:
State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
173 lines (149 loc) • 6.54 kB
JavaScript
import { Processor } from "../../base/processing_utils.js";
import { AutoImageProcessor } from "../auto/image_processing_auto.js";
import { AutoTokenizer } from "../../tokenizers.js";
import { max, softmax } from "../../utils/maths.js";
const DECODE_TYPE_MAPPING = {
'char': ['char_decode', 1],
'bpe': ['bpe_decode', 2],
'wp': ['wp_decode', 102],
}
export class MgpstrProcessor extends Processor {
static tokenizer_class = AutoTokenizer
static image_processor_class = AutoImageProcessor
/**
* @returns {import('../../tokenizers.js').MgpstrTokenizer} The character tokenizer.
*/
get char_tokenizer() {
return this.components.char_tokenizer;
}
/**
* @returns {import('../../tokenizers.js').GPT2Tokenizer} The BPE tokenizer.
*/
get bpe_tokenizer() {
return this.components.bpe_tokenizer;
}
/**
* @returns {import('../../tokenizers.js').BertTokenizer} The WordPiece tokenizer.
*/
get wp_tokenizer() {
return this.components.wp_tokenizer;
}
/**
* Helper function to decode the model prediction logits.
* @param {import('../../utils/tensor.js').Tensor} pred_logits Model prediction logits.
* @param {string} format Type of model prediction. Must be one of ['char', 'bpe', 'wp'].
* @returns {[string[], number[]]} The decoded sentences and their confidence scores.
*/
_decode_helper(pred_logits, format) {
if (!DECODE_TYPE_MAPPING.hasOwnProperty(format)) {
throw new Error(`Format ${format} is not supported.`);
}
const [decoder_name, eos_token] = DECODE_TYPE_MAPPING[format];
const decoder = this[decoder_name].bind(this);
const [batch_size, batch_max_length] = pred_logits.dims;
const conf_scores = [];
const all_ids = [];
/** @type {number[][][]} */
const pred_logits_list = pred_logits.tolist();
for (let i = 0; i < batch_size; ++i) {
const logits = pred_logits_list[i];
const ids = [];
const scores = [];
// Start and index=1 to skip the first token
for (let j = 1; j < batch_max_length; ++j) {
// NOTE: == to match bigint and number
const [max_prob, max_prob_index] = max(softmax(logits[j]));
scores.push(max_prob);
if (max_prob_index == eos_token) {
break;
}
ids.push(max_prob_index);
}
const confidence_score = scores.length > 0
? scores.reduce((a, b) => a * b, 1)
: 0;
all_ids.push(ids);
conf_scores.push(confidence_score);
}
const decoded = decoder(all_ids);
return [decoded, conf_scores];
}
/**
* Convert a list of lists of char token ids into a list of strings by calling char tokenizer.
* @param {number[][]} sequences List of tokenized input ids.
* @returns {string[]} The list of char decoded sentences.
*/
char_decode(sequences) {
return this.char_tokenizer.batch_decode(sequences).map(str => str.replaceAll(' ', ''));
}
/**
* Convert a list of lists of BPE token ids into a list of strings by calling BPE tokenizer.
* @param {number[][]} sequences List of tokenized input ids.
* @returns {string[]} The list of BPE decoded sentences.
*/
bpe_decode(sequences) {
return this.bpe_tokenizer.batch_decode(sequences)
}
/**
* Convert a list of lists of word piece token ids into a list of strings by calling word piece tokenizer.
* @param {number[][]} sequences List of tokenized input ids.
* @returns {string[]} The list of wp decoded sentences.
*/
wp_decode(sequences) {
return this.wp_tokenizer.batch_decode(sequences).map(str => str.replaceAll(' ', ''));
}
/**
* Convert a list of lists of token ids into a list of strings by calling decode.
* @param {import('../../utils/tensor.js').Tensor[]} sequences List of tokenized input ids.
* @returns {{generated_text: string[], scores: number[], char_preds: string[], bpe_preds: string[], wp_preds: string[]}}
* Dictionary of all the outputs of the decoded results.
* - generated_text: The final results after fusion of char, bpe, and wp.
* - scores: The final scores after fusion of char, bpe, and wp.
* - char_preds: The list of character decoded sentences.
* - bpe_preds: The list of BPE decoded sentences.
* - wp_preds: The list of wp decoded sentences.
*/
// @ts-expect-error The type of this method is not compatible with the one
// in the base class. It might be a good idea to fix this.
batch_decode([char_logits, bpe_logits, wp_logits]) {
const [char_preds, char_scores] = this._decode_helper(char_logits, 'char');
const [bpe_preds, bpe_scores] = this._decode_helper(bpe_logits, 'bpe');
const [wp_preds, wp_scores] = this._decode_helper(wp_logits, 'wp');
const generated_text = [];
const scores = [];
for (let i = 0; i < char_preds.length; ++i) {
const [max_score, max_score_index] = max([char_scores[i], bpe_scores[i], wp_scores[i]]);
generated_text.push([char_preds[i], bpe_preds[i], wp_preds[i]][max_score_index]);
scores.push(max_score);
}
return {
generated_text,
scores,
char_preds,
bpe_preds,
wp_preds,
}
}
/** @type {typeof Processor.from_pretrained} */
static async from_pretrained(...args) {
const base = await super.from_pretrained(...args);
// Load Transformers.js-compatible versions of the BPE and WordPiece tokenizers
const bpe_tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2") // openai-community/gpt2
const wp_tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased") // google-bert/bert-base-uncased
// Update components
base.components = {
image_processor: base.image_processor,
char_tokenizer: base.tokenizer,
bpe_tokenizer: bpe_tokenizer,
wp_tokenizer: wp_tokenizer,
}
return base;
}
async _call(images, text = null) {
const result = await this.image_processor(images);
if (text) {
result.labels = this.tokenizer(text).input_ids
}
return result;
}
}