node-llama-cpp

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model's output at the generation level.

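Before the bundled source below, here is a minimal usage sketch showing how the LlamaContext and LlamaContextSequence defined in this file are typically obtained through the library's documented high-level API, and how a JSON schema can be enforced on generation. This is a sketch, not the file's own code: the model path and schema are placeholders, and exact option names may differ between library versions.

import {fileURLToPath} from "url";
import path from "path";
import {getLlama, LlamaChatSession} from "node-llama-cpp";

// Load the llama.cpp bindings and a local GGUF model (the path is a placeholder).
const llama = await getLlama();
const model = await llama.loadModel({
    modelPath: path.join(path.dirname(fileURLToPath(import.meta.url)), "models", "model.gguf")
});

// A LlamaContext is created through the model; getSequence() hands out one of the
// context's sequences (see sequencesLeft/getSequence in the source below).
const context = await model.createContext({contextSize: 4096});
const session = new LlamaChatSession({contextSequence: context.getSequence()});

// Enforce a JSON schema on the generated output at the generation level.
const grammar = await llama.createGrammarForJsonSchema({
    type: "object",
    properties: {answer: {type: "string"}}
});
const response = await session.prompt("Give me a short answer.", {grammar});
console.log(grammar.parse(response));

The bundled source of LlamaContext and LlamaContextSequence follows.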
import { acquireLock, AsyncDisposeAggregator, DisposeAggregator, DisposedError, EventRelay, withLock } from "lifecycle-utils"; import { removeNullFields } from "../../utils/removeNullFields.js"; import { compareTokens } from "../../utils/compareTokens.js"; import { DisposeGuard } from "../../utils/DisposeGuard.js"; import { TokenMeter } from "../TokenMeter.js"; import { UnsupportedError } from "../../utils/UnsupportedError.js"; import { pushAll } from "../../utils/pushAll.js"; import { safeEventCallback } from "../../utils/safeEventCallback.js"; import { GgufArchitectureType } from "../../gguf/types/GgufMetadataTypes.js"; import { resolveBatchItemsPrioritizationStrategy } from "./utils/resolveBatchItemsPrioritizationStrategy.js"; import { LlamaSampler } from "./LlamaSampler.js"; const defaultLoraScale = 1; const shrinkRetriesMinContextSize = 4096; const defaultMaxPunishTokens = 64; const defaultFailedCreationRemedy = { retries: 6, autoContextSizeShrink: 0.16 }; const defaultEvaluationPriority = 5; const decodeSyncWorkaround = { vulkanLock: {} }; export class LlamaContext { /** @internal */ _llama; /** @internal */ _ctx; /** @internal */ _onReclaimUnusedSequenceId = new EventRelay(); /** @internal */ _backendContextDisposeGuard; /** @internal */ _model; /** @internal */ _contextSize; /** @internal */ _batchSize; /** @internal */ _flashAttention; /** @internal */ _idealThreads; /** @internal */ _minThreads; /** @internal */ _performanceTracking; /** @internal */ _totalSequences; /** @internal */ _unusedSequenceIds = []; /** @internal */ _batchingOptions; /** @internal */ _queuedDecodeSequenceIds = new Set(); /** @internal */ _queuedDecodes = []; /** @internal */ _disposeAggregator = new AsyncDisposeAggregator(); /** @internal */ _modelPreventDisposalHandle; /** @internal */ _loraAdapters = new Set(); /** @internal */ _gcRegistry; /** @internal */ _nextGeneratedSequenceId = 0; /** @internal */ _dispatchDecodeScheduled = false; /** @internal */ _batchDispatchPending = false; /** @internal */ _threadSplitterConsumer; /** @internal */ _freeReservedThreadsTimeout; /** @internal */ _currentDispatchBatchHandle = {}; /** @internal */ _allocatedContextSize; /** @internal */ _disposed = false; onDispose = new EventRelay(); constructor({ _model }, { sequences, contextSize, batchSize, flashAttention = _model.defaultContextFlashAttention, threads, batching: { dispatchSchedule: batchingDispatchSchedule = "nextCycle", itemPrioritizationStrategy: batchingItemsPrioritizationStrategy = "maximumParallelism" } = {}, performanceTracking = false, _embeddings, _ranking }) { if (_model.disposed) throw new DisposedError(); this._llama = _model._llama; this._model = _model; this._backendContextDisposeGuard = new DisposeGuard([this._model._backendModelDisposeGuard]); this._modelPreventDisposalHandle = this._model._backendModelDisposeGuard.createPreventDisposalHandle(); this._totalSequences = Math.max(1, Math.floor(sequences)); this._contextSize = Math.max(2, contextSize); this._batchSize = Math.max(batchSize, this._totalSequences); this._flashAttention = flashAttention; this._idealThreads = typeof threads === "number" ? this._llama._threadsSplitter.normalizeThreadsValue(threads) : this._llama._threadsSplitter.normalizeThreadsValue(threads?.ideal ?? (this._llama.maxThreads === 0 ? this._llama.cpuMathCores : this._llama.maxThreads)); this._minThreads = Math.max(1, typeof threads === "number" ? 1 : this._llama._threadsSplitter.normalizeThreadsValue(threads?.min ?? 
1)); this._performanceTracking = !!performanceTracking; this._ctx = new this._llama._bindings.AddonContext(this._model._model, removeNullFields({ contextSize: this._contextSize * this._totalSequences, // each sequence needs its own <contextSize> of cells batchSize: this._batchSize, sequences: this._totalSequences, flashAttention: this._flashAttention, threads: this._idealThreads, embeddings: _embeddings, ranking: _ranking, performanceTracking: this._performanceTracking })); this._batchingOptions = { dispatchSchedule: batchingDispatchSchedule, itemPrioritizationStrategy: batchingItemsPrioritizationStrategy }; this._gcRegistry = new FinalizationRegistry(this._model._removeLoraUsage); this._gcRegistry.register(this, this._loraAdapters); this._reclaimUnusedSequenceId = this._reclaimUnusedSequenceId.bind(this); this._freeReservedThreads = this._freeReservedThreads.bind(this); this._disposeAggregator.add(() => { this._disposed = true; }); this._disposeAggregator.add(() => void this._gcRegistry.unregister(this)); this._disposeAggregator.add(this._onReclaimUnusedSequenceId); this._disposeAggregator.add(this.onDispose.dispatchEvent); this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextIfReferenced.bind(null, new WeakRef(this)))); this._disposeAggregator.add(() => { if (this._loraAdapters.size > 0) { const loraAdapters = new Set(this._loraAdapters); this._loraAdapters.clear(); return this._model._removeLoraUsage(loraAdapters); } }); this._disposeAggregator.add(async () => { await this._backendContextDisposeGuard.acquireDisposeLock(); await this._ctx.dispose(); this._modelPreventDisposalHandle.dispose(); }); } async dispose() { if (this._disposed) return; this._disposed = true; await this._disposeAggregator.dispose(); } /** @hidden */ [Symbol.asyncDispose]() { return this.dispose(); } get disposed() { return this._disposed; } get model() { return this._model; } get contextSize() { return this._contextSize; } get batchSize() { return this._batchSize; } get flashAttention() { return this._flashAttention; } /** * The actual size of the state in the memory in bytes. * This value is provided by `llama.cpp` and doesn't include all the memory overhead of the context. */ get stateSize() { this._ensureNotDisposed(); return this._ctx.getStateSize(); } /** The number of threads currently used to evaluate tokens */ get currentThreads() { this._ensureNotDisposed(); return this._ctx.getThreads(); } /** * The number of threads that are preferred to be used to evaluate tokens. * * The actual number of threads used may be lower when other evaluations are running in parallel. */ get idealThreads() { return this._idealThreads; } getAllocatedContextSize() { this._ensureNotDisposed(); if (this._allocatedContextSize == null) this._allocatedContextSize = this._ctx.getContextSize(); return this._allocatedContextSize; } get totalSequences() { return this._totalSequences; } get sequencesLeft() { return this._totalSequences - this._nextGeneratedSequenceId + this._unusedSequenceIds.length; } /** * Before calling this method, make sure to call `sequencesLeft` to check if there are any sequences left. * When there are no sequences left, this method will throw an error. 
*/ getSequence(options = {}) { const { contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(this.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor, _tokenMeter } = options; this._ensureNotDisposed(); const nextSequenceId = this._popSequenceId(); if (nextSequenceId == null) throw new Error("No sequences left"); return LlamaContextSequence._create({ sequenceId: nextSequenceId, context: this, tokenMeter: _tokenMeter, contextShift: { size: contextShiftSize, strategy: contextShiftStrategy }, tokenPredictor }); } dispatchPendingBatch() { this._currentDispatchBatchHandle = {}; this._dispatchDecodeScheduled = false; if (this._batchDispatchPending) return; this._batchDispatchPending = true; void withLock(this, "context", async () => { this._currentDispatchBatchHandle = {}; this._dispatchDecodeScheduled = false; this._batchDispatchPending = false; let shouldHaveAnotherLoop = this._queuedDecodes.length > 0; const queuedDecodeToMappedLogits = new Map(); const resolvePrioritizationStrategy = () => { try { this._ensureNotDisposed(); return resolveBatchItemsPrioritizationStrategy(this._batchingOptions.itemPrioritizationStrategy); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); } return null; }; const getOrderedQueuedDecodes = (prioritizationStrategy) => { const batchItemToQueuedDecodeMap = new Map(); const batchItemsList = []; for (const queuedDecode of this._queuedDecodes) { const batchItem = { tokens: queuedDecode.tokens, logits: queuedDecode.logits, evaluationPriority: queuedDecode.evaluationPriority }; batchItemToQueuedDecodeMap.set(batchItem, queuedDecode); batchItemsList.push(batchItem); } let prioritizedItems; try { prioritizedItems = prioritizationStrategy({ items: batchItemsList, size: this._batchSize }); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); return null; } return prioritizedItems.map((prioritizedItem) => { const queuedDecode = batchItemToQueuedDecodeMap.get(prioritizedItem.item); if (queuedDecode == null) throw new Error("Received invalid batch item. Make sure you keep the original object reference " + "of the batch item on `item` on `PrioritizedBatchItem` in your custom prioritization strategy"); return { queuedDecode, processAmount: prioritizedItem.processAmount }; }); }; const fitQueuedDecodesToABatch = (queuedDecodes, batchSize) => { const currentBatchItems = []; let currentBatchSize = 0; let batchTokenSlotsLeft = batchSize; for (const { queuedDecode, processAmount } of queuedDecodes) { const resolvedProcessAmount = Math.min(processAmount <= 0 ? 1 : processAmount, queuedDecode.tokens.length, batchTokenSlotsLeft); if (resolvedProcessAmount <= 0) { if (batchTokenSlotsLeft === 0) break; continue; } batchTokenSlotsLeft -= resolvedProcessAmount; currentBatchSize += resolvedProcessAmount; currentBatchItems.push({ queuedDecode, processAmount: resolvedProcessAmount }); } return { currentBatchItems, currentBatchSize }; }; const decodeTokenBatchItems = async (batchItems, currentBatchSize) => { const afterDecodeActions = []; const queuedDecodesToDelete = new Set(); const currentQueuedDecodeItems = new Set(); if (currentBatchSize !== 0) this._ctx.initBatch(currentBatchSize); for (const { queuedDecode, processAmount } of batchItems) { let batchLogitIndexes; const tokensToProcess = queuedDecode.tokens.slice(0, processAmount); const tokenIndexesWithLogitsToProcess = queuedDecode.logits.slice(0, processAmount) .map((logit, index) => (logit ? 
index : undefined)) .filter((index) => index != undefined); const numberOfOutputTokens = tokenIndexesWithLogitsToProcess.length; TokenMeter.useTokens(queuedDecode.tokenMeter, Math.max(0, tokensToProcess.length - numberOfOutputTokens), "input"); TokenMeter.useTokens(queuedDecode.tokenMeter, numberOfOutputTokens, "output"); try { batchLogitIndexes = this._ctx.addToBatch(queuedDecode.sequenceId, queuedDecode.firstTokenSequenceIndex, Uint32Array.from(tokensToProcess), Uint32Array.from(tokenIndexesWithLogitsToProcess)); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set([queuedDecode]), err); continue; } currentQueuedDecodeItems.add(queuedDecode); if (queuedDecode.tokens.length === processAmount) { queuedDecodesToDelete.add(queuedDecode); afterDecodeActions.push({ queuedDecode, batchLogitIndexes, batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess, firstTokenIndex: queuedDecode.firstTokenSequenceIndex, returnResults: true }); } else { if (batchLogitIndexes.length > 0) afterDecodeActions.push({ queuedDecode, batchLogitIndexes, batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess, firstTokenIndex: queuedDecode.firstTokenSequenceIndex }); queuedDecode.tokens = queuedDecode.tokens.slice(processAmount); queuedDecode.logits = queuedDecode.logits.slice(processAmount); queuedDecode.firstTokenSequenceIndex += processAmount; } } for (let i = 0; i < this._queuedDecodes.length; i++) { const queuedDecode = this._queuedDecodes[i]; if (queuedDecodesToDelete.has(queuedDecode)) { this._queuedDecodes.splice(i, 1); this._queuedDecodeSequenceIds.delete(queuedDecode.sequenceId); i--; } } if (currentBatchSize !== 0) { const allocationResult = this._threadSplitterConsumer?.getAllocationToConsume(); const [threadsToUse, consumerHandle] = allocationResult instanceof Promise ? await allocationResult ?? [] : allocationResult ?? []; try { if (threadsToUse != null) this._ctx.setThreads(threadsToUse); await this._ctx.decodeBatch(); consumerHandle?.dispose(); } catch (err) { consumerHandle?.dispose(); this._dispatchErrorForQueuedDecodesAndDequeue(currentQueuedDecodeItems, err); return; } } function finishAfterDecodeAction(action, mappedLogitValues) { if (mappedLogitValues != null && mappedLogitValues.length > 0) { if (queuedDecodeToMappedLogits.has(action.queuedDecode)) pushAll(queuedDecodeToMappedLogits.get(action.queuedDecode), mappedLogitValues); else queuedDecodeToMappedLogits.set(action.queuedDecode, mappedLogitValues); } if (action.returnResults != null) { const [accept] = action.queuedDecode.response; const mappedLogits = queuedDecodeToMappedLogits.get(action.queuedDecode) ?? []; queuedDecodeToMappedLogits.delete(action.queuedDecode); accept(mappedLogits); } } const afterDecodeActionResults = afterDecodeActions.map((action) => { if (action.batchLogitIndexes.length === 0) { finishAfterDecodeAction(action); return undefined; } const mappedLogitValues = []; let promiseChain = undefined; const batchLogitIndexes = action.batchLogitIndexes; const batchLogitTokenIndexes = action.batchLogitTokenIndexes; for (let i = 0; i < batchLogitIndexes.length; i++) { const tokenIndex = batchLogitTokenIndexes[i]; const mappedValue = promiseChain != null ? 
promiseChain .then(() => action.queuedDecode.logitDataMapper(batchLogitIndexes[i], tokenIndex + action.firstTokenIndex)) : action.queuedDecode.logitDataMapper(batchLogitIndexes[i], tokenIndex + action.firstTokenIndex); if (mappedValue instanceof Promise) { promiseChain = mappedValue; mappedLogitValues.push(mappedValue .then((value) => [tokenIndex + action.firstTokenIndex, value])); } else mappedLogitValues.push([tokenIndex + action.firstTokenIndex, mappedValue]); } if (promiseChain != null) return Promise.all(mappedLogitValues) .then((resolvedMappedLogitValues) => finishAfterDecodeAction(action, resolvedMappedLogitValues)); finishAfterDecodeAction(action, mappedLogitValues); return undefined; }); await Promise.all(afterDecodeActionResults); }; const prioritizationStrategy = resolvePrioritizationStrategy(); if (prioritizationStrategy == null) return; // all queued items are rejected and dequeued when we get here this._reserveThreads(); try { while (shouldHaveAnotherLoop) { const orderedQueuedDecodes = getOrderedQueuedDecodes(prioritizationStrategy); if (orderedQueuedDecodes == null) return; // all queued items are rejected and dequeued when we get here const { currentBatchItems, currentBatchSize } = fitQueuedDecodesToABatch(orderedQueuedDecodes, this._batchSize); let preventDisposalHandle; try { preventDisposalHandle = this._backendContextDisposeGuard.createPreventDisposalHandle(); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); return; } let decodeLock; // this is a workaround to prevent Vulkan from crashing the process when decoding on multiple contexts in parallel if (this._llama.gpu === "vulkan") decodeLock = await acquireLock(decodeSyncWorkaround.vulkanLock, "decode"); try { await decodeTokenBatchItems(currentBatchItems, currentBatchSize); shouldHaveAnotherLoop = this._queuedDecodes.length > 0; } finally { decodeLock?.dispose(); preventDisposalHandle.dispose(); } } } finally { this._scheduleToFreeReservedThreads(); } }); } /** * Print the timings of token evaluation since that last print for this context. * * Requires the `performanceTracking` option to be enabled. * * > **Note:** it prints on the `LlamaLogLevel.info` level, so if you set the level of your `Llama` instance higher than that, * it won't print anything. 
*/ async printTimings() { this._ensureNotDisposed(); if (!this._performanceTracking) throw new UnsupportedError("Performance tracking is not enabled"); this._ctx.printTimings(); await new Promise((accept) => setTimeout(accept, 0)); // wait for the logs to finish printing } /** @internal */ async _decodeTokens({ sequenceId, firstTokenSequenceIndex, tokens, logits, evaluationPriority = defaultEvaluationPriority, tokenMeter }, logitDataMapper) { return await new Promise((accept, reject) => { this._queuedDecodes.push({ sequenceId, tokens, logits, firstTokenSequenceIndex, evaluationPriority, tokenMeter, response: [accept, reject], logitDataMapper }); this._queuedDecodeSequenceIds.add(sequenceId); this._scheduleDecode(); }); } /** @internal */ _reclaimUnusedSequenceId(sequenceId) { if (this._disposed) return; void withLock(this, "context", async () => { if (this._disposed) return; this._ctx.disposeSequence(sequenceId); this._unusedSequenceIds.push(sequenceId); this._onReclaimUnusedSequenceId.dispatchEvent(); }); } /** @internal */ _popSequenceId() { if (this._unusedSequenceIds.length > 0) return this._unusedSequenceIds.shift(); if (this._nextGeneratedSequenceId < this._totalSequences) { const sequenceId = this._nextGeneratedSequenceId; this._nextGeneratedSequenceId++; return sequenceId; } return null; } /** @internal */ _scheduleDecode() { if (this._dispatchDecodeScheduled || this._batchDispatchPending) return; this._dispatchDecodeScheduled = true; const currentPendingBatchHandle = this._currentDispatchBatchHandle; const dispatch = () => { if (this._currentDispatchBatchHandle !== currentPendingBatchHandle) return; this.dispatchPendingBatch(); }; const dispatchSchedule = this._batchingOptions.dispatchSchedule; if (this._queuedDecodeSequenceIds.size === this._totalSequences) dispatch(); if (dispatchSchedule === "nextCycle") { if (typeof setImmediate === "function") setImmediate(dispatch); else setTimeout(dispatch, 0); } else if (typeof dispatchSchedule === "function") dispatchSchedule(dispatch); else { if (typeof setImmediate === "function") setImmediate(dispatch); else setTimeout(dispatch, 0); } } /** @internal */ _dispatchErrorForQueuedDecodesAndDequeue(queuedDecodes, err) { for (const pendingDecode of queuedDecodes) { const [, reject] = pendingDecode.response; reject(err); } for (let i = 0; i < this._queuedDecodes.length; i++) { const item = this._queuedDecodes[i]; if (queuedDecodes.has(item)) { this._queuedDecodes.splice(i, 1); this._queuedDecodeSequenceIds.delete(item.sequenceId); i--; } } } /** @internal */ _ensureNotDisposed() { if (this._disposed) throw new DisposedError(); } /** @internal */ async _setLora({ filePath, scale }) { const lora = await this._model._getOrLoadLora(filePath); this._ctx.setLora(lora, scale ?? 
defaultLoraScale); if (!this._loraAdapters.has(lora)) { this._loraAdapters.add(lora); lora.usages++; } } /** @internal */ _reserveThreads() { clearTimeout(this._freeReservedThreadsTimeout); delete this._freeReservedThreadsTimeout; if (this._threadSplitterConsumer != null) return; this._threadSplitterConsumer = this._llama._threadsSplitter.createConsumer(this._idealThreads, this._minThreads); } /** @internal */ _freeReservedThreads() { clearTimeout(this._freeReservedThreadsTimeout); delete this._freeReservedThreadsTimeout; if (this._threadSplitterConsumer == null) return; this._threadSplitterConsumer.dispose(); delete this._threadSplitterConsumer; } /** @internal */ _scheduleToFreeReservedThreads() { if (this._threadSplitterConsumer == null) return; clearTimeout(this._freeReservedThreadsTimeout); this._freeReservedThreadsTimeout = setTimeout(this._freeReservedThreads, 0); } /** @internal */ static async _create(options, { _model }) { const sequences = options.sequences ?? getDefaultContextSequences(); const flashAttention = _model.flashAttentionSupported ? Boolean(options.flashAttention ?? _model.defaultContextFlashAttention) : false; const loraOptions = typeof options.lora === "string" ? { adapters: [{ filePath: options.lora }] } : options.lora; let failedCreationRetries = options.failedCreationRemedy === false ? 0 : Math.max(0, options.failedCreationRemedy?.retries ?? defaultFailedCreationRemedy.retries); const failedCreationAutoContextSizeShrink = options.failedCreationRemedy === false ? 0 : options.failedCreationRemedy?.autoContextSizeShrink ?? defaultFailedCreationRemedy.autoContextSizeShrink; let contextSize = await _model.fileInsights.configurationResolver.resolveContextContextSize(options.contextSize, { batchSize: options.batchSize, sequences: sequences, modelGpuLayers: _model.gpuLayers, modelTrainContextSize: _model.trainContextSize, flashAttention, getVramState: () => _model._llama._vramOrchestrator.getMemoryState(), llamaGpu: _model._llama.gpu, ignoreMemorySafetyChecks: options.ignoreMemorySafetyChecks, isEmbeddingContext: options._embeddings }); const minContextSize = options.contextSize === "auto" ? shrinkRetriesMinContextSize : (typeof options.contextSize === "object" && typeof options.contextSize.min === "number") ? options.contextSize.min : typeof options.contextSize === "number" ? options.contextSize : shrinkRetriesMinContextSize; const { createSignal } = options; async function createContext(contextSize) { const batchSize = options.batchSize ?? getDefaultContextBatchSize({ contextSize, sequences }); const resourceRequirementsEstimation = _model.fileInsights.estimateContextResourceRequirements({ contextSize, sequences, isEmbeddingContext: options._embeddings, modelGpuLayers: _model.gpuLayers, batchSize, flashAttention }); const context = new LlamaContext({ _model }, { ...options, contextSize, batchSize, sequences, flashAttention }); const contextCreationVramReservation = options.ignoreMemorySafetyChecks ? null : _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.gpuVram); const contextCreationRamReservation = options.ignoreMemorySafetyChecks ? 
null : _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.cpuRam); try { if (createSignal?.aborted) throw createSignal.reason; const contextLoaded = await context._ctx.init(); if (createSignal?.aborted) { if (contextLoaded) await context._ctx.dispose(); throw createSignal.reason; } else if (!contextLoaded) throw new Error("Failed to create context"); contextCreationVramReservation?.dispose?.(); contextCreationRamReservation?.dispose?.(); if (loraOptions != null && loraOptions.adapters.length > 0) { let loadedAdapters = 0; for (const adapter of loraOptions.adapters) { try { await context._setLora({ filePath: adapter.filePath, scale: adapter.scale }); loadedAdapters++; try { loraOptions.onLoadProgress?.(loadedAdapters / loraOptions.adapters.length); } catch (err) { console.error(err); } } catch (err) { await context.dispose(); throw err; } if (createSignal?.aborted) { await context.dispose(); throw createSignal.reason; } } } else if (loraOptions?.onLoadProgress != null) { try { loraOptions.onLoadProgress(1); } catch (err) { console.error(err); } } return context; } finally { contextCreationVramReservation?.dispose?.(); contextCreationRamReservation?.dispose?.(); } } while (failedCreationRetries >= 0) { try { return await createContext(contextSize); } catch (err) { if (failedCreationRetries === 0 || (createSignal?.aborted && err === createSignal.reason)) throw err; failedCreationRetries--; let newContextSize = typeof failedCreationAutoContextSizeShrink === "number" ? Math.floor(contextSize * (1 - failedCreationAutoContextSizeShrink)) : Math.floor(failedCreationAutoContextSizeShrink(contextSize)); if (!Number.isFinite(newContextSize)) throw err; if (newContextSize < minContextSize) newContextSize = minContextSize; if (newContextSize >= contextSize) throw err; contextSize = newContextSize; } } throw new Error("Failed to create context"); } } export class LlamaContextSequence { /** @internal */ _sequenceId; /** @internal */ _gcRegistry; /** @internal */ _context; /** @internal */ _contextShift; /** @internal */ _tokenPredictor; /** @internal */ _tokenMeter; /** @internal */ _disposeAggregator = new DisposeAggregator(); /** @internal */ _lock = {}; /** @internal */ _resetTokenPredictor = false; /** @internal */ _tokenPredictorOwner = {}; /** @internal */ _contextTokens = []; /** @internal */ _nextTokenIndex = 0; /** @internal */ _loadedTokenPredictions = []; /** @internal */ _usedTokenPredictions = 0; /** @internal */ _unusedTokenPredictions = 0; /** @internal */ _validatedTokenPredictions = 0; /** @internal */ _refutedTokenPredictions = 0; /** @internal */ _disposed = false; onDispose = new EventRelay(); constructor({ sequenceId, context, tokenMeter, contextShift, tokenPredictor }) { this._sequenceId = sequenceId; this._context = context; this._tokenMeter = tokenMeter ?? 
new TokenMeter(); this._contextShift = contextShift; this._tokenPredictor = tokenPredictor; this._gcRegistry = new FinalizationRegistry(this._context._reclaimUnusedSequenceId); this._gcRegistry.register(this, sequenceId); this._disposeAggregator.add(() => this._gcRegistry.unregister(this)); this._disposeAggregator.add(this.onDispose.dispatchEvent); this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextSequenceIfReferenced.bind(null, new WeakRef(this)))); this._disposeAggregator.add(() => { this._context._reclaimUnusedSequenceId(this._sequenceId); }); if (this._tokenPredictor != null) this._disposeAggregator.add(this._tokenPredictor); } dispose() { if (this._disposed) return; this._disposeAggregator.dispose(); this._contextTokens.length = 0; this._disposed = true; } /** @hidden */ [Symbol.dispose]() { return this.dispose(); } get disposed() { return this._disposed; } get context() { return this._context; } get model() { return this._context.model; } /** The maximum number of tokens that the sequence state can hold */ get contextSize() { return this._context.contextSize; } /** The index where the next evaluated token will be placed in the context */ get nextTokenIndex() { return this._nextTokenIndex - this._loadedTokenPredictions.length; } /** The current context state tokens */ get contextTokens() { if (this._loadedTokenPredictions.length === 0) return this._contextTokens.slice(); return this._contextTokens.slice(0, -this._loadedTokenPredictions.length); } get tokenMeter() { return this._tokenMeter; } /** * The token predictor used when creating this sequence. */ get tokenPredictor() { return this._tokenPredictor; } /** * Statistics of token predictions using the sequence's `tokenPredictor`. * * The statistics change only when token prediction is used in this sequence. * * `validated` + `refuted` = total number of evaluated predictions. * * Prefer using `validated` and `refuted` to evaluate the effectiveness of token prediction. */ get tokenPredictions() { return { used: this._usedTokenPredictions, unused: this._unusedTokenPredictions, validated: this._validatedTokenPredictions, refuted: this._refutedTokenPredictions }; } get isLoadedToMemory() { return !this._disposed; } compareContextTokens(tokens) { for (let i = 0; i < this._contextTokens.length - this._loadedTokenPredictions.length; i++) { if (compareTokens(this._contextTokens[i], tokens[i])) continue; return { firstDifferentIndex: i }; } return { firstDifferentIndex: this._contextTokens.length - this._loadedTokenPredictions.length }; } /** * Erase parts of the context state to align it with the given tokens. * * If the given tokens do not align with the current context state, the context state will be erased to align with the given tokens. * * To find the first different token index between the context state and the given tokens, access the `nextTokenIndex` property. * * If `allowShift` is `true` (the default), shifting tokens may happen to align the context state with the given tokens, * which incurs token evaluation of the shifted tokens. 
*/ async adaptStateToTokens(tokens, allowShift = true) { const modelSupportsShifting = !this.model.fileInsights.isRecurrent && this.model.fileInfo.metadata?.general?.architecture !== GgufArchitectureType.deepseek2; if (!modelSupportsShifting || !allowShift) { const { firstDifferentIndex } = this.compareContextTokens(tokens); if (firstDifferentIndex < this.nextTokenIndex) await this._eraseContextTokenRanges([{ start: firstDifferentIndex, end: this._nextTokenIndex }]); return; } const eraseRanges = []; let tokensIndex = 0; let differentTokenIndex = undefined; for (let i = 0; i < this._contextTokens.length - this._loadedTokenPredictions.length && tokensIndex < tokens.length; i++) { if (compareTokens(this._contextTokens[i], tokens[tokensIndex])) { if (differentTokenIndex != null) { eraseRanges.push({ start: differentTokenIndex, end: i }); differentTokenIndex = undefined; } tokensIndex++; continue; } if (differentTokenIndex == null) differentTokenIndex = i; } if (differentTokenIndex != null) eraseRanges.push({ start: differentTokenIndex, end: this._nextTokenIndex }); if (eraseRanges.length > 0) await this._eraseContextTokenRanges(eraseRanges); } /** * Clear the history of the sequence. * If `prependBos` was enabled, the BOS token will be prepended to the sequence again. */ async clearHistory() { this._ensureNotDisposed(); await this._eraseContextTokenRanges([{ start: 0, end: this._nextTokenIndex }]); } /** * Erase context tokens in the provided ranges to free up space for new tokens to be generated. * The start of each range is inclusive, and the end of each range is exclusive. * For example, the range `{start: 0, end: 1}` will remove the token at the `0` index only. */ eraseContextTokenRanges(ranges) { return this._eraseContextTokenRanges(ranges); } /** @internal */ async _eraseContextTokenRanges(ranges, { canResetTokenPredictor = true, canRemovePredictionTokens = true, skipLock = false } = {}) { this._ensureNotDisposed(); await withLock(this._context, "context", async () => { this._ensureNotDisposed(); if (ranges.length === 0) return; // if the deletion fails, we'll have to dispose the sequence and fill it up again let deletionSuccessful = true; const resolvedRanges = ranges .map(({ start, end }) => { if (start === end) return null; if (start > end) [start, end] = [end, start]; if (end > this._nextTokenIndex) end = this._nextTokenIndex; if (start >= this._nextTokenIndex) return null; return { start, end }; }) .filter((range) => range != null) .sort((a, b) => a.start - b.start) .reduce((ranges, range) => { if (ranges.length === 0) return [range]; const lastRange = ranges[ranges.length - 1]; if (lastRange.end >= range.start) { lastRange.end = Math.max(lastRange.end, range.end); return ranges; } ranges.push(range); return ranges; }, []); const tokenPredictionsToRemove = (resolvedRanges.length > 0 && canRemovePredictionTokens) ? 
this._loadedTokenPredictions.length : 0; if (tokenPredictionsToRemove > 0) { const startDeleteIndex = this._nextTokenIndex - this._loadedTokenPredictions.length; const lastDeleteRange = resolvedRanges[resolvedRanges.length - 1]; if (lastDeleteRange.end >= startDeleteIndex) lastDeleteRange.end = this._nextTokenIndex; else resolvedRanges.push({ start: startDeleteIndex, end: this._nextTokenIndex }); if (canResetTokenPredictor) await this._abortTokenPredictor(true); } let removedTokens = 0; let lastDeleteRangeEndPos = null; for (const range of resolvedRanges) { this._contextTokens.splice(range.start - removedTokens, range.end - range.start); if (deletionSuccessful) deletionSuccessful &&= this._context._ctx.removeTokenCellsFromSequence(this._sequenceId, range.start, range.end); if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== range.start) { this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, range.start, -removedTokens); const shiftedTokens = range.start - lastDeleteRangeEndPos; this._tokenMeter.useTokens(shiftedTokens, "input"); } removedTokens += range.end - range.start; lastDeleteRangeEndPos = range.end; } if (tokenPredictionsToRemove > 0) this._loadedTokenPredictions.splice(0, tokenPredictionsToRemove); if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== this._nextTokenIndex) { this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, this._nextTokenIndex, -removedTokens); const shiftedTokens = this._nextTokenIndex - lastDeleteRangeEndPos; this._tokenMeter.useTokens(shiftedTokens, "input"); } this._nextTokenIndex -= removedTokens; if (canResetTokenPredictor && removedTokens > 0) await this._abortTokenPredictor(true); if (deletionSuccessful) return; const newSequenceTokens = this._contextTokens.slice(); this._nextTokenIndex = 0; this._context._ctx.disposeSequence(this._sequenceId); await this.evaluateWithoutGeneratingNewTokens(newSequenceTokens, { _skipLock: skipLock }); }); } /** * Evaluate the provided tokens into the context sequence, and continue generating new tokens on iterator iterations. * * This method uses the token predictor (when provided) to generate new tokens faster. */ async *evaluate(tokens, options = {}) { const iterator = this.evaluateWithMetadata(tokens, {}, options); let iterateInput = undefined; try { while (true) { const { value, done } = await iterator.next(iterateInput); if (done) return; iterateInput = yield value.token; } } finally { await iterator.return(); } } /** * Like {@link evaluate `.evaluate(...)`}, but with additional metadata for each generated token. * * Configure the additional metadata options to choose which metadata to include. 
*/ evaluateWithMetadata(tokens, metadata, options = {}) { const { temperature = 0, minP = 0, topK = 40, topP = 0.95, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {}, yieldEogToken = false, _noSampling = false } = options; if (this._tokenPredictor != null && !_noSampling && tokens.length > 0) return this._speculativeEvaluate(tokens, metadata, { temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority, contextShiftOptions: { size: contextShiftSize, strategy: contextShiftStrategy }, yieldEogToken, tokenPredictor: this._tokenPredictor }); return this._evaluate(tokens, metadata, { temperature, minP, topK, topP, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority, contextShiftOptions: { size: contextShiftSize, strategy: contextShiftStrategy }, yieldEogToken, _noSampling }); } /** * Evaluate the provided tokens into the context sequence without generating new tokens. */ async evaluateWithoutGeneratingNewTokens(tokens, options = {}) { const { evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {}, _skipLock = false } = options; const iterator = this._evaluate(tokens, {}, { generateNewTokens: false, evaluationPriority, contextShiftOptions: { size: contextShiftSize, strategy: contextShiftStrategy }, _skipLock }); const predictorAlignmentPromise = this.tokenPredictor == null ? undefined : this._tokenPredictor?.reset({ stateTokens: [...this._contextTokens, ...tokens], evaluateOptions: { evaluationPriority, contextShift: { size: contextShiftSize, strategy: contextShiftStrategy } }, targetSequence: this }); if (predictorAlignmentPromise != null) { this._tokenPredictorOwner = {}; this._resetTokenPredictor = false; } // eslint-disable-next-line @typescript-eslint/no-unused-vars for await (const token of iterator) { // Array.from doesn't work with async generators, so we have to iterate over the generator } await iterator.return(); if (predictorAlignmentPromise != null) await predictorAlignmentPromise; } /** * Evaluate the provided tokens into the context sequence with custom options for each token. * * This method allows for more precise control of the generation process. * * A next token will be generated for a given token only if any of the `generateNext` options for it are used. * * To generate more tokens after this method finishes, * use it again with token(s) you selected to add to the context from the previous evaluation. * * This method doesn't use the token predictor (when provided) since it cannot predict which tokens are actually needed. * Use the `evaluate` method when you need to use token prediction. * @returns An array where for each token in the input array, there can be an output item at the same index in the output array. * For indexes that have no output, there won't be any value at the corresponding index in the output array. * * It's recommended to iterate from `0` up to the length of the input array to check the results in the output array. */ async controlledEvaluate(input, options) { const { evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {} } = options ?? 
{}; const contextShiftOptions = { size: contextShiftSize, strategy: contextShiftStrategy }; this._ensureNotDisposed(); if (input.length === 0) return []; a
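As a lower-level complement to the earlier example, here is a hedged sketch of driving the LlamaContextSequence.evaluate() async generator defined in this file. It assumes the model and context from the example above; the sampling options and the length cap are illustrative assumptions rather than recommended values.

// Low-level usage of a context sequence (a sketch; assumes `model` and `context`
// were created as in the earlier example).
const sequence = context.getSequence();
const promptTokens = model.tokenize("The quick brown fox");

const generated = [];
for await (const token of sequence.evaluate(promptTokens, {temperature: 0.8})) {
    generated.push(token);

    // By default the generator stops on an end-of-generation token
    // (yieldEogToken is false), so the loop ends on its own;
    // the length check here is only a safety cap for this sketch.
    if (generated.length >= 32)
        break;
}
console.log(model.detokenize(generated));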