UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports one-click, free deployment of your private ChatGPT/LLM web application.

249 lines (217 loc) 8.55 kB
import debug from 'debug';
import { and, eq } from 'drizzle-orm';
import { z } from 'zod';

import { AsyncTaskModel } from '@/database/models/asyncTask';
import {
  NewGeneration,
  NewGenerationBatch,
  asyncTasks,
  generationBatches,
  generations,
} from '@/database/schemas';
import { authedProcedure, router } from '@/libs/trpc/lambda';
import { keyVaults, serverDatabase } from '@/libs/trpc/lambda/middleware';
import { createAsyncCaller } from '@/server/routers/async/caller';
import { FileService } from '@/server/services/file';
import {
  AsyncTaskError,
  AsyncTaskErrorType,
  AsyncTaskStatus,
  AsyncTaskType,
} from '@/types/asyncTask';
import { generateUniqueSeeds } from '@/utils/number';

const log = debug('lobe-image:lambda');

/**
 * Base procedure for the image router.
 *
 * Runs the `keyVaults` and `serverDatabase` middlewares on top of
 * `authedProcedure`, then adds an `AsyncTaskModel` and a `FileService` (both
 * built from `ctx.serverDB` and `ctx.userId`) to the context.
 */
const imageProcedure = authedProcedure
  .use(keyVaults)
  .use(serverDatabase)
  .use(async (opts) => {
    const { ctx } = opts;

    const { apiKey } = ctx.jwtPayload;
    if (apiKey) {
      // Only record that a key is present; the key itself is a secret and must
      // never be written to logs.
      log('API key found in jwtPayload');
    } else {
      log('No API key found in jwtPayload');
    }

    return opts.next({
      ctx: {
        asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
        fileService: new FileService(ctx.serverDB, ctx.userId),
      },
    });
  });

/**
 * Input accepted by `createImage`.
 *
 * `params` is declared with `.passthrough()`, so keys that are not listed
 * here are kept rather than stripped.
 */
const createImageInputSchema = z.object({
  generationTopicId: z.string(),
  imageNum: z.number(),
  model: z.string(),
  params: z
    .object({
      cfg: z.number().optional(),
      height: z.number().optional(),
      imageUrls: z.array(z.string()).optional(),
      prompt: z.string(),
      seed: z.number().nullable().optional(),
      steps: z.number().optional(),
      width: z.number().optional(),
    })
    .passthrough(),
  provider: z.string(),
});

export type CreateImageServicePayload = z.infer<typeof createImageInputSchema>;

export const imageRouter = router({
  /**
   * Creates one generation batch with `imageNum` generations, one async task
   * per generation, and then triggers the async image-generation calls.
   *
   * Flow:
   * 1. Build the config to persist: `params.imageUrls` are mapped through
   *    `fileService.getKeyFromFullUrl`; if that throws, the original URLs are kept.
   * 2. In a single database transaction, insert the batch, the generations and
   *    one pending async task per generation, and link each task to its generation.
   * 3. Start `asyncCaller.image.createImage` for every generation without
   *    awaiting it. If starting the calls throws, all tasks are marked as errored.
   *
   * @returns `{ data: { batch, generations }, success: true }`. Note that
   *   `success` is `true` even when step 3 failed and the tasks were marked as
   *   errored; only a failure in step 1/2 propagates as a thrown error.
   */
  createImage: imageProcedure.input(createImageInputSchema).mutation(async ({ input, ctx }) => {
    const { userId, serverDB, asyncTaskModel, fileService } = ctx;
    const { generationTopicId, provider, model, imageNum, params } = input;

    log('Starting image creation process, input: %O', input);

    // If params contains imageUrls, convert them to keys for database storage
    // (the log messages below refer to these as S3 keys).
    let configForDatabase = { ...params };
    if (Array.isArray(params.imageUrls) && params.imageUrls.length > 0) {
      log('Converting imageUrls to S3 keys for database storage: %O', params.imageUrls);
      try {
        const imageKeys = params.imageUrls.map((url) => {
          const key = fileService.getKeyFromFullUrl(url);
          log('Converted URL %s to key %s', url, key);
          return key;
        });

        // Persist the converted keys instead of the full URLs.
        configForDatabase = {
          ...params,
          imageUrls: imageKeys,
        };
        log('Successfully converted imageUrls to keys for database: %O', imageKeys);
      } catch (error) {
        log('Error converting imageUrls to keys: %O', error);
        // Conversion is best-effort: on failure keep the original URLs
        // (they may be local files or use some other format).
        log('Keeping original imageUrls due to conversion error');
      }
    }

    // Step 1: create all database records atomically inside one transaction.
    const { batch: createdBatch, generationsWithTasks } = await serverDB.transaction(async (tx) => {
      log('Starting database transaction for image generation');

      // 1. Create the generation batch. `config` holds the converted
      //    (key-based) configuration, not the raw request params.
      const newBatch: NewGenerationBatch = {
        config: configForDatabase,
        generationTopicId,
        height: params.height,
        model,
        prompt: params.prompt,
        provider,
        userId,
        width: params.width,
      };
      log('Creating generation batch: %O', newBatch);
      const [batch] = await tx.insert(generationBatches).values(newBatch).returning();
      log('Generation batch created successfully: %s', batch.id);

      // 2. Create `imageNum` generations. When the request params carry a
      //    `seed` key (whatever its value), each generation gets a seed from
      //    `generateUniqueSeeds`; otherwise every seed is null.
      const seeds =
        'seed' in params
          ? generateUniqueSeeds(imageNum)
          : Array.from({ length: imageNum }, () => null);
      const newGenerations: NewGeneration[] = Array.from({ length: imageNum }, (_, index) => {
        return {
          generationBatchId: batch.id,
          seed: seeds[index],
          userId,
        };
      });

      log('Creating %d generations for batch: %s', newGenerations.length, batch.id);
      const createdGenerations = await tx.insert(generations).values(newGenerations).returning();
      log(
        'Generations created successfully: %O',
        createdGenerations.map((g) => g.id),
      );

      // 3. Create one asyncTask per generation (still inside the transaction).
      log('Creating async tasks for generations');

      const generationsWithTasks = await Promise.all(
        createdGenerations.map(async (generation) => {
          // Insert the asyncTask directly through the transaction handle.
          const [createdAsyncTask] = await tx
            .insert(asyncTasks)
            .values({
              status: AsyncTaskStatus.Pending,
              type: AsyncTaskType.ImageGeneration,
              userId,
            })
            .returning();

          const asyncTaskId = createdAsyncTask.id;
          log('Created async task %s for generation %s', asyncTaskId, generation.id);

          // Link the generation to its asyncTask.
          await tx
            .update(generations)
            .set({ asyncTaskId })
            .where(and(eq(generations.id, generation.id), eq(generations.userId, userId)));

          // NOTE: `generation` is the row as returned by the insert above, so
          // it does not reflect the asyncTaskId that was just written.
          return { asyncTaskId, generation };
        }),
      );

      log('All async tasks created in transaction');

      return {
        batch,
        generationsWithTasks,
      };
    });

    log('Database transaction completed successfully. Starting async task triggers directly.');

    // Step 2: trigger all image generation tasks directly (no `after` wrapper).
    log('Starting async image generation tasks directly');

    try {
      log('Creating unified async caller for userId: %s', userId);
      log(
        'Lambda context - userId: %s, jwtPayload keys: %O',
        ctx.userId,
        Object.keys(ctx.jwtPayload || {}),
      );

      // Build the caller through the shared caller factory.
      const asyncCaller = await createAsyncCaller({
        jwtPayload: ctx.jwtPayload,
        userId: ctx.userId,
      });

      log('Unified async caller created successfully for userId: %s', ctx.userId);
      log('Processing %d async image generation tasks', generationsWithTasks.length);

      // Start every image generation task as a background task (not awaited).
      generationsWithTasks.forEach(({ generation, asyncTaskId }) => {
        log('Starting background async task %s for generation %s', asyncTaskId, generation.id);

        // Deliberately neither awaited nor chained with .then/.catch, so the
        // runtime can release compute resources early. As a consequence only
        // synchronous failures reach the catch block below; a rejection of the
        // returned promise is not observed here.
        asyncCaller.image.createImage({
          generationId: generation.id,
          model,
          // The original request params are forwarded, not the converted
          // database config.
          params,
          provider,
          taskId: asyncTaskId,
        });
      });

      log('All %d background async image generation tasks started', generationsWithTasks.length);
    } catch (e) {
      console.error('[createImage] Failed to process async tasks:', e);
      log('Failed to process async tasks: %O', e);

      // Starting the tasks failed as a whole: mark every task as errored.
      try {
        await Promise.allSettled(
          generationsWithTasks.map(({ asyncTaskId }) =>
            asyncTaskModel.update(asyncTaskId, {
              error: new AsyncTaskError(
                AsyncTaskErrorType.ServerError,
                'start async task error: ' + (e instanceof Error ? e.message : 'Unknown error'),
              ),
              status: AsyncTaskStatus.Error,
            }),
          ),
        );
      } catch (batchUpdateError) {
        console.error('Failed to update batch task statuses:', batchUpdateError);
      }
    }

    const createdGenerations = generationsWithTasks.map((item) => item.generation);
    log('Image creation process completed successfully: %O', {
      batchId: createdBatch.id,
      generationCount: createdGenerations.length,
      generationIds: createdGenerations.map((g) => g.id),
    });

    return {
      data: {
        batch: createdBatch,
        generations: createdGenerations,
      },
      success: true,
    };
  }),
});

export type ImageRouter = typeof imageRouter;