node-llama-cpp
Version:
Run AI models locally on your machine with node.js bindings for llama.cpp. Enforce a JSON schema on the model output on the generation level
291 lines • 11.8 kB
JavaScript
import { SpecialToken, isLlamaText, SpecialTokensText } from "./LlamaText.js";
export class StopGenerationDetector {
/** @internal */ _stopTriggers = new Map();
/** @internal */ _activeChecks = new Set();
/** @internal */ _triggeredStops = new Map();
recordGeneration({ text, tokens, queuedTokenRelease, startNewChecks = true, triggerMustStartWithGeneration = false }) {
const currentActiveChecks = this._activeChecks;
this._activeChecks = new Set();
for (const check of currentActiveChecks) {
let checkKept = false;
if (text.length > 0)
this._checkTriggerPart(check, text);
else {
this._activeChecks.add(check);
checkKept = true;
}
if (tokens.length > 0)
this._checkTriggerPart(check, tokens);
else {
this._activeChecks.add(check);
checkKept = true;
}
if (!checkKept)
check.queuedTokenReleaseLock?.dispose();
}
if (!startNewChecks)
return;
for (let i = 0; i < text.length && (!triggerMustStartWithGeneration || i === 0); i++) {
const char = text[i];
const currentPart = this._stopTriggers.get(char);
if (currentPart == null)
continue;
const textCheck = {
queuedTokenReleaseLock: queuedTokenRelease?.createTextIndexLock(i),
currentPart
};
this._checkTriggerPart(textCheck, text.slice(i + 1));
textCheck.queuedTokenReleaseLock?.dispose();
}
for (let i = 0; i < tokens.length && (!triggerMustStartWithGeneration || i === 0); i++) {
const token = tokens[i];
const currentPart = this._stopTriggers.get(token);
if (currentPart == null)
continue;
const tokenCheck = {
queuedTokenReleaseLock: queuedTokenRelease?.createTokenIndexLock(i),
currentPart
};
this._checkTriggerPart(tokenCheck, tokens.slice(i + 1));
tokenCheck.queuedTokenReleaseLock?.dispose();
}
}
addStopTrigger(stopTrigger, completeEvent) {
const simplifiedTrigger = simplifyStopTrigger(stopTrigger);
const triggerValues = simplifiedTrigger
.map((item) => {
if (typeof item === "string")
return item.split("");
else
return [item];
})
.flat(1);
let currentMap = this._stopTriggers;
for (let i = 0; i < triggerValues.length; i++) {
const value = triggerValues[i];
const isLast = i === triggerValues.length - 1;
if (!currentMap.has(value)) {
currentMap.set(value, {
next: new Map()
});
}
const part = currentMap.get(value);
if (isLast) {
part.next = undefined;
part.completesTrigger = simplifiedTrigger;
part.completeEvents = part.completeEvents ?? new Set();
if (completeEvent != null)
part.completeEvents.add(completeEvent);
}
else if (part.next == null)
break;
else
currentMap = part.next;
}
return this;
}
/** Whether there are some stops that have been found and triggered. */
get hasTriggeredStops() {
return this._triggeredStops.size > 0;
}
/** Whether there are some stops that have been found, but not triggered yet. */
get hasInProgressStops() {
return this._activeChecks.size > 0;
}
/** Gets the stops that have been found and triggered. */
getTriggeredStops() {
const res = [];
for (const [triggerPart, triggeredStop] of this._triggeredStops.entries()) {
res.push({
stopTrigger: triggerPart.completesTrigger,
events: Array.from(triggerPart.completeEvents ?? new Set()),
remainingGeneration: Array.from(triggeredStop.remainingGenerations),
queuedTokenReleaseLocks: Array.from(triggeredStop.queuedTokenReleaseLocks)
});
}
return res;
}
clearTriggeredStops() {
for (const triggeredStop of this._triggeredStops.values()) {
for (const queuedTokenReleaseLock of triggeredStop.queuedTokenReleaseLocks)
queuedTokenReleaseLock.dispose();
}
this._triggeredStops.clear();
}
clearInProgressStops() {
for (const check of this._activeChecks)
check.queuedTokenReleaseLock?.dispose();
this._activeChecks.clear();
}
get hasTriggers() {
return this._stopTriggers.size > 0;
}
/**
* For a given generation, get the number of possibilities that would be disregarded if the generation is recorded.
*
* Calling this function does not change the state of the detector.
*/
getDisregardedPossibilitiesCountForAGeneration({ text, tokens, startNewChecks }) {
let res = 0;
for (const check of this._activeChecks) {
const disregardedTextPossibilities = this._getCountOfPossibleTriggersToBeDisregarded(check.currentPart, text);
const disregardedTokenPossibilities = this._getCountOfPossibleTriggersToBeDisregarded(check.currentPart, tokens);
res += Math.min(disregardedTextPossibilities, disregardedTokenPossibilities);
}
if (startNewChecks) {
const disregardedTextPossibilities = text.length > 0
? this._getCountOfPossibleTriggersToBeDisregarded(this._stopTriggers.get(text[0]), text.slice(1))
: null;
const disregardedTokenPossibilities = tokens.length > 0
? this._getCountOfPossibleTriggersToBeDisregarded(this._stopTriggers.get(tokens[0]), tokens.slice(1))
: null;
if (disregardedTextPossibilities != null && disregardedTokenPossibilities != null)
res += Math.min(disregardedTextPossibilities, disregardedTokenPossibilities);
else if (disregardedTextPossibilities != null)
res += disregardedTextPossibilities;
else if (disregardedTokenPossibilities != null)
res += disregardedTokenPossibilities;
}
return res;
}
/** @internal */
_addFoundStop(part, remainingGeneration, queuedTokenReleaseLock) {
if (!this._triggeredStops.has(part))
this._triggeredStops.set(part, {
remainingGenerations: new Set(),
queuedTokenReleaseLocks: new Set()
});
const triggeredStop = this._triggeredStops.get(part);
if (remainingGeneration != null)
triggeredStop.remainingGenerations.add(remainingGeneration);
if (queuedTokenReleaseLock != null)
triggeredStop.queuedTokenReleaseLocks.add(queuedTokenReleaseLock);
}
/** @internal */
_getCountOfPossibleTriggersToBeDisregarded(initialPart, value) {
if (initialPart == null)
return 0;
let part = initialPart;
let res = 0;
for (let i = 0; i < value.length && part != null; i++) {
const item = value[i];
if (part.next == null)
return res + 1;
if (part.next.has(item)) {
res += part.next.size - 1;
part = part.next.get(item);
continue;
}
return res + part.next.size;
}
if (part == null || part.next == null)
return res + 1;
return res;
}
/** @internal */
_checkTriggerPart(check, value) {
if (check == null)
return false;
let part = check.currentPart;
for (let i = 0; i < value.length && part != null; i++) {
const item = value[i];
if (part.next == null) {
this._addFoundStop(part, value.slice(i), check.queuedTokenReleaseLock?.duplicate?.());
return true;
}
if (part.next.has(item)) {
part = part.next.get(item);
continue;
}
return false;
}
if (part == null)
return false;
if (part.next == null) {
this._addFoundStop(part, undefined, check.queuedTokenReleaseLock?.duplicate?.());
return true;
}
else {
this._activeChecks.add({
...check,
currentPart: part,
queuedTokenReleaseLock: check.queuedTokenReleaseLock?.duplicate?.()
});
return true;
}
}
static resolveStopTriggers(stopTriggers, tokenizer) {
return stopTriggers
.map((stopTrigger) => {
if (isLlamaText(stopTrigger))
return StopGenerationDetector.resolveLlamaTextTrigger(stopTrigger, tokenizer);
else if (typeof stopTrigger === "string")
return simplifyStopTrigger([stopTrigger]);
else
return simplifyStopTrigger(stopTrigger);
})
.filter((stopTrigger) => stopTrigger.length > 0);
}
static resolveLlamaTextTrigger(llamaText, tokenizer) {
return simplifyStopTrigger(llamaText.values
.filter((value) => value !== "")
.map((value) => {
if (typeof value === "string")
return [value];
else if (value instanceof SpecialToken)
return value.tokenize(tokenizer);
else if (value instanceof SpecialTokensText)
return value.tokenizeSpecialTokensOnly(tokenizer);
return value;
})
.flat(1));
}
static getFirstRemainingGenerationAfterStop(triggeredStops) {
const [stopTrigger] = triggeredStops
.filter((stopTrigger) => (stopTrigger.remainingGeneration.some((remainingGeneration) => remainingGeneration.length > 0)));
return {
stopTrigger: stopTrigger?.stopTrigger ?? triggeredStops?.[0]?.stopTrigger,
firstRemainingGenerationAfterStop: stopTrigger?.remainingGeneration?.filter((remainingGeneration) => remainingGeneration.length > 0)?.[0]
};
}
static detokenizeRemainingGeneration(remainingGeneration, stopTrigger, tokenizer, specialTokens = false) {
if (remainingGeneration == null || remainingGeneration.length === 0)
return "";
if (typeof remainingGeneration === "string")
return remainingGeneration;
return tokenizer.detokenize(remainingGeneration, specialTokens, tokenizeStopTrigger(stopTrigger, tokenizer, specialTokens));
}
}
function simplifyStopTrigger(stopTrigger) {
let text = "";
const res = [];
for (const item of stopTrigger) {
if (typeof item === "string") {
text += item;
continue;
}
if (text !== "") {
res.push(text);
text = "";
}
res.push(item);
}
if (text !== "")
res.push(text);
return res;
}
function tokenizeStopTrigger(stopTrigger, tokenizer, specialTokens = false) {
if (stopTrigger == null)
return [];
const res = [];
for (const item of stopTrigger) {
if (typeof item === "string") {
const tokens = tokenizer(item, specialTokens, "trimLeadingSpace");
res.push(...tokens);
}
else
res.push(item);
}
return res;
}
//# sourceMappingURL=StopGenerationDetector.js.map