/**
* @file Tokenizers are used to prepare textual inputs for a model.
*
* **Example:** Create an `AutoTokenizer` and use it to tokenize a sentence.
* This will automatically detect the tokenizer type based on the tokenizer class defined in `tokenizer.json`.
* ```javascript
* import { AutoTokenizer } from '@xenova/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* let { input_ids } = await tokenizer('I love transformers!');
* // Tensor {
* // data: BigInt64Array(6) [101n, 1045n, 2293n, 19081n, 999n, 102n],
* // dims: [1, 6],
* // type: 'int64',
* // size: 6,
* // }
* ```
*
* @module tokenizers
*/
import {
Callable,
reverseDictionary,
escapeRegExp,
isIntegralNumber,
mergeArrays,
} from './utils/core.js';
import {
getModelJSON,
} from './utils/hub.js';
import { max, min, round } from './utils/maths.js';
import { Tensor } from './utils/tensor.js';
import {
PriorityQueue,
TokenLattice,
CharTrie,
} from './utils/data-structures.js';
import { Template } from '@huggingface/jinja';
/**
* @typedef {Object} TokenizerProperties Additional tokenizer-specific properties.
* @property {boolean} [legacy=false] Whether or not the `legacy` behavior of the tokenizer should be used.
* @typedef {import('./utils/hub.js').PretrainedOptions & TokenizerProperties} PretrainedTokenizerOptions
*/
/**
* Loads a tokenizer from the specified path.
* @param {string} pretrained_model_name_or_path The path to the tokenizer directory.
* @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer.
* @returns {Promise<any[]>} A promise that resolves with information about the loaded tokenizer.
*/
async function loadTokenizer(pretrained_model_name_or_path, options) {
let info = await Promise.all([
getModelJSON(pretrained_model_name_or_path, 'tokenizer.json', true, options),
getModelJSON(pretrained_model_name_or_path, 'tokenizer_config.json', true, options),
])
// Override legacy option if `options.legacy` is not null
if (options.legacy !== null) {
info[1].legacy = options.legacy;
}
return info;
}
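// Illustrative usage sketch (the options object follows `PretrainedTokenizerOptions`;
// the model id below is only an example):
//   const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(
//     'Xenova/bert-base-uncased', { legacy: null },
//   );
//   // tokenizerJSON   -> parsed tokenizer.json (model, normalizer, pre_tokenizer, ...)
//   // tokenizerConfig -> parsed tokenizer_config.json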
/**
* Helper function to split a string on a regex, but keep the delimiters.
* This is required because the JavaScript `.split()` method does not keep the delimiters,
* and wrapping the whole pattern in a capturing group causes issues with existing capturing groups (due to nesting).
* @param {string} text The text to split.
* @param {RegExp} regex The regex to split on.
* @returns {string[]} The split string.
*/
function regexSplit(text, regex) {
const result = [];
let prev = 0;
for (const match of text.matchAll(regex)) {
const fullMatch = match[0];
if (prev < match.index) {
result.push(text.slice(prev, match.index));
}
if (fullMatch.length > 0) {
result.push(fullMatch);
}
prev = match.index + fullMatch.length;
}
if (prev < text.length) {
result.push(text.slice(prev));
}
return result;
}
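// Example (illustrative): unlike `'a,b'.split(/,/)`, the delimiters are kept:
//   regexSplit('a,b,,c', /,/g);  // => ['a', ',', 'b', ',', ',', 'c']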
/**
* Helper method to construct a pattern from a config object.
* @param {Object} pattern The pattern object.
* @param {boolean} [invert=true] Whether the pattern will be used to match (true) rather than split (false) the text.
* @returns {RegExp|null} The compiled pattern.
*/
function createPattern(pattern, invert = true) {
if (pattern.Regex !== undefined) {
// In certain cases, the pattern may contain unnecessary escape sequences (e.g., \# or \& or \~).
// i.e., valid in Python (where the patterns are exported from) but invalid in JavaScript (where the patterns are parsed).
// This isn't an issue when creating the regex w/o the 'u' flag, but it is when the 'u' flag is used.
// For this reason, it is necessary to remove these backslashes before creating the regex.
// See https://stackoverflow.com/a/63007777/13989043 for more information
const regex = pattern.Regex.replace(/\\([#&~])/g, '$1'); // TODO: add more characters to this list if necessary
return new RegExp(regex, 'gu');
} else if (pattern.String !== undefined) {
const escaped = escapeRegExp(pattern.String);
// NOTE: if invert is false, we wrap the pattern in a capturing group so that it is kept when performing .split()
return new RegExp(invert ? escaped : `(${escaped})`, 'gu');
} else {
console.warn('Unknown pattern type:', pattern)
return null;
}
}
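// Example (illustrative): how the two config shapes compile:
//   createPattern({ Regex: '\\s+' });         // => /\s+/gu
//   createPattern({ String: 'a.b' });         // => /a\.b/gu (escaped; invert defaults to true)
//   createPattern({ String: 'a.b' }, false);  // => /(a\.b)/gu (grouped, so it is kept on split)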
/**
* Helper function to convert an Object to a Map
* @param {Object} obj The object to convert.
* @returns {Map<string, any>} The map.
*/
function objectToMap(obj) {
return new Map(Object.entries(obj));
}
/**
* Helper function to convert a tensor to a list before decoding.
* @param {Tensor} tensor The tensor to convert.
* @returns {number[]} The tensor as a list.
*/
function prepareTensorForDecode(tensor) {
const dims = tensor.dims;
switch (dims.length) {
case 1:
return tensor.tolist();
case 2:
if (dims[0] !== 1) {
throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.');
}
return tensor.tolist()[0];
default:
throw new Error(`Expected tensor to have 1-2 dimensions, got ${dims.length}.`)
}
}
/**
* Clean up a list of simple English tokenization artifacts, like spaces before punctuation and contracted forms.
* @param {string} text The text to clean up.
* @returns {string} The cleaned up text.
*/
function clean_up_tokenization(text) {
// Clean up spaces before punctuation and contracted forms
return text.replace(/ \./g, '.')
.replace(/ \?/g, '?')
.replace(/ \!/g, '!')
.replace(/ ,/g, ',')
.replace(/ \' /g, "'")
.replace(/ n\'t/g, "n't")
.replace(/ \'m/g, "'m")
.replace(/ \'s/g, "'s")
.replace(/ \'ve/g, "'ve")
.replace(/ \'re/g, "'re");
}
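// Example (illustrative):
//   clean_up_tokenization("do n't stop , please !");  // => "don't stop, please!"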
/**
* Helper function to remove accents from a string.
* @param {string} text The text to remove accents from.
* @returns {string} The text with accents removed.
*/
function remove_accents(text) {
return text.replace(/[\u0300-\u036f]/g, '');
}
/**
* Helper function to lowercase a string and remove accents.
* @param {string} text The text to lowercase and remove accents from.
* @returns {string} The lowercased text with accents removed.
*/
function lowercase_and_remove_accent(text) {
return remove_accents(text.toLowerCase());
}
/**
* Helper function to fuse consecutive values in an array equal to the specified value.
* @param {Array} arr The input array
* @param {any} value The value to fuse on.
* @returns {Array} The array with consecutive runs of `value` collapsed to a single occurrence.
*/
function fuse(arr, value) {
let fused = [];
let i = 0;
while (i < arr.length) {
fused.push(arr[i])
if (arr[i] !== value) {
++i;
continue;
}
while (i < arr.length && arr[i] === value) {
++i;
}
}
return fused;
}
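// Example (illustrative): consecutive occurrences of `value` collapse to one:
//   fuse([1, 2, 2, 3, 2, 2, 2], 2);  // => [1, 2, 3, 2]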
/**
* Split a string on whitespace.
* @param {string} text The text to split.
* @returns {string[]} The split string.
*/
function whitespace_split(text) {
return text.match(/\S+/g) || [];
}
const PUNCTUATION_REGEX = '\\p{P}\\u0021-\\u002F\\u003A-\\u0040\\u005B-\\u0060\\u007B-\\u007E';
/**
* Represents a token added by the user on top of the existing Model vocabulary.
* An AddedToken can be configured to specify the behavior it should have in various situations, such as:
* - Whether they should only match single words
* - Whether to include any whitespace on its left or right
*/
class AddedToken {
/**
* Creates a new instance of AddedToken.
* @param {Object} config Added token configuration object.
* @param {string} config.content The content of the added token.
* @param {number} config.id The id of the added token.
* @param {boolean} [config.single_word=false] Whether this token must be a single word or can break words.
* @param {boolean} [config.lstrip=false] Whether this token should strip whitespaces on its left.
* @param {boolean} [config.rstrip=false] Whether this token should strip whitespaces on its right.
* @param {boolean} [config.normalized=null] Whether this token should be normalized.
* @param {boolean} [config.special=false] Whether this token is special.
*/
constructor(config) {
this.content = config.content;
this.id = config.id;
this.single_word = config.single_word ?? false;
this.lstrip = config.lstrip ?? false;
this.rstrip = config.rstrip ?? false;
this.special = config.special ?? false;
this.normalized = config.normalized ?? null;
}
}
/**
* Abstract base class for tokenizer models.
*
* @extends Callable
*/
export class TokenizerModel extends Callable {
/**
* Creates a new instance of TokenizerModel.
* @param {Object} config The configuration object for the TokenizerModel.
*/
constructor(config) {
super();
this.config = config;
/** @type {string[]} */
this.vocab = [];
/**
* A mapping of tokens to ids.
* @type {Map<string, number>}
*/
this.tokens_to_ids = new Map();
this.unk_token_id = undefined;
this.unk_token = undefined;
this.end_of_word_suffix = undefined;
/** @type {boolean} Whether to fuse unknown tokens when encoding. Defaults to false. */
this.fuse_unk = this.config.fuse_unk ?? false;
}
/**
* Instantiates a new TokenizerModel instance based on the configuration object provided.
* @param {Object} config The configuration object for the TokenizerModel.
* @param {...*} args Optional arguments to pass to the specific TokenizerModel constructor.
* @returns {TokenizerModel} A new instance of a TokenizerModel.
* @throws Will throw an error if the TokenizerModel type in the config is not recognized.
*/
static fromConfig(config, ...args) {
switch (config.type) {
case 'WordPiece':
return new WordPieceTokenizer(config);
case 'Unigram':
// @ts-ignore
return new Unigram(config, ...args);
case 'BPE':
return new BPE(config);
default:
if (config.vocab) {
// @ts-ignore
return new LegacyTokenizerModel(config, ...args);
}
throw new Error(`Unknown TokenizerModel type: ${config.type}`);
}
}
/**
* Internal function to call the TokenizerModel instance.
* @param {string[]} tokens The tokens to encode.
* @returns {string[]} The encoded tokens.
*/
_call(tokens) {
return this.encode(tokens);
}
/**
* Encodes a list of tokens into a list of subword tokens.
* @param {string[]} tokens The tokens to encode.
* @returns {string[]} The encoded tokens.
* @throws Will throw an error if not implemented in a subclass.
*/
encode(tokens) {
throw Error("encode should be implemented in subclass.")
}
/**
* Converts a list of tokens into a list of token IDs.
* @param {string[]} tokens The tokens to convert.
* @returns {number[]} The converted token IDs.
*/
convert_tokens_to_ids(tokens) {
let ids = tokens.map(t => this.tokens_to_ids.get(t) ?? this.unk_token_id);
if (this.fuse_unk) {
// Fuse unknown tokens
ids = fuse(ids, this.unk_token_id);
}
return ids;
}
/**
* Converts a list of token IDs into a list of tokens.
* @param {number[]} ids The token IDs to convert.
* @returns {string[]} The converted tokens.
*/
convert_ids_to_tokens(ids) {
return ids.map(i => this.vocab[i] ?? this.unk_token);
}
}
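// Illustrative sketch: constructing a model from a `tokenizer.json`-style config
// (the tiny vocab below is made up for the example):
//   const model = TokenizerModel.fromConfig({
//     type: 'WordPiece',
//     vocab: { '[UNK]': 0, 'hello': 1, '##!': 2 },
//     unk_token: '[UNK]',
//     continuing_subword_prefix: '##',
//   });
//   const tokens = model(['hello!']);     // => ['hello', '##!']
//   model.convert_tokens_to_ids(tokens);  // => [1, 2]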
/**
* A subclass of TokenizerModel that uses WordPiece encoding to encode tokens.
* @extends TokenizerModel
*/
class WordPieceTokenizer extends TokenizerModel {
/**
* @param {Object} config The configuration object.
* @param {Object} config.vocab A mapping of tokens to ids.
* @param {string} config.unk_token The unknown token string.
* @param {string} config.continuing_subword_prefix The prefix to use for continuing subwords.
* @param {number} [config.max_input_chars_per_word=100] The maximum number of characters per word.
*/
constructor(config) {
super(config);
/**
* A mapping of tokens to ids.
* @type {Map<string, number>}
*/
this.tokens_to_ids = objectToMap(config.vocab);
/**
* The id of the unknown token.
* @type {number}
*/
this.unk_token_id = this.tokens_to_ids.get(config.unk_token);
/**
* The unknown token string.
* @type {string}
*/
this.unk_token = config.unk_token;
/**
* The maximum number of characters allowed per word.
* @type {number}
*/
this.max_input_chars_per_word = config.max_input_chars_per_word ?? 100;
/**
* An array of tokens.
* @type {string[]}
*/
this.vocab = new Array(this.tokens_to_ids.size);
for (const [key, value] of this.tokens_to_ids) {
this.vocab[value] = key;
}
}
/**
* Encodes an array of tokens using WordPiece encoding.
* @param {string[]} tokens The tokens to encode.
* @returns {string[]} An array of encoded tokens.
*/
encode(tokens) {
let outputTokens = [];
for (let token of tokens) {
let chars = [...token];
if (chars.length > this.max_input_chars_per_word) {
outputTokens.push(this.unk_token);
continue;
}
let isUnknown = false;
let start = 0;
let subTokens = [];
while (start < chars.length) {
let end = chars.length;
let currentSubstring = null;
while (start < end) {
let substr = chars.slice(start, end).join('');
if (start > 0) {
substr = this.config.continuing_subword_prefix + substr;
}
if (this.tokens_to_ids.has(substr)) {
currentSubstring = substr;
break;
}
--end;
}
if (currentSubstring === null) {
isUnknown = true;
break;
}
subTokens.push(currentSubstring);
start = end;
}
if (isUnknown) {
outputTokens.push(this.unk_token);
} else {
outputTokens.push(...subTokens);
}
}
return outputTokens;
}
}
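// Note (illustrative): matching is greedy longest-prefix-first. With the made-up vocab
// from the sketch above,
//   model(['hello!xyz']);  // => ['[UNK]']
// because a single unmatchable character marks the whole word as unknown.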
/**
* Class representing a Unigram tokenizer model.
* @extends TokenizerModel
*/
class Unigram extends TokenizerModel {
/**
* Create a new Unigram tokenizer model.
* @param {Object} config The configuration object for the Unigram model.
* @param {number} config.unk_id The ID of the unknown token
* @param {any[][]} config.vocab A 2D array representing a mapping of tokens to scores.
* @param {Object} moreConfig Additional configuration object for the Unigram model.
*/
constructor(config, moreConfig) {
super(config);
const vocabSize = config.vocab.length;
this.vocab = new Array(vocabSize);
this.scores = new Array(vocabSize);
for (let i = 0; i < vocabSize; ++i) {
const piece = config.vocab[i];
this.vocab[i] = piece[0];
this.scores[i] = piece[1];
}
this.unk_token_id = config.unk_id;
this.unk_token = this.vocab[config.unk_id];
this.tokens_to_ids = new Map(this.vocab.map((x, i) => [x, i]));
this.bosToken = ' '; // beginning of a sentence token
this.bosTokenId = this.tokens_to_ids.get(this.bosToken); // NOTE: may be undefined
this.eosToken = moreConfig.eos_token;
this.eosTokenId = this.tokens_to_ids.get(this.eosToken);
this.unkToken = this.vocab[this.unk_token_id];
this.minScore = min(this.scores)[0];
this.unkScore = this.minScore - 10.0;
this.scores[this.unk_token_id] = this.unkScore;
this.trie = new CharTrie();
this.trie.extend(this.vocab);
// NOTE: `fuse_unk` is hardcoded to true for Unigram models
// See: https://github.com/huggingface/tokenizers/blob/b58227c7f1ccf8b73ee2268354336da56d91e492/tokenizers/src/models/unigram/model.rs#L119
this.fuse_unk = true;
}
/**
* Populates lattice nodes.
* @param {TokenLattice} lattice The token lattice to populate with nodes.
*/
populateNodes(lattice) {
const sentence = lattice.sentence;
const len = sentence.length;
let beginPos = 0;
while (beginPos < len) {
const mblen = 1;
let hasSingleNode = false;
const tokens = [];
for (let token of this.trie.commonPrefixSearch(sentence.slice(beginPos))) {
tokens.push(token);
const tokenId = this.tokens_to_ids.get(token);
const tokenScore = this.scores[tokenId];
const n = token.length;
lattice.insert(beginPos, n, tokenScore, tokenId);
if (!hasSingleNode && n === mblen) {
hasSingleNode = true;
}
}
if (!hasSingleNode) {
lattice.insert(beginPos, mblen, this.unkScore, this.unk_token_id);
}
beginPos += mblen;
}
}
/**
* Tokenizes a normalized string into subtokens using the unigram model.
*
* @param {string} normalized The normalized string to tokenize.
* @returns {string[]} An array of subtokens obtained by encoding the normalized string with the unigram model.
*/
tokenize(normalized) {
const lattice = new TokenLattice(normalized, this.bosTokenId, this.eosTokenId);
this.populateNodes(lattice);
return lattice.tokens();
}
/**
* Encodes an array of tokens using Unigram encoding.
* @param {string[]} tokens The tokens to encode.
* @returns {string[]} An array of encoded subtokens.
*/
encode(tokens) {
let toReturn = [];
for (let token of tokens) {
const tokenized = this.tokenize(token);
toReturn.push(...tokenized);
}
return toReturn;
}
}
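// Flow sketch (illustrative): `encode` runs each pre-token through `tokenize`, which
// builds a TokenLattice over the normalized string, inserts every vocab entry found by
// the CharTrie prefix search (with unk fallbacks for uncovered positions), and returns
// the highest-scoring segmentation via `lattice.tokens()`.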
/**
* Returns a mapping of utf-8 bytes to unicode strings.
* Specifically avoids mapping to whitespace/control characters the BPE code barfs on.
* @returns {Object} Object with utf-8 byte keys and unicode string values.
*/
const BYTES_TO_UNICODE = (() => {
// Returns a mapping of utf-8 bytes to unicode strings.
// We specifically avoid mapping to whitespace/control characters
// the bpe code barfs on.
const bs = [
...Array.from({ length: "~".charCodeAt(0) - "!".charCodeAt(0) + 1 }, (_, i) => i + "!".charCodeAt(0)),
...Array.from({ length: "¬".charCodeAt(0) - "¡".charCodeAt(0) + 1 }, (_, i) => i + "¡".charCodeAt(0)),
...Array.from({ length: "ÿ".charCodeAt(0) - "®".charCodeAt(0) + 1 }, (_, i) => i + "®".charCodeAt(0)),
];
let cs = bs.slice();
let n = 0;
for (let b = 0; b < 256; ++b) {
if (!bs.includes(b)) {
bs.push(b);
cs.push(256 + n);
n += 1;
}
}
let ccs = cs.map(n => String.fromCharCode(n));
return Object.fromEntries(bs.map((b, i) => [b, ccs[i]]));
})();
const UNICODE_TO_BYTES = reverseDictionary(BYTES_TO_UNICODE);
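// Example: under this mapping the space byte (0x20) becomes 'Ġ' (U+0120) and the newline
// byte (0x0A) becomes 'Ċ' (U+010A), which is why GPT-2-style vocabularies contain tokens
// like 'Ġhello'.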
/**
* @typedef {Object} BPENode
* @property {string} token The token associated with the node
* @property {number} bias A positional bias for the node.
* @property {number} [score] The score of the node.
* @property {BPENode} [prev] The previous node in the linked list.
* @property {BPENode} [next] The next node in the linked list.
*/
/**
* BPE class for encoding text into Byte-Pair-Encoding (BPE) tokens.
* @extends TokenizerModel
*/
class BPE extends TokenizerModel {
/**
* Create a BPE instance.
* @param {Object} config The configuration object for BPE.
* @param {Object} config.vocab A mapping of tokens to ids.
* @param {string} config.unk_token The unknown token used for out of vocabulary words.
* @param {string} config.end_of_word_suffix The suffix to place at the end of each word.
* @param {string} [config.continuing_subword_suffix] The suffix to insert between words.
* @param {Array} config.merges An array of BPE merges as strings.
*/
constructor(config) {
super(config);
this.BPE_SPLIT_TOKEN = ' ';
/** @type {Map<string, number>} */
this.tokens_to_ids = objectToMap(config.vocab);
this.unk_token_id = this.tokens_to_ids.get(config.unk_token);
this.unk_token = config.unk_token;
this.vocab = new Array(this.tokens_to_ids.size);
for (const [key, value] of this.tokens_to_ids) {
this.vocab[value] = key;
}
this.bpe_ranks = new Map(config.merges.map((x, i) => [x, i]));
this.merges = config.merges.map(x => x.split(this.BPE_SPLIT_TOKEN));
this.end_of_word_suffix = config.end_of_word_suffix;
// NOTE: `continuing_subword_suffix` is custom (to support `BlenderbotSmallTokenizer`)
this.continuing_subword_suffix = config.continuing_subword_suffix ?? null;
this.byte_fallback = this.config.byte_fallback ?? false;
if (this.byte_fallback) {
this.text_encoder = new TextEncoder();
}
/** @type {Map<string, string[]>} */
this.cache = new Map();
}
/**
* Apply Byte-Pair-Encoding (BPE) to a given token. Efficient heap-based priority
* queue implementation adapted from https://github.com/belladoreai/llama-tokenizer-js.
* @param {string} token The token to encode.
* @returns {string[]} The BPE encoded tokens.
*/
bpe(token) {
if (token.length === 0) {
return [];
}
const cached = this.cache.get(token);
if (cached !== undefined) {
return cached;
}
const word = Array.from(token);
if (this.end_of_word_suffix) {
word[word.length - 1] += this.end_of_word_suffix;
}
let result = [];
if (word.length > 1) {
// Create a priority queue to store the nodes that will be merged.
// The comparator function compares the scores of the nodes.
const queue = new PriorityQueue((a, b) => a.score < b.score);
// Construct a doubly-linked list of nodes that will be inserted into the priority queue,
// starting with the individual characters. We also populate each node with a positional
// bias to break ties in the priority queue.
let startingNode = {
token: word[0],
bias: 0,
prev: null,
next: null,
}
let previousNode = startingNode
for (let i = 1; i < word.length; ++i) {
const currentNode = {
bias: i / word.length, // Add fractional component to break ties
token: word[i],
prev: previousNode,
next: null,
}
previousNode.next = currentNode
this._add_node(queue, previousNode)
previousNode = currentNode
}
while (!queue.isEmpty()) {
// Get the next node with the highest priority
const node = queue.pop();
// Check that this merge is still possible
if (node.deleted || !node.next || node.next.deleted) continue;
// Here, we mark the current node (left side of the merge) and the next node (right side of the merge) as deleted.
// This is because they will both be replaced by a new node representing the merge result.
node.deleted = true;
node.next.deleted = true;
// Next, we fix the node that comes before the current node (i.e., left side of the merge).
if (node.prev) {
// Make a shallow copy of the previous node
const newPreviousNode = { ...node.prev };
// Mark the old previous node as deleted. This avoids erroneous merges later,
// because there may still be references to this node in the priority queue.
node.prev.deleted = true;
node.prev = newPreviousNode;
// Update the reference of the previous node, by pointing its previous node to this new previous node.
if (newPreviousNode.prev) {
newPreviousNode.prev.next = newPreviousNode;
} else {
// If the previous of the previous node does not exist, it means that
// `newPreviousNode` must be the new `startingNode`.
startingNode = newPreviousNode;
}
}
// Create a new node which represents the result of the merge.
const merged = {
token: node.token + node.next.token,
bias: node.bias,
prev: node.prev,
next: node.next.next,
}
// We now consider where we can add the new merged node to the priority queue:
// 1. prev <-> merged
if (merged.prev) {
merged.prev.next = merged;
this._add_node(queue, merged.prev);
} else {
// If `merged.prev` does not exist, then `merged` must be the new `startingNode`.
startingNode = merged;
}
// 2. merged <-> next
if (merged.next) {
merged.next.prev = merged;
this._add_node(queue, merged);
}
}
// Traverse the linked list, starting from the `startingNode`, and collect the tokens.
for (let currentNode = startingNode; currentNode !== null; currentNode = currentNode.next) {
result.push(currentNode.token);
}
} else {
result = word;
}
// Possibly append suffix
if (this.continuing_subword_suffix) {
// Do not append suffix to the last token
for (let i = 0; i < result.length - 1; ++i) {
result[i] += this.continuing_subword_suffix;
}
}
// Save the result to the cache
this.cache.set(token, result);
return result;
}
/**
* Helper function to add a node to the priority queue.
* @param {PriorityQueue} queue
* @param {BPENode} node
* @private
*/
_add_node(queue, node) {
// `score` is a measure of the merge priority: lower means higher priority
// We use the BPE rank as a measure of priority (i.e., the position of the merge in the merges list)
// We also add a fractional component to the score to break ties (with the earlier character having higher priority)
const rank = this.bpe_ranks.get(node.token + this.BPE_SPLIT_TOKEN + node.next.token);
if (rank !== undefined) {
node.score = rank + node.bias;
queue.push(node);
}
}
/**
* Encodes the input sequence of tokens using the BPE algorithm and returns the resulting subword tokens.
* @param {string[]} tokens The input sequence of tokens to encode.
* @returns {string[]} The resulting subword tokens after applying the BPE algorithm to the input sequence of tokens.
*/
encode(tokens) {
let outputTokens = [];
for (let token of tokens) {
let bpe_token_list = this.bpe(token);
for (let t of bpe_token_list) {
if (this.tokens_to_ids.has(t)) {
outputTokens.push(t);
} else {
if (this.byte_fallback) {
outputTokens.push(
...Array.from(this.text_encoder.encode(t))
.map(x => `<0x${x.toString(16).toUpperCase().padStart(2, '0')}>`)
);
} else {
outputTokens.push(this.unk_token);
}
}
}
}
return outputTokens;
}
}
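// Worked example (illustrative, with a made-up merge list): given
//   merges = ['h e', 'he l', 'hel l', 'hell o']
// bpe('hello') repeatedly applies the lowest-ranked available merge:
//   ['h','e','l','l','o'] -> ['he','l','l','o'] -> ['hel','l','o'] -> ['hell','o'] -> ['hello']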
/**
* Legacy tokenizer class for tokenizers with only a vocabulary.
*/
class LegacyTokenizerModel extends TokenizerModel {
/**
* Create a LegacyTokenizerModel instance.
* @param {Object} config The configuration object for LegacyTokenizerModel.
* @param {Object} config.vocab A (possibly nested) mapping of tokens to ids.
* @param {Object} moreConfig Additional configuration object for the LegacyTokenizerModel model.
*/
constructor(config, moreConfig) {
super(config);
/**@type {Map<string, number>} */
this.tokens_to_ids = objectToMap(
moreConfig.target_lang
? config.vocab[moreConfig.target_lang]
: config.vocab
);
this.bos_token = moreConfig.bos_token;
this.bos_token_id = this.tokens_to_ids.get(this.bos_token);
this.eos_token = moreConfig.eos_token;
this.eos_token_id = this.tokens_to_ids.get(this.eos_token);
this.pad_token = moreConfig.pad_token;
this.pad_token_id = this.tokens_to_ids.get(this.pad_token);
this.unk_token = moreConfig.unk_token;
this.unk_token_id = this.tokens_to_ids.get(this.unk_token);
this.vocab = new Array(this.tokens_to_ids.size);
for (const [key, value] of this.tokens_to_ids) {
this.vocab[value] = key;
}
}
encode(tokens) {
return tokens;
}
}
/**
* A base class for text normalization.
* @abstract
*/
class Normalizer extends Callable {
/**
* @param {Object} config The configuration object for the normalizer.
*/
constructor(config) {
super();
this.config = config;
}
/**
* Factory method for creating normalizers from config objects.
* @static
* @param {Object} config The configuration object for the normalizer.
* @returns {Normalizer} A Normalizer object.
* @throws {Error} If an unknown Normalizer type is specified in the config.
*/
static fromConfig(config) {
if (config === null) return null;
switch (config.type) {
case 'BertNormalizer':
return new BertNormalizer(config);
case 'Precompiled':
return new Precompiled(config);
case 'Sequence':
return new NormalizerSequence(config);
case 'Replace':
return new Replace(config);
case 'NFC':
return new NFC(config);
case 'NFKC':
return new NFKC(config);
case 'NFKD':
return new NFKD(config);
case 'Strip':
return new StripNormalizer(config);
case 'StripAccents':
return new StripAccents(config);
case 'Lowercase':
return new Lowercase(config);
case 'Prepend':
return new Prepend(config);
default:
throw new Error(`Unknown Normalizer type: ${config.type}`);
}
}
/**
* Normalize the input text.
* @abstract
* @param {string} text The text to normalize.
* @returns {string} The normalized text.
* @throws {Error} If this method is not implemented in a subclass.
*/
normalize(text) {
throw Error("normalize should be implemented in subclass.")
}
/**
* Alias for {@link Normalizer#normalize}.
* @param {string} text The text to normalize.
* @returns {string} The normalized text.
*/
_call(text) {
return this.normalize(text);
}
}
/**
* Replace normalizer that replaces occurrences of a pattern with a given string or regular expression.
* @extends Normalizer
*/
class Replace extends Normalizer {
/**
* Normalize the input text by replacing the pattern with the content.
* @param {string} text The input text to be normalized.
* @returns {string} The normalized text after replacing the pattern with the content.
*/
normalize(text) {
let pattern = createPattern(this.config.pattern);
if (pattern === null) {
return text;
}
text = text.replaceAll(pattern, this.config.content)
return text;
}
}
/**
* A normalizer that applies Unicode normalization form C (NFC) to the input text.
* @extends Normalizer
*/
class NFC extends Normalizer {
/**
* Normalize the input text by applying Unicode normalization form C (NFC).
* @param {string} text The input text to be normalized.
* @returns {string} The normalized text.
*/
normalize(text) {
text = text.normalize('NFC')
return text;
}
}
/**
* NFKC Normalizer.
* @extends Normalizer
*/
class NFKC extends Normalizer {
/**
* Normalize text using NFKC normalization.
* @param {string} text The text to be normalized.
* @returns {string} The normalized text.
*/
normalize(text) {
text = text.normalize('NFKC')
return text;
}
}
/**
* NFKD Normalizer.
* @extends Normalizer
*/
class NFKD extends Normalizer {
/**
* Normalize text using NFKD normalization.
* @param {string} text The text to be normalized.
* @returns {string} The normalized text.
*/
normalize(text) {
text = text.normalize('NFKD')
return text;
}
}
/**
* A normalizer that strips leading and/or trailing whitespace from the input text.
*/
class StripNormalizer extends Normalizer {
/**
* Strip leading and/or trailing whitespace from the input text.
* @param {string} text The input text.
* @returns {string} The normalized text.
*/
normalize(text) {
if (this.config.strip_left && this.config.strip_right) {
// Fast path to avoid an extra trim call
text = text.trim();
} else {
if (this.config.strip_left) {
text = text.trimStart();
}
if (this.config.strip_right) {
text = text.trimEnd();
}
}
return text;
}
}
/**
* StripAccents normalizer removes all accents from the text.
* @extends Normalizer
*/
class StripAccents extends Normalizer {
/**
* Remove all accents from the text.
* @param {string} text The input text.
* @returns {string} The normalized text without accents.
*/
normalize(text) {
text = remove_accents(text);
return text;
}
}
/**
* A Normalizer that lowercases the input string.
* @extends Normalizer
*/
class Lowercase extends Normalizer {
/**
* Lowercases the input string.
* @param {string} text The text to normalize.
* @returns {string} The normalized text.
*/
normalize(text) {
text = text.toLowerCase();
return text;
}
}
/**
* A Normalizer that prepends a string to the input string.
* @extends Normalizer
*/
class Prepend extends Normalizer {
/**
* Prepends the input string.
* @param {string} text The text to normalize.
* @returns {string} The normalized text.
*/
normalize(text) {
text = this.config.prepend + text;
return text;
}
}
/**
* A Normalizer that applies a sequence of Normalizers.
* @extends Normalizer
*/
class NormalizerSequence extends Normalizer {
/**
* Create a new instance of NormalizerSequence.
* @param {Object} config The configuration object.
* @param {Object[]} config.normalizers An array of Normalizer configuration objects.
*/
constructor(config) {
super(config);
this.normalizers = config.normalizers.map(x => Normalizer.fromConfig(x));
}
/**
* Apply a sequence of Normalizers to the input text.
* @param {string} text The text to normalize.
* @returns {string} The normalized text.
*/
normalize(text) {
return this.normalizers.reduce((t, normalizer) => {
return normalizer.normalize(t);
}, text);
}
}
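// Illustrative usage of the factory plus a sequence:
//   const normalizer = Normalizer.fromConfig({
//     type: 'Sequence',
//     normalizers: [{ type: 'NFKC' }, { type: 'Lowercase' }],
//   });
//   normalizer('ＨＥＬＬＯ');  // => 'hello' (fullwidth letters folded by NFKC, then lowercased)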
/**
* A class representing a normalizer used in BERT tokenization.
* @extends Normalizer
*/
class BertNormalizer extends Normalizer {
/**
* Adds whitespace around any CJK (Chinese, Japanese, or Korean) character in the input text.
*
* @param {string} text The input text to tokenize.
* @returns {string} The tokenized text with whitespace added around CJK characters.
*/
_tokenize_chinese_chars(text) {
/* Adds whitespace around any CJK character. */
let output = [];
for (let i = 0; i < text.length; ++i) {
let char = text[i];
let cp = char.charCodeAt(0);
if (this._is_chinese_char(cp)) {
output.push(" ");
output.push(char);
output.push(" ");
} else {
output.push(char);
}
}
return output.join("");
}
/**
* Checks whether the given Unicode codepoint represents a CJK (Chinese, Japanese, or Korean) character.
*
* A "chinese character" is defined as anything in the CJK Unicode block:
* https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
*
* Note that the CJK Unicode block is NOT all Japanese and Korean characters, despite its name.
* The modern Korean Hangul alphabet is a different block, as is Japanese Hiragana and Katakana.
* Those alphabets are used to write space-separated words, so they are not treated specially
* and are handled like all other languages.
*
* @param {number} cp The Unicode codepoint to check.
* @returns {boolean} True if the codepoint represents a CJK character, false otherwise.
*/
_is_chinese_char(cp) {
return (
(cp >= 0x4E00 && cp <= 0x9FFF)
|| (cp >= 0x3400 && cp <= 0x4DBF)
|| (cp >= 0x20000 && cp <= 0x2A6DF)
|| (cp >= 0x2A700 && cp <= 0x2B73F)
|| (cp >= 0x2B740 && cp <= 0x2B81F)
|| (cp >= 0x2B820 && cp <= 0x2CEAF)
|| (cp >= 0xF900 && cp <= 0xFAFF)
|| (cp >= 0x2F800 && cp <= 0x2FA1F)
)
}
/**
* Strips accents from the given text.
* @param {string} text The text to strip accents from.
* @returns {string} The text with accents removed.
*/
stripAccents(text) {
return text.normalize('NFD').replace(/[\u0300-\u036f]/g, '');
}
/**
* Checks whether `char` is a control character.
* @param {string} char The character to check.
* @returns {boolean} Whether `char` is a control character.
* @private
*/
_is_control(char) {
switch (char) {
case '\t':
case '\n':
case '\r':
// These are technically control characters but we count them as whitespace characters.
return false;
default:
// Check if unicode category starts with C:
// Cc - Control
// Cf - Format
// Co - Private Use
// Cs - Surrogate
// (a character class is used so the ^/$ anchors apply to every category)
return /^[\p{Cc}\p{Cf}\p{Co}\p{Cs}]$/u.test(char);
}
}
/**
* Performs invalid character removal and whitespace cleanup on text.
* @param {string} text The text to clean.
* @returns {string} The cleaned text.
* @private
*/
_clean_text(text) {
const output = [];
for (const char of text) {
const cp = char.charCodeAt(0);
if (cp === 0 || cp === 0xFFFD || this._is_control(char)) {
continue;
}
if (/^\s$/.test(char)) { // is whitespace
output.push(" ");
} else {
output.push(char);
}
}
return output.join("");
}
/**
* Normalizes the given text based on the configuration.
* @param {string} text The text to normalize.
* @returns {string} The normalized text.
*/
normalize(text) {
if (this.config.clean_text) {
text = this._clean_text(text);
}
if (this.config.handle_chinese_chars) {
text = this._tokenize_chinese_chars(text);
}
if (this.config.lowercase) {
text = text.toLowerCase();
if (this.config.strip_accents !== false) {
text = this.stripAccents(text);
}
} else if (this.config.strip_accents) {
text = this.stripAccents(text);
}
return text;
}
}
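// Example (illustrative): with config { clean_text: true, lowercase: true } (and
// strip_accents left unset, so it follows lowercase):
//   new BertNormalizer({ clean_text: true, lowercase: true })
//     .normalize('Héllo\u0000World');  // => 'helloworld' (NUL removed, lowercased, accent stripped)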
/**
* A callable class representing a pre-tokenizer used in tokenization. Subclasses
* should implement the `pre_tokenize_text` method to define the specific pre-tokenization logic.
* @extends Callable
*/
class PreTokenizer extends Callable {
/**
* Factory method that returns an instance of a subclass of `PreTokenizer` based on the provided configuration.
*
* @static
* @param {Object} config A configuration object for the pre-tokenizer.
* @returns {PreTokenizer} An instance of a subclass of `PreTokenizer`.
* @throws {Error} If the provided configuration object does not correspond to any known pre-tokenizer.
*/
static fromConfig(config) {
if (config === null) return null;
switch (config.type) {
case 'BertPreTokenizer':
return new BertPreTokenizer(config);
case 'Sequence':
return new PreTokenizerSequence(config);
case 'WhitespaceSplit':
return new WhitespaceSplit(config);
case 'Metaspace':
return new MetaspacePreTokenizer(config);
case 'ByteLevel':
return new ByteLevelPreTokenizer(config);
case 'Split':
return new SplitPreTokenizer(config);
case 'Punctuation':
return new PunctuationPreTokenizer(config);
case 'Digits':
return new DigitsPreTokenizer(config);
case 'Replace':
return new ReplacePreTokenizer(config);
default:
throw new Error(`Unknown PreTokenizer type: ${config.type}`);
}
}
/**
* Method that should be implemented by subclasses to define the specific pre-tokenization logic.
*
* @abstract
* @param {string} text The text to pre-tokenize.
* @param {Object} [options] Additional options for the pre-tokenization logic.
* @returns {string[]} The pre-tokenized text.
* @throws {Error} If the method is not implemented in the subclass.
*/
pre_tokenize_text(text, options) {
throw Error("pre_tokenize_text should be implemented in subclass.")
}
/**
* Tokenizes the given text into pre-tokens.
* @param {string|string[]} text The text or array of texts to pre-tokenize.
* @param {Object} [options] Additional options for the pre-tokenization logic.
* @returns {string[]} An array of pre-tokens.
*/
pre_tokenize(text, options) {
let result = [];
if (Array.isArray(text)) {
result = text.map(x => this.pre_tokenize_text(x, options))
} else {
result = this.pre_tokenize_text(text, options);
}
return result.flat();
}
/**
* Alias for {@link PreTokenizer#pre_tokenize}.
* @param {string|string[]} text The text or array of texts to pre-tokenize.
* @param {Object} [options] Additional options for the pre-tokenization logic.
* @returns {string[]} An array of pre-tokens.
*/
_call(text, options) {
return this.pre_tokenize(text, options);
}
}
/**
* @extends PreTokenizer
*/
class BertPreTokenizer extends PreTokenizer {
/**
* A PreTokenizer that splits text into words and punctuation using a basic tokenization scheme
* similar to that used in the original implementation of BERT.
*
* @param {Object} config The configuration object.
*/
constructor(config) {
super();
// Construct a pattern which matches the rust implementation:
// https://github.com/huggingface/tokenizers/blob/b4fcc9ce6e4ad5806e82826f816acfdfdc4fcc67/tokenizers/src/pre_tokenizers/bert.rs#L11
// Equivalent to removing whitespace and splitting on punctuation (both \p{P} and other ascii characters)
this.pattern = new RegExp(`[^\\s${PUNCTUATION_REGEX}]+|[${PUNCTUATION_REGEX}]`, 'gu');
}
/**
* Tokenizes a single text using the BERT pre-tokenization scheme.
*
* @param {string} text The text to tokenize.
* @param {Object} [options] Additional options for the pre-tokenization logic.
* @returns {string[]} An array of tokens.
*/
pre_tokenize_text(text, options) {
return text.trim().match(this.pattern) || [];
}
}
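// Example (illustrative):
//   new BertPreTokenizer({}).pre_tokenize_text('Hello, world!');
//   // => ['Hello', ',', 'world', '!']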
/**
* A pre-tokenizer that splits text into byte-level units, as used for Byte-Pair-Encoding (BPE).
* @extends PreTokenizer
*/
class ByteLevelPreTokenizer extends PreTokenizer {
/**
* Creates a new instance of the `ByteLevelPreTokenizer` class.
* @param {Object} config The configuration object.
*/
constructor(config) {
super();
this.config = config;
/**
* @type {boolean} Whether to add a leading space to the first word.
* This allows the leading word to be treated just like any other word.
*/
this.add_prefix_space = this.config.add_prefix_space;
/**
* @type {boolean} Whether the post processing step should trim offsets
* to avoid including whitespaces.
* @todo Use this in the pretokenization step.
*/
this.trim_offsets = this.config.trim_offsets;
/**
* @type {boolean} Whether to use the standard GPT2 regex for whitespace splitting.
* Set it to `false` if you want to use your own splitting. Defaults to `true`.
*/
this.use_regex = this.config.use_regex ?? true;
this.pattern = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu;
this.byte_encoder = BYTES_TO_UNICODE;
this.text_encoder = new TextEncoder();
}
/**
* Tokenizes a single piece of text using byte-level tokenization.
* @param {string} text The text to tokenize.
* @param {Object} [options] Additional options for the pre-tokenization logic.
* @returns {string[]} An array of tokens.
*/
pre_tokenize_text(text, options) {
// Add a leading space if the option is enabled
if (this.add_prefix_space && !text.startsWith(' ')) {
text = ' ' + text;
}
// Split on whitespace and punctuation
let tokens = this.use_regex ? (text.match(this.pattern) || []) : [text];
// Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
return tokens.map(
token => Array.from(this.text_encoder.encode(token), byte => this.byte_encoder[byte]).join('')
);
}
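// Example (illustrative): with { add_prefix_space: true, trim_offsets: true },
//   pre_tokenize_text('hello world');  // => ['Ġhello', 'Ġworld']
// (each leading space is byte-encoded as 'Ġ' via BYTES_TO_UNICODE)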
}
/**
* @typedef {'removed'|'isolated'|'mergedWithPrevious'|'mergedWithNext'|'contiguous'} SplitDelimiterBehavior
*/
/**
* Splits text using a given pattern.
* @extends PreTokenizer
*/
class SplitPreTokenizer extends PreTokenizer {
/**
* @param {Object} config The configuration options for the pre-tokenizer.
* @param {Object} config.pattern The pattern used to split the text. Can be a string or a regex object.
* @param {string|undefined} config.pattern.String The string to use for splitting. Only defined if the pattern is a string.
* @param {string|undefined} config.pattern.Regex The regex to use for splitting. Only defined if the pattern is a regex.
* @param {SplitDelimiterBehavior} config.behavior The behavior to use when splitting.
* @param {boolean} config.invert Whether to split (invert=false) or match (invert=true) the pattern.
*/
constructor(config) {
super();
this.config = config;
// TODO support all behaviours (config.behavior)
this.pattern = createPattern(this.config.pattern, this.config.invert);
}
/**
* Tokenizes text by splitting it using the given pattern.
* @param {string} text The text to tokenize.
* @param {Object} [options] Additional options for the pre-tokenization logic.
* @returns {string[]} An array of tokens.
*/
pre_tokenize_text(text, options) {
if (this.pattern === null) {
return [];
}
if (this.config.invert) {
return text.match(this.pattern) || [];
} else {
return regexSplit(text, this.pattern);
}
}
}
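// Example (illustrative): with { pattern: { String: '-' }, invert: false },
//   pre_tokenize_text('pre-tokenize');  // => ['pre', '-', 'tokenize']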
/**
* Splits text based on punctuation.
* @extends PreTokenizer
*/
class PunctuationPreTokenizer extends PreTokenizer {