gpt-tokenizer
Version:
A pure JavaScript implementation of a BPE tokenizer (Encoder/Decoder) for GPT-2 / GPT-3 / GPT-4 and other OpenAI models
417 lines • 17.1 kB
JavaScript
"use strict";
/* eslint-disable no-continue */
Object.defineProperty(exports, "__esModule", { value: true });
exports.BytePairEncodingCore = exports.decoder = void 0;
const constants_js_1 = require("./constants.js");
const utfUtil_js_1 = require("./utfUtil.js");
const util_js_1 = require("./util.js");
// Zero-length byte array, allocated once when the module loads.
const emptyBuffer = new Uint8Array();
// One UTF-8 TextDecoder exported for use by other modules.
// NOTE: a TextDecoder keeps undecoded bytes between `{ stream: true }` calls, so this
// single instance's streaming state is shared by every caller that uses it.
exports.decoder = new TextDecoder('utf8');
/**
 * Byte-pair-encoding engine.
 * - Splits text into pieces with `tokenSplitRegex`.
 * - Maps each piece to ranks: a direct lookup when the whole piece is a vocabulary string,
 *   otherwise lowest-rank pair merging over the piece's UTF-8 bytes.
 * - Emits allowed special tokens as single ranks.
 * - Decodes ranks back to text or bytes.
 */
class BytePairEncodingCore {
    /** Number of entries actually present in `bytePairRankDecoder` (array holes are not counted). */
    mergeableBytePairRankCount;
    /**
     * an array where the index is the BPE rank,
     * and the value is the string or the array of bytes that it decodes to
     * it may contain holes if token is unused
     */
    bytePairRankDecoder;
    /** rank -> Uint8Array, for every rank whose value in `bytePairRankDecoder` is not a string */
    bytePairNonUtfRankDecoder = new Map();
    /** [bytes, rank] pairs for the non-string ranks, sorted with `compareUint8Arrays` for `binarySearch` */
    bytePairNonUtfSortedEncoder;
    /**
     * a reverse map of the bytePairRankDecoder,
     * where the key is the string and the value is the rank
     * values that cannot be represented as a string are present in `bytePairNonUtfSortedEncoder`
     */
    bytePairStringRankEncoder;
    /** Regex passed to `String.prototype.matchAll`, so it must carry the `g` flag. */
    tokenSplitRegex;
    /** special token string -> rank */
    specialTokensEncoder;
    /** rank -> special token string */
    specialTokensDecoder;
    /** Global (`g`) regex matching any non-empty special token string; never matches if there are none. */
    specialTokenPatternRegex;
    textEncoder = new TextEncoder();
    /**
     * Decoder owned by this instance and only ever called without `{ stream: true }`,
     * so no partially decoded bytes survive from one `decodeNative` call to the next.
     */
    textDecoder = new TextDecoder('utf8');
    /** LRU cache of piece string -> ranks (Map insertion order = recency); `undefined` when disabled. */
    mergeCache;
    mergeCacheSize;
    constructor({ bytePairRankDecoder, specialTokensEncoder, tokenSplitRegex, mergeCacheSize = constants_js_1.DEFAULT_MERGE_CACHE_SIZE, }) {
        this.bytePairRankDecoder = bytePairRankDecoder;
        this.bytePairStringRankEncoder = new Map();
        this.mergeCacheSize = mergeCacheSize;
        if (mergeCacheSize > 0) {
            this.mergeCache = new Map();
        }
        // size without array holes (which may be present in the encoder)
        this.mergeableBytePairRankCount = Object.keys(bytePairRankDecoder).length;
        const binaryLookup = [];
        // forEach skips array holes:
        bytePairRankDecoder.forEach((value, rank) => {
            if (typeof value === 'string') {
                this.bytePairStringRankEncoder.set(value, rank);
                return;
            }
            const byteArray = new Uint8Array(value);
            binaryLookup.push([byteArray, rank]);
            this.bytePairNonUtfRankDecoder.set(rank, byteArray);
        });
        this.bytePairNonUtfSortedEncoder = binaryLookup.sort((a, b) => (0, utfUtil_js_1.compareUint8Arrays)(a[0], b[0]));
        this.specialTokensEncoder =
            specialTokensEncoder ?? new Map();
        this.specialTokensDecoder = specialTokensEncoder
            ? new Map([...specialTokensEncoder].map(([key, value]) => [value, key]))
            : new Map();
        this.tokenSplitRegex = tokenSplitRegex;
        // An empty special token would match zero characters, and the encode loops
        // (startIndex = start + token.length) would then never advance.
        const escapedSpecialTokens = [...this.specialTokensEncoder.keys()]
            .filter((token) => token !== '')
            .map(util_js_1.escapeRegExp);
        // An empty alternation matches "" at every position; `(?!)` never matches.
        const allSpecialTokensRegex = escapedSpecialTokens.length > 0
            ? escapedSpecialTokens.join('|')
            : '(?!)';
        try {
            // 'g' (not sticky 'y'): findNextSpecialToken relies on exec() scanning forward
            // from lastIndex rather than matching only exactly at lastIndex.
            this.specialTokenPatternRegex = new RegExp(allSpecialTokensRegex, 'g');
        }
        catch (error) {
            throw new Error('Invalid regular expression pattern.', { cause: error });
        }
    }
    /**
     * Changes the maximum number of cached merge results.
     * A size <= 0 disables (and drops) the cache; shrinking evicts the least recently
     * used entries immediately so the cache never stays above its limit.
     */
    setMergeCacheSize(newSize) {
        this.mergeCacheSize = newSize;
        if (!(newSize > 0)) {
            this.mergeCache = undefined;
            return;
        }
        if (!this.mergeCache) {
            this.mergeCache = new Map();
            return;
        }
        // Map iterates in insertion order, which `bytePairEncode` keeps as LRU order
        // (oldest first); deleting entries while iterating a Map is safe.
        for (const key of this.mergeCache.keys()) {
            if (this.mergeCache.size <= newSize) {
                break;
            }
            this.mergeCache.delete(key);
        }
    }
    /** Empties the merge cache (no-op when caching is disabled). */
    clearMergeCache() {
        this.mergeCache?.clear();
    }
    /**
     * Lazily encodes `text`, yielding one array of ranks per split piece or special token.
     * @param {string} text
     * @param {Set<string> | undefined} allowedSpecial special tokens that may be emitted as such
     * @returns the number of ranks in the last yielded chunk (generator return value)
     */
    *encodeNativeGenerator(text, allowedSpecial) {
        let startIndex = 0;
        let lastTokenLength = 0;
        // eslint-disable-next-line no-constant-condition
        while (true) {
            const nextSpecialMatch = this.findNextSpecialToken(text, allowedSpecial, startIndex);
            const nextSpecialStartIndex = nextSpecialMatch?.[0];
            const endIndex = nextSpecialStartIndex ?? text.length;
            // Avoid copying the string when there is no special token at all.
            const textBeforeSpecial = startIndex === 0 && endIndex === text.length
                ? text
                : text.slice(startIndex, endIndex);
            for (const [match] of textBeforeSpecial.matchAll(this.tokenSplitRegex)) {
                const token = this.getBpeRankFromString(match);
                if (token !== undefined) {
                    lastTokenLength = 1;
                    yield [token];
                    continue;
                }
                const tokens = this.bytePairEncode(match);
                lastTokenLength = tokens.length;
                yield tokens;
            }
            if (nextSpecialStartIndex === undefined) {
                break;
            }
            const specialToken = nextSpecialMatch[1];
            const specialTokenValue = this.specialTokensEncoder.get(specialToken);
            if (specialTokenValue === undefined) {
                throw new Error(`Special token "${specialToken}" is not in the special token encoder.`);
            }
            yield [specialTokenValue];
            startIndex = nextSpecialStartIndex + specialToken.length;
            lastTokenLength = 1;
        }
        return lastTokenLength;
    }
    /**
     * Encodes `text` into a flat array of ranks.
     * @param {string} text
     * @param {Set<string> | undefined} allowedSpecial special tokens that may be emitted as such
     * @returns {number[]}
     */
    encodeNative(text, allowedSpecial) {
        let startIndex = 0;
        const tokensArray = []; // Flat list to collect the tokens
        // eslint-disable-next-line no-constant-condition
        while (true) {
            const nextSpecialMatch = this.findNextSpecialToken(text, allowedSpecial, startIndex);
            const nextSpecialStartIndex = nextSpecialMatch?.[0];
            const endIndex = nextSpecialStartIndex ?? text.length;
            const textBeforeSpecial = startIndex === 0 && endIndex === text.length
                ? text
                : text.slice(startIndex, endIndex);
            for (const [match] of textBeforeSpecial.matchAll(this.tokenSplitRegex)) {
                const token = this.getBpeRankFromString(match);
                if (token !== undefined) {
                    tokensArray.push(token);
                    continue;
                }
                const tokens = this.bytePairEncode(match);
                tokensArray.push(...tokens);
            }
            if (nextSpecialStartIndex === undefined) {
                break;
            }
            const specialToken = nextSpecialMatch[1];
            const specialTokenValue = this.specialTokensEncoder.get(specialToken);
            if (specialTokenValue === undefined) {
                throw new Error(`Special token "${specialToken}" is not in the special token encoder.`);
            }
            tokensArray.push(specialTokenValue);
            startIndex = nextSpecialStartIndex + specialToken.length;
        }
        return tokensArray;
    }
    /**
     * Counts the ranks `encodeNative(text, allowedSpecial)` would produce without
     * materialising the full token array.
     * @returns {number}
     */
    countNative(text, allowedSpecial) {
        let startIndex = 0;
        let tokensCount = 0;
        // eslint-disable-next-line no-constant-condition
        while (true) {
            const nextSpecialMatch = this.findNextSpecialToken(text, allowedSpecial, startIndex);
            const nextSpecialStartIndex = nextSpecialMatch?.[0];
            const endIndex = nextSpecialStartIndex ?? text.length;
            const textBeforeSpecial = startIndex === 0 && endIndex === text.length
                ? text
                : text.slice(startIndex, endIndex);
            for (const [match] of textBeforeSpecial.matchAll(this.tokenSplitRegex)) {
                const token = this.getBpeRankFromString(match);
                if (token !== undefined) {
                    tokensCount++;
                    continue;
                }
                const tokens = this.bytePairEncode(match);
                tokensCount += tokens.length;
            }
            if (nextSpecialStartIndex === undefined) {
                break;
            }
            const specialToken = nextSpecialMatch[1];
            const specialTokenValue = this.specialTokensEncoder.get(specialToken);
            if (specialTokenValue === undefined) {
                throw new Error(`Special token "${specialToken}" is not in the special token encoder.`);
            }
            tokensCount++;
            startIndex = nextSpecialStartIndex + specialToken.length;
        }
        return tokensCount;
    }
    /**
     * Yields, per rank, either its string or its raw bytes (Uint8Array).
     * Unknown ranks (and empty-string values) are skipped silently, unlike `decodeNative`,
     * which throws on unknown ranks.
     */
    *decodeNativeGenerator(tokens) {
        for (const token of tokens) {
            const tokenBytes = this.tryDecodeToken(token);
            if (tokenBytes) {
                yield tokenBytes;
            }
        }
    }
    /**
     * Decodes ranks into a string.
     * Consecutive byte-valued ranks are decoded together, so a character split across
     * several of them is reassembled; an incomplete UTF-8 sequence becomes U+FFFD at its position.
     * @param {Iterable<number>} tokens
     * @returns {string}
     * @throws {Error} if a rank is neither in the vocabulary nor a special token
     */
    decodeNative(tokens) {
        let decoded = '';
        const pendingChunks = [];
        let pendingLength = 0;
        const flushPendingBytes = () => {
            if (pendingChunks.length === 0) {
                return;
            }
            let bytes = pendingChunks[0];
            if (pendingChunks.length > 1) {
                // Concatenate once per run instead of reallocating on every byte token.
                bytes = new Uint8Array(pendingLength);
                let offset = 0;
                for (const chunk of pendingChunks) {
                    bytes.set(chunk, offset);
                    offset += chunk.length;
                }
            }
            // Non-streaming decode: leftover partial bytes are flushed as U+FFFD here rather
            // than lingering in decoder state and leaking into a later decode call.
            decoded += this.textDecoder.decode(bytes);
            pendingChunks.length = 0;
            pendingLength = 0;
        };
        for (const token of tokens) {
            const tokenBytes = this.tryDecodeToken(token);
            if (tokenBytes === undefined) {
                throw new Error(`Token ${token} is not in the byte pair encoder.`);
            }
            if (typeof tokenBytes === 'string') {
                // String values are complete text, so pending bytes cannot continue into them.
                flushPendingBytes();
                decoded += tokenBytes;
            }
            else {
                pendingChunks.push(tokenBytes);
                pendingLength += tokenBytes.length;
            }
        }
        flushPendingBytes();
        return decoded;
    }
    /**
     * Async counterpart of `decodeNativeGenerator`; unknown ranks are skipped silently.
     */
    async *decodeNativeAsyncIterable(tokens) {
        for await (const token of tokens) {
            const tokenBytesOrString = this.tryDecodeToken(token);
            if (tokenBytesOrString) {
                yield tokenBytesOrString;
            }
        }
    }
    /** Rank of a string-valued vocabulary entry, or `undefined`. */
    getBpeRankFromString(key) {
        return this.bytePairStringRankEncoder.get(key);
    }
    getBpeRankFromStringOrThrow(key) {
        const value = this.getBpeRankFromString(key);
        if (value === undefined) {
            throw new Error(`The byte-pair encoding does not contain a value for: ${key}`);
        }
        return value;
    }
    /**
     * Rank for a byte sequence: looked up as a string when `tryConvertToString` succeeds,
     * otherwise by binary search over the non-string entries.
     */
    getBpeRankFromBytes(key) {
        const keyAsString = (0, utfUtil_js_1.tryConvertToString)(key);
        if (keyAsString !== undefined) {
            return this.getBpeRankFromString(keyAsString);
        }
        // Perform binary search on the binary keys
        const index = this.binarySearch(key);
        if (index !== -1) {
            return this.bytePairNonUtfSortedEncoder[index][1];
        }
        return undefined;
    }
    getBpeRankFromBytesOrThrow(key) {
        const value = this.getBpeRankFromBytes(key);
        if (value === undefined) {
            throw new Error(`The byte-pair encoding does not contain a value for: ${key.toString()}`);
        }
        return value;
    }
    /**
     * Binary search on `bytePairNonUtfSortedEncoder`, ordering byte-wise and then by length.
     * @returns the matching index, or -1
     */
    binarySearch(key) {
        let low = 0;
        let high = this.bytePairNonUtfSortedEncoder.length - 1;
        while (low <= high) {
            // eslint-disable-next-line no-bitwise
            const mid = (low + high) >>> 1;
            const midKey = this.bytePairNonUtfSortedEncoder[mid][0];
            let cmp = 0;
            const maxLength = Math.min(midKey.length, key.length);
            for (let i = 0; i < maxLength; i++) {
                cmp = midKey[i] - key[i];
                if (cmp !== 0)
                    break;
            }
            if (cmp === 0) {
                cmp = midKey.length - key.length;
            }
            if (cmp === 0) {
                return mid;
            }
            if (cmp < 0) {
                low = mid + 1;
            }
            else {
                high = mid - 1;
            }
        }
        return -1;
    }
    /**
     * Finds the first special token at or after `startIndex` that is in `allowedSpecial`.
     * Occurrences of special tokens that are not allowed are skipped over.
     * @returns {[number, string] | undefined} absolute start index within `text`, and the token
     */
    findNextSpecialToken(text, allowedSpecial, startIndex) {
        // No token can be allowed, so no scan is needed.
        if (allowedSpecial == null || allowedSpecial.size === 0) {
            return undefined;
        }
        const pattern = this.specialTokenPatternRegex;
        let searchIndex = startIndex;
        while (searchIndex <= text.length) {
            // With the 'g' flag, exec() searches forward from lastIndex and reports an
            // absolute `index` into `text`.
            pattern.lastIndex = searchIndex;
            const match = pattern.exec(text);
            if (match === null) {
                return undefined;
            }
            const specialToken = match[0];
            if (allowedSpecial.has(specialToken)) {
                return [match.index, specialToken];
            }
            // Not allowed: resume the search one character after this occurrence.
            searchIndex = match.index + 1;
        }
        return undefined;
    }
    /**
     * Value for a rank: its string, its bytes (Uint8Array), a special token string,
     * or `undefined` when the rank is unknown.
     */
    tryDecodeToken(tokenRank) {
        const value = this.bytePairRankDecoder[tokenRank];
        if (typeof value === 'string') {
            return value;
        }
        if (typeof value === 'object') {
            const fromBinary = this.bytePairNonUtfRankDecoder.get(tokenRank);
            if (fromBinary) {
                return fromBinary;
            }
        }
        return this.specialTokensDecoder.get(tokenRank);
    }
    /** Inserts a merge result, evicting the least recently used entry when full. */
    addToMergeCache(key, value) {
        if (!this.mergeCache)
            return;
        if (this.mergeCache.size >= this.mergeCacheSize) {
            // Remove least recently used item (first item)
            const firstKey = this.mergeCache.keys().next().value;
            this.mergeCache.delete(firstKey);
        }
        this.mergeCache.set(key, value);
    }
    /**
     * Ranks for one split piece that is not itself a vocabulary string.
     * Single ASCII characters are looked up directly; other pieces go through
     * `bytePairMerge`, with results memoised in the LRU merge cache.
     */
    bytePairEncode(input) {
        if (input.length === 1 && (0, utfUtil_js_1.isAscii)(input.codePointAt(0))) {
            return [this.getBpeRankFromStringOrThrow(input)];
        }
        const cached = this.mergeCache?.get(input);
        if (cached !== undefined) {
            // Move to end to mark as recently used
            this.mergeCache.delete(input);
            this.mergeCache.set(input, cached);
            return cached;
        }
        const inputBytes = this.textEncoder.encode(input);
        const result = this.bytePairMerge(inputBytes);
        this.addToMergeCache(input, result);
        return result;
    }
    /**
     * Repeatedly merges the adjacent pair of parts with the lowest rank until no adjacent
     * pair has a rank, then maps each remaining part to its rank.
     * @param {Uint8Array} piece bytes of one split piece
     * @returns {number[]}
     * @throws {Error} if a remaining part has no rank
     */
    bytePairMerge(
    // Input array of bytes to process
    piece) {
        // 'starts' holds the start indices of each partition
        const starts = [];
        // 'ranks' holds the BPE ranks of each partition pair
        const ranks = [];
        // Helper function to get the rank of a byte pair starting at 'startIndex'
        const getRank = (startIndex, pairStart = starts[startIndex], pairEnd = starts[startIndex + 2]) => {
            if (pairEnd === undefined) {
                // No valid pair exists
                return Number.POSITIVE_INFINITY;
            }
            // Extract the byte pair
            const key = piece.subarray(pairStart, pairEnd);
            // Retrieve the BPE rank of this byte pair (if it exists)
            const rank = this.getBpeRankFromBytes(key);
            return rank ?? Number.POSITIVE_INFINITY;
        };
        // Initialize the 'starts' array with all possible start indices
        for (let i = 0; i <= piece.length; i++) {
            starts.push(i);
            if (i < piece.length - 1) {
                // Initialize the BPE values for all adjacent pairs
                ranks.push(getRank(i, i, i + 2));
            }
            else {
                // Initialize BPE values to infinity for the last pair
                ranks.push(Number.POSITIVE_INFINITY);
            }
        }
        // Iteratively merge byte pairs until no more useful merges can be done
        while (starts.length > 1) {
            let lowestRank = Number.POSITIVE_INFINITY;
            let lowestPartitionIndex = -1;
            // Find the partition with the minimum rank
            for (let i = 0; i < ranks.length - 1; i++) {
                const rank = ranks[i];
                if (rank < lowestRank) {
                    lowestRank = rank;
                    lowestPartitionIndex = i;
                }
            }
            // If no valid pair is left to merge, exit the loop
            if (lowestRank === Number.POSITIVE_INFINITY ||
                lowestPartitionIndex === -1) {
                break;
            }
            // Merge the pair at 'lowestPartitionIndex' by removing the next start index
            starts.splice(lowestPartitionIndex + 1, 1);
            // Remove the BPE value of the merged pair
            ranks.splice(lowestPartitionIndex, 1);
            // Update the current merged pair's rank
            ranks[lowestPartitionIndex] = getRank(lowestPartitionIndex);
            // Update the rank of the previous pair, if it exists
            if (lowestPartitionIndex > 0) {
                ranks[lowestPartitionIndex - 1] = getRank(lowestPartitionIndex - 1);
            }
        }
        // Map each remaining partition to its rank
        const output = [];
        for (let i = 0; i < starts.length - 1; i++) {
            const pairStart = starts[i];
            const pairEnd = starts[i + 1];
            const bpeValue = this.getBpeRankFromBytesOrThrow(piece.subarray(pairStart, pairEnd));
            output.push(bpeValue);
        }
        return output;
    }
}
exports.BytePairEncodingCore = BytePairEncodingCore;
//# sourceMappingURL=BytePairEncodingCore.js.map