node-llama-cpp

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.
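The schema enforcement mentioned above is exposed through the library's grammar API. A minimal sketch, assuming a local GGUF model; the model path and the schema fields below are placeholders, not part of this file:

import { getLlama, LlamaChatSession } from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({ modelPath: "path/to/model.gguf" }); // placeholder path
const context = await model.createContext();
const session = new LlamaChatSession({ contextSequence: context.getSequence() });

// Compile a JSON schema into a grammar that constrains token sampling,
// so the output is valid JSON matching the schema at the generation level
const grammar = await llama.createGrammarForJsonSchema({
    type: "object",
    properties: {
        answer: { type: "string" } // illustrative field
    }
});

const response = await session.prompt("Summarize llama.cpp in one sentence.", { grammar });
console.log(grammar.parse(response));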

import { DisposeAggregator, DisposedError } from "lifecycle-utils";
import { getConsoleLogPrefix } from "../../../utils/getConsoleLogPrefix.js";
import { LruCache } from "../../../utils/LruCache.js";
import { safeEventCallback } from "../../../utils/safeEventCallback.js";

const defaultMaxPreloadTokens = 256;
const defaultMaxCachedCompletions = 100;

export class LlamaChatSessionPromptCompletionEngine {
    /** @internal */ _chatSession;
    /** @internal */ _maxPreloadTokens;
    /** @internal */ _maxCachedCompletions;
    /** @internal */ _onGeneration;
    /** @internal */ _completionOptions;
    // Caches are keyed by the chat session's history state ref, so a fresh
    // cache is used whenever the chat history changes
    /** @internal */ _completionCaches = new WeakMap();
    /** @internal */ _disposeAggregator = new DisposeAggregator();
    /** @internal */ _currentCompletionAbortController = new AbortController();
    /** @internal */ _lastPrompt;
    /** @internal */ _disposed = false;

    constructor(chatSession, {
        maxPreloadTokens = defaultMaxPreloadTokens,
        onGeneration,
        maxCachedCompletions = defaultMaxCachedCompletions,
        ...options
    }) {
        this._chatSession = chatSession;
        this._maxPreloadTokens = Math.max(1, maxPreloadTokens);
        this._maxCachedCompletions = Math.max(1, maxCachedCompletions);
        this._onGeneration = safeEventCallback(onGeneration);
        this._completionOptions = options;

        this.dispose = this.dispose.bind(this);

        this._disposeAggregator.add(this._chatSession.onDispose.createListener(this.dispose));
        this._disposeAggregator.add(() => {
            this._disposed = true;
            this._currentCompletionAbortController.abort();
        });
    }

    dispose() {
        if (this._disposed)
            return;

        this._disposeAggregator.dispose();
    }

    /**
     * Get completion for the prompt from the cache,
     * and begin preloading this prompt into the context sequence and completing it.
     *
     * On completion progress, `onGeneration` (configured for this engine instance) will be called.
     */
    complete(prompt) {
        if (this._disposed)
            throw new DisposedError();

        const completionCache = this._getCurrentCompletionCache();
        const completion = completionCache.getCompletion(prompt);

        // Restart background generation when the new prompt is not a
        // continuation of the last prompt plus its cached completion
        if (this._lastPrompt == null || !(this._lastPrompt + (completion ?? "")).startsWith(prompt)) {
            this._lastPrompt = prompt;
            this._restartCompletion(completionCache);
        }

        this._lastPrompt = prompt;
        return completion ?? "";
    }

    /** @internal */
    _getCurrentCompletionCache() {
        const completionCache = this._completionCaches.get(this._chatSession._chatHistoryStateRef);
        if (completionCache != null)
            return completionCache;

        const newCompletionCache = new CompletionCache(this._maxCachedCompletions);
        this._completionCaches.set(this._chatSession._chatHistoryStateRef, newCompletionCache);
        return newCompletionCache;
    }

    /** @internal */
    _restartCompletion(completionCache) {
        if (this._disposed)
            return;

        // Abort any in-flight completion before starting a new one
        this._currentCompletionAbortController.abort();
        this._currentCompletionAbortController = new AbortController();

        const prompt = this._lastPrompt;
        if (prompt == null)
            return;

        const existingCompletion = completionCache.getCompletion(prompt);
        const promptToComplete = prompt + (existingCompletion ?? "");

        const currentPromptTokens = this._chatSession.model.tokenize(promptToComplete, false, "trimLeadingSpace").length;
        const leftTokens = Math.max(0, this._maxPreloadTokens - currentPromptTokens);

        if (leftTokens === 0)
            return;

        const currentAbortController = this._currentCompletionAbortController;
        const currentAbortSignal = this._currentCompletionAbortController.signal;
        let currentCompletion = "";

        void this._chatSession.completePrompt(promptToComplete, {
            ...this._completionOptions,
            stopOnAbortSignal: false,
            maxTokens: leftTokens,
            signal: currentAbortSignal,
            onTextChunk: (chunk) => {
                currentCompletion += chunk;
                const completion = (existingCompletion ?? "") + currentCompletion;
                completionCache.putCompletion(prompt, completion);

                // The chat history changed mid-generation, so this cache is stale
                if (this._getCurrentCompletionCache() !== completionCache) {
                    currentAbortController.abort();
                    return;
                }

                if (this._lastPrompt === prompt)
                    this._onGeneration?.(prompt, completion);
            }
        })
            .then(() => {
                // The prompt changed while generating; continue from the new prompt
                if (this._lastPrompt !== prompt && this._getCurrentCompletionCache() === completionCache)
                    return this._restartCompletion(completionCache);
            })
            .catch((err) => {
                // Ignore abort errors; log anything else
                if ((currentAbortSignal.aborted && err === currentAbortSignal.reason) || err instanceof DOMException)
                    return;

                console.error(getConsoleLogPrefix(false, false), err);
            });
    }

    /** @internal */
    static _create(chatSession, options = {}) {
        return new LlamaChatSessionPromptCompletionEngine(chatSession, options);
    }
}

// An LRU-bounded character trie mapping input prefixes to cached completions.
// Each trie node has the shape [childrenByChar, completion?]
class CompletionCache {
    /** @internal */ _cache;
    /** @internal */ _rootNode = [new Map()];

    constructor(maxInputs) {
        this._cache = new LruCache(maxInputs, {
            onDelete: (key) => {
                this._deleteInput(key);
            }
        });
    }

    get maxInputs() {
        return this._cache.maxSize;
    }

    getCompletion(input) {
        let node = this._rootNode;

        for (let i = 0; i < input.length; i++) {
            if (node == null)
                return null;

            const [next, completion] = node;
            const char = input[i];

            if (!next.has(char)) {
                // The trie path ends here, but a completion stored at this node
                // may still cover the rest of the input; if so, return the part
                // that extends past the input
                if (completion != null && completion.startsWith(input.slice(i))) {
                    this._cache.get(input.slice(0, i)); // refresh LRU recency
                    return completion.slice(input.length - i);
                }
            }

            node = next.get(char);
        }

        if (node == null)
            return null;

        const [, possibleCompletion] = node;
        if (possibleCompletion != null) {
            this._cache.get(input); // refresh LRU recency
            return possibleCompletion;
        }

        return null;
    }

    putCompletion(input, completion) {
        this._cache.set(input, null);

        let node = this._rootNode;
        for (let i = 0; i < input.length; i++) {
            const [next] = node;
            const char = input[i];

            if (!next.has(char))
                next.set(char, [new Map()]);

            node = next.get(char);
        }

        // Keep the existing completion if it already extends the new one
        const currentCompletion = node[1];
        if (currentCompletion != null && currentCompletion.startsWith(completion))
            return currentCompletion;

        node[1] = completion;
        return completion;
    }

    /** @internal */
    _deleteInput(input) {
        // Called on LRU eviction: prune the branch below the deepest node
        // along the input's path that still has other children
        let lastNodeWithMultipleChildren = this._rootNode;
        let lastNodeWithMultipleChildrenDeleteChar = input[0];

        let node = this._rootNode;
        for (let i = 0; i < input.length; i++) {
            const [next] = node;
            const char = input[i];

            if (next.size > 1) {
                lastNodeWithMultipleChildren = node;
                lastNodeWithMultipleChildrenDeleteChar = char;
            }

            if (!next.has(char))
                return;

            node = next.get(char);
        }

        if (lastNodeWithMultipleChildrenDeleteChar !== "")
            lastNodeWithMultipleChildren[0].delete(lastNodeWithMultipleChildrenDeleteChar);
    }
}
//# sourceMappingURL=LlamaChatSessionPromptCompletionEngine.js.map