UNPKG

gpt-tokenizer

Version:

A pure JavaScript implementation of a BPE tokenizer (Encoder/Decoder) for GPT-2 / GPT-3 / GPT-4 and other OpenAI models

417 lines 17.1 kB
"use strict";
/* eslint-disable no-continue */
Object.defineProperty(exports, "__esModule", { value: true });
exports.BytePairEncodingCore = exports.decoder = void 0;
const constants_js_1 = require("./constants.js");
const utfUtil_js_1 = require("./utfUtil.js");
const util_js_1 = require("./util.js");
/**
 * Module-level UTF-8 decoder, exported from this module.
 * NOTE(review): a TextDecoder keeps state between `{ stream: true }` calls, so sharing
 * it between unrelated decodes is fragile; `decodeNative` below therefore uses its own
 * per-call decoder instead of this one.
 */
exports.decoder = new TextDecoder('utf8');
/**
 * Core byte-pair-encoding engine.
 *
 * - Encoding: splits text with `tokenSplitRegex`, looks each piece up in the string rank
 *   table, and falls back to byte-pair merging of the piece's UTF-8 bytes. Allowed special
 *   tokens are emitted as single ranks.
 * - Decoding: maps ranks back to strings, to raw byte arrays, or to special-token text.
 */
class BytePairEncodingCore {
    /** Number of defined (non-hole) entries in `bytePairRankDecoder`. */
    mergeableBytePairRankCount;
    /**
     * An array where the index is the BPE rank,
     * and the value is the string or the array of bytes that it decodes to.
     * It may contain holes if a token is unused.
     */
    bytePairRankDecoder;
    /** rank -> bytes, for every rank whose decoder value is not a string. */
    bytePairNonUtfRankDecoder = new Map();
    /**
     * `[bytes, rank]` pairs for the non-string ranks, sorted with `compareUint8Arrays`
     * on the bytes so that `binarySearch` can look them up.
     */
    bytePairNonUtfSortedEncoder;
    /**
     * A reverse map of the bytePairRankDecoder,
     * where the key is the string and the value is the rank.
     * Values that cannot be represented as a string are present in `bytePairNonUtfSortedEncoder`.
     */
    bytePairStringRankEncoder;
    /** Pre-tokenization regex; must carry the `g` flag because it is used with `matchAll`. */
    tokenSplitRegex;
    /** special-token text -> rank */
    specialTokensEncoder;
    /** rank -> special-token text (inverse of `specialTokensEncoder`). */
    specialTokensDecoder;
    /** Global (`g`) alternation of every escaped special token; used by `findNextSpecialToken`. */
    specialTokenPatternRegex;
    textEncoder = new TextEncoder();
    /**
     * piece -> ranks memo, or undefined when caching is disabled. Used as an LRU:
     * Map iteration follows insertion order, and hits are re-inserted at the end.
     */
    mergeCache;
    /** Maximum number of entries kept in `mergeCache`. */
    mergeCacheSize;
    /**
     * @param {object} options
     * @param {Array<string | number[]>} options.bytePairRankDecoder - rank -> string or byte array (may have holes)
     * @param {Map<string, number>} [options.specialTokensEncoder] - special-token text -> rank
     * @param {RegExp} options.tokenSplitRegex - global regex used to pre-split text
     * @param {number} [options.mergeCacheSize] - merge-cache capacity; not > 0 disables the cache
     * @throws {Error} if the special-token alternation cannot be compiled into a RegExp
     */
    constructor({ bytePairRankDecoder, specialTokensEncoder, tokenSplitRegex, mergeCacheSize = constants_js_1.DEFAULT_MERGE_CACHE_SIZE, }) {
        this.bytePairRankDecoder = bytePairRankDecoder;
        this.bytePairStringRankEncoder = new Map();
        this.mergeCacheSize = mergeCacheSize;
        if (mergeCacheSize > 0) {
            this.mergeCache = new Map();
        }
        // size without array holes (which may be present in the decoder)
        this.mergeableBytePairRankCount = Object.keys(bytePairRankDecoder).length;
        const binaryLookup = [];
        // forEach skips array holes:
        bytePairRankDecoder.forEach((value, rank) => {
            if (typeof value === 'string') {
                this.bytePairStringRankEncoder.set(value, rank);
                return;
            }
            const byteArray = new Uint8Array(value);
            binaryLookup.push([byteArray, rank]);
            this.bytePairNonUtfRankDecoder.set(rank, byteArray);
        });
        this.bytePairNonUtfSortedEncoder = binaryLookup.sort((a, b) => (0, utfUtil_js_1.compareUint8Arrays)(a[0], b[0]));
        this.specialTokensEncoder = specialTokensEncoder ?? new Map();
        this.specialTokensDecoder = specialTokensEncoder
            ? new Map([...specialTokensEncoder].map(([key, value]) => [value, key]))
            : new Map();
        this.tokenSplitRegex = tokenSplitRegex;
        // Explicit arrow so that map()'s index/array arguments are not forwarded to escapeRegExp.
        const escapedSpecialTokens = [...this.specialTokensEncoder.keys()].map((token) => (0, util_js_1.escapeRegExp)(token));
        const allSpecialTokensRegex = escapedSpecialTokens.join('|');
        try {
            // 'g', not 'y': exec() must scan forward from lastIndex to find the *next*
            // special token. A sticky regex only matches exactly at lastIndex.
            this.specialTokenPatternRegex = new RegExp(allSpecialTokensRegex, 'g');
        }
        catch (error) {
            throw new Error('Invalid regular expression pattern.', { cause: error });
        }
    }
    /**
     * Changes the merge-cache capacity.
     * A size that is not a positive number disables and drops the cache. Shrinking an
     * existing cache evicts its least recently used entries immediately.
     * @param {number} newSize
     */
    setMergeCacheSize(newSize) {
        this.mergeCacheSize = newSize;
        // `!(x > 0)` also treats NaN as "disabled"
        if (!(newSize > 0)) {
            this.mergeCache = undefined;
            return;
        }
        if (this.mergeCache) {
            this.trimMergeCache(newSize);
        }
        else {
            this.mergeCache = new Map();
        }
    }
    /** Removes every entry from the merge cache (no-op when caching is disabled). */
    clearMergeCache() {
        this.mergeCache?.clear();
    }
    /**
     * Lazily encodes `text`, yielding the ranks of one pre-split piece (or one special
     * token) at a time.
     * @param {string} text
     * @param {Set<string> | undefined} allowedSpecial - special tokens to honour; others are treated as plain text
     * @returns {Generator<number[], number>} the generator's return value is the length of the last yielded array
     */
    *encodeNativeGenerator(text, allowedSpecial) {
        let startIndex = 0;
        let lastTokenLength = 0;
        // eslint-disable-next-line no-constant-condition
        while (true) {
            const nextSpecialMatch = this.findNextSpecialToken(text, allowedSpecial, startIndex);
            const nextSpecialStartIndex = nextSpecialMatch?.[0];
            const endIndex = nextSpecialStartIndex ?? text.length;
            // avoid copying when there is no special token at all
            const textBeforeSpecial = startIndex === 0 && endIndex === text.length
                ? text
                : text.slice(startIndex, endIndex);
            for (const [match] of textBeforeSpecial.matchAll(this.tokenSplitRegex)) {
                const token = this.getBpeRankFromString(match);
                if (token !== undefined) {
                    lastTokenLength = 1;
                    yield [token];
                    continue;
                }
                const tokens = this.bytePairEncode(match);
                lastTokenLength = tokens.length;
                yield tokens;
            }
            if (nextSpecialStartIndex !== undefined) {
                const specialToken = nextSpecialMatch[1];
                const specialTokenValue = this.specialTokensEncoder.get(specialToken);
                if (specialTokenValue === undefined) {
                    throw new Error(`Special token "${specialToken}" is not in the special token encoder.`);
                }
                yield [specialTokenValue];
                startIndex = nextSpecialStartIndex + specialToken.length;
                lastTokenLength = 1;
            }
            else {
                break;
            }
        }
        return lastTokenLength;
    }
    /**
     * Encodes `text` into a flat array of ranks.
     * @param {string} text
     * @param {Set<string> | undefined} allowedSpecial - special tokens to honour; others are treated as plain text
     * @returns {number[]}
     */
    encodeNative(text, allowedSpecial) {
        let startIndex = 0;
        const tokensArray = []; // Flat list to collect the tokens
        // eslint-disable-next-line no-constant-condition
        while (true) {
            const nextSpecialMatch = this.findNextSpecialToken(text, allowedSpecial, startIndex);
            const nextSpecialStartIndex = nextSpecialMatch?.[0];
            const endIndex = nextSpecialStartIndex ?? text.length;
            const textBeforeSpecial = startIndex === 0 && endIndex === text.length
                ? text
                : text.slice(startIndex, endIndex);
            for (const [match] of textBeforeSpecial.matchAll(this.tokenSplitRegex)) {
                const token = this.getBpeRankFromString(match);
                if (token !== undefined) {
                    tokensArray.push(token);
                    continue;
                }
                const tokens = this.bytePairEncode(match);
                tokensArray.push(...tokens);
            }
            if (nextSpecialStartIndex !== undefined) {
                const specialToken = nextSpecialMatch[1];
                const specialTokenValue = this.specialTokensEncoder.get(specialToken);
                if (specialTokenValue === undefined) {
                    throw new Error(`Special token "${specialToken}" is not in the special token encoder.`);
                }
                tokensArray.push(specialTokenValue);
                startIndex = nextSpecialStartIndex + specialToken.length;
            }
            else {
                break;
            }
        }
        return tokensArray;
    }
    /**
     * Counts the ranks `encodeNative(text, allowedSpecial)` would produce, without
     * building the output array.
     * @param {string} text
     * @param {Set<string> | undefined} allowedSpecial
     * @returns {number}
     */
    countNative(text, allowedSpecial) {
        let startIndex = 0;
        let tokensCount = 0;
        // eslint-disable-next-line no-constant-condition
        while (true) {
            const nextSpecialMatch = this.findNextSpecialToken(text, allowedSpecial, startIndex);
            const nextSpecialStartIndex = nextSpecialMatch?.[0];
            const endIndex = nextSpecialStartIndex ?? text.length;
            const textBeforeSpecial = startIndex === 0 && endIndex === text.length
                ? text
                : text.slice(startIndex, endIndex);
            for (const [match] of textBeforeSpecial.matchAll(this.tokenSplitRegex)) {
                const token = this.getBpeRankFromString(match);
                if (token !== undefined) {
                    tokensCount++;
                    continue;
                }
                const tokens = this.bytePairEncode(match);
                tokensCount += tokens.length;
            }
            if (nextSpecialStartIndex !== undefined) {
                const specialToken = nextSpecialMatch[1];
                const specialTokenValue = this.specialTokensEncoder.get(specialToken);
                if (specialTokenValue === undefined) {
                    throw new Error(`Special token "${specialToken}" is not in the special token encoder.`);
                }
                tokensCount++;
                startIndex = nextSpecialStartIndex + specialToken.length;
            }
            else {
                break;
            }
        }
        return tokensCount;
    }
    /**
     * Yields the string or raw bytes for each token.
     * Tokens that cannot be decoded are skipped silently (unlike `decodeNative`, which throws).
     * @param {Iterable<number>} tokens
     */
    *decodeNativeGenerator(tokens) {
        for (const token of tokens) {
            const tokenBytes = this.tryDecodeToken(token);
            if (tokenBytes) {
                yield tokenBytes;
            }
        }
    }
    /**
     * Decodes a sequence of ranks (including special tokens) into a string.
     * Byte-array tokens are fed through a UTF-8 stream decoder. That decoder is flushed
     * before each string token and at the end, so an incomplete or invalid byte sequence
     * becomes U+FFFD instead of being dropped.
     * @param {Iterable<number>} tokens
     * @returns {string}
     * @throws {Error} if a token is neither a known rank nor a special token
     */
    decodeNative(tokens) {
        let decoded = '';
        // Per call (and created only when a byte token appears): a decoder shared across
        // calls would carry incomplete trailing bytes from one call into the next.
        let byteDecoder;
        let hasPendingBytes = false;
        for (const token of tokens) {
            const tokenBytes = this.tryDecodeToken(token);
            if (tokenBytes === undefined) {
                throw new Error(`Token ${token} is not in the byte pair encoder.`);
            }
            if (typeof tokenBytes === 'string') {
                if (hasPendingBytes) {
                    // flush: emits any buffered incomplete sequence as U+FFFD and resets the decoder
                    decoded += byteDecoder.decode();
                    hasPendingBytes = false;
                }
                decoded += tokenBytes;
            }
            else {
                // ignoreBOM: keep a decoded U+FEFF in the output instead of stripping it
                byteDecoder ??= new TextDecoder('utf8', { ignoreBOM: true });
                decoded += byteDecoder.decode(tokenBytes, { stream: true });
                hasPendingBytes = true;
            }
        }
        if (hasPendingBytes) {
            decoded += byteDecoder.decode();
        }
        return decoded;
    }
    /**
     * Async variant of `decodeNativeGenerator`: yields the string or raw bytes for each
     * token, silently skipping tokens that cannot be decoded.
     * @param {AsyncIterable<number> | Iterable<number>} tokens
     */
    async *decodeNativeAsyncIterable(tokens) {
        for await (const token of tokens) {
            const tokenBytesOrString = this.tryDecodeToken(token);
            if (tokenBytesOrString) {
                yield tokenBytesOrString;
            }
        }
    }
    /**
     * @param {string} key
     * @returns {number | undefined} the rank of `key` in the string table
     */
    getBpeRankFromString(key) {
        return this.bytePairStringRankEncoder.get(key);
    }
    /**
     * @param {string} key
     * @returns {number}
     * @throws {Error} if `key` has no rank
     */
    getBpeRankFromStringOrThrow(key) {
        const value = this.getBpeRankFromString(key);
        if (value === undefined) {
            throw new Error(`The byte-pair encoding does not contain a value for: ${key}`);
        }
        return value;
    }
    /**
     * Looks up the rank of a byte sequence.
     * If `tryConvertToString` turns the bytes into a string, only the string table is
     * consulted. Otherwise the sorted binary table is binary-searched.
     * @param {Uint8Array} key
     * @returns {number | undefined}
     */
    getBpeRankFromBytes(key) {
        const keyAsString = (0, utfUtil_js_1.tryConvertToString)(key);
        if (keyAsString !== undefined) {
            return this.getBpeRankFromString(keyAsString);
        }
        // Perform binary search on the binary keys
        const index = this.binarySearch(key);
        if (index !== -1) {
            return this.bytePairNonUtfSortedEncoder[index][1];
        }
        return undefined;
    }
    /**
     * @param {Uint8Array} key
     * @returns {number}
     * @throws {Error} if the bytes have no rank
     */
    getBpeRankFromBytesOrThrow(key) {
        const value = this.getBpeRankFromBytes(key);
        if (value === undefined) {
            throw new Error(`The byte-pair encoding does not contain a value for: ${key.toString()}`);
        }
        return value;
    }
    /**
     * Binary search over `bytePairNonUtfSortedEncoder`.
     * The comparison below is lexicographic by byte, then by length. It must agree with
     * `compareUint8Arrays`, which sorted the array.
     * NOTE(review): confirm `compareUint8Arrays` uses that same ordering.
     * @param {Uint8Array} key
     * @returns {number} index of the matching entry, or -1
     */
    binarySearch(key) {
        let low = 0;
        let high = this.bytePairNonUtfSortedEncoder.length - 1;
        while (low <= high) {
            // eslint-disable-next-line no-bitwise
            const mid = (low + high) >>> 1;
            const midKey = this.bytePairNonUtfSortedEncoder[mid][0];
            let cmp = 0;
            const maxLength = Math.min(midKey.length, key.length);
            for (let i = 0; i < maxLength; i++) {
                cmp = midKey[i] - key[i];
                if (cmp !== 0)
                    break;
            }
            if (cmp === 0) {
                cmp = midKey.length - key.length;
            }
            if (cmp === 0) {
                return mid;
            }
            if (cmp < 0) {
                low = mid + 1;
            }
            else {
                high = mid - 1;
            }
        }
        return -1;
    }
    /**
     * Finds the first *allowed* special token that starts at or after `startIndex`.
     * Disallowed special tokens are skipped; scanning resumes one character after the
     * skipped match's start, so overlapping allowed tokens are still found.
     * @param {string} text
     * @param {Set<string> | undefined} allowedSpecial
     * @param {number} startIndex
     * @returns {[number, string] | undefined} absolute start index and token text
     */
    findNextSpecialToken(text, allowedSpecial, startIndex) {
        // Nothing can be returned in these cases. Returning early also avoids scanning
        // with the empty pattern that an empty special-token encoder produces.
        if (this.specialTokensEncoder.size === 0 || !allowedSpecial || allowedSpecial.size === 0) {
            return undefined;
        }
        const regex = this.specialTokenPatternRegex;
        let searchIndex = startIndex;
        // eslint-disable-next-line no-constant-condition
        while (true) {
            regex.lastIndex = searchIndex;
            const nextSpecialMatch = regex.exec(text);
            if (!nextSpecialMatch) {
                return undefined;
            }
            const specialToken = nextSpecialMatch[0];
            if (allowedSpecial.has(specialToken)) {
                // exec() with lastIndex reports an absolute index into `text`
                return [nextSpecialMatch.index, specialToken];
            }
            searchIndex = nextSpecialMatch.index + 1;
        }
    }
    /**
     * @param {number} tokenRank
     * @returns {string | Uint8Array | undefined} the token's string, its raw bytes, its
     *   special-token text, or undefined if the rank is unknown
     */
    tryDecodeToken(tokenRank) {
        const value = this.bytePairRankDecoder[tokenRank];
        if (typeof value === 'string') {
            return value;
        }
        if (typeof value === 'object') {
            const fromBinary = this.bytePairNonUtfRankDecoder.get(tokenRank);
            if (fromBinary) {
                return fromBinary;
            }
        }
        return this.specialTokensDecoder.get(tokenRank);
    }
    /**
     * Evicts least recently used merge-cache entries until at most `maxEntries` remain.
     * The first key in Map iteration order is the least recently used one, because
     * `bytePairEncode` re-inserts entries on every hit.
     * @param {number} maxEntries
     */
    trimMergeCache(maxEntries) {
        const cache = this.mergeCache;
        if (!cache)
            return;
        // `size > 0` guarantees progress even for a nonsensical limit
        while (cache.size > 0 && cache.size > maxEntries) {
            cache.delete(cache.keys().next().value);
        }
    }
    /**
     * Stores `value` for `key`, first evicting as many LRU entries as needed to stay
     * within `mergeCacheSize`. No-op when caching is disabled.
     * @param {string} key
     * @param {number[]} value
     */
    addToMergeCache(key, value) {
        if (!this.mergeCache)
            return;
        this.trimMergeCache(this.mergeCacheSize - 1);
        this.mergeCache.set(key, value);
    }
    /**
     * Encodes one pre-split piece of text into ranks by byte-pair merging its UTF-8 bytes.
     * - A single ASCII character is looked up directly in the string table.
     * - Other results are memoized in the merge cache when it is enabled.
     * NOTE(review): a cache hit returns the cached array itself, so callers must not mutate it.
     * @param {string} input
     * @returns {number[]}
     */
    bytePairEncode(input) {
        if (input.length === 1 && (0, utfUtil_js_1.isAscii)(input.codePointAt(0))) {
            return [this.getBpeRankFromStringOrThrow(input)];
        }
        if (this.mergeCache?.has(input)) {
            const result = this.mergeCache.get(input);
            // Move to end to mark as recently used
            this.mergeCache.delete(input);
            this.mergeCache.set(input, result);
            return result;
        }
        const inputBytes = this.textEncoder.encode(input);
        const result = this.bytePairMerge(inputBytes);
        this.addToMergeCache(input, result);
        return result;
    }
    /**
     * Repeatedly merges the adjacent partition pair with the lowest rank (leftmost on
     * ties) until no adjacent pair has a rank, then returns each partition's rank.
     * @param {Uint8Array} piece - input bytes to process
     * @returns {number[]}
     * @throws {Error} if a final partition has no rank
     */
    bytePairMerge(
    // Input array of bytes to process
    piece) {
        // 'starts' holds the start indices of each partition (plus the final end index)
        const starts = [];
        // 'ranks[i]' holds the rank of the bytes spanning partitions i and i + 1
        const ranks = [];
        // Helper function to get the rank of a byte pair starting at 'startIndex'
        const getRank = (startIndex, pairStart = starts[startIndex], pairEnd = starts[startIndex + 2]) => {
            if (pairEnd === undefined) {
                // No valid pair exists
                return Number.POSITIVE_INFINITY;
            }
            // Extract the byte pair
            const key = piece.subarray(pairStart, pairEnd);
            // Retrieve the BPE rank of this byte pair (if it exists)
            const rank = this.getBpeRankFromBytes(key);
            return rank ?? Number.POSITIVE_INFINITY;
        };
        // Initialize the 'starts' array with all possible start indices
        for (let i = 0; i <= piece.length; i++) {
            starts.push(i);
            if (i < piece.length - 1) {
                // Initialize the BPE values for all adjacent pairs
                ranks.push(getRank(i, i, i + 2));
            }
            else {
                // Initialize BPE values to infinity for the last pair
                ranks.push(Number.POSITIVE_INFINITY);
            }
        }
        // Iteratively merge byte pairs until no more useful merges can be done
        while (starts.length > 1) {
            let lowestRank = Number.POSITIVE_INFINITY;
            let lowestPartitionIndex = -1;
            // Find the partition with the minimum rank
            for (let i = 0; i < ranks.length - 1; i++) {
                const rank = ranks[i];
                if (rank < lowestRank) {
                    lowestRank = rank;
                    lowestPartitionIndex = i;
                }
            }
            // If no valid pair is left to merge, exit the loop
            if (lowestRank === Number.POSITIVE_INFINITY || lowestPartitionIndex === -1) {
                break;
            }
            // Merge the pair at 'lowestPartitionIndex' by removing the next start index
            starts.splice(lowestPartitionIndex + 1, 1);
            // Remove the BPE value of the merged pair
            ranks.splice(lowestPartitionIndex, 1);
            // Update the current merged pair's rank
            ranks[lowestPartitionIndex] = getRank(lowestPartitionIndex);
            // Update the rank of the previous pair, if it exists
            if (lowestPartitionIndex > 0) {
                ranks[lowestPartitionIndex - 1] = getRank(lowestPartitionIndex - 1);
            }
        }
        // Look up the rank of every final partition
        const output = [];
        for (let i = 0; i < starts.length - 1; i++) {
            const pairStart = starts[i];
            const pairEnd = starts[i + 1];
            const bpeValue = this.getBpeRankFromBytesOrThrow(piece.subarray(pairStart, pairEnd));
            output.push(bpeValue);
        }
        return output;
    }
}
exports.BytePairEncodingCore = BytePairEncodingCore;
//# sourceMappingURL=BytePairEncodingCore.js.map