node-llama-cpp

Run AI models locally on your machine with Node.js bindings for llama.cpp. Enforce a JSON schema on the model's output at the generation level.

import { DisposedError } from "lifecycle-utils";
import { pushAll } from "../../../utils/pushAll.js";
import { TokenPredictor } from "../TokenPredictor.js";

const defaultPatternMinLength = 1;
const defaultPatternMaxLength = 0;
const defaultPredictionMinLength = 1;
const defaultPredictionMaxLength = 3;

/**
 * Attempts to find the last few generated tokens in the input (prompt) tokens to predict the next tokens.
 *
 * This is useful in input-grounded tasks (when the model frequently repeats some of the input tokens in the output,
 * such as in text summarization or modifying code).
 *
 * This works in all completion classes, including `LlamaChatSession`, `LlamaChat`, and `LlamaCompletion`.
 *
 * Based on https://github.com/apoorvumang/prompt-lookup-decoding.
 * @see [Using Token Predictors: Input Lookup Token Predictor](https://node-llama-cpp.withcat.ai/guide/token-prediction#input-lookup)
 */
export class InputLookupTokenPredictor extends TokenPredictor {
    /** @internal */ _patternMinLength;
    /** @internal */ _patternMaxLength;
    /** @internal */ _predictionMinLength;
    /** @internal */ _predictionMaxLength;
    /** @internal */ _lastPredictionMatchStartIndex = undefined;
    /** @internal */ _lastPredictionMatchLength = undefined;
    /** @internal */ _stateTokens = [];
    /** @internal */ _inputTokens = [];
    /** @internal */ _disposed = false;

    constructor(options = {}) {
        super();
        this._patternMinLength = Math.floor(Math.max(1, options?.patternLength?.min ?? defaultPatternMinLength));
        this._patternMaxLength = Math.floor(Math.max(0, Math.max(this._patternMinLength, options?.patternLength?.max ?? defaultPatternMaxLength)));
        this._predictionMinLength = Math.floor(Math.max(1, options.predictionLength?.min ?? defaultPredictionMinLength));
        this._predictionMaxLength = Math.floor(Math.max(this._patternMinLength, options.predictionLength?.max ?? defaultPredictionMaxLength));
    }

    get patternMinLength() {
        return this._patternMinLength;
    }

    get patternMaxLength() {
        return this._patternMaxLength;
    }

    get predictionMinLength() {
        return this._predictionMinLength;
    }

    get predictionMaxLength() {
        return this._predictionMaxLength;
    }

    reset({ stateTokens }) {
        this._stateTokens = stateTokens.slice();
        delete this._lastPredictionMatchStartIndex;
        delete this._lastPredictionMatchLength;
    }

    updateInputTokens(tokens) {
        this._inputTokens = tokens.slice();
        delete this._lastPredictionMatchStartIndex;
        delete this._lastPredictionMatchLength;
    }

    pushTokens(tokens) {
        pushAll(this._stateTokens, tokens);
        if (this._lastPredictionMatchStartIndex != null && this._lastPredictionMatchLength != null) {
            this._lastPredictionMatchLength += tokens.length;
        }
    }

    predictTokens() {
        if (this._disposed)
            throw new DisposedError();
        if (this._inputTokens.length === 0 || this._stateTokens.length === 0)
            return [];
        if (this._lastPredictionMatchStartIndex != null && this._lastPredictionMatchLength != null) {
            for (let p = this._lastPredictionMatchStartIndex + this._lastPredictionMatchLength - 1, s = this._stateTokens.length - 1; p >= this._lastPredictionMatchStartIndex && s >= 0; p--, s--) {
                if (this._inputTokens[p] !== this._stateTokens[s]) {
                    delete this._lastPredictionMatchStartIndex;
                    delete this._lastPredictionMatchLength;
                    break;
                }
            }
            if (this._lastPredictionMatchStartIndex != null && this._lastPredictionMatchLength != null) {
                const predictionEndIndex = this._lastPredictionMatchStartIndex + this._lastPredictionMatchLength;
                if (predictionEndIndex < this._inputTokens.length) {
                    return this._inputTokens.slice(predictionEndIndex, predictionEndIndex + this._predictionMaxLength);
                }
            }
        }
        const [matchStartIndex, matchLength] = this._findLongestPatternIndex(this._inputTokens, this._stateTokens);
        if (matchStartIndex == null || matchLength == null)
            return [];
        const predictionEndIndex = matchStartIndex + matchLength;
        const res = this._inputTokens.slice(predictionEndIndex, predictionEndIndex + this._predictionMaxLength);
        if (res.length >= this._predictionMinLength) {
            this._lastPredictionMatchStartIndex = matchStartIndex;
            this._lastPredictionMatchLength = matchLength;
            return res;
        }
        return [];
    }

    dispose() {
        this._disposed = true;
        this._stateTokens = [];
        this._inputTokens = [];
        delete this._lastPredictionMatchStartIndex;
        delete this._lastPredictionMatchLength;
    }

    /** @internal */
    _findLongestPatternIndex(findIn, lookupPattern) {
        const checkIndexes = [];
        let bestIndex = -1;
        let bestIndexDiff = -1;
        for (let i = findIn.length - this._predictionMinLength; i >= 0; i--) {
            const token = findIn[i];
            for (let j = checkIndexes.length - 1; j >= 0; j--) {
                const startIndex = checkIndexes[j];
                const indexDiff = startIndex - i;
                if (lookupPattern[lookupPattern.length - 1 - indexDiff] !== token ||
                    (this._patternMaxLength > 0 && indexDiff >= this._patternMaxLength)
                ) {
                    checkIndexes.splice(j, 1);
                    if (indexDiff >= this._patternMinLength && indexDiff >= bestIndexDiff) {
                        bestIndex = startIndex;
                        bestIndexDiff = indexDiff;
                    }
                }
            }
            if (token === lookupPattern[lookupPattern.length - 1])
                checkIndexes.unshift(i);
        }
        for (let j = checkIndexes.length - 1; j >= 0; j--) {
            const startIndex = checkIndexes[j];
            const indexDiff = startIndex + 1;
            checkIndexes.splice(j, 1);
            if (indexDiff >= this._patternMinLength && indexDiff >= bestIndexDiff) {
                bestIndex = startIndex;
                bestIndexDiff = indexDiff;
            }
        }
        if (bestIndex >= 0)
            return [bestIndex - (bestIndexDiff - 1), bestIndexDiff];
        return [];
    }
}
//# sourceMappingURL=InputLookupTokenPredictor.js.map
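
A minimal usage sketch, for context, following the token-prediction guide linked in the doc comment above. The model path, option values, and prompt are illustrative assumptions, not part of this file:

import { getLlama, LlamaCompletion, InputLookupTokenPredictor } from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({
    modelPath: "path/to/model.gguf" // placeholder; any local GGUF model
});
const context = await model.createContext();

// Attach the predictor to a context sequence; the next tokens are then
// speculated from matches of the recent output inside the input tokens,
// and incorrect predictions are discarded after validation.
const sequence = context.getSequence({
    tokenPredictor: new InputLookupTokenPredictor({
        patternLength: { min: 1, max: 32 }, // trailing output tokens to match against the input
        predictionLength: { min: 1, max: 3 } // tokens to predict from each match
    })
});

const completion = new LlamaCompletion({ contextSequence: sequence });
const summary = await completion.generateCompletion("Summarize: ...", { maxTokens: 128 });
console.log(summary);

Since the predictor only reads tokens already present in the input, a correct prediction lets the model validate several tokens in a single pass, which pays off most in input-grounded tasks such as summarization or code editing where the output quotes the input; a wrong prediction costs only the discarded speculative work.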