llamaindex (LlamaIndex.TS): Data framework for your LLM application.
import { PromptMixin, PromptTemplate } from '@llamaindex/core/prompts';
import { extractText } from '@llamaindex/core/utils';
import { Settings as Settings$1 } from '@llamaindex/core/global';
import { PromptHelper } from '@llamaindex/core/indices';
import { SentenceSplitter } from '@llamaindex/core/node-parser';
import { AsyncLocalStorage } from '@llamaindex/env';
function wrapChoice(choice) {
if (typeof choice === "string") {
return {
description: choice
};
} else {
return choice;
}
}
class BaseSelector extends PromptMixin {
async select(choices, query) {
const metadata = choices.map((choice)=>wrapChoice(choice));
return await this._select(metadata, query);
}
}
/**
 * Error class for output parsing. Due to the nature of LLMs, any time we use an LLM
 * to generate structured output, it's possible that it will hallucinate something
 * that doesn't match the expected output format, so make sure to catch these
 * errors in production.
 */ class OutputParserError extends Error {
constructor(message, options = {}){
super(message, options); // https://github.com/tc39/proposal-error-cause
this.name = "OutputParserError";
        if (!this.cause) {
            // Set the cause manually for environments that haven't implemented the proposal
            this.cause = options.cause;
        }
this.output = options.output;
// This line is to maintain proper stack trace in V8
// (https://v8.dev/docs/stack-trace-api)
if (Error.captureStackTrace) {
Error.captureStackTrace(this, OutputParserError);
}
}
}
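// Illustrative sketch (not part of the library): because the model may emit
// malformed JSON, callers should catch OutputParserError and inspect `.output`:
//
//   try {
//     parseJsonMarkdown(llmText);
//   } catch (e) {
//     if (e instanceof OutputParserError) {
//       console.error("Unparseable LLM output:", e.output);
//     }
//   }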
/**
 * Extracts and parses the first JSON array or object embedded in a string.
 * @param text raw text (e.g. a markdown block) containing JSON
 * @returns the parsed JSON, always wrapped in an array
 */ function parseJsonMarkdown(text) {
text = text.trim();
const left_square = text.indexOf("[");
const left_brace = text.indexOf("{");
let left;
let right;
    // Use whichever of "[" or "{" appears first, guarding against either being absent
    if (left_square !== -1 && (left_brace === -1 || left_square < left_brace)) {
left = left_square;
right = text.lastIndexOf("]");
} else {
left = left_brace;
right = text.lastIndexOf("}");
}
const jsonText = text.substring(left, right + 1);
try {
//Single JSON object case
if (left_square === -1) {
return [
JSON.parse(jsonText)
];
}
//Multiple JSON object case.
return JSON.parse(jsonText);
    } catch (e) {
        throw new OutputParserError("Not a JSON markdown block", {
            output: text,
            cause: e
        });
    }
}
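// Example behavior (a sketch of the logic above, not additional API surface):
//
//   parseJsonMarkdown('{"choice": 1}');
//   // => [ { choice: 1 } ]                 (a single object is wrapped in an array)
//
//   parseJsonMarkdown('Sure:\n[{"choice": 2, "reason": "best"}]');
//   // => [ { choice: 2, reason: "best" } ] (surrounding prose is stripped)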
const formatStr = `The output should be ONLY JSON formatted as a JSON instance.
Here is an example:
[
{
"choice": 1,
"reason": "<insert reason for choice>"
},
...
]
`;
/*
* An OutputParser is used to extract structured data from the raw output of the LLM.
*/ class SelectionOutputParser {
    /**
     * @param output raw text emitted by the LLM
     */ parse(output) {
let parsed;
try {
parsed = parseJsonMarkdown(output);
} catch (e) {
try {
parsed = JSON.parse(output);
} catch (e) {
throw new Error(`Got invalid JSON object. Error: ${e}. Got JSON string: ${output}`);
}
}
return {
rawOutput: output,
parsedOutput: parsed
};
}
format(output) {
return output + "\n\n" + formatStr;
}
}
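// Illustrative round trip (sketch): `format` appends the JSON instructions to
// the prompt before the LLM call, and `parse` recovers the selections afterwards.
//
//   const parser = new SelectionOutputParser();
//   const prompt = parser.format("Pick the most relevant choice for: ...");
//   // ... send `prompt` to the LLM ...
//   const { parsedOutput } = parser.parse('[{"choice": 1, "reason": "most relevant"}]');
//   // parsedOutput => [ { choice: 1, reason: "most relevant" } ]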
const defaultSingleSelectPrompt = new PromptTemplate({
templateVars: [
"context",
"query",
"numChoices"
],
    template: `Some choices are given below. They are provided in a numbered list (1 to {numChoices}), where each item in the list corresponds to a summary.
---------------------
{context}
---------------------
Using only the choices above and not prior knowledge, return the choice that is most relevant to the question: '{query}'
`
});
const defaultMultiSelectPrompt = new PromptTemplate({
templateVars: [
"contextList",
"query",
"maxOutputs",
"numChoices"
],
    template: `Some choices are given below. They are provided in a numbered list (1 to {numChoices}), where each item in the list corresponds to a summary.
---------------------
{contextList}
---------------------
Using only the choices above and not prior knowledge, return the top choices (no more than {maxOutputs}, but only select what is needed) that are most relevant to the question: '{query}'
`
});
function buildChoicesText(choices) {
const texts = [];
for (const [ind, choice] of choices.entries()){
let text = choice.description.split("\n").join(" ");
text = `(${ind + 1}) ${text}`; // to one indexing
texts.push(text);
}
    return texts.join("\n"); // newline-separated so the prompt renders as a numbered list
}
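// Example (sketch): choices render as a one-indexed list, one per line.
//
//   buildChoicesText([
//     { description: "Useful for questions about cats" },
//     { description: "Useful for questions about dogs" }
//   ]);
//   // => "(1) Useful for questions about cats\n(2) Useful for questions about dogs"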
function structuredOutputToSelectorResult(output) {
    const answers = output.parsedOutput;
// adjust for zero indexing
const selections = answers.map((answer)=>{
return {
index: answer.choice - 1,
reason: answer.reason
};
});
return {
selections
};
}
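// Example (sketch): LLM answers are one-indexed, selector results are zero-indexed.
//
//   structuredOutputToSelectorResult({
//     rawOutput: "...",
//     parsedOutput: [{ choice: 2, reason: "covers the query" }]
//   });
//   // => { selections: [ { index: 1, reason: "covers the query" } ] }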
/**
* A selector that uses the LLM to select a single or multiple choices from a list of choices.
*/ class LLMMultiSelector extends BaseSelector {
constructor(init){
super();
this.llm = init.llm;
this.prompt = init.prompt ?? defaultMultiSelectPrompt;
this.maxOutputs = init.maxOutputs ?? 10;
this.outputParser = init.outputParser ?? new SelectionOutputParser();
}
_getPrompts() {
return {
prompt: this.prompt
};
}
_updatePrompts(prompts) {
if ("prompt" in prompts) {
this.prompt = prompts.prompt;
}
}
_getPromptModules() {
throw new Error("Method not implemented.");
}
    /**
     * Selects one or more choices from a list of choices.
     * @param choices
     * @param query
     */ async _select(choices, query) {
const choicesText = buildChoicesText(choices);
const prompt = this.prompt.format({
contextList: choicesText,
query: extractText(query.query),
maxOutputs: `${this.maxOutputs}`,
            numChoices: `${choices.length}` // count of choices, not the length of the rendered text
});
const formattedPrompt = this.outputParser?.format(prompt);
const prediction = await this.llm.complete({
prompt: formattedPrompt
});
const parsed = this.outputParser?.parse(prediction.text);
if (!parsed) {
throw new Error("Parsed output is undefined");
}
return structuredOutputToSelectorResult(parsed);
}
asQueryComponent() {
throw new Error("Method not implemented.");
}
}
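// Usage sketch (assumes an LLM has already been configured, e.g. via Settings.llm;
// the tool descriptions below are invented for illustration):
//
//   const selector = new LLMMultiSelector({ llm: Settings.llm, maxOutputs: 2 });
//   const result = await selector.select(
//     ["Useful for math questions", "Useful for code questions", "Useful for trivia"],
//     { query: "What is 2 + 2, and who won the 1998 World Cup?" }
//   );
//   // result.selections => e.g. [ { index: 0, reason: "..." }, { index: 2, reason: "..." } ]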
/**
* A selector that uses the LLM to select a single choice from a list of choices.
*/ class LLMSingleSelector extends BaseSelector {
constructor(init){
super();
this.llm = init.llm;
this.prompt = init.prompt ?? defaultSingleSelectPrompt;
this.outputParser = init.outputParser ?? new SelectionOutputParser();
}
_getPrompts() {
return {
prompt: this.prompt
};
}
_updatePrompts(prompts) {
if ("prompt" in prompts) {
this.prompt = prompts.prompt;
}
}
/**
* Selects a single choice from a list of choices.
* @param choices
* @param query
*/ async _select(choices, query) {
const choicesText = buildChoicesText(choices);
const prompt = this.prompt.format({
            numChoices: `${choices.length}`, // count of choices, not the length of the rendered text
            context: choicesText,
            query: extractText(query.query)
});
const formattedPrompt = this.outputParser.format(prompt);
const prediction = await this.llm.complete({
prompt: formattedPrompt
});
const parsed = this.outputParser?.parse(prediction.text);
if (!parsed) {
throw new Error("Parsed output is undefined");
}
return structuredOutputToSelectorResult(parsed);
}
asQueryComponent() {
throw new Error("Method not implemented.");
}
_getPromptModules() {
return {};
}
}
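// Usage sketch (same assumptions as the multi-selector example above):
//
//   const selector = new LLMSingleSelector({ llm: Settings.llm });
//   const { selections } = await selector.select(
//     ["Summarization tool", "Vector search tool"],
//     { query: "Find documents mentioning llamas" }
//   );
//   // selections => e.g. [ { index: 1, reason: "the query requires retrieval" } ]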
/**
* @internal
*/ class GlobalSettings {
#prompt;
#promptHelper;
#nodeParser;
#chunkOverlap;
#promptHelperAsyncLocalStorage;
#nodeParserAsyncLocalStorage;
#chunkOverlapAsyncLocalStorage;
#promptAsyncLocalStorage;
get debug() {
return Settings$1.debug;
}
get llm() {
return Settings$1.llm;
}
set llm(llm) {
Settings$1.llm = llm;
}
withLLM(llm, fn) {
return Settings$1.withLLM(llm, fn);
}
get promptHelper() {
if (this.#promptHelper === null) {
this.#promptHelper = new PromptHelper();
}
return this.#promptHelperAsyncLocalStorage.getStore() ?? this.#promptHelper;
}
set promptHelper(promptHelper) {
this.#promptHelper = promptHelper;
}
withPromptHelper(promptHelper, fn) {
return this.#promptHelperAsyncLocalStorage.run(promptHelper, fn);
}
get embedModel() {
return Settings$1.embedModel;
}
set embedModel(embedModel) {
Settings$1.embedModel = embedModel;
}
withEmbedModel(embedModel, fn) {
return Settings$1.withEmbedModel(embedModel, fn);
}
get nodeParser() {
if (this.#nodeParser === null) {
this.#nodeParser = new SentenceSplitter({
chunkSize: this.chunkSize,
chunkOverlap: this.chunkOverlap
});
}
return this.#nodeParserAsyncLocalStorage.getStore() ?? this.#nodeParser;
}
set nodeParser(nodeParser) {
this.#nodeParser = nodeParser;
}
withNodeParser(nodeParser, fn) {
return this.#nodeParserAsyncLocalStorage.run(nodeParser, fn);
}
get callbackManager() {
return Settings$1.callbackManager;
}
set callbackManager(callbackManager) {
Settings$1.callbackManager = callbackManager;
}
withCallbackManager(callbackManager, fn) {
return Settings$1.withCallbackManager(callbackManager, fn);
}
set chunkSize(chunkSize) {
Settings$1.chunkSize = chunkSize;
}
get chunkSize() {
return Settings$1.chunkSize;
}
withChunkSize(chunkSize, fn) {
return Settings$1.withChunkSize(chunkSize, fn);
}
get chunkOverlap() {
return this.#chunkOverlapAsyncLocalStorage.getStore() ?? this.#chunkOverlap;
}
set chunkOverlap(chunkOverlap) {
if (typeof chunkOverlap === "number") {
this.#chunkOverlap = chunkOverlap;
}
}
withChunkOverlap(chunkOverlap, fn) {
return this.#chunkOverlapAsyncLocalStorage.run(chunkOverlap, fn);
}
get prompt() {
return this.#promptAsyncLocalStorage.getStore() ?? this.#prompt;
}
set prompt(prompt) {
this.#prompt = prompt;
}
withPrompt(prompt, fn) {
return this.#promptAsyncLocalStorage.run(prompt, fn);
}
constructor(){
this.#prompt = {};
this.#promptHelper = null;
this.#nodeParser = null;
this.#promptHelperAsyncLocalStorage = new AsyncLocalStorage();
this.#nodeParserAsyncLocalStorage = new AsyncLocalStorage();
this.#chunkOverlapAsyncLocalStorage = new AsyncLocalStorage();
this.#promptAsyncLocalStorage = new AsyncLocalStorage();
}
}
const Settings = new GlobalSettings();
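// Usage sketch: the `with*` helpers scope a setting to a callback through
// AsyncLocalStorage, so concurrent requests can use different values without
// mutating the global default.
//
//   Settings.chunkSize = 1024;            // global default
//   Settings.withChunkSize(256, () => {
//     // code here (including async continuations) sees chunkSize === 256
//   });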
const getSelectorFromContext = (isMulti = false)=>{
    const llm = Settings.llm;
    // Both branches construct a selector, so no null check is needed.
    if (isMulti) {
        return new LLMMultiSelector({
            llm
        });
    }
    return new LLMSingleSelector({
        llm
    });
};
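// Usage sketch: getSelectorFromContext reads the globally configured LLM, so
// Settings.llm (or the provider's default) must be set before calling it.
//
//   const multi = getSelectorFromContext(true);  // LLMMultiSelector
//   const single = getSelectorFromContext();     // LLMSingleSelector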
export { BaseSelector, LLMMultiSelector, LLMSingleSelector, getSelectorFromContext };