// @llamaindex/core: LlamaIndex Core Module (response synthesizers)
import { randomUUID } from '@llamaindex/env';
import { Settings } from '../../global/dist/index.js';
import { PromptHelper, getBiggestPrompt } from '../../indices/dist/index.js';
import { PromptMixin, defaultTextQAPrompt, defaultRefinePrompt, defaultTreeSummarizePrompt } from '../../prompts/dist/index.js';
import { EngineResponse, MetadataMode, splitNodesByType, ModalityType, TextNode } from '../../schema/dist/index.js';
import { z } from 'zod';
import { imageToDataUrl, extractText, streamConverter } from '../../utils/dist/index.js';
class BaseSynthesizer extends PromptMixin {
constructor(options){
super();
this.llm = options.llm ?? Settings.llm;
this.promptHelper = options.promptHelper ?? PromptHelper.fromLLMMetadata(this.llm.metadata);
}
async synthesize(query, stream = false) {
const callbackManager = Settings.callbackManager;
const id = randomUUID();
callbackManager.dispatchEvent("synthesize-start", {
id,
query
});
let response;
        if (query.nodes.length === 0) {
            // nothing was retrieved: short-circuit with a canned empty response
            response = EngineResponse.fromResponse("Empty Response", stream, []);
} else {
const queryMessage = typeof query.query === "string" ? query.query : query.query.query;
response = await this.getResponse(queryMessage, query.nodes, stream);
}
callbackManager.dispatchEvent("synthesize-end", {
id,
query,
response
});
return response;
}
}
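// Illustrative sketch (not part of the original source): every `synthesize`
// call dispatches "synthesize-start" / "synthesize-end" events on the shared
// callback manager, so they can be observed externally (assuming
// `Settings.callbackManager.on` is the subscription API, as in LlamaIndex.TS,
// with the payload on `event.detail`):
//
//   Settings.callbackManager.on("synthesize-end", (event) => {
//       console.log(event.detail.response);
//   });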
async function createContentPerModality(prompt, type, nodes, extraParams, metadataMode) {
switch(type){
case ModalityType.TEXT:
return [
{
type: "text",
text: prompt.format({
...extraParams,
context: nodes.map((r)=>r.getContent(metadataMode)).join("\n\n")
})
}
];
case ModalityType.IMAGE:
return Promise.all(nodes.map(async (node)=>{
return {
type: "image_url",
image_url: {
url: await imageToDataUrl(node.image)
}
};
}));
default:
return [];
}
}
async function createMessageContent(prompt, nodes, extraParams = {}, metadataMode = MetadataMode.NONE) {
const content = [];
const nodeMap = splitNodesByType(nodes);
for(const type in nodeMap){
        // for each retrieved modality type, create message content
        const nodesOfType = nodeMap[type];
        if (nodesOfType) {
            content.push(...await createContentPerModality(prompt, type, nodesOfType, extraParams, metadataMode));
        }
}
return content;
}
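// Illustrative example (assumption: `defaultTextQAPrompt` exposes `context`
// and `query` template variables, as its use in `Refine` suggests):
//
//   const content = await createMessageContent(
//       defaultTextQAPrompt,
//       [new TextNode({ text: "The sky is blue." })],
//       { query: "What color is the sky?" }
//   );
//   // -> [{ type: "text", text: "<QA prompt with the node text as context>" }]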
const responseModeSchema = z.enum([
"refine",
"compact",
"tree_summarize",
"multi_modal"
]);
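// Example: `responseModeSchema` can validate user input before it reaches
// `getResponseSynthesizer` (standard zod API):
//
//   responseModeSchema.parse("tree_summarize"); // ok
//   responseModeSchema.parse("summarise");      // throws ZodError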
/**
 * A response builder that uses the query to ask the LLM to generate a better
 * response over multiple text chunks, answering with the first chunk and then
 * refining that answer with each subsequent chunk.
 */ class Refine extends BaseSynthesizer {
constructor(options){
super(options);
this.textQATemplate = options.textQATemplate ?? defaultTextQAPrompt;
this.refineTemplate = options.refineTemplate ?? defaultRefinePrompt;
}
_getPromptModules() {
return {};
}
_getPrompts() {
return {
textQATemplate: this.textQATemplate,
refineTemplate: this.refineTemplate
};
}
_updatePrompts(prompts) {
if (prompts.textQATemplate) {
this.textQATemplate = prompts.textQATemplate;
}
if (prompts.refineTemplate) {
this.refineTemplate = prompts.refineTemplate;
}
}
async getResponse(query, nodes, stream) {
let response = undefined;
const textChunks = nodes.map(({ node })=>node.getContent(MetadataMode.LLM));
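        // answer from the first chunk, then refine that answer with each later
        // chunk; only the very last chunk is allowed to stream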
for(let i = 0; i < textChunks.length; i++){
const text = textChunks[i];
const lastChunk = i === textChunks.length - 1;
if (!response) {
response = await this.giveResponseSingle(query, text, !!stream && lastChunk);
} else {
response = await this.refineResponseSingle(response, query, text, !!stream && lastChunk);
}
}
if (response === undefined) {
response = stream ? async function*() {
yield "";
}() : "";
}
if (typeof response === "string") {
return EngineResponse.fromResponse(response, false, nodes);
} else {
return streamConverter(response, (text)=>EngineResponse.fromResponse(text, true, nodes));
}
}
async giveResponseSingle(query, textChunk, stream) {
const textQATemplate = this.textQATemplate.partialFormat({
query: extractText(query)
});
const textChunks = this.promptHelper.repack(textQATemplate, [
textChunk
]);
let response = undefined;
for(let i = 0; i < textChunks.length; i++){
const chunk = textChunks[i];
const lastChunk = i === textChunks.length - 1;
if (!response) {
response = await this.complete({
prompt: textQATemplate.format({
context: chunk
}),
stream: stream && lastChunk
});
} else {
response = await this.refineResponseSingle(response, query, chunk, stream && lastChunk);
}
}
return response;
}
    async refineResponseSingle(initialResponse, query, textChunk, stream) {
const refineTemplate = this.refineTemplate.partialFormat({
query: extractText(query)
});
const textChunks = this.promptHelper.repack(refineTemplate, [
textChunk
]);
        let response = initialResponse;
for(let i = 0; i < textChunks.length; i++){
const chunk = textChunks[i];
const lastChunk = i === textChunks.length - 1;
response = await this.complete({
prompt: refineTemplate.format({
context: chunk,
existingAnswer: response
}),
stream: stream && lastChunk
});
}
return response;
}
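    // normalizes streaming vs non-streaming LLM calls: returns a plain string,
    // or an async iterable of text deltas when `stream` is true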
async complete(params) {
if (params.stream) {
const response = await this.llm.complete({
...params,
stream: true
});
return streamConverter(response, (chunk)=>chunk.text);
} else {
const response = await this.llm.complete({
...params,
stream: false
});
return response.text;
}
}
}
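// Usage sketch (illustrative): refining an answer across retrieved chunks.
// `nodesWithScore` is a hypothetical retriever result; the LLM defaults to
// `Settings.llm` when none is passed in the options.
//
//   const refine = new Refine({});
//   const answer = await refine.synthesize({
//       query: "What does the document say about pricing?",
//       nodes: nodesWithScore
//   });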
/**
 * CompactAndRefine is a variation of Refine that first packs the text chunks
 * into the fewest prompts that fit the context window, reducing the number of
 * LLM calls.
 */ class CompactAndRefine extends Refine {
async getResponse(query, nodes, stream) {
const textQATemplate = this.textQATemplate.partialFormat({
query: extractText(query)
});
const refineTemplate = this.refineTemplate.partialFormat({
query: extractText(query)
});
const textChunks = nodes.map(({ node })=>node.getContent(MetadataMode.LLM));
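        // repack against the larger of the two templates so that every packed
        // chunk is guaranteed to fit whichever prompt it ends up in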
const maxPrompt = getBiggestPrompt([
textQATemplate,
refineTemplate
]);
const newTexts = this.promptHelper.repack(maxPrompt, textChunks);
const newNodes = newTexts.map((text)=>new TextNode({
text
}));
if (stream) {
const streamResponse = await super.getResponse(query, newNodes.map((node)=>({
node
})), true);
return streamConverter(streamResponse, (chunk)=>{
chunk.sourceNodes = nodes;
return chunk;
});
}
const originalResponse = await super.getResponse(query, newNodes.map((node)=>({
node
})), false);
originalResponse.sourceNodes = nodes;
return originalResponse;
}
}
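// Illustrative contrast with Refine (not from the source): given ten short
// chunks that repack into two prompts, Refine would make ten LLM calls while
// CompactAndRefine makes two. `tenSmallNodes` is hypothetical:
//
//   const compact = getResponseSynthesizer("compact");
//   const res = await compact.synthesize({ query, nodes: tenSmallNodes });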
/**
 * TreeSummarize repacks the text chunks into the smallest possible number of
 * chunks, summarizes each one, and then recurses on the summaries until a
 * single chunk (the final answer) remains.
 */ class TreeSummarize extends BaseSynthesizer {
constructor(options){
super(options);
this.summaryTemplate = options.summaryTemplate ?? defaultTreeSummarizePrompt;
}
_getPromptModules() {
return {};
}
_getPrompts() {
return {
summaryTemplate: this.summaryTemplate
};
}
_updatePrompts(prompts) {
if (prompts.summaryTemplate) {
this.summaryTemplate = prompts.summaryTemplate;
}
}
async getResponse(query, nodes, stream) {
const textChunks = nodes.map(({ node })=>node.getContent(MetadataMode.LLM));
if (!textChunks || textChunks.length === 0) {
throw new Error("Must have at least one text chunk");
}
// Should we send the query here too?
const packedTextChunks = this.promptHelper.repack(this.summaryTemplate, textChunks);
if (packedTextChunks.length === 1) {
const params = {
prompt: this.summaryTemplate.format({
context: packedTextChunks[0],
query: extractText(query)
})
};
if (stream) {
const response = await this.llm.complete({
...params,
stream
});
return streamConverter(response, (chunk)=>EngineResponse.fromResponse(chunk.text, true, nodes));
}
return EngineResponse.fromResponse((await this.llm.complete(params)).text, false, nodes);
} else {
const summaries = await Promise.all(packedTextChunks.map((chunk)=>this.llm.complete({
prompt: this.summaryTemplate.format({
context: chunk,
query: extractText(query)
})
})));
            // recurse on the intermediate summaries until a single packed chunk remains
            return this.getResponse(query, summaries.map((s)=>({
                    node: new TextNode({
                        text: s.text
                    })
                })), stream);
}
}
}
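// Shape of the recursion (illustrative): with chunks [c1..c5] repacking into
// two prompts, the tree looks like
//
//   [c1, c2, c3] -> summary s1 \
//                               -> repack -> final summary
//   [c4, c5]     -> summary s2 /
//
//   const tree = getResponseSynthesizer("tree_summarize");
//   const summary = await tree.synthesize({
//       query: "Summarize the document.",
//       nodes: nodesWithScore
//   });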
class MultiModal extends BaseSynthesizer {
constructor({ textQATemplate, metadataMode, ...options } = {}){
super(options);
this.metadataMode = metadataMode ?? MetadataMode.NONE;
this.textQATemplate = textQATemplate ?? defaultTextQAPrompt;
}
_getPromptModules() {
return {};
}
_getPrompts() {
return {
textQATemplate: this.textQATemplate
};
}
_updatePrompts(promptsDict) {
if (promptsDict.textQATemplate) {
this.textQATemplate = promptsDict.textQATemplate;
}
}
async getResponse(query, nodes, stream) {
        const prompt = await createMessageContent(this.textQATemplate, nodes.map(({ node })=>node), // this might not be good as it removes the image information
{
query: extractText(query)
}, this.metadataMode);
const llm = this.llm;
if (stream) {
const response = await llm.complete({
prompt,
stream
});
return streamConverter(response, ({ text })=>EngineResponse.fromResponse(text, true, nodes));
}
const response = await llm.complete({
prompt
});
return EngineResponse.fromResponse(response.text, false, nodes);
}
}
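// Usage sketch (illustrative): MultiModal builds one message whose content
// mixes text and image parts, so the retrieved nodes may include image nodes
// (e.g. `ImageNode` from the schema module; the variables below are
// hypothetical):
//
//   const mm = getResponseSynthesizer("multi_modal");
//   const answer = await mm.synthesize({
//       query: "What is shown in the picture?",
//       nodes: [textNodeWithScore, imageNodeWithScore]
//   });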
const modeToSynthesizer = {
compact: CompactAndRefine,
refine: Refine,
tree_summarize: TreeSummarize,
multi_modal: MultiModal
};
function getResponseSynthesizer(mode, options = {}) {
const Synthesizer = modeToSynthesizer[mode];
if (!Synthesizer) {
throw new Error(`Invalid response mode: ${mode}`);
}
return new Synthesizer(options);
}
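// Example (illustrative): validate the mode with the exported schema, then
// build the matching synthesizer with default options.
//
//   const mode = responseModeSchema.parse("compact");
//   const synthesizer = getResponseSynthesizer(mode);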
export { BaseSynthesizer, CompactAndRefine, MultiModal, Refine, TreeSummarize, createMessageContent, getResponseSynthesizer, responseModeSchema };