UNPKG

dtamind-components

Version:

DTAmindai Components

106 lines 4.63 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
const chains_1 = require("langchain/chains");
const utils_1 = require("../../../src/utils");
const handler_1 = require("../../../src/handler");
const Moderation_1 = require("../../moderation/Moderation");
const OutputParserHelpers_1 = require("../../outputparsers/OutputParserHelpers");

/** Number of documents requested from a retriever whose vector store does not define `k`. */
const DEFAULT_TOP_K = 4;

/**
 * Normalizes a node input that may be absent, a single value, or a list into an array.
 * @param {*} value - raw input value
 * @returns {Array} `[]` for null/undefined, the value itself if already an array, else `[value]`
 */
const toArray = (value) => {
    if (value == null) {
        return [];
    }
    return Array.isArray(value) ? value : [value];
};

/**
 * Extracts a human-readable message from any thrown value.
 * @param {*} err - the caught value (usually an Error, but not guaranteed)
 * @returns {string} the error message
 */
const errorMessage = (err) => (err instanceof Error ? err.message : String(err?.message ?? err));

/**
 * Flow node that builds LangChain's `MultiRetrievalQAChain`: the chain is created from a
 * language model plus several vector-store retrievers, each identified by a name and a
 * description.
 */
class MultiRetrievalQAChain_Chains {
    constructor() {
        this.label = 'Multi Retrieval QA Chain';
        this.name = 'multiRetrievalQAChain';
        this.version = 2.0;
        this.badge = 'DEPRECATING';
        this.type = 'MultiRetrievalQAChain';
        this.icon = 'qa.svg';
        this.category = 'Chains';
        this.description = 'QA Chain that automatically picks an appropriate vector store from multiple retrievers';
        this.baseClasses = [this.type, ...(0, utils_1.getBaseClasses)(chains_1.MultiRetrievalQAChain)];
        this.inputs = [
            {
                label: 'Language Model',
                name: 'model',
                type: 'BaseLanguageModel'
            },
            {
                label: 'Vector Store Retriever',
                name: 'vectorStoreRetriever',
                type: 'VectorStoreRetriever',
                list: true
            },
            {
                label: 'Return Source Documents',
                name: 'returnSourceDocuments',
                type: 'boolean',
                optional: true
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            }
        ];
    }

    /**
     * Builds the MultiRetrievalQAChain from the node's configured inputs.
     *
     * @param {object} nodeData - node configuration; reads `inputs.model`,
     *   `inputs.vectorStoreRetriever` and `inputs.returnSourceDocuments`.
     *   Each retriever entry is expected to expose `name`, `description` and
     *   `vectorStore` (with an optional `k`).
     * @returns {Promise<object>} the chain returned by `MultiRetrievalQAChain.fromLLMAndRetrievers`.
     * @throws {Error} when no vector store retriever is connected.
     */
    async init(nodeData) {
        const model = nodeData.inputs?.model;
        const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments;
        const vectorStoreRetrievers = toArray(nodeData.inputs?.vectorStoreRetriever);
        // Fail with a clear message instead of an opaque "is not iterable" TypeError.
        if (vectorStoreRetrievers.length === 0) {
            throw new Error('Multi Retrieval QA Chain requires at least one Vector Store Retriever');
        }
        const retrieverNames = vectorStoreRetrievers.map((vs) => vs.name);
        const retrieverDescriptions = vectorStoreRetrievers.map((vs) => vs.description);
        const retrievers = vectorStoreRetrievers.map((vs) => vs.vectorStore.asRetriever(vs.vectorStore.k ?? DEFAULT_TOP_K));
        return chains_1.MultiRetrievalQAChain.fromLLMAndRetrievers(model, {
            retrieverNames,
            retrieverDescriptions,
            retrievers,
            retrievalQAChainOpts: {
                // Verbose chain logging is enabled only when the DEBUG env var is exactly 'true'.
                verbose: process.env.DEBUG === 'true',
                returnSourceDocuments
            }
        });
    }

    /**
     * Runs the chain built by `init` on a user question.
     *
     * When input moderation is configured, the input is first passed through
     * `checkInputs`; if moderation throws, its message is (optionally streamed and)
     * returned as the response instead of calling the chain.
     *
     * @param {object} nodeData - node data; `instance` is the chain, `inputs` holds
     *   `returnSourceDocuments` and `inputModeration`.
     * @param {string} input - the user question.
     * @param {object} options - runtime options; reads `shouldStreamResponse`,
     *   `sseStreamer`, `chatId`, `logger` and `orgId`.
     * @returns {Promise<*>} the full chain result when it carries both `text` and
     *   `sourceDocuments`, otherwise just the answer text.
     */
    async run(nodeData, input, options) {
        const chain = nodeData.instance;
        const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments;
        const moderations = nodeData.inputs?.inputModeration;
        const { shouldStreamResponse, sseStreamer, chatId } = options;
        if (moderations && moderations.length > 0) {
            try {
                // Use the output of the moderation chain as input for the Multi Retrieval QA Chain
                input = await (0, Moderation_1.checkInputs)(moderations, input);
            }
            catch (e) {
                const message = errorMessage(e);
                // NOTE(review): fixed 500 ms pause before replying with the moderation
                // message; presumably lets the client settle — confirm intent.
                await new Promise((resolve) => setTimeout(resolve, 500));
                if (shouldStreamResponse) {
                    (0, Moderation_1.streamResponse)(sseStreamer, chatId, message);
                }
                return (0, OutputParserHelpers_1.formatResponse)(message);
            }
        }
        const loggerHandler = new handler_1.ConsoleCallbackHandler(options.logger, options.orgId);
        const callbacks = await (0, handler_1.additionalCallbacks)(nodeData, options);
        // Handler order: logger first, then the streaming handler (if any), then extra callbacks.
        const handlers = [loggerHandler];
        if (shouldStreamResponse) {
            // NOTE(review): the meaning of the numeric `2` argument is defined by
            // handler.CustomChainHandler — confirm there.
            handlers.push(new handler_1.CustomChainHandler(sseStreamer, chatId, 2, returnSourceDocuments));
        }
        handlers.push(...callbacks);
        const res = await chain.call({ input }, handlers);
        return res?.text && res.sourceDocuments ? res : res?.text;
    }
}
module.exports = { nodeClass: MultiRetrievalQAChain_Chains };
//# sourceMappingURL=MultiRetrievalQAChain.js.map