/*
 * @roadiehq/rag-ai-backend — compiled CommonJS output (RagAiController.cjs.js).
 */
;
class RagAiController {
  // Lazily-created process-wide singleton (see getInstance).
  static instance;
  llmService;
  augmentationIndexer;
  retrievalPipeline;
  logger;
  /**
   * @param {object} logger - Logger exposing info/error.
   * @param {object} llmService - Exposes query(embeddingDocs, query) returning an async iterable of chunks.
   * @param {object} augmentationIndexer - Exposes createEmbeddings/deleteEmbeddings.
   * @param {object} [retrievalPipeline] - Optional; exposes retrieveAugmentationContext(query, source, entityFilter).
   */
  constructor(logger, llmService, augmentationIndexer, retrievalPipeline) {
    this.logger = logger;
    this.llmService = llmService;
    this.augmentationIndexer = augmentationIndexer;
    this.retrievalPipeline = retrievalPipeline;
  }
  /**
   * Returns the singleton controller, constructing it on first use.
   * Dependencies passed on subsequent calls are ignored.
   */
  static getInstance({
    logger,
    llmService,
    augmentationIndexer,
    retrievalPipeline
  }) {
    if (!RagAiController.instance) {
      RagAiController.instance = new RagAiController(
        logger,
        llmService,
        augmentationIndexer,
        retrievalPipeline
      );
    }
    return RagAiController.instance;
  }
  /**
   * HTTP handler: (re)creates embeddings for the `source` route param,
   * scoped by the `entityFilter` from the request body.
   * Responds 200 with a human-readable summary of how many were created.
   */
  createEmbeddings = async (req, res) => {
    const source = req.params.source;
    const entityFilter = req.body.entityFilter;
    this.logger.info(`Creating embeddings for source ${source}`);
    const amountOfEmbeddings = await this.augmentationIndexer.createEmbeddings(
      source,
      entityFilter
    );
    return res.status(200).send({
      response: `${amountOfEmbeddings} embeddings created for source ${source}, for entities with filter ${JSON.stringify(
        entityFilter
      )}`
    });
  };
  /**
   * HTTP handler: retrieves augmentation context documents for a query.
   * Responds 500 when no retrieval pipeline was configured.
   */
  getEmbeddings = async (req, res) => {
    if (!this.retrievalPipeline) {
      return res.status(500).send({
        message: "No retrieval pipeline configured for this AI backend. "
      });
    }
    const source = req.params.source;
    const query = req.query.query;
    const entityFilter = req.body.entityFilter;
    const response = await this.retrievalPipeline.retrieveAugmentationContext(
      query,
      source,
      entityFilter
    );
    return res.status(200).send({ response });
  };
  /**
   * HTTP handler: deletes all embeddings for the `source` route param,
   * scoped by the body's `entityFilter`.
   * NOTE(review): 201 Created is unusual for a deletion — 200/204 would be
   * conventional; kept as-is since clients may depend on it.
   */
  deleteEmbeddings = async (req, res) => {
    const source = req.params.source;
    const entityFilter = req.body.entityFilter;
    await this.augmentationIndexer.deleteEmbeddings(source, entityFilter);
    return res.status(201).send({ response: `Embeddings deleted for source ${source}` });
  };
  /**
   * HTTP handler: answers a query as a Server-Sent-Events stream.
   * Emits `embeddings` (retrieved context), `response` (model text chunks),
   * and `usage` (token counts; -1 sentinels when the model reports none).
   * Rethrows on failure so upstream error handling still observes the error.
   */
  query = async (req, res) => {
    const source = req.params.source;
    const query = req.body.query;
    const entityFilter = req.body.entityFilter;
    // Headers are flushed immediately; from here on the response is an SSE stream.
    res.writeHead(200, {
      "Content-Type": "text/event-stream",
      Connection: "keep-alive",
      "Cache-Control": "no-cache"
    });
    try {
      const embeddingDocs = this.retrievalPipeline ? await this.retrievalPipeline.retrieveAugmentationContext(
        query,
        source,
        entityFilter
      ) : [];
      // NOTE(review): per the SSE spec an event is dispatched on a blank line
      // ("\n\n"); this event (and the `usage` events below) end with a single
      // "\n". Confirm the consuming client tolerates this before changing the
      // wire format — preserved as-is here.
      res.write(`event: embeddings\ndata: ${JSON.stringify(embeddingDocs)}\n`);
      const stream = await this.llmService.query(embeddingDocs, query);
      const usage = { input_tokens: 0, output_tokens: 0, total_tokens: 0 };
      for await (const chunk of stream) {
        // Chunks may be plain strings or message objects; only objects can
        // carry usage metadata to accumulate.
        if (typeof chunk !== "string" && "usage_metadata" in chunk) {
          usage.input_tokens += chunk.usage_metadata?.input_tokens ?? 0;
          usage.output_tokens += chunk.usage_metadata?.output_tokens ?? 0;
          usage.total_tokens += chunk.usage_metadata?.total_tokens ?? 0;
        }
        const text = typeof chunk === "string" ? chunk : chunk.content;
        res.write(`event: response\n` + this.parseSseText(text));
        // flush() only exists behind compression middleware — optional call.
        res.flush?.();
      }
      if (Object.values(usage).some((it) => it !== 0)) {
        this.logger.info(
          `Produced response with token usage: ${JSON.stringify(usage)}`
        );
        res.write(`event: usage\ndata: ${JSON.stringify(usage)}\n`);
      } else {
        this.logger.info(
          `Unable to retrieve token usage information from this model invocation.`
        );
        // -1 sentinels tell the client that usage is unknown for this model.
        res.write(
          `event: usage\ndata: ${JSON.stringify({
            input_tokens: -1,
            output_tokens: -1,
            total_tokens: -1
          })}\n`
        );
      }
    } catch (e) {
      this.logger.error(
        `There was an error executing query ${query} for source ${source} on entity ${entityFilter}: ${e.message}`,
        e
      );
      throw e;
    } finally {
      // BUGFIX: always terminate the SSE response. Previously res.end() sat
      // after the try/catch, so the rethrow above skipped it and a failed
      // stream left the HTTP connection open indefinitely.
      res.end();
    }
  };
  /**
   * Formats model text as SSE payload lines: one `data:` line per input line,
   * followed by the blank line that terminates the event.
   */
  parseSseText = (text) => {
    const lines = text.split("\n");
    const output = lines.reduce(
      (result, line) => result + `data: ${line}\n`,
      ""
    );
    return `${output}\n`;
  };
}
// CommonJS named export; consumers do `require(...).RagAiController`.
exports.RagAiController = RagAiController;
//# sourceMappingURL=RagAiController.cjs.js.map