
@roadiehq/rag-ai-backend

RagAiController.cjs.js

'use strict';

class RagAiController {
  static instance;
  llmService;
  augmentationIndexer;
  retrievalPipeline;
  logger;

  constructor(logger, llmService, augmentationIndexer, retrievalPipeline) {
    this.logger = logger;
    this.llmService = llmService;
    this.augmentationIndexer = augmentationIndexer;
    this.retrievalPipeline = retrievalPipeline;
  }

  // Lazily constructed singleton; subsequent calls reuse the first instance.
  static getInstance({ logger, llmService, augmentationIndexer, retrievalPipeline }) {
    if (!RagAiController.instance) {
      RagAiController.instance = new RagAiController(
        logger,
        llmService,
        augmentationIndexer,
        retrievalPipeline
      );
    }
    return RagAiController.instance;
  }

  // Index documents for a source, optionally narrowed by an entity filter.
  createEmbeddings = async (req, res) => {
    const source = req.params.source;
    const entityFilter = req.body.entityFilter;
    this.logger.info(`Creating embeddings for source ${source}`);
    const amountOfEmbeddings = await this.augmentationIndexer.createEmbeddings(
      source,
      entityFilter
    );
    return res.status(200).send({
      response: `${amountOfEmbeddings} embeddings created for source ${source}, for entities with filter ${JSON.stringify(
        entityFilter
      )}`
    });
  };

  // Return the retrieved augmentation context for a query without invoking the LLM.
  getEmbeddings = async (req, res) => {
    if (!this.retrievalPipeline) {
      return res.status(500).send({
        message: "No retrieval pipeline configured for this AI backend. "
      });
    }
    const source = req.params.source;
    const query = req.query.query;
    const entityFilter = req.body.entityFilter;
    const response = await this.retrievalPipeline.retrieveAugmentationContext(
      query,
      source,
      entityFilter
    );
    return res.status(200).send({ response });
  };

  deleteEmbeddings = async (req, res) => {
    const source = req.params.source;
    const entityFilter = req.body.entityFilter;
    await this.augmentationIndexer.deleteEmbeddings(source, entityFilter);
    return res.status(201).send({
      response: `Embeddings deleted for source ${source}`
    });
  };

  // Stream an LLM answer over Server-Sent Events: first an `embeddings` event
  // carrying the retrieved context, then `response` events with the generated
  // text, and finally a `usage` event with token counts (-1 values if unknown).
  query = async (req, res) => {
    const source = req.params.source;
    const query = req.body.query;
    const entityFilter = req.body.entityFilter;
    res.writeHead(200, {
      "Content-Type": "text/event-stream",
      Connection: "keep-alive",
      "Cache-Control": "no-cache"
    });
    try {
      const embeddingDocs = this.retrievalPipeline ? await this.retrievalPipeline.retrieveAugmentationContext(
        query,
        source,
        entityFilter
      ) : [];
      const embeddingsEvent = `event: embeddings\n`;
      const embeddingsData = `data: ${JSON.stringify(embeddingDocs)}\n\n`;
      res.write(embeddingsEvent + embeddingsData);
      const stream = await this.llmService.query(embeddingDocs, query);
      const usage = { input_tokens: 0, output_tokens: 0, total_tokens: 0 };
      for await (const chunk of stream) {
        if (typeof chunk !== "string" && "usage_metadata" in chunk) {
          usage.input_tokens += chunk.usage_metadata?.input_tokens ?? 0;
          usage.output_tokens += chunk.usage_metadata?.output_tokens ?? 0;
          usage.total_tokens += chunk.usage_metadata?.total_tokens ?? 0;
        }
        const text = typeof chunk === "string" ? chunk : chunk.content;
        const event = `event: response\n`;
        const data = this.parseSseText(text);
        res.write(event + data);
        res.flush?.();
      }
      if (Object.values(usage).some((it) => it !== 0)) {
        this.logger.info(
          `Produced response with token usage: ${JSON.stringify(usage)}`
        );
        res.write(`event: usage\ndata: ${JSON.stringify(usage)}\n\n`);
      } else {
        this.logger.info(
          `Unable to retrieve token usage information from this model invocation.`
        );
        res.write(
          `event: usage\ndata: ${JSON.stringify({
            input_tokens: -1,
            output_tokens: -1,
            total_tokens: -1
          })}\n\n`
        );
      }
    } catch (e) {
      this.logger.error(
        `There was an error executing query ${query} for source ${source} on entity ${entityFilter}: ${e.message}`,
        e
      );
      throw e;
    }
    res.end();
  };

  // Split multi-line text into SSE `data:` lines and terminate the event with a blank line.
  parseSseText = (text) => {
    const lines = text.split("\n");
    const output = lines.reduce((result, line) => {
      const data = `data: ${line}\n`;
      return result + data;
    }, "");
    return `${output}\n`;
  };
}

exports.RagAiController = RagAiController;
//# sourceMappingURL=RagAiController.cjs.js.map
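
For orientation, the handlers above use the Express-style (req, res) signature, so the controller is meant to be mounted on an HTTP router by the surrounding plugin. The sketch below shows one way such wiring could look; the route paths, the createIllustrativeRouter helper, and the availability of a logger, llmService, augmentationIndexer, and retrievalPipeline are all illustrative assumptions, not taken from this package.

// Minimal sketch, assuming RagAiController (exported above) is in scope and
// the collaborators have been constructed elsewhere. Paths are illustrative.
const express = require('express');

function createIllustrativeRouter({ logger, llmService, augmentationIndexer, retrievalPipeline }) {
  const controller = RagAiController.getInstance({
    logger,
    llmService,
    augmentationIndexer,
    retrievalPipeline
  });

  const router = express.Router();
  router.use(express.json());

  // Index documents for a source, scoped by an optional entityFilter in the body.
  router.post('/embeddings/:source', controller.createEmbeddings);
  // Return only the retrieved augmentation context for ?query=...
  router.get('/embeddings/:source', controller.getEmbeddings);
  // Drop previously created embeddings for a source.
  router.delete('/embeddings/:source', controller.deleteEmbeddings);
  // Streamed answer over Server-Sent Events.
  router.post('/query/:source', controller.query);

  return router;
}

With routes like these, createEmbeddings, getEmbeddings, and deleteEmbeddings return ordinary JSON responses, while query keeps the connection open and writes text/event-stream data until res.end() is reached.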
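
Because the query handler emits embeddings, response, and usage events in SSE format, a client has to read the response body incrementally rather than waiting for a single JSON payload. A minimal sketch, assuming Node 18+ with a global fetch and the illustrative /query/:source path from the previous example:

// Illustrative only: prints the raw SSE lines ("event: ..." / "data: ...")
// produced by the query handler as they arrive.
async function streamQuery(baseUrl, source, query) {
  const res = await fetch(`${baseUrl}/query/${source}`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ query })
  });
  const decoder = new TextDecoder();
  for await (const chunk of res.body) {
    process.stdout.write(decoder.decode(chunk, { stream: true }));
  }
}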