UNPKG

gpt-tokenizer

Version:

A pure JavaScript implementation of a BPE tokenizer (Encoder/Decoder) for GPT-2 / GPT-3 / GPT-4 and other OpenAI models

334 lines 15.5 kB
/* eslint-disable @typescript-eslint/member-ordering */ /* eslint-disable no-param-reassign */ import { BytePairEncodingCore, decoder } from './BytePairEncodingCore.js'; import { ALL_SPECIAL_TOKENS } from './constants.js'; import { chatModelParams, modelToEncodingMap, } from './mapping.js'; import { getEncodingParams, } from './modelParams.js'; import { models } from './models.js'; import { EndOfPrompt, EndOfText, FimMiddle, FimPrefix, FimSuffix, ImEnd, ImSep, ImStart, } from './specialTokens.js'; import { endsWithIncompleteUtfPairSurrogate } from './utfUtil.js'; import { getMaxValueFromMap, getSpecialTokenRegex } from './util.js'; export class GptEncoding { static EndOfPrompt = EndOfPrompt; static EndOfText = EndOfText; static FimMiddle = FimMiddle; static FimPrefix = FimPrefix; static FimSuffix = FimSuffix; modelName; bytePairEncodingCoreProcessor; specialTokensEncoder; specialTokensSet; allSpecialTokenRegex; defaultSpecialTokenConfig; vocabularySize; constructor({ bytePairRankDecoder: mergeableBytePairRanks, specialTokensEncoder, expectedVocabularySize, modelName, ...rest }) { this.specialTokensEncoder = specialTokensEncoder; this.specialTokensSet = new Set(this.specialTokensEncoder.keys()); this.allSpecialTokenRegex = getSpecialTokenRegex(this.specialTokensSet); this.bytePairEncodingCoreProcessor = new BytePairEncodingCore({ bytePairRankDecoder: mergeableBytePairRanks, specialTokensEncoder, ...rest, }); this.defaultSpecialTokenConfig = this.processSpecialTokens(); const maxTokenValue = Math.max(mergeableBytePairRanks.length - 1, getMaxValueFromMap(specialTokensEncoder)); this.vocabularySize = this.bytePairEncodingCoreProcessor.mergeableBytePairRankCount + specialTokensEncoder.size; if (expectedVocabularySize !== undefined) { if (this.vocabularySize !== expectedVocabularySize) { throw new Error('The number of mergeable tokens and special tokens must be equal to expectedVocabularySize.'); } if (maxTokenValue !== expectedVocabularySize - 1) { throw new Error(`The model encodings are invalid. The maximum token value must be equal to expectedVocabularySize - 1. Currently ${maxTokenValue}, expected ${expectedVocabularySize - 1}`); } } this.encode = this.encode.bind(this); this.decode = this.decode.bind(this); this.encodeGenerator = this.encodeGenerator.bind(this); this.decodeGenerator = this.decodeGenerator.bind(this); this.decodeAsyncGenerator = this.decodeAsyncGenerator.bind(this); this.decodeAsync = this.decodeAsync.bind(this); this.isWithinTokenLimit = this.isWithinTokenLimit.bind(this); this.encodeChat = this.encodeChat.bind(this); this.encodeChatGenerator = this.encodeChatGenerator.bind(this); this.countTokens = this.countTokens.bind(this); this.setMergeCacheSize = this.setMergeCacheSize.bind(this); this.clearMergeCache = this.clearMergeCache.bind(this); this.estimateCost = this.estimateCost.bind(this); this.modelName = modelName; } static getEncodingApi(encodingName, getMergeableRanks) { const modelParams = getEncodingParams(encodingName, getMergeableRanks); return new GptEncoding(modelParams); } static getEncodingApiForModel(modelName, getMergeableRanks) { const encodingName = modelToEncodingMap[modelName]; const modelParams = getEncodingParams(encodingName, getMergeableRanks); return new GptEncoding({ ...modelParams, modelName }); } processSpecialTokens({ allowedSpecial, disallowedSpecial, } = {}) { let regexPattern; if (allowedSpecial === ALL_SPECIAL_TOKENS || allowedSpecial?.has(ALL_SPECIAL_TOKENS)) { allowedSpecial = new Set(this.specialTokensSet); const allowedSpecialSet = allowedSpecial; if (disallowedSpecial === ALL_SPECIAL_TOKENS) { throw new Error('allowedSpecial and disallowedSpecial cannot both be set to "all".'); } if (typeof disallowedSpecial === 'object') { // remove any special tokens that are disallowed disallowedSpecial.forEach((val) => allowedSpecialSet.delete(val)); } else { // all special tokens are allowed, and no 'disallowedSpecial' is provided disallowedSpecial = new Set(); } } if (!disallowedSpecial || disallowedSpecial === ALL_SPECIAL_TOKENS || disallowedSpecial.has(ALL_SPECIAL_TOKENS)) { // by default, all special tokens are disallowed disallowedSpecial = new Set(this.specialTokensSet); const disallowedSpecialSet = disallowedSpecial; if (allowedSpecial?.size) { allowedSpecial.forEach((val) => disallowedSpecialSet.delete(val)); // disallowed takes precedence over allowed disallowedSpecial.forEach((val) => allowedSpecial.delete(val)); if (disallowedSpecial.size > 0) { regexPattern = getSpecialTokenRegex(disallowedSpecial); } } else { regexPattern = this.allSpecialTokenRegex; } } return { allowedSpecial, regexPattern }; } encodeGenerator(lineToEncode, encodeOptions) { const specialTokenConfig = encodeOptions ? this.processSpecialTokens(encodeOptions) : this.defaultSpecialTokenConfig; if (specialTokenConfig.regexPattern) { const match = lineToEncode.match(specialTokenConfig.regexPattern); if (match !== null) { throw new Error(`Disallowed special token found: ${match[0]}`); } } return this.bytePairEncodingCoreProcessor.encodeNativeGenerator(lineToEncode, specialTokenConfig.allowedSpecial); } encode(lineToEncode, encodeOptions) { const specialTokenConfig = encodeOptions ? this.processSpecialTokens(encodeOptions) : this.defaultSpecialTokenConfig; if (specialTokenConfig.regexPattern) { const match = lineToEncode.match(specialTokenConfig.regexPattern); if (match !== null) { throw new Error(`Disallowed special token found: ${match[0]}`); } } return this.bytePairEncodingCoreProcessor.encodeNative(lineToEncode, specialTokenConfig.allowedSpecial); } /** * Progressively tokenizes an OpenAI chat. * Warning: gpt-3.5-turbo and gpt-4 chat format may change over time. * Returns tokens assuming the 'gpt-3.5-turbo-0301' / 'gpt-4-0314' format. * Based on OpenAI's guidelines: https://github.com/openai/openai-python/blob/main/chatml.md * Also mentioned in section 6 of this document: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb */ *encodeChatGenerator(chat, model = this.modelName) { if (!model) { throw new Error('Model name must be provided either during initialization or passed in to the method.'); } const params = chatModelParams[model]; const chatStartToken = this.specialTokensEncoder.get(ImStart); const chatEndToken = this.specialTokensEncoder.get(ImEnd); if (!params || chatStartToken === undefined || chatEndToken === undefined) { throw new Error(`Model '${model}' does not support chat.`); } const allowedSpecial = new Set([ImSep]); const { messageSeparator, roleSeparator } = params; const encodedMessageSeparator = messageSeparator.length > 0 ? this.encode(messageSeparator) : []; const encodedRoleSeparator = roleSeparator.length > 0 ? this.encode(roleSeparator, { allowedSpecial }) : []; const nameCache = new Map(); for (const { role = 'system', name = role, content } of chat) { if (content === undefined) { throw new Error('Content must be defined for all messages.'); } yield [chatStartToken]; const encodedName = nameCache.get(name) ?? this.encode(name); nameCache.set(name, encodedName); yield encodedName; if (encodedRoleSeparator.length > 0) { yield encodedRoleSeparator; } yield* this.encodeGenerator(content); yield [chatEndToken]; yield encodedMessageSeparator; } // every reply is primed with <|start|>assistant<|message|> yield [chatStartToken]; yield* this.encodeGenerator('assistant'); if (encodedRoleSeparator.length > 0) { yield encodedRoleSeparator; } } /** * Encodes a chat into a single array of tokens. * Warning: gpt-3.5-turbo and gpt-4 chat format may change over time. * Returns tokens assuming the 'gpt-3.5-turbo-0301' / 'gpt-4-0314' format. * Based on OpenAI's guidelines: https://github.com/openai/openai-python/blob/main/chatml.md * Also mentioned in section 6 of this document: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb */ encodeChat(chat, model = this.modelName) { return [...this.encodeChatGenerator(chat, model)].flat(); } /** * @returns {false | number} false if token limit is exceeded, otherwise the number of tokens */ isWithinTokenLimit(input, tokenLimit) { const tokenGenerator = typeof input === 'string' ? this.encodeGenerator(input) : this.encodeChatGenerator(input); let count = 0; for (const tokens of tokenGenerator) { count += tokens.length; if (count > tokenLimit) { return false; } } return count; } /** * Counts the number of tokens in the input. * @returns {number} The number of tokens. */ countTokens(input, encodeOptions) { if (typeof input === 'string') { const specialTokenConfig = encodeOptions ? this.processSpecialTokens(encodeOptions) : this.defaultSpecialTokenConfig; if (specialTokenConfig.regexPattern) { const match = input.match(specialTokenConfig.regexPattern); if (match !== null) { throw new Error(`Disallowed special token found: ${match[0]}`); } } return this.bytePairEncodingCoreProcessor.countNative(input, specialTokenConfig.allowedSpecial); } const tokenGenerator = this.encodeChatGenerator(input); let count = 0; for (const tokens of tokenGenerator) { count += tokens.length; } return count; } setMergeCacheSize(size) { this.bytePairEncodingCoreProcessor.setMergeCacheSize(size); } clearMergeCache() { this.bytePairEncodingCoreProcessor.clearMergeCache(); } decode(inputTokensToDecode) { return this.bytePairEncodingCoreProcessor.decodeNative(inputTokensToDecode); } *decodeGenerator(inputTokensToDecode) { const decodedByteGenerator = this.bytePairEncodingCoreProcessor.decodeNativeGenerator(inputTokensToDecode); let buffer = ''; for (const decodedPart of decodedByteGenerator) { buffer += typeof decodedPart === 'string' ? decodedPart : decoder.decode(decodedPart, { stream: true }); if (buffer.length === 0 || endsWithIncompleteUtfPairSurrogate(buffer)) { // Keep the high surrogate in the buffer and continue with the next token // eslint-disable-next-line no-continue continue; } else { yield buffer; // reset buffer buffer = ''; } } // Yield any remaining characters in the buffer if (buffer.length > 0) { yield buffer; } } async *decodeAsyncGenerator(inputTokensToDecode) { const decodedByteGenerator = this.bytePairEncodingCoreProcessor.decodeNativeAsyncIterable(inputTokensToDecode); let buffer = ''; for await (const decodedPart of decodedByteGenerator) { buffer += typeof decodedPart === 'string' ? decodedPart : decoder.decode(decodedPart, { stream: true }); if (buffer.length === 0 || endsWithIncompleteUtfPairSurrogate(buffer)) { // Keep the high surrogate in the buffer and continue with the next token // eslint-disable-next-line no-continue continue; } else { yield buffer; // reset buffer buffer = ''; } } // Yield any remaining characters in the buffer if (buffer.length > 0) { yield buffer; } } async decodeAsync(inputTokensToDecode) { const decodedByteGenerator = this.bytePairEncodingCoreProcessor.decodeNativeAsyncIterable(inputTokensToDecode); let buffer = ''; for await (const decodedPart of decodedByteGenerator) { buffer += typeof decodedPart === 'string' ? decodedPart : decoder.decode(decodedPart, { stream: true }); } return buffer; } /** * Estimates the cost of processing a given token count using the model's pricing. * * @param tokenCount - The number of tokens to estimate cost for * @param modelName - Optional model name to use for cost calculation (defaults to this.modelName) * @returns Cost estimate object with applicable price components (input, output, batchInput, batchOutput) */ estimateCost(tokenCount, modelName = this.modelName) { if (!modelName) { throw new Error('Model name must be provided either during initialization or passed in to the method.'); } const model = models[modelName]; if (!model) { throw new Error(`Unknown model: ${modelName}`); } if (!model.cost) { throw new Error(`No cost information available for model: ${modelName}`); } const costPerMillion = model.cost; const result = {}; // Calculate cost per token and multiply by token count // eslint-disable-next-line no-magic-numbers const millionTokens = tokenCount / 1_000_000; if (costPerMillion.input !== undefined) { result.input = costPerMillion.input * millionTokens; } if (costPerMillion.output !== undefined) { result.output = costPerMillion.output * millionTokens; } if (costPerMillion.batchInput !== undefined) { result.batchInput = costPerMillion.batchInput * millionTokens; } if (costPerMillion.batchOutput !== undefined) { result.batchOutput = costPerMillion.batchOutput * millionTokens; } return result; } } //# sourceMappingURL=GptEncoding.js.map