UNPKG

gpt-tokenizer

A pure JavaScript implementation of a BPE tokenizer (Encoder/Decoder) for GPT-2 / GPT-3 / GPT-4 and other OpenAI models

515 lines (449 loc) 16.3 kB
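
A minimal usage sketch, assuming the package's documented top-level exports (which wrap the GptEncoding class defined below); names here follow the README, not this file:

import { encode, decode, isWithinTokenLimit } from 'gpt-tokenizer'

const tokens = encode('Hello, world!')       // string -> token ids
const text = decode(tokens)                  // token ids -> string
const fits = isWithinTokenLimit(text, 2048)  // false, or the token count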
/* eslint-disable @typescript-eslint/member-ordering */
/* eslint-disable no-param-reassign */
import { BytePairEncodingCore, decoder } from './BytePairEncodingCore.js'
import { ALL_SPECIAL_TOKENS } from './constants.js'
import {
  type ChatModelName,
  type ChatParameters,
  type EncodingName,
  type ModelName,
  chatModelParams,
  modelToEncodingMap,
} from './mapping.js'
import {
  type EncodingParams,
  type GetMergeableRanksFn,
  getEncodingParams,
} from './modelParams.js'
import { type CostEstimate, models } from './models.js'
import {
  EndOfPrompt,
  EndOfText,
  FimMiddle,
  FimPrefix,
  FimSuffix,
  ImEnd,
  ImSep,
  ImStart,
} from './specialTokens.js'
import { endsWithIncompleteUtfPairSurrogate } from './utfUtil.js'
import { getMaxValueFromMap, getSpecialTokenRegex } from './util.js'

export interface EncodeOptions {
  /**
   * A list of special tokens that are allowed in the input.
   * If set to 'all', all special tokens are allowed except those in disallowedSpecial.
   * @default undefined
   */
  allowedSpecial?: Set<string> | typeof ALL_SPECIAL_TOKENS
  /**
   * A list of special tokens that are disallowed in the input.
   * If set to 'all', all special tokens are disallowed except those in allowedSpecial.
   * @default 'all'
   */
  disallowedSpecial?: Set<string> | typeof ALL_SPECIAL_TOKENS
}

export interface ChatMessage {
  role?: 'system' | 'user' | 'assistant'
  name?: string
  content: string
}

export interface EncodeChatOptions {
  primeWithAssistantResponse?: string
}

interface SpecialTokenConfig {
  allowedSpecial: Set<string> | undefined
  regexPattern: RegExp | undefined
}

export class GptEncoding {
  static EndOfPrompt = EndOfPrompt
  static EndOfText = EndOfText
  static FimMiddle = FimMiddle
  static FimPrefix = FimPrefix
  static FimSuffix = FimSuffix

  modelName?: ModelName
  private bytePairEncodingCoreProcessor: BytePairEncodingCore
  private specialTokensEncoder: Map<string, number>
  private specialTokensSet: Set<string>
  private allSpecialTokenRegex: RegExp
  private defaultSpecialTokenConfig: SpecialTokenConfig

  readonly vocabularySize: number

  private constructor({
    bytePairRankDecoder: mergeableBytePairRanks,
    specialTokensEncoder,
    expectedVocabularySize,
    modelName,
    ...rest
  }: EncodingParams) {
    this.specialTokensEncoder = specialTokensEncoder
    this.specialTokensSet = new Set<string>(this.specialTokensEncoder.keys())
    this.allSpecialTokenRegex = getSpecialTokenRegex(this.specialTokensSet)

    this.bytePairEncodingCoreProcessor = new BytePairEncodingCore({
      bytePairRankDecoder: mergeableBytePairRanks,
      specialTokensEncoder,
      ...rest,
    })
    this.defaultSpecialTokenConfig = this.processSpecialTokens()

    const maxTokenValue = Math.max(
      mergeableBytePairRanks.length - 1,
      getMaxValueFromMap(specialTokensEncoder),
    )
    this.vocabularySize =
      this.bytePairEncodingCoreProcessor.mergeableBytePairRankCount +
      specialTokensEncoder.size

    if (expectedVocabularySize !== undefined) {
      if (this.vocabularySize !== expectedVocabularySize) {
        throw new Error(
          'The number of mergeable tokens and special tokens must be equal to expectedVocabularySize.',
        )
      }

      if (maxTokenValue !== expectedVocabularySize - 1) {
        throw new Error(
          `The model encodings are invalid. The maximum token value must be equal to expectedVocabularySize - 1.
Currently ${maxTokenValue}, expected ${
            expectedVocabularySize - 1
          }`,
        )
      }
    }
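    // Illustrative check of the invariant enforced above (hypothetical
    // numbers): with 100,000 mergeable byte-pair ranks and 5 special tokens,
    // expectedVocabularySize must be 100,005 and the highest token id must be
    // 100,004; together the two checks guarantee that token ids form a
    // contiguous range starting at 0.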
    this.encode = this.encode.bind(this)
    this.decode = this.decode.bind(this)
    this.encodeGenerator = this.encodeGenerator.bind(this)
    this.decodeGenerator = this.decodeGenerator.bind(this)
    this.decodeAsyncGenerator = this.decodeAsyncGenerator.bind(this)
    this.decodeAsync = this.decodeAsync.bind(this)
    this.isWithinTokenLimit = this.isWithinTokenLimit.bind(this)
    this.encodeChat = this.encodeChat.bind(this)
    this.encodeChatGenerator = this.encodeChatGenerator.bind(this)
    this.countTokens = this.countTokens.bind(this)
    this.setMergeCacheSize = this.setMergeCacheSize.bind(this)
    this.clearMergeCache = this.clearMergeCache.bind(this)
    this.estimateCost = this.estimateCost.bind(this)
    this.modelName = modelName
  }

  static getEncodingApi(
    encodingName: EncodingName,
    getMergeableRanks: GetMergeableRanksFn,
  ): GptEncoding {
    const modelParams = getEncodingParams(encodingName, getMergeableRanks)
    return new GptEncoding(modelParams)
  }

  static getEncodingApiForModel(
    modelName: ModelName,
    getMergeableRanks: GetMergeableRanksFn,
  ): GptEncoding {
    const encodingName = modelToEncodingMap[modelName]
    const modelParams = getEncodingParams(encodingName, getMergeableRanks)
    return new GptEncoding({ ...modelParams, modelName })
  }

  private processSpecialTokens({
    allowedSpecial,
    disallowedSpecial,
  }: EncodeOptions = {}): SpecialTokenConfig {
    let regexPattern: RegExp | undefined

    if (
      allowedSpecial === ALL_SPECIAL_TOKENS ||
      allowedSpecial?.has(ALL_SPECIAL_TOKENS)
    ) {
      allowedSpecial = new Set(this.specialTokensSet)
      const allowedSpecialSet = allowedSpecial
      if (disallowedSpecial === ALL_SPECIAL_TOKENS) {
        throw new Error(
          'allowedSpecial and disallowedSpecial cannot both be set to "all".',
        )
      }
      if (typeof disallowedSpecial === 'object') {
        // remove any special tokens that are disallowed
        disallowedSpecial.forEach((val) => allowedSpecialSet.delete(val))
      } else {
        // all special tokens are allowed, and no 'disallowedSpecial' is provided
        disallowedSpecial = new Set()
      }
    }

    if (
      !disallowedSpecial ||
      disallowedSpecial === ALL_SPECIAL_TOKENS ||
      disallowedSpecial.has(ALL_SPECIAL_TOKENS)
    ) {
      // by default, all special tokens are disallowed
      disallowedSpecial = new Set(this.specialTokensSet)
      const disallowedSpecialSet = disallowedSpecial

      if (allowedSpecial?.size) {
        allowedSpecial.forEach((val) => disallowedSpecialSet.delete(val))
        // disallowed takes precedence over allowed
        disallowedSpecial.forEach((val) => allowedSpecial.delete(val))
        if (disallowedSpecial.size > 0) {
          regexPattern = getSpecialTokenRegex(disallowedSpecial)
        }
      } else {
        regexPattern = this.allSpecialTokenRegex
      }
    }

    return { allowedSpecial, regexPattern }
  }

  encodeGenerator(
    lineToEncode: string,
    encodeOptions?: EncodeOptions,
  ): Generator<number[], number, undefined> {
    const specialTokenConfig = encodeOptions
      ? this.processSpecialTokens(encodeOptions)
      : this.defaultSpecialTokenConfig

    if (specialTokenConfig.regexPattern) {
      const match = lineToEncode.match(specialTokenConfig.regexPattern)
      if (match !== null) {
        throw new Error(`Disallowed special token found: ${match[0]}`)
      }
    }

    return this.bytePairEncodingCoreProcessor.encodeNativeGenerator(
      lineToEncode,
      specialTokenConfig.allowedSpecial,
    )
  }
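  // Usage sketch for the special-token handling above (illustrative):
  //   encoding.encode('<|endoftext|>', {
  //     allowedSpecial: new Set(['<|endoftext|>']),
  //   })
  // encodes the special token to its reserved id, whereas calling encode()
  // without options throws a "Disallowed special token found" error.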
  encode(lineToEncode: string, encodeOptions?: EncodeOptions): number[] {
    const specialTokenConfig = encodeOptions
      ? this.processSpecialTokens(encodeOptions)
      : this.defaultSpecialTokenConfig

    if (specialTokenConfig.regexPattern) {
      const match = lineToEncode.match(specialTokenConfig.regexPattern)
      if (match !== null) {
        throw new Error(`Disallowed special token found: ${match[0]}`)
      }
    }

    return this.bytePairEncodingCoreProcessor.encodeNative(
      lineToEncode,
      specialTokenConfig.allowedSpecial,
    )
  }

  /**
   * Progressively tokenizes an OpenAI chat.
   * Warning: gpt-3.5-turbo and gpt-4 chat format may change over time.
   * Returns tokens assuming the 'gpt-3.5-turbo-0301' / 'gpt-4-0314' format.
   * Based on OpenAI's guidelines: https://github.com/openai/openai-python/blob/main/chatml.md
   * Also mentioned in section 6 of this document:
   * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
   */
  *encodeChatGenerator(
    chat: Iterable<ChatMessage>,
    model = this.modelName,
  ): Generator<number[], void, undefined> {
    if (!model) {
      throw new Error(
        'Model name must be provided either during initialization or passed in to the method.',
      )
    }

    const params: ChatParameters | undefined =
      chatModelParams[model as ChatModelName]
    const chatStartToken = this.specialTokensEncoder.get(ImStart)
    const chatEndToken = this.specialTokensEncoder.get(ImEnd)

    if (!params || chatStartToken === undefined || chatEndToken === undefined) {
      throw new Error(`Model '${model}' does not support chat.`)
    }
    const allowedSpecial = new Set([ImSep])
    const { messageSeparator, roleSeparator } = params
    const encodedMessageSeparator =
      messageSeparator.length > 0 ? this.encode(messageSeparator) : []
    const encodedRoleSeparator =
      roleSeparator.length > 0
        ? this.encode(roleSeparator, { allowedSpecial })
        : []
    const nameCache = new Map<string, number[]>()

    for (const { role = 'system', name = role, content } of chat) {
      if (content === undefined) {
        throw new Error('Content must be defined for all messages.')
      }

      yield [chatStartToken]
      const encodedName = nameCache.get(name) ?? this.encode(name)
      nameCache.set(name, encodedName)
      yield encodedName
      if (encodedRoleSeparator.length > 0) {
        yield encodedRoleSeparator
      }
      yield* this.encodeGenerator(content)
      yield [chatEndToken]
      yield encodedMessageSeparator
    }

    // every reply is primed with <|start|>assistant<|message|>
    yield [chatStartToken]
    yield* this.encodeGenerator('assistant')
    if (encodedRoleSeparator.length > 0) {
      yield encodedRoleSeparator
    }
  }

  /**
   * Encodes a chat into a single array of tokens.
   * Warning: gpt-3.5-turbo and gpt-4 chat format may change over time.
   * Returns tokens assuming the 'gpt-3.5-turbo-0301' / 'gpt-4-0314' format.
   * Based on OpenAI's guidelines: https://github.com/openai/openai-python/blob/main/chatml.md
   * Also mentioned in section 6 of this document:
   * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
   */
  encodeChat(chat: readonly ChatMessage[], model = this.modelName): number[] {
    return [...this.encodeChatGenerator(chat, model)].flat()
  }

  /**
   * @returns {false | number} false if token limit is exceeded, otherwise the number of tokens
   */
  isWithinTokenLimit(
    input: string | Iterable<ChatMessage>,
    tokenLimit: number,
  ): false | number {
    const tokenGenerator =
      typeof input === 'string'
        ? this.encodeGenerator(input)
        : this.encodeChatGenerator(input)
    let count = 0
    for (const tokens of tokenGenerator) {
      count += tokens.length
      if (count > tokenLimit) {
        return false
      }
    }
    return count
  }
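  // Usage sketch (illustrative; rejectInput/logTokenCount are hypothetical
  // helpers): because isWithinTokenLimit consumes the generator lazily, it
  // stops encoding as soon as the limit is crossed instead of tokenizing the
  // whole input.
  //   const result = encoding.isWithinTokenLimit(userInput, 4096)
  //   if (result === false) rejectInput()
  //   else logTokenCount(result)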
  /**
   * Counts the number of tokens in the input.
   * @returns {number} The number of tokens.
   */
  countTokens(
    input: string | Iterable<ChatMessage>,
    encodeOptions?: EncodeOptions,
  ): number {
    if (typeof input === 'string') {
      const specialTokenConfig = encodeOptions
        ? this.processSpecialTokens(encodeOptions)
        : this.defaultSpecialTokenConfig

      if (specialTokenConfig.regexPattern) {
        const match = input.match(specialTokenConfig.regexPattern)
        if (match !== null) {
          throw new Error(`Disallowed special token found: ${match[0]}`)
        }
      }

      return this.bytePairEncodingCoreProcessor.countNative(
        input,
        specialTokenConfig.allowedSpecial,
      )
    }

    const tokenGenerator = this.encodeChatGenerator(input)
    let count = 0
    for (const tokens of tokenGenerator) {
      count += tokens.length
    }
    return count
  }

  /** Adjusts the size of the internal BPE merge cache (delegates to the core processor). */
  setMergeCacheSize(size: number): void {
    this.bytePairEncodingCoreProcessor.setMergeCacheSize(size)
  }

  /** Clears the internal BPE merge cache (delegates to the core processor). */
  clearMergeCache(): void {
    this.bytePairEncodingCoreProcessor.clearMergeCache()
  }

  decode(inputTokensToDecode: Iterable<number>): string {
    return this.bytePairEncodingCoreProcessor.decodeNative(inputTokensToDecode)
  }

  *decodeGenerator(
    inputTokensToDecode: Iterable<number>,
  ): Generator<string, void, void> {
    const decodedByteGenerator =
      this.bytePairEncodingCoreProcessor.decodeNativeGenerator(
        inputTokensToDecode,
      )

    let buffer = ''

    for (const decodedPart of decodedByteGenerator) {
      buffer +=
        typeof decodedPart === 'string'
          ? decodedPart
          : decoder.decode(decodedPart, { stream: true })

      if (buffer.length === 0 || endsWithIncompleteUtfPairSurrogate(buffer)) {
        // Keep the high surrogate in the buffer and continue with the next token
        // eslint-disable-next-line no-continue
        continue
      } else {
        yield buffer
        // reset buffer
        buffer = ''
      }
    }

    // Yield any remaining characters in the buffer
    if (buffer.length > 0) {
      yield buffer
    }
  }

  async *decodeAsyncGenerator(
    inputTokensToDecode: AsyncIterable<number>,
  ): AsyncGenerator<string, void> {
    const decodedByteGenerator =
      this.bytePairEncodingCoreProcessor.decodeNativeAsyncIterable(
        inputTokensToDecode,
      )

    let buffer = ''

    for await (const decodedPart of decodedByteGenerator) {
      buffer +=
        typeof decodedPart === 'string'
          ? decodedPart
          : decoder.decode(decodedPart, { stream: true })

      if (buffer.length === 0 || endsWithIncompleteUtfPairSurrogate(buffer)) {
        // Keep the high surrogate in the buffer and continue with the next token
        // eslint-disable-next-line no-continue
        continue
      } else {
        yield buffer
        // reset buffer
        buffer = ''
      }
    }

    // Yield any remaining characters in the buffer
    if (buffer.length > 0) {
      yield buffer
    }
  }

  async decodeAsync(
    inputTokensToDecode: AsyncIterable<number>,
  ): Promise<string> {
    const decodedByteGenerator =
      this.bytePairEncodingCoreProcessor.decodeNativeAsyncIterable(
        inputTokensToDecode,
      )

    let buffer = ''

    for await (const decodedPart of decodedByteGenerator) {
      buffer +=
        typeof decodedPart === 'string'
          ? decodedPart
          : decoder.decode(decodedPart, { stream: true })
    }

    return buffer
  }
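  // Note on the streaming decoders above: output is buffered while it ends in
  // an unpaired UTF-16 high surrogate, so a character split across token
  // boundaries is never emitted half-finished. Usage sketch (illustrative;
  // tokenStream is a hypothetical AsyncIterable<number>):
  //   for await (const chunk of encoding.decodeAsyncGenerator(tokenStream)) {
  //     process.stdout.write(chunk)
  //   }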
  /**
   * Estimates the cost of processing a given token count using the model's pricing.
   *
   * @param tokenCount - The number of tokens to estimate cost for
   * @param modelName - Optional model name to use for cost calculation (defaults to this.modelName)
   * @returns Cost estimate object with applicable price components (input, output, batchInput, batchOutput)
   */
  estimateCost(tokenCount: number, modelName = this.modelName): CostEstimate {
    if (!modelName) {
      throw new Error(
        'Model name must be provided either during initialization or passed in to the method.',
      )
    }

    const model = models[modelName]
    if (!model) {
      throw new Error(`Unknown model: ${modelName}`)
    }

    if (!model.cost) {
      throw new Error(`No cost information available for model: ${modelName}`)
    }

    const costPerMillion = model.cost
    const result: CostEstimate = {}

    // Calculate cost per token and multiply by token count
    // eslint-disable-next-line no-magic-numbers
    const millionTokens = tokenCount / 1_000_000

    if (costPerMillion.input !== undefined) {
      result.input = costPerMillion.input * millionTokens
    }

    if (costPerMillion.output !== undefined) {
      result.output = costPerMillion.output * millionTokens
    }

    if (costPerMillion.batchInput !== undefined) {
      result.batchInput = costPerMillion.batchInput * millionTokens
    }

    if (costPerMillion.batchOutput !== undefined) {
      result.batchOutput = costPerMillion.batchOutput * millionTokens
    }

    return result
  }
}
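
// Worked example for estimateCost (hypothetical pricing, not taken from
// models.js): if a model's cost entry were { input: 2.5, output: 10 } USD per
// million tokens, then estimateCost(250_000) would return
// { input: 2.5 * 0.25, output: 10 * 0.25 }, i.e. { input: 0.625, output: 2.5 }.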