node-llama-cpp

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.
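For orientation before the source itself, here is a minimal usage sketch of the LlamaCompletion class defined in this file, following the library's text-completion guide (https://node-llama-cpp.withcat.ai/guide/text-completion); the model file name is only a placeholder, and the exact set of options may vary between versions:

import {fileURLToPath} from "url";
import path from "path";
import {getLlama, LlamaCompletion} from "node-llama-cpp";

const __dirname = path.dirname(fileURLToPath(import.meta.url));

// load a local GGUF model and create a context to generate with
const llama = await getLlama();
const model = await llama.loadModel({
    modelPath: path.join(__dirname, "models", "stable-code-3b.Q5_K_M.gguf") // placeholder path
});
const context = await model.createContext();

// LlamaCompletion wraps a context sequence and drives the generation loop implemented below
const completion = new LlamaCompletion({
    contextSequence: context.getSequence()
});

const res = await completion.generateCompletion("const arrayFromOneToTwenty = [1, 2, 3,", {
    maxTokens: 128
});
console.log(res);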

import { DisposeAggregator, DisposedError, EventRelay, withLock } from "lifecycle-utils";
import { tokenizeInput } from "../utils/tokenizeInput.js";
import { UnsupportedError } from "../utils/UnsupportedError.js";
import { removeNullFields } from "../utils/removeNullFields.js";
import { TokenStreamRegulator } from "../utils/TokenStreamRegulator.js";
import { StopGenerationDetector } from "../utils/StopGenerationDetector.js";
import { UNKNOWN_UNICODE_CHAR } from "../consts.js";
import { getQueuedTokensBeforeStopTrigger } from "../utils/getQueuedTokensBeforeStopTrigger.js";
import { safeEventCallback } from "../utils/safeEventCallback.js";
import { pushAll } from "../utils/pushAll.js";
import { GgufArchitectureType } from "../gguf/types/GgufMetadataTypes.js";
import { resolveBeginningTokenToPrepend } from "../utils/tokenizerUtils.js";
import { LlamaGrammarEvaluationState } from "./LlamaGrammarEvaluationState.js";

const defaultContextShiftSize = ((sequence) => Math.max(1, Math.floor(sequence.context.contextSize / 10)));
const defaultMinPrefixKeepTokens = ((sequence) => Math.max(1, Math.floor(sequence.context.contextSize / 10)));

/**
 * @see [Text Completion](https://node-llama-cpp.withcat.ai/guide/text-completion) tutorial
 */
export class LlamaCompletion {
    /** @internal */ _disposeAggregator = new DisposeAggregator();
    /** @internal */ _autoDisposeSequence;
    /** @internal */ _sequence;
    onDispose = new EventRelay();

    constructor({ contextSequence, autoDisposeSequence = false }) {
        this._sequence = contextSequence;
        this._autoDisposeSequence = autoDisposeSequence;
        this._disposeAggregator.add(this._sequence.onDispose.createListener(() => {
            this.dispose();
        }));
        this._disposeAggregator.add(this.onDispose.dispatchEvent);
    }

    dispose({ disposeSequence = this._autoDisposeSequence } = {}) {
        if (this._sequence == null || this.disposed)
            return;
        if (disposeSequence)
            this._sequence.dispose();
        this._sequence = null;
        this._disposeAggregator.dispose();
    }

    /** @hidden */
    [Symbol.dispose]() {
        return this.dispose();
    }

    get disposed() {
        return this._sequence == null || this._sequence.disposed;
    }

    get infillSupported() {
        if (this._sequence == null)
            throw new DisposedError();
        return this._sequence.model.tokens.infill.prefix != null &&
            this._sequence.model.tokens.infill.suffix != null;
    }

    /**
     * Generate a completion for an input.
     */
    async generateCompletion(input, options = {}) {
        const { response } = await this.generateCompletionWithMeta(input, options);
        return response;
    }

    /**
     * Same as `generateCompletion`, but returns additional metadata about the generation.
     * See `generateCompletion` for more information.
     */
    async generateCompletionWithMeta(input, {
        onTextChunk, onToken, signal, maxTokens, temperature, minP, topK, topP, seed,
        trimWhitespaceSuffix = false, repeatPenalty = {}, tokenBias, evaluationPriority = 5,
        grammar, customStopTriggers, contextShiftSize = defaultContextShiftSize, disableContextShift
    } = {}) {
        if (this._sequence == null || this.disposed)
            throw new DisposedError();
        const beginningTokenToPrepend = resolveBeginningTokenToPrepend(this._sequence.model.vocabularyType, this._sequence.model.tokens);
        const extraEosTokens = getExtraCompletionEosTokens(this._sequence.model);
        async function fitInputIntoContext({ maxTokens, tokens }) {
            const res = [];
            if (beginningTokenToPrepend != null)
                res.push(beginningTokenToPrepend);
            const inputTokensSize = Math.max(0, Math.min(maxTokens - res.length, tokens.length));
            if (inputTokensSize === 0 && tokens.length > 0)
                throw new Error("The context size is too small to generate a response for the given input");
            const slicedTokens = tokens.slice(-inputTokensSize);
            pushAll(res, slicedTokens);
            return res;
        }
        const ensureNotAborted = () => {
            if (signal?.aborted)
                throw signal.reason;
            if (this.disposed)
                throw new DisposedError();
        };
        return await withLock(this, "generateCompletion", signal, async () => {
            ensureNotAborted();
            if (this._sequence == null || this.disposed)
                throw new DisposedError();
            const resolvedInput = tokenizeInput(input, this._sequence.model.tokenizer, beginningTokenToPrepend != null ? "trimLeadingSpace" : undefined);
            const resolvedContextShiftSize = await resolveContextShiftSize(contextShiftSize, this._sequence);
            ensureNotAborted();
            const inputTokens = await fitInputIntoContext({
                maxTokens: this._sequence.context.contextSize - resolvedContextShiftSize,
                tokens: resolvedInput
            });
            ensureNotAborted();
            const resolvedMaxTokens = !disableContextShift
                ? maxTokens
                : (maxTokens != null && maxTokens > 0)
                    ? Math.min(maxTokens, this._sequence.context.contextSize - inputTokens.length)
                    : this._sequence.context.contextSize - inputTokens.length;
            this._sequence.tokenPredictor?.updateInputTokens?.(inputTokens.slice());
            return await this._generateResponse(inputTokens, {
                onTextChunk: safeEventCallback(onTextChunk),
                onToken: safeEventCallback(onToken),
                signal, maxTokens: resolvedMaxTokens, temperature, minP, topK, topP, seed,
                trimWhitespaceSuffix, repeatPenalty, tokenBias, evaluationPriority, grammar,
                contextShiftSize, customStopTriggers
            }, {
                async contextShift({ shiftSize, res, pendingTokens, sequence }) {
                    return {
                        newContextState: await fitInputIntoContext({
                            maxTokens: sequence.context.contextSize - shiftSize,
                            tokens: [...resolvedInput, ...res, ...pendingTokens]
                        })
                    };
                },
                extraEosTokens
            });
        });
    }

    /**
     * Infill (also known as Fill-In-Middle), generates a completion for an input (`prefixInput`) that
     * should connect to a given continuation (`suffixInput`).
     * For example, for `prefixInput: "123"` and `suffixInput: "789"`, the model is expected to generate `456`
     * to make the final text be `123456789`.
     */
    async generateInfillCompletion(prefixInput, suffixInput, options = {}) {
        const { response } = await this.generateInfillCompletionWithMeta(prefixInput, suffixInput, options);
        return response;
    }

    /**
     * Same as `generateInfillCompletion`, but returns additional metadata about the generation.
     * See `generateInfillCompletion` for more information.
     */
    async generateInfillCompletionWithMeta(prefixInput, suffixInput, {
        onTextChunk, onToken, signal, maxTokens, temperature, minP, topK, topP, seed,
        trimWhitespaceSuffix = false, repeatPenalty = {}, tokenBias, evaluationPriority = 5,
        grammar, contextShiftSize = defaultContextShiftSize, customStopTriggers,
        minPrefixKeepTokens = defaultMinPrefixKeepTokens, disableContextShift = false
    } = {}) {
        if (this._sequence == null || this.disposed)
            throw new DisposedError();
        const prefixToken = this._sequence.model.tokens.infill.prefix;
        const suffixToken = this._sequence.model.tokens.infill.suffix;
        const middleToken = this._sequence.model.tokens.infill.middle;
        const beginningTokenToPrepend = resolveBeginningTokenToPrepend(this._sequence.model.vocabularyType, this._sequence.model.tokens);
        if (prefixToken == null || suffixToken == null)
            throw new UnsupportedError("Infill completions are not supported by this model");
        const extraEosTokens = getExtraInfillEosTokens(this._sequence.model);
        async function fitInputIntoContext({ maxTokens, prefixTokens, suffixTokens, sequence }) {
            if (prefixToken == null || suffixToken == null)
                throw new UnsupportedError("Infill completions are not supported by this model");
            // 2 - InfillPrefix token, InfillSuffix token
            const specialTokensInContext = 2 +
                (middleToken != null ? 1 : 0) +
                (beginningTokenToPrepend != null ? 1 : 0);
            const resolvedMaxTokens = maxTokens - specialTokensInContext;
            let sizeLeftToFill = resolvedMaxTokens;
            let suffixTokensSize = Math.min(sizeLeftToFill, suffixTokens.length);
            sizeLeftToFill -= suffixTokensSize;
            let prefixTokensSize = Math.min(sizeLeftToFill, prefixTokens.length);
            sizeLeftToFill -= prefixTokensSize;
            if (sizeLeftToFill <= 0 && disableContextShift)
                throw new Error("The context size is too small to generate a response for the given input, and context shift is disabled. " +
                    "Consider removing `disableContextShift` or reducing the input size.");
            const resolvedMinPrefixKeepTokens = Math.min(Math.min(resolvedMaxTokens, prefixTokens.length), Math.max(1, Math.floor(minPrefixKeepTokens instanceof Function
                ? await minPrefixKeepTokens(sequence)
                : minPrefixKeepTokens)));
            if (prefixTokensSize < resolvedMinPrefixKeepTokens) {
                const diffToFill = Math.min(suffixTokensSize, resolvedMinPrefixKeepTokens - prefixTokensSize);
                prefixTokensSize += diffToFill;
                suffixTokensSize -= diffToFill;
            }
            const resolvedPrefixTokens = prefixTokens.slice(-prefixTokensSize);
            const resolvedSuffixTokens = suffixTokens.slice(0, suffixTokensSize);
            const newContextState = [];
            if (beginningTokenToPrepend != null)
                newContextState.push(beginningTokenToPrepend);
            if (middleToken != null) {
                newContextState.push(prefixToken);
                pushAll(newContextState, resolvedPrefixTokens);
                newContextState.push(suffixToken);
                pushAll(newContextState, resolvedSuffixTokens);
                newContextState.push(middleToken);
            }
            else {
                newContextState.push(suffixToken);
                pushAll(newContextState, resolvedSuffixTokens);
                newContextState.push(prefixToken);
                pushAll(newContextState, resolvedPrefixTokens);
            }
            return newContextState;
        }
        const ensureNotAborted = () => {
            if (signal?.aborted)
                throw signal.reason;
            if (this.disposed)
                throw new DisposedError();
        };
        return await withLock(this, "generateCompletion", signal, async () => {
            ensureNotAborted();
            if (this._sequence == null || this.disposed)
                throw new DisposedError();
            const resolvedPrefixInputTokens = tokenizeInput(prefixInput, this._sequence.model.tokenizer, "trimLeadingSpace");
            const resolvedSuffixInputTokens = tokenizeInput(suffixInput, this._sequence.model.tokenizer, "trimLeadingSpace");
            const resolvedContextShiftSize = await resolveContextShiftSize(contextShiftSize, this._sequence);
            ensureNotAborted();
            const inputTokens = await fitInputIntoContext({
                maxTokens: this._sequence.context.contextSize - resolvedContextShiftSize,
                prefixTokens: resolvedPrefixInputTokens,
                suffixTokens: resolvedSuffixInputTokens,
                sequence: this._sequence
            });
            ensureNotAborted();
            const resolvedMaxTokens = !disableContextShift
                ? maxTokens
                : (maxTokens != null && maxTokens > 0)
                    ? Math.min(maxTokens, this._sequence.context.contextSize - inputTokens.length)
                    : this._sequence.context.contextSize - inputTokens.length;
            this._sequence.tokenPredictor?.updateInputTokens?.(inputTokens.slice());
            return await this._generateResponse(inputTokens, {
                onTextChunk: safeEventCallback(onTextChunk),
                onToken: safeEventCallback(onToken),
                signal, maxTokens: resolvedMaxTokens, temperature, minP, topK, topP, seed,
                trimWhitespaceSuffix, repeatPenalty, tokenBias, evaluationPriority, grammar,
                contextShiftSize, customStopTriggers
            }, {
                async contextShift({ shiftSize, res, pendingTokens, sequence }) {
                    return {
                        newContextState: await fitInputIntoContext({
                            maxTokens: sequence.context.contextSize - shiftSize,
                            prefixTokens: [...resolvedPrefixInputTokens, ...res, ...pendingTokens],
                            suffixTokens: resolvedSuffixInputTokens,
                            sequence
                        })
                    };
                },
                extraEosTokens
            });
        });
    }

    /** @internal */
    async _generateResponse(tokens, {
        onTextChunk, onToken, signal, maxTokens, temperature, minP, topK, topP, seed,
        trimWhitespaceSuffix = false, repeatPenalty = {}, tokenBias, evaluationPriority = 5,
        grammar, contextShiftSize = defaultContextShiftSize, customStopTriggers
    }, { contextShift, extraEosTokens = new Set() }) {
        if (this._sequence == null)
            throw new DisposedError();
        const sequence = this._sequence;
        const model = sequence.model;
        const context = sequence.context;
        const res = [];
        const pendingTokens = [];
        const grammarEvaluationState = grammar != null
            ? new LlamaGrammarEvaluationState({ model, grammar })
            : undefined;
        const { lastTokens: repeatPenaltyLastTokens = 64, punishTokensFilter, penalizeNewLine, penalty, frequencyPenalty, presencePenalty } = repeatPenalty === false
            ? { lastTokens: 0 }
            : repeatPenalty;
        const streamRegulator = new TokenStreamRegulator();
        const stopGenerationDetector = new StopGenerationDetector();
        const customStopGenerationTriggersDetector = new StopGenerationDetector();
        const locksToReleaseOnValidGeneration = [];
        const repeatPenaltyEnabled = repeatPenaltyLastTokens > 0;
        let inputTokens = tokens;
        let generatedTokens = 0;
        if (grammar != null)
            StopGenerationDetector.resolveStopTriggers(grammar.stopGenerationTriggers, model.tokenizer)
                .map((stopTrigger) => stopGenerationDetector.addStopTrigger(stopTrigger));
        if (customStopTriggers != null)
            StopGenerationDetector.resolveStopTriggers(customStopTriggers, model.tokenizer)
                .map((stopTrigger) => customStopGenerationTriggersDetector.addStopTrigger(stopTrigger));
        const ensureNotAborted = () => {
            if (signal?.aborted)
                throw signal.reason;
            if (this.disposed)
                throw new DisposedError();
        };
        const getPenaltyTokens = () => {
            if (this._sequence == null)
                throw new DisposedError();
            let punishTokens = res.slice(-repeatPenaltyLastTokens);
            if (punishTokensFilter != null)
                punishTokens = punishTokensFilter(punishTokens);
            if (penalizeNewLine == null || !penalizeNewLine) {
                const nlToken = model.tokens.nl;
                if (nlToken != null)
                    punishTokens = punishTokens.filter((token) => token !== nlToken);
            }
            return punishTokens;
        };
        while (true) {
            ensureNotAborted();
            let shouldContextShift = false;
            if (inputTokens.length === 1 && sequence.nextTokenIndex !== 0)
                await sequence.eraseContextTokenRanges([{ start: 0, end: sequence.nextTokenIndex }]);
            else {
                const lastToken = inputTokens[inputTokens.length - 1];
                // we need to decode at least one token to generate a response
                inputTokens.pop();
                await sequence.adaptStateToTokens(inputTokens, false);
                inputTokens.push(lastToken);
                ensureNotAborted();
                const firstDifferentIndex = sequence.nextTokenIndex;
                inputTokens.splice(0, firstDifferentIndex);
            }
            const evaluationIterator = sequence.evaluate(inputTokens, removeNullFields({
                temperature, minP, topK, topP, seed,
                grammarEvaluationState,
                repeatPenalty: !repeatPenaltyEnabled ? undefined : {
                    punishTokens: getPenaltyTokens,
                    maxPunishTokens: repeatPenaltyLastTokens,
                    penalty,
                    frequencyPenalty,
                    presencePenalty
                },
                tokenBias,
                evaluationPriority,
                yieldEogToken: true
            }));
            const pendingPartialTokens = [];
            for await (const token of evaluationIterator) {
                ensureNotAborted();
                generatedTokens++;
                const tokens = pendingPartialTokens.length === 0
                    ? [token]
                    : [...pendingPartialTokens, token];
                const text = model.detokenize([token]);
                if (pendingPartialTokens.length === 0 &&
                    text.endsWith(UNKNOWN_UNICODE_CHAR) &&
                    !model.isSpecialToken(token) &&
                    !model.isEogToken(token)) {
                    pendingPartialTokens.push(token);
                    continue;
                }
                else {
                    pendingPartialTokens.length = 0;
                    const queuedTokenRelease = streamRegulator.addChunk({ tokens, text });
                    if (text.endsWith(UNKNOWN_UNICODE_CHAR) ||
                        ((grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix) && text.trim() === "") ||
                        (text === "" && locksToReleaseOnValidGeneration.length > 0 && !model.isSpecialToken(token))) {
                        locksToReleaseOnValidGeneration.push(queuedTokenRelease.createTextIndexLock(0));
                    }
                    else {
                        while (locksToReleaseOnValidGeneration.length > 0)
                            locksToReleaseOnValidGeneration.shift().dispose();
                    }
                    stopGenerationDetector.recordGeneration({ text, tokens, queuedTokenRelease });
                    customStopGenerationTriggersDetector.recordGeneration({ text, tokens, queuedTokenRelease });
                    if (model.isEogToken(token) || extraEosTokens.has(token))
                        queuedTokenRelease.createTokenIndexLock(0);
                    pushAll(pendingTokens, streamRegulator.popFreeChunkTokens());
                    if (stopGenerationDetector.hasTriggeredStops ||
                        customStopGenerationTriggersDetector.hasTriggeredStops ||
                        model.isEogToken(token) ||
                        extraEosTokens.has(token)) {
                        const triggeredStops = stopGenerationDetector.hasTriggeredStops
                            ? stopGenerationDetector.getTriggeredStops()
                            : customStopGenerationTriggersDetector.getTriggeredStops();
                        const partiallyFreeTokens = streamRegulator.getPartiallyFreeChunk(model.tokenizer);
                        const queuedTokensBeforeStopTrigger = getQueuedTokensBeforeStopTrigger(triggeredStops, partiallyFreeTokens, model.tokenizer);
                        pushAll(pendingTokens, queuedTokensBeforeStopTrigger);
                        const { firstRemainingGenerationAfterStop } = StopGenerationDetector.getFirstRemainingGenerationAfterStop(triggeredStops);
                        if (pendingTokens.length > 0) {
                            onToken?.(pendingTokens.slice());
                            onTextChunk?.(model.detokenize(pendingTokens, false, res));
                        }
                        pushAll(res, pendingTokens);
                        pendingTokens.length = 0;
                        let modelResponse = model.detokenize(res);
                        if (grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix)
                            modelResponse = modelResponse.trimEnd();
                        const isEogToken = model.isEogToken(token) || extraEosTokens.has(token);
                        if (isEogToken || stopGenerationDetector.hasTriggeredStops)
                            return {
                                response: modelResponse,
                                metadata: {
                                    remainingGenerationAfterStop: firstRemainingGenerationAfterStop,
                                    stopReason: isEogToken ? "eogToken" : "stopGenerationTrigger"
                                }
                            };
                        return {
                            response: modelResponse,
                            metadata: {
                                remainingGenerationAfterStop: firstRemainingGenerationAfterStop,
                                stopReason: "customStopTrigger",
                                customStopTrigger: triggeredStops[0].stopTrigger
                            }
                        };
                    }
                    if (pendingTokens.length > 0) {
                        onToken?.(pendingTokens.slice());
                        onTextChunk?.(model.detokenize(pendingTokens, false, res));
                        pushAll(res, pendingTokens);
                        pendingTokens.length = 0;
                    }
                }
                if (maxTokens != null && maxTokens > 0 && generatedTokens >= maxTokens) {
                    let modelResponse = model.detokenize(res);
                    if (grammar?.trimWhitespaceSuffix || trimWhitespaceSuffix)
                        modelResponse = modelResponse.trimEnd();
                    return {
                        response: modelResponse,
                        metadata: {
                            stopReason: "maxTokens"
                        }
                    };
                }
                if (sequence.nextTokenIndex >= context.contextSize - 1) {
                    shouldContextShift = true;
                    break;
                }
            }
            if (shouldContextShift) {
                const resolvedContextShiftSize = await resolveContextShiftSize(contextShiftSize, sequence);
                ensureNotAborted();
                const { newContextState } = await contextShift({
                    shiftSize: resolvedContextShiftSize,
                    res,
                    pendingTokens,
                    sequence
                });
                ensureNotAborted();
                inputTokens = newContextState;
                continue;
            }
            break;
        }
        throw new Error("The context size is too small to generate a response");
    }
}

async function resolveContextShiftSize(contextShiftSize, sequence) {
    if (typeof contextShiftSize === "number")
        return contextShiftSize;
    else if (contextShiftSize instanceof Function)
        return Math.min(sequence.context.contextSize, Math.max(1, Math.floor(contextShiftSize instanceof Function
            ? await contextShiftSize(sequence)
            : contextShiftSize)));
    return defaultContextShiftSize(sequence);
}

function getExtraCompletionEosTokens(model) {
    const extraEosTokens = new Set();
    if (model.fileInfo.metadata?.general?.architecture === GgufArchitectureType.gemma ||
        model.fileInfo.metadata?.general?.architecture === GgufArchitectureType.gemma2) {
        for (const token of model.iterateAllTokens()) {
            const tokenText = model.detokenize([token], true);
            if (tokenText === "<|file_separator|>" || tokenText === "<|fim_prefix|>") {
                extraEosTokens.add(token);
                if (extraEosTokens.size === 2)
                    break;
            }
        }
    }
    return extraEosTokens;
}

function getExtraInfillEosTokens(model) {
    const extraEosTokens = new Set();
    if (model.fileInfo.metadata?.general?.architecture === GgufArchitectureType.gemma ||
        model.fileInfo.metadata?.general?.architecture === GgufArchitectureType.gemma2) {
        for (const token of model.iterateAllTokens()) {
            const tokenText = model.detokenize([token], true);
            if (tokenText === "<|file_separator|>") {
                extraEosTokens.add(token);
                break;
            }
        }
    }
    return extraEosTokens;
}
//# sourceMappingURL=LlamaCompletion.js.map