node-llama-cpp
Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model output at the generation level.
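A minimal usage sketch of the public API this file backs (the model path and prompt are placeholders; omitted options use their defaults):

import { getLlama, LlamaChatSession } from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({ modelPath: "path/to/model.gguf" });
const context = await model.createContext();
const session = new LlamaChatSession({ contextSequence: context.getSequence() });
console.log(await session.prompt("Hello!"));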
1,071 lines • 77.8 kB
JavaScript
import { acquireLock, AsyncDisposeAggregator, DisposeAggregator, DisposedError, EventRelay, withLock } from "lifecycle-utils";
import { removeNullFields } from "../../utils/removeNullFields.js";
import { compareTokens } from "../../utils/compareTokens.js";
import { DisposeGuard } from "../../utils/DisposeGuard.js";
import { TokenMeter } from "../TokenMeter.js";
import { UnsupportedError } from "../../utils/UnsupportedError.js";
import { pushAll } from "../../utils/pushAll.js";
import { safeEventCallback } from "../../utils/safeEventCallback.js";
import { GgufArchitectureType } from "../../gguf/types/GgufMetadataTypes.js";
import { resolveBatchItemsPrioritizationStrategy } from "./utils/resolveBatchItemsPrioritizationStrategy.js";
import { LlamaSampler } from "./LlamaSampler.js";
const defaultLoraScale = 1;
const shrinkRetriesMinContextSize = 4096;
const defaultMaxPunishTokens = 64;
const defaultFailedCreationRemedy = {
retries: 6,
autoContextSizeShrink: 0.16
};
const defaultEvaluationPriority = 5;
const decodeSyncWorkaround = {
vulkanLock: {}
};
export class LlamaContext {
/** @internal */ _llama;
/** @internal */ _ctx;
/** @internal */ _onReclaimUnusedSequenceId = new EventRelay();
/** @internal */ _backendContextDisposeGuard;
/** @internal */ _model;
/** @internal */ _contextSize;
/** @internal */ _batchSize;
/** @internal */ _flashAttention;
/** @internal */ _idealThreads;
/** @internal */ _minThreads;
/** @internal */ _performanceTracking;
/** @internal */ _totalSequences;
/** @internal */ _unusedSequenceIds = [];
/** @internal */ _batchingOptions;
/** @internal */ _queuedDecodeSequenceIds = new Set();
/** @internal */ _queuedDecodes = [];
/** @internal */ _disposeAggregator = new AsyncDisposeAggregator();
/** @internal */ _modelPreventDisposalHandle;
/** @internal */ _loraAdapters = new Set();
/** @internal */ _gcRegistry;
/** @internal */ _nextGeneratedSequenceId = 0;
/** @internal */ _dispatchDecodeScheduled = false;
/** @internal */ _batchDispatchPending = false;
/** @internal */ _threadSplitterConsumer;
/** @internal */ _freeReservedThreadsTimeout;
/** @internal */ _currentDispatchBatchHandle = {};
/** @internal */ _allocatedContextSize;
/** @internal */ _disposed = false;
onDispose = new EventRelay();
constructor({ _model }, { sequences, contextSize, batchSize, flashAttention = _model.defaultContextFlashAttention, threads, batching: { dispatchSchedule: batchingDispatchSchedule = "nextCycle", itemPrioritizationStrategy: batchingItemsPrioritizationStrategy = "maximumParallelism" } = {}, performanceTracking = false, _embeddings, _ranking }) {
if (_model.disposed)
throw new DisposedError();
this._llama = _model._llama;
this._model = _model;
this._backendContextDisposeGuard = new DisposeGuard([this._model._backendModelDisposeGuard]);
this._modelPreventDisposalHandle = this._model._backendModelDisposeGuard.createPreventDisposalHandle();
this._totalSequences = Math.max(1, Math.floor(sequences));
this._contextSize = Math.max(2, contextSize);
this._batchSize = Math.max(batchSize, this._totalSequences);
this._flashAttention = flashAttention;
this._idealThreads = typeof threads === "number"
? this._llama._threadsSplitter.normalizeThreadsValue(threads)
: this._llama._threadsSplitter.normalizeThreadsValue(threads?.ideal ?? (this._llama.maxThreads === 0
? this._llama.cpuMathCores
: this._llama.maxThreads));
this._minThreads = Math.max(1, typeof threads === "number"
? 1
: this._llama._threadsSplitter.normalizeThreadsValue(threads?.min ?? 1));
this._performanceTracking = !!performanceTracking;
this._ctx = new this._llama._bindings.AddonContext(this._model._model, removeNullFields({
contextSize: this._contextSize * this._totalSequences, // each sequence needs its own <contextSize> of cells
batchSize: this._batchSize,
sequences: this._totalSequences,
flashAttention: this._flashAttention,
threads: this._idealThreads,
embeddings: _embeddings,
ranking: _ranking,
performanceTracking: this._performanceTracking
}));
this._batchingOptions = {
dispatchSchedule: batchingDispatchSchedule,
itemPrioritizationStrategy: batchingItemsPrioritizationStrategy
};
this._gcRegistry = new FinalizationRegistry(this._model._removeLoraUsage);
this._gcRegistry.register(this, this._loraAdapters);
this._reclaimUnusedSequenceId = this._reclaimUnusedSequenceId.bind(this);
this._freeReservedThreads = this._freeReservedThreads.bind(this);
this._disposeAggregator.add(() => {
this._disposed = true;
});
this._disposeAggregator.add(() => void this._gcRegistry.unregister(this));
this._disposeAggregator.add(this._onReclaimUnusedSequenceId);
this._disposeAggregator.add(this.onDispose.dispatchEvent);
this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextIfReferenced.bind(null, new WeakRef(this))));
this._disposeAggregator.add(() => {
if (this._loraAdapters.size > 0) {
const loraAdapters = new Set(this._loraAdapters);
this._loraAdapters.clear();
return this._model._removeLoraUsage(loraAdapters);
}
});
this._disposeAggregator.add(async () => {
await this._backendContextDisposeGuard.acquireDisposeLock();
await this._ctx.dispose();
this._modelPreventDisposalHandle.dispose();
});
}
async dispose() {
if (this._disposed)
return;
this._disposed = true;
await this._disposeAggregator.dispose();
}
/** @hidden */
[Symbol.asyncDispose]() {
return this.dispose();
}
get disposed() {
return this._disposed;
}
get model() {
return this._model;
}
get contextSize() {
return this._contextSize;
}
get batchSize() {
return this._batchSize;
}
get flashAttention() {
return this._flashAttention;
}
/**
* The actual size of the state in memory, in bytes.
* This value is provided by `llama.cpp` and doesn't include all the memory overhead of the context.
*/
get stateSize() {
this._ensureNotDisposed();
return this._ctx.getStateSize();
}
/** The number of threads currently used to evaluate tokens */
get currentThreads() {
this._ensureNotDisposed();
return this._ctx.getThreads();
}
/**
* The preferred number of threads to use to evaluate tokens.
*
* The actual number of threads used may be lower when other evaluations are running in parallel.
*/
get idealThreads() {
return this._idealThreads;
}
getAllocatedContextSize() {
this._ensureNotDisposed();
if (this._allocatedContextSize == null)
this._allocatedContextSize = this._ctx.getContextSize();
return this._allocatedContextSize;
}
get totalSequences() {
return this._totalSequences;
}
get sequencesLeft() {
return this._totalSequences - this._nextGeneratedSequenceId + this._unusedSequenceIds.length;
}
/**
* Before calling this method, check `sequencesLeft` to verify there are sequences available.
* When there are no sequences left, this method will throw an error.
*/
getSequence(options = {}) {
const { contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(this.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" } = {}, tokenPredictor, _tokenMeter } = options;
this._ensureNotDisposed();
const nextSequenceId = this._popSequenceId();
if (nextSequenceId == null)
throw new Error("No sequences left");
return LlamaContextSequence._create({
sequenceId: nextSequenceId,
context: this,
tokenMeter: _tokenMeter,
contextShift: {
size: contextShiftSize,
strategy: contextShiftStrategy
},
tokenPredictor
});
}
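// Usage sketch (illustrative, not part of the original source): check `sequencesLeft`
// before requesting a sequence, since `getSequence()` throws when none are available.
//
// if (context.sequencesLeft === 0)
//     throw new Error("No sequences left");
// const sequence = context.getSequence();
// try {
//     // ...evaluate tokens on the sequence...
// } finally {
//     sequence.dispose();
// }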
dispatchPendingBatch() {
this._currentDispatchBatchHandle = {};
this._dispatchDecodeScheduled = false;
if (this._batchDispatchPending)
return;
this._batchDispatchPending = true;
void withLock(this, "context", async () => {
this._currentDispatchBatchHandle = {};
this._dispatchDecodeScheduled = false;
this._batchDispatchPending = false;
let shouldHaveAnotherLoop = this._queuedDecodes.length > 0;
const queuedDecodeToMappedLogits = new Map();
const resolvePrioritizationStrategy = () => {
try {
this._ensureNotDisposed();
return resolveBatchItemsPrioritizationStrategy(this._batchingOptions.itemPrioritizationStrategy);
}
catch (err) {
this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err);
}
return null;
};
const getOrderedQueuedDecodes = (prioritizationStrategy) => {
const batchItemToQueuedDecodeMap = new Map();
const batchItemsList = [];
for (const queuedDecode of this._queuedDecodes) {
const batchItem = {
tokens: queuedDecode.tokens,
logits: queuedDecode.logits,
evaluationPriority: queuedDecode.evaluationPriority
};
batchItemToQueuedDecodeMap.set(batchItem, queuedDecode);
batchItemsList.push(batchItem);
}
let prioritizedItems;
try {
prioritizedItems = prioritizationStrategy({
items: batchItemsList,
size: this._batchSize
});
}
catch (err) {
this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err);
return null;
}
return prioritizedItems.map((prioritizedItem) => {
const queuedDecode = batchItemToQueuedDecodeMap.get(prioritizedItem.item);
if (queuedDecode == null)
throw new Error("Received invalid batch item. Make sure you keep the original object reference " +
"of the batch item on `item` on `PrioritizedBatchItem` in your custom prioritization strategy");
return {
queuedDecode,
processAmount: prioritizedItem.processAmount
};
});
};
const fitQueuedDecodesToABatch = (queuedDecodes, batchSize) => {
const currentBatchItems = [];
let currentBatchSize = 0;
let batchTokenSlotsLeft = batchSize;
for (const { queuedDecode, processAmount } of queuedDecodes) {
const resolvedProcessAmount = Math.min(processAmount <= 0 ? 1 : processAmount, queuedDecode.tokens.length, batchTokenSlotsLeft);
if (resolvedProcessAmount <= 0) {
if (batchTokenSlotsLeft === 0)
break;
continue;
}
batchTokenSlotsLeft -= resolvedProcessAmount;
currentBatchSize += resolvedProcessAmount;
currentBatchItems.push({
queuedDecode,
processAmount: resolvedProcessAmount
});
}
return {
currentBatchItems,
currentBatchSize
};
};
const decodeTokenBatchItems = async (batchItems, currentBatchSize) => {
const afterDecodeActions = [];
const queuedDecodesToDelete = new Set();
const currentQueuedDecodeItems = new Set();
if (currentBatchSize !== 0)
this._ctx.initBatch(currentBatchSize);
for (const { queuedDecode, processAmount } of batchItems) {
let batchLogitIndexes;
const tokensToProcess = queuedDecode.tokens.slice(0, processAmount);
const tokenIndexesWithLogitsToProcess = queuedDecode.logits.slice(0, processAmount)
.map((logit, index) => (logit ? index : undefined))
.filter((index) => index != undefined);
const numberOfOutputTokens = tokenIndexesWithLogitsToProcess.length;
TokenMeter.useTokens(queuedDecode.tokenMeter, Math.max(0, tokensToProcess.length - numberOfOutputTokens), "input");
TokenMeter.useTokens(queuedDecode.tokenMeter, numberOfOutputTokens, "output");
try {
batchLogitIndexes = this._ctx.addToBatch(queuedDecode.sequenceId, queuedDecode.firstTokenSequenceIndex, Uint32Array.from(tokensToProcess), Uint32Array.from(tokenIndexesWithLogitsToProcess));
}
catch (err) {
this._dispatchErrorForQueuedDecodesAndDequeue(new Set([queuedDecode]), err);
continue;
}
currentQueuedDecodeItems.add(queuedDecode);
if (queuedDecode.tokens.length === processAmount) {
queuedDecodesToDelete.add(queuedDecode);
afterDecodeActions.push({
queuedDecode,
batchLogitIndexes,
batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess,
firstTokenIndex: queuedDecode.firstTokenSequenceIndex,
returnResults: true
});
}
else {
if (batchLogitIndexes.length > 0)
afterDecodeActions.push({
queuedDecode,
batchLogitIndexes,
batchLogitTokenIndexes: tokenIndexesWithLogitsToProcess,
firstTokenIndex: queuedDecode.firstTokenSequenceIndex
});
queuedDecode.tokens = queuedDecode.tokens.slice(processAmount);
queuedDecode.logits = queuedDecode.logits.slice(processAmount);
queuedDecode.firstTokenSequenceIndex += processAmount;
}
}
for (let i = 0; i < this._queuedDecodes.length; i++) {
const queuedDecode = this._queuedDecodes[i];
if (queuedDecodesToDelete.has(queuedDecode)) {
this._queuedDecodes.splice(i, 1);
this._queuedDecodeSequenceIds.delete(queuedDecode.sequenceId);
i--;
}
}
if (currentBatchSize !== 0) {
const allocationResult = this._threadSplitterConsumer?.getAllocationToConsume();
const [threadsToUse, consumerHandle] = allocationResult instanceof Promise
? await allocationResult ?? []
: allocationResult ?? [];
try {
if (threadsToUse != null)
this._ctx.setThreads(threadsToUse);
await this._ctx.decodeBatch();
consumerHandle?.dispose();
}
catch (err) {
consumerHandle?.dispose();
this._dispatchErrorForQueuedDecodesAndDequeue(currentQueuedDecodeItems, err);
return;
}
}
function finishAfterDecodeAction(action, mappedLogitValues) {
if (mappedLogitValues != null && mappedLogitValues.length > 0) {
if (queuedDecodeToMappedLogits.has(action.queuedDecode))
pushAll(queuedDecodeToMappedLogits.get(action.queuedDecode), mappedLogitValues);
else
queuedDecodeToMappedLogits.set(action.queuedDecode, mappedLogitValues);
}
if (action.returnResults != null) {
const [accept] = action.queuedDecode.response;
const mappedLogits = queuedDecodeToMappedLogits.get(action.queuedDecode) ?? [];
queuedDecodeToMappedLogits.delete(action.queuedDecode);
accept(mappedLogits);
}
}
const afterDecodeActionResults = afterDecodeActions.map((action) => {
if (action.batchLogitIndexes.length === 0) {
finishAfterDecodeAction(action);
return undefined;
}
const mappedLogitValues = [];
let promiseChain = undefined;
const batchLogitIndexes = action.batchLogitIndexes;
const batchLogitTokenIndexes = action.batchLogitTokenIndexes;
for (let i = 0; i < batchLogitIndexes.length; i++) {
const tokenIndex = batchLogitTokenIndexes[i];
const mappedValue = promiseChain != null
? promiseChain
.then(() => action.queuedDecode.logitDataMapper(batchLogitIndexes[i], tokenIndex + action.firstTokenIndex))
: action.queuedDecode.logitDataMapper(batchLogitIndexes[i], tokenIndex + action.firstTokenIndex);
if (mappedValue instanceof Promise) {
promiseChain = mappedValue;
mappedLogitValues.push(mappedValue
.then((value) => [tokenIndex + action.firstTokenIndex, value]));
}
else
mappedLogitValues.push([tokenIndex + action.firstTokenIndex, mappedValue]);
}
if (promiseChain != null)
return Promise.all(mappedLogitValues)
.then((resolvedMappedLogitValues) => finishAfterDecodeAction(action, resolvedMappedLogitValues));
finishAfterDecodeAction(action, mappedLogitValues);
return undefined;
});
await Promise.all(afterDecodeActionResults);
};
const prioritizationStrategy = resolvePrioritizationStrategy();
if (prioritizationStrategy == null)
return; // all queued items are rejected and dequeued when we get here
this._reserveThreads();
try {
while (shouldHaveAnotherLoop) {
const orderedQueuedDecodes = getOrderedQueuedDecodes(prioritizationStrategy);
if (orderedQueuedDecodes == null)
return; // all queued items are rejected and dequeued when we get here
const { currentBatchItems, currentBatchSize } = fitQueuedDecodesToABatch(orderedQueuedDecodes, this._batchSize);
let preventDisposalHandle;
try {
preventDisposalHandle = this._backendContextDisposeGuard.createPreventDisposalHandle();
}
catch (err) {
this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err);
return;
}
let decodeLock;
// this is a workaround to prevent Vulkan from crashing the process when decoding on multiple contexts in parallel
if (this._llama.gpu === "vulkan")
decodeLock = await acquireLock(decodeSyncWorkaround.vulkanLock, "decode");
try {
await decodeTokenBatchItems(currentBatchItems, currentBatchSize);
shouldHaveAnotherLoop = this._queuedDecodes.length > 0;
}
finally {
decodeLock?.dispose();
preventDisposalHandle.dispose();
}
}
}
finally {
this._scheduleToFreeReservedThreads();
}
});
}
/**
* Print the timings of token evaluation since the last print for this context.
*
* Requires the `performanceTracking` option to be enabled.
*
* > **Note:** it prints on the `LlamaLogLevel.info` level, so if you set the level of your `Llama` instance higher than that,
* it won't print anything.
*/
async printTimings() {
this._ensureNotDisposed();
if (!this._performanceTracking)
throw new UnsupportedError("Performance tracking is not enabled");
this._ctx.printTimings();
await new Promise((accept) => setTimeout(accept, 0)); // wait for the logs to finish printing
}
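// Usage sketch (illustrative): `printTimings()` requires the context to have been
// created with performance tracking enabled, e.g.
//
// const context = await model.createContext({ performanceTracking: true });
// // ...evaluate some tokens...
// await context.printTimings();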
/** @internal */
async _decodeTokens({ sequenceId, firstTokenSequenceIndex, tokens, logits, evaluationPriority = defaultEvaluationPriority, tokenMeter }, logitDataMapper) {
return await new Promise((accept, reject) => {
this._queuedDecodes.push({
sequenceId,
tokens,
logits,
firstTokenSequenceIndex,
evaluationPriority,
tokenMeter,
response: [accept, reject],
logitDataMapper
});
this._queuedDecodeSequenceIds.add(sequenceId);
this._scheduleDecode();
});
}
/** @internal */
_reclaimUnusedSequenceId(sequenceId) {
if (this._disposed)
return;
void withLock(this, "context", async () => {
if (this._disposed)
return;
this._ctx.disposeSequence(sequenceId);
this._unusedSequenceIds.push(sequenceId);
this._onReclaimUnusedSequenceId.dispatchEvent();
});
}
/** @internal */
_popSequenceId() {
if (this._unusedSequenceIds.length > 0)
return this._unusedSequenceIds.shift();
if (this._nextGeneratedSequenceId < this._totalSequences) {
const sequenceId = this._nextGeneratedSequenceId;
this._nextGeneratedSequenceId++;
return sequenceId;
}
return null;
}
/** @internal */
_scheduleDecode() {
if (this._dispatchDecodeScheduled || this._batchDispatchPending)
return;
this._dispatchDecodeScheduled = true;
const currentPendingBatchHandle = this._currentDispatchBatchHandle;
const dispatch = () => {
if (this._currentDispatchBatchHandle !== currentPendingBatchHandle)
return;
this.dispatchPendingBatch();
};
const dispatchSchedule = this._batchingOptions.dispatchSchedule;
if (this._queuedDecodeSequenceIds.size === this._totalSequences)
dispatch();
if (dispatchSchedule === "nextCycle") {
if (typeof setImmediate === "function")
setImmediate(dispatch);
else
setTimeout(dispatch, 0);
}
else if (typeof dispatchSchedule === "function")
dispatchSchedule(dispatch);
else {
if (typeof setImmediate === "function")
setImmediate(dispatch);
else
setTimeout(dispatch, 0);
}
}
/** @internal */
_dispatchErrorForQueuedDecodesAndDequeue(queuedDecodes, err) {
for (const pendingDecode of queuedDecodes) {
const [, reject] = pendingDecode.response;
reject(err);
}
for (let i = 0; i < this._queuedDecodes.length; i++) {
const item = this._queuedDecodes[i];
if (queuedDecodes.has(item)) {
this._queuedDecodes.splice(i, 1);
this._queuedDecodeSequenceIds.delete(item.sequenceId);
i--;
}
}
}
/** @internal */
_ensureNotDisposed() {
if (this._disposed)
throw new DisposedError();
}
/** @internal */
async _setLora({ filePath, scale }) {
const lora = await this._model._getOrLoadLora(filePath);
this._ctx.setLora(lora, scale ?? defaultLoraScale);
if (!this._loraAdapters.has(lora)) {
this._loraAdapters.add(lora);
lora.usages++;
}
}
/** @internal */
_reserveThreads() {
clearTimeout(this._freeReservedThreadsTimeout);
delete this._freeReservedThreadsTimeout;
if (this._threadSplitterConsumer != null)
return;
this._threadSplitterConsumer = this._llama._threadsSplitter.createConsumer(this._idealThreads, this._minThreads);
}
/** @internal */
_freeReservedThreads() {
clearTimeout(this._freeReservedThreadsTimeout);
delete this._freeReservedThreadsTimeout;
if (this._threadSplitterConsumer == null)
return;
this._threadSplitterConsumer.dispose();
delete this._threadSplitterConsumer;
}
/** @internal */
_scheduleToFreeReservedThreads() {
if (this._threadSplitterConsumer == null)
return;
clearTimeout(this._freeReservedThreadsTimeout);
this._freeReservedThreadsTimeout = setTimeout(this._freeReservedThreads, 0);
}
/** @internal */
static async _create(options, { _model }) {
const sequences = options.sequences ?? getDefaultContextSequences();
const flashAttention = _model.flashAttentionSupported
? Boolean(options.flashAttention ?? _model.defaultContextFlashAttention)
: false;
const loraOptions = typeof options.lora === "string"
? { adapters: [{ filePath: options.lora }] }
: options.lora;
let failedCreationRetries = options.failedCreationRemedy === false
? 0
: Math.max(0, options.failedCreationRemedy?.retries ?? defaultFailedCreationRemedy.retries);
const failedCreationAutoContextSizeShrink = options.failedCreationRemedy === false
? 0
: options.failedCreationRemedy?.autoContextSizeShrink ?? defaultFailedCreationRemedy.autoContextSizeShrink;
let contextSize = await _model.fileInsights.configurationResolver.resolveContextContextSize(options.contextSize, {
batchSize: options.batchSize,
sequences: sequences,
modelGpuLayers: _model.gpuLayers,
modelTrainContextSize: _model.trainContextSize,
flashAttention,
getVramState: () => _model._llama._vramOrchestrator.getMemoryState(),
llamaGpu: _model._llama.gpu,
ignoreMemorySafetyChecks: options.ignoreMemorySafetyChecks,
isEmbeddingContext: options._embeddings
});
const minContextSize = options.contextSize === "auto"
? shrinkRetriesMinContextSize
: (typeof options.contextSize === "object" && typeof options.contextSize.min === "number")
? options.contextSize.min
: typeof options.contextSize === "number"
? options.contextSize
: shrinkRetriesMinContextSize;
const { createSignal } = options;
async function createContext(contextSize) {
const batchSize = options.batchSize ?? getDefaultContextBatchSize({ contextSize, sequences });
const resourceRequirementsEstimation = _model.fileInsights.estimateContextResourceRequirements({
contextSize,
sequences,
isEmbeddingContext: options._embeddings,
modelGpuLayers: _model.gpuLayers,
batchSize,
flashAttention
});
const context = new LlamaContext({ _model }, { ...options, contextSize, batchSize, sequences, flashAttention });
const contextCreationVramReservation = options.ignoreMemorySafetyChecks
? null
: _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.gpuVram);
const contextCreationRamReservation = options.ignoreMemorySafetyChecks
? null
: _model._llama._vramOrchestrator.reserveMemory(resourceRequirementsEstimation.cpuRam);
try {
if (createSignal?.aborted)
throw createSignal.reason;
const contextLoaded = await context._ctx.init();
if (createSignal?.aborted) {
if (contextLoaded)
await context._ctx.dispose();
throw createSignal.reason;
}
else if (!contextLoaded)
throw new Error("Failed to create context");
contextCreationVramReservation?.dispose?.();
contextCreationRamReservation?.dispose?.();
if (loraOptions != null && loraOptions.adapters.length > 0) {
let loadedAdapters = 0;
for (const adapter of loraOptions.adapters) {
try {
await context._setLora({
filePath: adapter.filePath,
scale: adapter.scale
});
loadedAdapters++;
try {
loraOptions.onLoadProgress?.(loadedAdapters / loraOptions.adapters.length);
}
catch (err) {
console.error(err);
}
}
catch (err) {
await context.dispose();
throw err;
}
if (createSignal?.aborted) {
await context.dispose();
throw createSignal.reason;
}
}
}
else if (loraOptions?.onLoadProgress != null) {
try {
loraOptions.onLoadProgress(1);
}
catch (err) {
console.error(err);
}
}
return context;
}
finally {
contextCreationVramReservation?.dispose?.();
contextCreationRamReservation?.dispose?.();
}
}
while (failedCreationRetries >= 0) {
try {
return await createContext(contextSize);
}
catch (err) {
if (failedCreationRetries === 0 || (createSignal?.aborted && err === createSignal.reason))
throw err;
failedCreationRetries--;
let newContextSize = typeof failedCreationAutoContextSizeShrink === "number"
? Math.floor(contextSize * (1 - failedCreationAutoContextSizeShrink))
: Math.floor(failedCreationAutoContextSizeShrink(contextSize));
if (!Number.isFinite(newContextSize))
throw err;
if (newContextSize < minContextSize)
newContextSize = minContextSize;
if (newContextSize >= contextSize)
throw err;
contextSize = newContextSize;
}
}
throw new Error("Failed to create context");
}
}
export class LlamaContextSequence {
/** @internal */ _sequenceId;
/** @internal */ _gcRegistry;
/** @internal */ _context;
/** @internal */ _contextShift;
/** @internal */ _tokenPredictor;
/** @internal */ _tokenMeter;
/** @internal */ _disposeAggregator = new DisposeAggregator();
/** @internal */ _lock = {};
/** @internal */ _resetTokenPredictor = false;
/** @internal */ _tokenPredictorOwner = {};
/** @internal */ _contextTokens = [];
/** @internal */ _nextTokenIndex = 0;
/** @internal */ _loadedTokenPredictions = [];
/** @internal */ _usedTokenPredictions = 0;
/** @internal */ _unusedTokenPredictions = 0;
/** @internal */ _validatedTokenPredictions = 0;
/** @internal */ _refutedTokenPredictions = 0;
/** @internal */ _disposed = false;
onDispose = new EventRelay();
constructor({ sequenceId, context, tokenMeter, contextShift, tokenPredictor }) {
this._sequenceId = sequenceId;
this._context = context;
this._tokenMeter = tokenMeter ?? new TokenMeter();
this._contextShift = contextShift;
this._tokenPredictor = tokenPredictor;
this._gcRegistry = new FinalizationRegistry(this._context._reclaimUnusedSequenceId);
this._gcRegistry.register(this, sequenceId);
this._disposeAggregator.add(() => this._gcRegistry.unregister(this));
this._disposeAggregator.add(this.onDispose.dispatchEvent);
this._disposeAggregator.add(this.model.onDispose.createListener(disposeContextSequenceIfReferenced.bind(null, new WeakRef(this))));
this._disposeAggregator.add(() => {
this._context._reclaimUnusedSequenceId(this._sequenceId);
});
if (this._tokenPredictor != null)
this._disposeAggregator.add(this._tokenPredictor);
}
dispose() {
if (this._disposed)
return;
this._disposeAggregator.dispose();
this._contextTokens.length = 0;
this._disposed = true;
}
/** @hidden */
[Symbol.dispose]() {
return this.dispose();
}
get disposed() {
return this._disposed;
}
get context() {
return this._context;
}
get model() {
return this._context.model;
}
/** The maximum number of tokens that the sequence state can hold */
get contextSize() {
return this._context.contextSize;
}
/** The index where the next evaluated token will be placed in the context */
get nextTokenIndex() {
return this._nextTokenIndex - this._loadedTokenPredictions.length;
}
/** The current context state tokens */
get contextTokens() {
if (this._loadedTokenPredictions.length === 0)
return this._contextTokens.slice();
return this._contextTokens.slice(0, -this._loadedTokenPredictions.length);
}
get tokenMeter() {
return this._tokenMeter;
}
/**
* The token predictor used when creating this sequence.
*/
get tokenPredictor() {
return this._tokenPredictor;
}
/**
* Statistics of token predictions using the sequence's `tokenPredictor`.
*
* The statistics change only when token prediction is used in this sequence.
*
* `validated` + `refuted` = total number of evaluated predictions.
*
* Prefer using `validated` and `refuted` to evaluate the effectiveness of token prediction.
*/
get tokenPredictions() {
return {
used: this._usedTokenPredictions,
unused: this._unusedTokenPredictions,
validated: this._validatedTokenPredictions,
refuted: this._refutedTokenPredictions
};
}
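// Illustrative check (not part of the original source): estimate how accurate the
// token predictor has been for this sequence, using evaluated predictions only.
//
// const { validated, refuted } = sequence.tokenPredictions;
// const evaluated = validated + refuted;
// const accuracy = evaluated === 0 ? 0 : validated / evaluated;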
get isLoadedToMemory() {
return !this._disposed;
}
compareContextTokens(tokens) {
for (let i = 0; i < this._contextTokens.length - this._loadedTokenPredictions.length; i++) {
if (compareTokens(this._contextTokens[i], tokens[i]))
continue;
return {
firstDifferentIndex: i
};
}
return {
firstDifferentIndex: this._contextTokens.length - this._loadedTokenPredictions.length
};
}
/**
* Erase parts of the context state to align it with the given tokens.
*
* If the given tokens do not align with the current context state, the context state will be erased to align with the given tokens.
*
* To find the first different token index between the context state and the given tokens, access the `nextTokenIndex` property.
*
* If `allowShift` is `true` (the default), tokens may be shifted to align the context state with the given tokens,
* which incurs token evaluation of the shifted tokens.
*/
async adaptStateToTokens(tokens, allowShift = true) {
const modelSupportsShifting = !this.model.fileInsights.isRecurrent &&
this.model.fileInfo.metadata?.general?.architecture !== GgufArchitectureType.deepseek2;
if (!modelSupportsShifting || !allowShift) {
const { firstDifferentIndex } = this.compareContextTokens(tokens);
if (firstDifferentIndex < this.nextTokenIndex)
await this._eraseContextTokenRanges([{
start: firstDifferentIndex,
end: this._nextTokenIndex
}]);
return;
}
const eraseRanges = [];
let tokensIndex = 0;
let differentTokenIndex = undefined;
for (let i = 0; i < this._contextTokens.length - this._loadedTokenPredictions.length && tokensIndex < tokens.length; i++) {
if (compareTokens(this._contextTokens[i], tokens[tokensIndex])) {
if (differentTokenIndex != null) {
eraseRanges.push({
start: differentTokenIndex,
end: i
});
differentTokenIndex = undefined;
}
tokensIndex++;
continue;
}
if (differentTokenIndex == null)
differentTokenIndex = i;
}
if (differentTokenIndex != null)
eraseRanges.push({
start: differentTokenIndex,
end: this._nextTokenIndex
});
if (eraseRanges.length > 0)
await this._eraseContextTokenRanges(eraseRanges);
}
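// Usage sketch (illustrative): reuse the existing state when a new prompt shares a
// prefix with the previous one, then evaluate only the tokens that differ.
//
// const newTokens = model.tokenize(newPromptText);
// await sequence.adaptStateToTokens(newTokens);
// const remainingTokens = newTokens.slice(sequence.nextTokenIndex);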
/**
* Clear the history of the sequence.
* If `prependBos` was enabled, the BOS token will be prepended to the sequence again.
*/
async clearHistory() {
this._ensureNotDisposed();
await this._eraseContextTokenRanges([{ start: 0, end: this._nextTokenIndex }]);
}
/**
* Erase context tokens in the provided ranges to free up space for new tokens to be generated.
* The start of each range is inclusive, and the end of each range is exclusive.
* For example, the range `{start: 0, end: 1}` will remove the token at the `0` index only.
*/
eraseContextTokenRanges(ranges) {
return this._eraseContextTokenRanges(ranges);
}
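// Illustrative call: ranges are half-open (start inclusive, end exclusive), so this
// removes the first 10 tokens from the sequence state.
//
// await sequence.eraseContextTokenRanges([{ start: 0, end: 10 }]);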
/** @internal */
async _eraseContextTokenRanges(ranges, { canResetTokenPredictor = true, canRemovePredictionTokens = true, skipLock = false } = {}) {
this._ensureNotDisposed();
await withLock(this._context, "context", async () => {
this._ensureNotDisposed();
if (ranges.length === 0)
return;
// if the deletion fails, we'll have to dispose the sequence and fill it up again
let deletionSuccessful = true;
const resolvedRanges = ranges
.map(({ start, end }) => {
if (start === end)
return null;
if (start > end)
[start, end] = [end, start];
if (end > this._nextTokenIndex)
end = this._nextTokenIndex;
if (start >= this._nextTokenIndex)
return null;
return { start, end };
})
.filter((range) => range != null)
.sort((a, b) => a.start - b.start)
.reduce((ranges, range) => {
if (ranges.length === 0)
return [range];
const lastRange = ranges[ranges.length - 1];
if (lastRange.end >= range.start) {
lastRange.end = Math.max(lastRange.end, range.end);
return ranges;
}
ranges.push(range);
return ranges;
}, []);
const tokenPredictionsToRemove = (resolvedRanges.length > 0 && canRemovePredictionTokens)
? this._loadedTokenPredictions.length
: 0;
if (tokenPredictionsToRemove > 0) {
const startDeleteIndex = this._nextTokenIndex - this._loadedTokenPredictions.length;
const lastDeleteRange = resolvedRanges[resolvedRanges.length - 1];
if (lastDeleteRange.end >= startDeleteIndex)
lastDeleteRange.end = this._nextTokenIndex;
else
resolvedRanges.push({ start: startDeleteIndex, end: this._nextTokenIndex });
if (canResetTokenPredictor)
await this._abortTokenPredictor(true);
}
let removedTokens = 0;
let lastDeleteRangeEndPos = null;
for (const range of resolvedRanges) {
this._contextTokens.splice(range.start - removedTokens, range.end - range.start);
if (deletionSuccessful)
deletionSuccessful &&= this._context._ctx.removeTokenCellsFromSequence(this._sequenceId, range.start, range.end);
if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== range.start) {
this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, range.start, -removedTokens);
const shiftedTokens = range.start - lastDeleteRangeEndPos;
this._tokenMeter.useTokens(shiftedTokens, "input");
}
removedTokens += range.end - range.start;
lastDeleteRangeEndPos = range.end;
}
if (tokenPredictionsToRemove > 0)
this._loadedTokenPredictions.splice(0, tokenPredictionsToRemove);
if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 &&
lastDeleteRangeEndPos !== this._nextTokenIndex) {
this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, this._nextTokenIndex, -removedTokens);
const shiftedTokens = this._nextTokenIndex - lastDeleteRangeEndPos;
this._tokenMeter.useTokens(shiftedTokens, "input");
}
this._nextTokenIndex -= removedTokens;
if (canResetTokenPredictor && removedTokens > 0)
await this._abortTokenPredictor(true);
if (deletionSuccessful)
return;
const newSequenceTokens = this._contextTokens.slice();
this._nextTokenIndex = 0;
this._context._ctx.disposeSequence(this._sequenceId);
await this.evaluateWithoutGeneratingNewTokens(newSequenceTokens, { _skipLock: skipLock });
});
}
/**
* Evaluate the provided tokens into the context sequence, and continue generating new tokens as the returned iterator is iterated.
*
* This method uses the token predictor (when provided) to generate new tokens faster.
*/
async *evaluate(tokens, options = {}) {
const iterator = this.evaluateWithMetadata(tokens, {}, options);
let iterateInput = undefined;
try {
while (true) {
const { value, done } = await iterator.next(iterateInput);
if (done)
return;
iterateInput = yield value.token;
}
}
finally {
await iterator.return();
}
}
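// Usage sketch (illustrative): stream generated tokens until the model stops or a
// token limit is reached (the prompt text and limit are placeholders).
//
// const promptTokens = model.tokenize("The quick brown fox");
// const generatedTokens = [];
// for await (const token of sequence.evaluate(promptTokens, { temperature: 0.8 })) {
//     generatedTokens.push(token);
//     if (generatedTokens.length >= 64)
//         break;
// }
// console.log(model.detokenize(generatedTokens));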
/**
* Like {@link evaluate `.evaluate(...)`}, but with additional metadata for each generated token.
*
* Configure the additional metadata options to choose which metadata to include.
*/
evaluateWithMetadata(tokens, metadata, options = {}) {
const { temperature = 0, minP = 0, topK = 40, topP = 0.95, seed, grammarEvaluationState, repeatPenalty, tokenBias, evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {}, yieldEogToken = false, _noSampling = false } = options;
if (this._tokenPredictor != null && !_noSampling && tokens.length > 0)
return this._speculativeEvaluate(tokens, metadata, {
temperature,
minP,
topK,
topP,
seed,
grammarEvaluationState,
repeatPenalty,
tokenBias,
evaluationPriority,
contextShiftOptions: {
size: contextShiftSize,
strategy: contextShiftStrategy
},
yieldEogToken,
tokenPredictor: this._tokenPredictor
});
return this._evaluate(tokens, metadata, {
temperature,
minP,
topK,
topP,
seed,
grammarEvaluationState,
repeatPenalty,
tokenBias,
evaluationPriority,
contextShiftOptions: {
size: contextShiftSize,
strategy: contextShiftStrategy
},
yieldEogToken,
_noSampling
});
}
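// Usage sketch (illustrative; the `confidence` metadata key is an assumption and
// should be checked against the current typings before use):
//
// const iterator = sequence.evaluateWithMetadata(promptTokens, { confidence: true });
// for await (const { token, confidence } of iterator) {
//     // ...use the token together with its confidence...
// }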
/**
* Evaluate the provided tokens into the context sequence without generating new tokens.
*/
async evaluateWithoutGeneratingNewTokens(tokens, options = {}) {
const { evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {}, _skipLock = false } = options;
const iterator = this._evaluate(tokens, {}, {
generateNewTokens: false,
evaluationPriority,
contextShiftOptions: {
size: contextShiftSize,
strategy: contextShiftStrategy
},
_skipLock
});
const predictorAlignmentPromise = this.tokenPredictor == null
? undefined
: this._tokenPredictor?.reset({
stateTokens: [...this._contextTokens, ...tokens],
evaluateOptions: {
evaluationPriority,
contextShift: {
size: contextShiftSize,
strategy: contextShiftStrategy
}
},
targetSequence: this
});
if (predictorAlignmentPromise != null) {
this._tokenPredictorOwner = {};
this._resetTokenPredictor = false;
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
for await (const token of iterator) {
// Array.from doesn't work with async generators, so we have to iterate over the generator
}
await iterator.return();
if (predictorAlignmentPromise != null)
await predictorAlignmentPromise;
}
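// Usage sketch (illustrative): prefill a prompt into the sequence state without
// sampling any new tokens, so a later `evaluate()` call can start from it.
//
// await sequence.evaluateWithoutGeneratingNewTokens(model.tokenize(systemPromptText));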
/**
* Evaluate the provided tokens into the context sequence with custom options for each token.
*
* This method allows for more precise control of the generation process.
*
* A next token will be generated for a given token only if any of the `generateNext` options for it are used.
*
* To generate more tokens after this method finishes,
* use it again with token(s) you selected to add to the context from the previous evaluation.
*
* This method doesn't use the token predictor (when provided) since it cannot predict which tokens are actually needed.
* Use the `evaluate` method when you need to use token prediction.
* @returns An array where for each token in the input array, there can be an output item at the same index in the output array.
* For indexes that have no output, there won't be any value at the corresponding index in the output array.
*
* It's recommended to iterate from `0` up to the length of the input array to check the results in the output array.
*/
async controlledEvaluate(input, options) {
const { evaluationPriority = defaultEvaluationPriority, contextShift: { size: contextShiftSize = this._contextShift.size, strategy: contextShiftStrategy = this._contextShift.strategy } = {} } = options ?? {};
const contextShiftOptions = {
size: contextShiftSize,
strategy: contextShiftStrategy
};
this._ensureNotDisposed();
if (input.length === 0)
return [];
a