llamaindex
import { PromptMixin, PromptTemplate } from '@llamaindex/core/prompts';
import { extractText } from '@llamaindex/core/utils';
import { Settings as Settings$1 } from '@llamaindex/core/global';
import { PromptHelper } from '@llamaindex/core/indices';
import { SentenceSplitter } from '@llamaindex/core/node-parser';
import { AsyncLocalStorage } from '@llamaindex/env';

function wrapChoice(choice) {
  if (typeof choice === "string") {
    return { description: choice };
  }
  return choice;
}

class BaseSelector extends PromptMixin {
  async select(choices, query) {
    const metadata = choices.map((choice) => wrapChoice(choice));
    return await this._select(metadata, query);
  }
}

/**
 * Error class for output parsing. Due to the nature of LLMs, any time we use an
 * LLM to generate structured output, it may hallucinate something that doesn't
 * match the expected output format. Make sure to catch these errors in
 * production.
 */
class OutputParserError extends Error {
  constructor(message, options = {}) {
    super(message, options);
    // https://github.com/tc39/proposal-error-cause
    this.name = "OutputParserError";
    if (!this.cause) {
      // Needed for environments that have not implemented the proposal
      this.cause = options.cause;
    }
    this.output = options.output;
    // Maintain a proper stack trace in V8 (https://v8.dev/docs/stack-trace-api)
    if (Error.captureStackTrace) {
      Error.captureStackTrace(this, OutputParserError);
    }
  }
}

/**
 * Extracts JSON from a markdown block.
 * @param text A markdown block containing JSON
 * @returns the parsed JSON, always wrapped in an array
 */
function parseJsonMarkdown(text) {
  text = text.trim();
  const left_square = text.indexOf("[");
  const left_brace = text.indexOf("{");
  let left;
  let right;
  // Fixed: also take the array branch when there is no "{" at all
  // (the original `left_square < left_brace` test fails when left_brace is -1).
  if (left_square !== -1 && (left_brace === -1 || left_square < left_brace)) {
    left = left_square;
    right = text.lastIndexOf("]");
  } else {
    left = left_brace;
    right = text.lastIndexOf("}");
  }
  const jsonText = text.substring(left, right + 1);
  try {
    // Single JSON object case
    if (left_square === -1) {
      return [JSON.parse(jsonText)];
    }
    // Multiple JSON objects case
    return JSON.parse(jsonText);
  } catch (e) {
    throw new OutputParserError("Not a json markdown", { output: text });
  }
}

const formatStr = `The output should be ONLY JSON formatted as a JSON instance.

Here is an example:
[
  {
    "choice": 1,
    "reason": "<insert reason for choice>"
  },
  ...
]
`;

/**
 * An OutputParser is used to extract structured data from the raw output of the LLM.
 */
class SelectionOutputParser {
  /**
   * @param output the raw LLM output to parse
   */
  parse(output) {
    let parsed;
    try {
      parsed = parseJsonMarkdown(output);
    } catch (e) {
      try {
        parsed = JSON.parse(output);
      } catch (e) {
        throw new Error(
          `Got invalid JSON object. Error: ${e}. Got JSON string: ${output}`,
        );
      }
    }
    return { rawOutput: output, parsedOutput: parsed };
  }

  format(output) {
    return output + "\n\n" + formatStr;
  }
}
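// Usage sketch (not part of the module): a typical round trip through
// SelectionOutputParser. `format` appends the JSON instructions above to a
// prompt; `parse` pulls the JSON back out of the model's reply, whether or not
// it is wrapped in a markdown fence. The reply string here is made up.
//
//   const parser = new SelectionOutputParser();
//   const { parsedOutput } = parser.parse(
//     '```json\n[{ "choice": 2, "reason": "best match" }]\n```',
//   );
//   // parsedOutput => [{ choice: 2, reason: "best match" }]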
const defaultSingleSelectPrompt = new PromptTemplate({
  templateVars: ["context", "query", "numChoices"],
  template: `Some choices are given below. It is provided in a numbered list (1 to {numChoices}), where each item in the list corresponds to a summary.
---------------------
{context}
---------------------

Using only the choices above and not prior knowledge, return the choice that is most relevant to the question: '{query}'
`,
});

const defaultMultiSelectPrompt = new PromptTemplate({
  templateVars: ["contextList", "query", "maxOutputs", "numChoices"],
  template: `Some choices are given below. It is provided in a numbered list (1 to {numChoices}), where each item in the list corresponds to a summary.
---------------------
{contextList}
---------------------

Using only the choices above and not prior knowledge, return the top choices (no more than {maxOutputs}, but only select what is needed) that are most relevant to the question: '{query}'
`,
});

function buildChoicesText(choices) {
  const texts = [];
  for (const [ind, choice] of choices.entries()) {
    let text = choice.description.split("\n").join(" ");
    text = `(${ind + 1}) ${text}`; // to one indexing
    texts.push(text);
  }
  // Fixed: join with newlines so the prompt really contains a numbered list
  // (the original joined with "", running all choices together).
  return texts.join("\n");
}

function structuredOutputToSelectorResult(output) {
  const structuredOutput = output;
  const answers = structuredOutput.parsedOutput;
  // adjust for zero indexing
  const selections = answers.map((answer) => {
    return { index: answer.choice - 1, reason: answer.reason };
  });
  return { selections };
}

/**
 * A selector that uses the LLM to select multiple choices from a list of choices.
 */
class LLMMultiSelector extends BaseSelector {
  constructor(init) {
    super();
    this.llm = init.llm;
    this.prompt = init.prompt ?? defaultMultiSelectPrompt;
    this.maxOutputs = init.maxOutputs ?? 10;
    this.outputParser = init.outputParser ?? new SelectionOutputParser();
  }

  _getPrompts() {
    return { prompt: this.prompt };
  }

  _updatePrompts(prompts) {
    if ("prompt" in prompts) {
      this.prompt = prompts.prompt;
    }
  }

  _getPromptModules() {
    throw new Error("Method not implemented.");
  }

  /**
   * Selects multiple choices from a list of choices.
   * @param choices
   * @param query
   */
  async _select(choices, query) {
    const choicesText = buildChoicesText(choices);
    const prompt = this.prompt.format({
      contextList: choicesText,
      query: extractText(query.query),
      maxOutputs: `${this.maxOutputs}`,
      // Fixed: count the choices, not the characters of the joined string.
      numChoices: `${choices.length}`,
    });
    const formattedPrompt = this.outputParser?.format(prompt);
    const prediction = await this.llm.complete({ prompt: formattedPrompt });
    const parsed = this.outputParser?.parse(prediction.text);
    if (!parsed) {
      throw new Error("Parsed output is undefined");
    }
    return structuredOutputToSelectorResult(parsed);
  }

  asQueryComponent() {
    throw new Error("Method not implemented.");
  }
}

/**
 * A selector that uses the LLM to select a single choice from a list of choices.
 */
class LLMSingleSelector extends BaseSelector {
  constructor(init) {
    super();
    this.llm = init.llm;
    this.prompt = init.prompt ?? defaultSingleSelectPrompt;
    this.outputParser = init.outputParser ?? new SelectionOutputParser();
  }

  _getPrompts() {
    return { prompt: this.prompt };
  }

  _updatePrompts(prompts) {
    if ("prompt" in prompts) {
      this.prompt = prompts.prompt;
    }
  }

  /**
   * Selects a single choice from a list of choices.
   * @param choices
   * @param query
   */
  async _select(choices, query) {
    const choicesText = buildChoicesText(choices);
    const prompt = this.prompt.format({
      // Fixed: count the choices, not the characters of the joined string.
      numChoices: `${choices.length}`,
      context: choicesText,
      query: extractText(query),
    });
    const formattedPrompt = this.outputParser.format(prompt);
    const prediction = await this.llm.complete({ prompt: formattedPrompt });
    const parsed = this.outputParser?.parse(prediction.text);
    if (!parsed) {
      throw new Error("Parsed output is undefined");
    }
    return structuredOutputToSelectorResult(parsed);
  }

  asQueryComponent() {
    throw new Error("Method not implemented.");
  }

  _getPromptModules() {
    return {};
  }
}
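// Usage sketch (not part of the module; assumes an LLM is configured, e.g. on
// the Settings singleton below): route a query to the most relevant choice.
// Plain strings are wrapped into `{ description }` metadata by
// BaseSelector.select() before _select runs.
//
//   const selector = new LLMSingleSelector({ llm: Settings.llm });
//   const result = await selector.select(
//     ["vector index, useful for factual questions", "summary index, useful for overviews"],
//     "What did the author do in college?",
//   );
//   console.log(result.selections[0].index); // zero-based index of the pick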
/**
 * @internal
 */
class GlobalSettings {
  #prompt;
  #promptHelper;
  #nodeParser;
  #chunkOverlap;
  #promptHelperAsyncLocalStorage;
  #nodeParserAsyncLocalStorage;
  #chunkOverlapAsyncLocalStorage;
  #promptAsyncLocalStorage;

  constructor() {
    this.#prompt = {};
    this.#promptHelper = null;
    this.#nodeParser = null;
    this.#promptHelperAsyncLocalStorage = new AsyncLocalStorage();
    this.#nodeParserAsyncLocalStorage = new AsyncLocalStorage();
    this.#chunkOverlapAsyncLocalStorage = new AsyncLocalStorage();
    this.#promptAsyncLocalStorage = new AsyncLocalStorage();
  }

  get debug() {
    return Settings$1.debug;
  }

  get llm() {
    return Settings$1.llm;
  }

  set llm(llm) {
    Settings$1.llm = llm;
  }

  withLLM(llm, fn) {
    return Settings$1.withLLM(llm, fn);
  }

  get promptHelper() {
    if (this.#promptHelper === null) {
      this.#promptHelper = new PromptHelper();
    }
    return this.#promptHelperAsyncLocalStorage.getStore() ?? this.#promptHelper;
  }

  set promptHelper(promptHelper) {
    this.#promptHelper = promptHelper;
  }

  withPromptHelper(promptHelper, fn) {
    return this.#promptHelperAsyncLocalStorage.run(promptHelper, fn);
  }

  get embedModel() {
    return Settings$1.embedModel;
  }

  set embedModel(embedModel) {
    Settings$1.embedModel = embedModel;
  }

  withEmbedModel(embedModel, fn) {
    return Settings$1.withEmbedModel(embedModel, fn);
  }

  get nodeParser() {
    if (this.#nodeParser === null) {
      this.#nodeParser = new SentenceSplitter({
        chunkSize: this.chunkSize,
        chunkOverlap: this.chunkOverlap,
      });
    }
    return this.#nodeParserAsyncLocalStorage.getStore() ?? this.#nodeParser;
  }

  set nodeParser(nodeParser) {
    this.#nodeParser = nodeParser;
  }

  withNodeParser(nodeParser, fn) {
    return this.#nodeParserAsyncLocalStorage.run(nodeParser, fn);
  }

  get callbackManager() {
    return Settings$1.callbackManager;
  }

  set callbackManager(callbackManager) {
    Settings$1.callbackManager = callbackManager;
  }

  withCallbackManager(callbackManager, fn) {
    return Settings$1.withCallbackManager(callbackManager, fn);
  }

  set chunkSize(chunkSize) {
    Settings$1.chunkSize = chunkSize;
  }

  get chunkSize() {
    return Settings$1.chunkSize;
  }

  withChunkSize(chunkSize, fn) {
    return Settings$1.withChunkSize(chunkSize, fn);
  }

  get chunkOverlap() {
    return this.#chunkOverlapAsyncLocalStorage.getStore() ?? this.#chunkOverlap;
  }

  set chunkOverlap(chunkOverlap) {
    if (typeof chunkOverlap === "number") {
      this.#chunkOverlap = chunkOverlap;
    }
  }

  withChunkOverlap(chunkOverlap, fn) {
    return this.#chunkOverlapAsyncLocalStorage.run(chunkOverlap, fn);
  }

  get prompt() {
    return this.#promptAsyncLocalStorage.getStore() ?? this.#prompt;
  }

  set prompt(prompt) {
    this.#prompt = prompt;
  }

  withPrompt(prompt, fn) {
    return this.#promptAsyncLocalStorage.run(prompt, fn);
  }
}
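// Usage sketch (not part of the module): the `with*` helpers use
// AsyncLocalStorage to scope an override to a single callback instead of
// mutating the global value, so concurrent requests don't see each other's
// settings. `Settings` is the singleton created just below.
//
//   Settings.withChunkSize(512, () => {
//     // code in here observes Settings.chunkSize === 512
//   });
//   // afterwards the previous chunk size is back in effect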
const Settings = new GlobalSettings();

const getSelectorFromContext = (isMulti = false) => {
  let selector = null;
  const llm = Settings.llm;
  if (isMulti) {
    selector = new LLMMultiSelector({ llm });
  } else {
    selector = new LLMSingleSelector({ llm });
  }
  if (selector === null) {
    throw new Error("Selector is null");
  }
  return selector;
};

export { BaseSelector, LLMMultiSelector, LLMSingleSelector, getSelectorFromContext };
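// Usage sketch (not part of the module): getSelectorFromContext builds a
// selector backed by the globally configured LLM; pass `true` to allow
// multiple selections.
//
//   const single = getSelectorFromContext();     // LLMSingleSelector
//   const multi = getSelectorFromContext(true);  // LLMMultiSelector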