UNPKG

mongodb-chatbot-server

Version:

A chatbot server for retrieval augmented generation (RAG).

381 lines 15.2 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.convertMessageFromLlmToDb = exports.streamGenerateResponseMessage = exports.awaitGenerateResponseMessage = exports.generateResponse = void 0; const mongodb_rag_core_1 = require("mongodb-rag-core"); const utils_1 = require("../utils"); const assert_1 = require("assert"); /** Generate a response with/without streaming. Supports tool calling and standard response generation. Response includes the user message with any data mutations and the assistant response message, plus any intermediate tool calls. */ async function generateResponse({ shouldStream, llm, latestMessageText, clientContext, customData, generateUserPrompt, filterPreviousMessages, dataStreamer, reqId, llmNotWorkingMessage, noRelevantContentMessage, conversation, request, }) { const { userMessage, references, staticResponse, rejectQuery } = await (generateUserPrompt ? generateUserPrompt({ userMessageText: latestMessageText, clientContext, conversation, reqId, customData, }) : { userMessage: { role: "user", content: latestMessageText, customData, }, }); // Add request custom data to user message. const userMessageWithCustomData = customData ? { ...userMessage, // Override request custom data fields with user message custom data fields. customData: { ...customData, ...(userMessage.customData ?? {}) }, } : userMessage; const newMessages = [userMessageWithCustomData]; // Metadata for streaming let streamingResponseMetadata; // Send static response if query rejected or static response provided if (rejectQuery) { const rejectionMessage = { role: "assistant", content: noRelevantContentMessage, references: references ?? [], }; newMessages.push(rejectionMessage); } else if (staticResponse) { newMessages.push(staticResponse); // Need to specify response metadata for streaming streamingResponseMetadata = staticResponse.metadata; } // Prepare conversation messages for LLM const previousConversationMessagesForLlm = (filterPreviousMessages ? await filterPreviousMessages(conversation) : conversation.messages).map(convertConversationMessageToLlmMessage); const newMessagesForLlm = newMessages.map((m) => { // Use transformed content if it exists for user message // (e.g. from a custom user prompt, query preprocessor, etc), // otherwise use original content. if (m.role === "user") { return { content: m.contentForLlm ?? m.content, role: "user", }; } return convertConversationMessageToLlmMessage(m); }); const llmConversation = [ ...previousConversationMessagesForLlm, ...newMessagesForLlm, ]; const shouldGenerateMessage = !rejectQuery && !staticResponse; if (shouldStream) { (0, assert_1.strict)(dataStreamer, "Data streamer required for streaming"); const { messages } = await streamGenerateResponseMessage({ dataStreamer, reqId, llm, llmConversation, noRelevantContentMessage, llmNotWorkingMessage, request, shouldGenerateMessage, conversation, references, metadata: streamingResponseMetadata, }); newMessages.push(...messages); } else { const { messages } = await awaitGenerateResponseMessage({ reqId, llm, llmConversation, llmNotWorkingMessage, noRelevantContentMessage, request, shouldGenerateMessage, conversation, references, }); newMessages.push(...messages); } return { messages: newMessages }; } exports.generateResponse = generateResponse; async function awaitGenerateResponseMessage({ reqId, llmConversation, llm, llmNotWorkingMessage, noRelevantContentMessage, request, references, conversation, shouldGenerateMessage = true, }) { const newMessages = []; const outputReferences = []; if (references) { outputReferences.push(...references); } if (shouldGenerateMessage) { try { (0, utils_1.logRequest)({ reqId, message: `All messages for LLM: ${JSON.stringify(llmConversation)}`, }); const answer = await llm.answerQuestionAwaited({ messages: llmConversation, }); newMessages.push(convertMessageFromLlmToDb(answer)); // LLM responds with tool call if (answer?.function_call) { (0, assert_1.strict)(llm.callTool, "You must implement the callTool() method on your ChatLlm to access this code."); const toolAnswer = await llm.callTool({ messages: [...llmConversation, ...newMessages], conversation, request, }); (0, utils_1.logRequest)({ reqId, message: `LLM tool call: ${JSON.stringify(toolAnswer)}`, }); const { toolCallMessage, references: toolReferences, rejectUserQuery, } = toolAnswer; newMessages.push(convertMessageFromLlmToDb(toolCallMessage)); // Update references from tool call if (toolReferences) { outputReferences.push(...toolReferences); } // Return static response if query rejected by tool call if (rejectUserQuery) { newMessages.push({ role: "assistant", content: noRelevantContentMessage, }); } else { // Otherwise respond with LLM again const answer = await llm.answerQuestionAwaited({ messages: [...llmConversation, ...newMessages], // Only allow 1 tool call per user message. }); newMessages.push(convertMessageFromLlmToDb(answer)); } } } catch (err) { const errorMessage = err instanceof Error ? err.message : JSON.stringify(err); (0, utils_1.logRequest)({ reqId, message: `LLM error: ${errorMessage}`, type: "error", }); (0, utils_1.logRequest)({ reqId, message: "Only sending vector search results to user", }); const llmNotWorkingResponse = { role: "assistant", content: llmNotWorkingMessage, references, }; newMessages.push(llmNotWorkingResponse); } } // Add references to the last assistant message (excluding function calls) if (newMessages.at(-1)?.role === "assistant" && !newMessages.at(-1).functionCall && outputReferences.length > 0) { newMessages.at(-1).references = outputReferences; } return { messages: newMessages }; } exports.awaitGenerateResponseMessage = awaitGenerateResponseMessage; async function streamGenerateResponseMessage({ dataStreamer, llm, llmConversation, reqId, references, noRelevantContentMessage, llmNotWorkingMessage, conversation, request, metadata, shouldGenerateMessage, }) { const newMessages = []; const outputReferences = []; if (references) { outputReferences.push(...references); } if (metadata) { dataStreamer.streamData({ type: "metadata", data: metadata }); } if (shouldGenerateMessage) { try { const answerStream = await llm.answerQuestionStream({ messages: llmConversation, }); const initialAssistantMessage = { role: "assistant", content: "", }; const functionCallContent = { name: "", arguments: "", }; for await (const event of answerStream) { if (event.choices.length === 0) { continue; } // The event could contain many choices, but we only want the first one const choice = event.choices[0]; // Assistant response to user if (choice.delta?.content) { const content = (0, mongodb_rag_core_1.escapeNewlines)(choice.delta.content ?? ""); dataStreamer.streamData({ type: "delta", data: content, }); initialAssistantMessage.content += content; } // Tool call else if (choice.delta?.function_call) { if (choice.delta?.function_call.name) { functionCallContent.name += (0, mongodb_rag_core_1.escapeNewlines)(choice.delta?.function_call.name ?? ""); } if (choice.delta?.function_call.arguments) { functionCallContent.arguments += (0, mongodb_rag_core_1.escapeNewlines)(choice.delta?.function_call.arguments ?? ""); } } else if (choice.delta) { (0, utils_1.logRequest)({ reqId, message: `Unexpected message in stream: no delta. Message: ${JSON.stringify(choice.delta.content)}`, type: "warn", }); } } const shouldCallTool = functionCallContent.name !== ""; if (shouldCallTool) { initialAssistantMessage.functionCall = functionCallContent; } newMessages.push(initialAssistantMessage); (0, utils_1.logRequest)({ reqId, message: `LLM response: ${JSON.stringify(initialAssistantMessage)}`, }); // Tool call if (shouldCallTool) { (0, assert_1.strict)(llm.callTool, "You must implement the callTool() method on your ChatLlm to access this code."); const { toolCallMessage, references: toolReferences, rejectUserQuery, } = await llm.callTool({ messages: [...llmConversation, ...newMessages], conversation, dataStreamer, request, }); newMessages.push(convertMessageFromLlmToDb(toolCallMessage)); if (rejectUserQuery) { newMessages.push({ role: "assistant", content: noRelevantContentMessage, }); dataStreamer.streamData({ type: "delta", data: noRelevantContentMessage, }); } else { if (toolReferences) { outputReferences.push(...toolReferences); } const answerStream = await llm.answerQuestionStream({ messages: [...llmConversation, ...newMessages], }); const answerContent = await dataStreamer.stream({ stream: answerStream, }); const answerMessage = { role: "assistant", content: answerContent, }; newMessages.push(answerMessage); } } } catch (err) { const errorMessage = err instanceof Error ? err.message : JSON.stringify(err); (0, utils_1.logRequest)({ reqId, message: `LLM error: ${errorMessage}`, type: "error", }); (0, utils_1.logRequest)({ reqId, message: "Only sending vector search results to user", }); const llmNotWorkingResponse = { role: "assistant", content: llmNotWorkingMessage, }; dataStreamer.streamData({ type: "delta", data: llmNotWorkingMessage, }); newMessages.push(llmNotWorkingResponse); } } // Handle streaming static message response else { const staticMessage = llmConversation.at(-1); (0, assert_1.strict)(staticMessage?.content, "No static message content"); (0, assert_1.strict)(staticMessage.role === "assistant", "Static message not assistant"); (0, utils_1.logRequest)({ reqId, message: `Sending static message to user: ${staticMessage.content}`, type: "warn", }); dataStreamer.streamData({ type: "delta", data: staticMessage.content, }); } // Add references to the last assistant message if (newMessages.at(-1)?.role === "assistant" && outputReferences.length > 0) { newMessages.at(-1).references = outputReferences; } if (outputReferences.length > 0) { // Stream back references dataStreamer.streamData({ type: "references", data: outputReferences, }); } return { messages: newMessages.map(convertMessageFromLlmToDb) }; } exports.streamGenerateResponseMessage = streamGenerateResponseMessage; function convertMessageFromLlmToDb(message) { const dbMessage = { ...message, content: message?.content ?? "", }; if (message.role === "assistant" && message.function_call) { dbMessage.functionCall = message.function_call; } return dbMessage; } exports.convertMessageFromLlmToDb = convertMessageFromLlmToDb; function convertConversationMessageToLlmMessage(message) { const { content, role } = message; if (role === "system") { return { content: content, role: "system", }; } if (role === "function") { return { content: content, role: "function", name: message.name, }; } if (role === "user") { return { content: content, role: "user", }; } if (role === "assistant") { return { content: content, role: "assistant", ...(message.functionCall ? { function_call: message.functionCall } : {}), }; } throw new Error(`Invalid message role: ${role}`); } //# sourceMappingURL=generateResponse.js.map