@lobehub/chat
Lobe Chat - an open-source, high-performance chatbot framework with support for speech synthesis, multimodal interaction, and an extensible Function Call plugin system. It supports one-click free deployment of your private ChatGPT/LLM web application.
import { TRPCError } from '@trpc/server';
import { inArray } from 'drizzle-orm/expressions';
import { z } from 'zod';
import { DEFAULT_FILE_EMBEDDING_MODEL_ITEM } from '@/const/settings/knowledge';
import { AsyncTaskModel } from '@/database/models/asyncTask';
import { ChunkModel } from '@/database/models/chunk';
import { EmbeddingModel } from '@/database/models/embedding';
import { FileModel } from '@/database/models/file';
import { MessageModel } from '@/database/models/message';
import { knowledgeBaseFiles } from '@/database/schemas';
import { authedProcedure, router } from '@/libs/trpc/lambda';
import { keyVaults, serverDatabase } from '@/libs/trpc/lambda/middleware';
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';
import { ChunkService } from '@/server/services/chunk';
import { SemanticSearchSchema } from '@/types/rag';
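
// Shared base procedure for all chunk routes: requires an authenticated user,
// attaches the server database and key-vault middleware, and injects the
// per-user models and services used by the routes below into the context.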
const chunkProcedure = authedProcedure
.use(serverDatabase)
.use(keyVaults)
.use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
chunkService: new ChunkService(ctx.userId),
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
fileModel: new FileModel(ctx.serverDB, ctx.userId),
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
},
});
});
export const chunkRouter = router({
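  // Queue an async task that embeds all chunks of the given file.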
createEmbeddingChunksTask: chunkProcedure
.input(
z.object({
id: z.string(),
}),
)
.mutation(async ({ ctx, input }) => {
const asyncTaskId = await ctx.chunkService.asyncEmbeddingFileChunks(input.id, ctx.jwtPayload);
return { id: asyncTaskId, success: true };
}),
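  // Queue an async task that parses a file into chunks; `skipExist` skips
  // files that have already been chunked.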
createParseFileTask: chunkProcedure
.input(
z.object({
id: z.string(),
skipExist: z.boolean().optional(),
}),
)
.mutation(async ({ ctx, input }) => {
const asyncTaskId = await ctx.chunkService.asyncParseFileToChunks(
input.id,
ctx.jwtPayload,
input.skipExist,
);
return { id: asyncTaskId, success: true };
}),
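  // Page through a file's chunks using a simple incrementing cursor.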
getChunksByFileId: chunkProcedure
.input(
z.object({
cursor: z.number().nullish(),
id: z.string(),
}),
)
.query(async ({ ctx, input }) => {
return {
items: await ctx.chunkModel.findByFileId(input.id, input.cursor || 0),
nextCursor: input.cursor ? input.cursor + 1 : 1,
};
}),
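  // Retry chunking for a file: remove the stale async task record first,
  // then queue a fresh parse task.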
retryParseFileTask: chunkProcedure
.input(
z.object({
id: z.string(),
}),
)
.mutation(async ({ ctx, input }) => {
const result = await ctx.fileModel.findById(input.id);
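      // nothing to retry if the file no longer exists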
if (!result) return;
      // 1. delete the previous task if it exists
if (result.chunkTaskId) {
await ctx.asyncTaskModel.delete(result.chunkTaskId);
}
// 2. create a new asyncTask for chunking
const asyncTaskId = await ctx.chunkService.asyncParseFileToChunks(input.id, ctx.jwtPayload);
return { id: asyncTaskId, success: true };
}),
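  // Embed the query with the configured embedding model and run a vector
  // similarity search over the user's chunks, optionally scoped to fileIds.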
semanticSearch: chunkProcedure
.input(
z.object({
fileIds: z.array(z.string()).optional(),
query: z.string(),
}),
)
.mutation(async ({ ctx, input }) => {
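      // resolve the embedding model, falling back to the built-in default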
const { model, provider } =
getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);
      // time the embedding call; pairs with console.timeEnd below
      console.time('embedding');
      const embeddings = await agentRuntime.embeddings({
dimensions: 1024,
input: input.query,
model,
});
console.timeEnd('embedding');
return ctx.chunkModel.semanticSearch({
embedding: embeddings![0],
fileIds: input.fileIds,
query: input.query,
});
}),
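  // RAG retrieval for chat: reuse the embedding stored for this message's
  // query when available, expand knowledge bases into their file ids, then
  // run a semantic search over the combined file set.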
semanticSearchForChat: chunkProcedure
.input(SemanticSearchSchema)
.mutation(async ({ ctx, input }) => {
try {
const item = await ctx.messageModel.findMessageQueriesById(input.messageId);
const { model, provider } =
getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
let embedding: number[];
let ragQueryId: string;
        // if there is no rag query for this message, or it has no embeddings,
        // we need to create one
if (!item || !item.embeddings) {
          // TODO: need to support customization
const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);
          // truncate the query to stay within the embedding context window limit
const query =
input.rewriteQuery.length > 8000
? input.rewriteQuery.slice(0, 8000)
: input.rewriteQuery;
const embeddings = await agentRuntime.embeddings({
dimensions: 1024,
input: query,
model,
});
embedding = embeddings![0];
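          // persist the embedding and record the rewritten query so later
          // searches for this message can reuse them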
const embeddingsId = await ctx.embeddingModel.create({
embeddings: embedding,
model,
});
const result = await ctx.messageModel.createMessageQuery({
embeddingsId,
messageId: input.messageId,
rewriteQuery: input.rewriteQuery,
userQuery: input.userQuery,
});
ragQueryId = result.id;
} else {
embedding = item.embeddings;
ragQueryId = item.id;
}
let finalFileIds = input.fileIds ?? [];
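        // expand the selected knowledge bases into their member file ids and
        // merge them with any explicitly selected files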
if (input.knowledgeIds && input.knowledgeIds.length > 0) {
const knowledgeFiles = await ctx.serverDB.query.knowledgeBaseFiles.findMany({
where: inArray(knowledgeBaseFiles.knowledgeBaseId, input.knowledgeIds),
});
finalFileIds = knowledgeFiles.map((f) => f.fileId).concat(finalFileIds);
}
const chunks = await ctx.chunkModel.semanticSearchForChat({
embedding,
fileIds: finalFileIds,
query: input.rewriteQuery,
});
        // TODO: rerank the chunks before returning them
return { chunks, queryId: ragQueryId };
} catch (e) {
console.error(e);
throw new TRPCError({
code: 'INTERNAL_SERVER_ERROR',
message: (e as any).errorType || JSON.stringify(e),
});
}
}),
});