UNPKG

@lobehub/chat

Version:

Lobe Chat is an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It also supports one-click, free deployment of your private ChatGPT/LLM web application.

284 lines (247 loc) • 7.87 kB
import { and, asc, desc, eq, inArray } from 'drizzle-orm';
import {
  AiModelSortMap,
  AiModelSourceEnum,
  AiProviderModelListItem,
  EnabledAiModel,
  ToggleAiModelEnableParams,
} from 'model-bank';

import { AiModelSelectItem, NewAiModelItem, aiModels } from '../schemas';
import { LobeChatDatabase } from '../type';

/**
 * Data-access layer for the `aiModels` table, scoped to a single user.
 *
 * Every statement built by this class either filters on, or writes, the
 * `userId` supplied to the constructor.
 */
export class AiModelModel {
  private userId: string;
  private db: LobeChatDatabase;

  /**
   * @param db - Database handle used to build and run all queries
   * @param userId - Owner id applied to every read and write
   */
  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
  }

  /**
   * Helper method to validate if array is empty and return early if needed
   * @param array - Array to validate
   * @returns true if array is empty, false otherwise
   */
  private isEmptyArray(array: unknown[]): boolean {
    return array.length === 0;
  }

  /**
   * Insert a new model row for the current user.
   *
   * `enabled` defaults to `true` unless the caller passes an explicit value.
   * `source` is always stored as `AiModelSourceEnum.Custom`, overriding any
   * `source` present in `params`.
   *
   * @param params - Column values for the new row
   * @returns The first row returned by the insert
   */
  create = async (params: NewAiModelItem) => {
    const [result] = await this.db
      .insert(aiModels)
      .values({
        ...params,
        enabled: params.enabled ?? true, // enabled by default, but respect explicit value
        source: AiModelSourceEnum.Custom,
        userId: this.userId,
      })
      .returning();

    return result;
  };

  /**
   * Delete the model matching `id` and `providerId` for the current user.
   *
   * @param id - Model id
   * @param providerId - Provider the model belongs to
   */
  delete = async (id: string, providerId: string) => {
    return this.db
      .delete(aiModels)
      .where(
        and(
          eq(aiModels.id, id),
          eq(aiModels.providerId, providerId),
          eq(aiModels.userId, this.userId),
        ),
      );
  };

  /** Delete every model row owned by the current user. */
  deleteAll = async () => {
    return this.db.delete(aiModels).where(eq(aiModels.userId, this.userId));
  };

  /** List all model rows of the current user, most recently updated first. */
  query = async () => {
    return this.db.query.aiModels.findMany({
      orderBy: [desc(aiModels.updatedAt)],
      where: eq(aiModels.userId, this.userId),
    });
  };

  /**
   * List the current user's models for one provider.
   *
   * Rows are ordered by `sort` ascending, then `enabled`, `releasedAt` and
   * `updatedAt` descending.
   *
   * @param providerId - Provider whose models are listed
   * @returns The selected columns, typed as `AiProviderModelListItem[]`
   */
  getModelListByProviderId = async (providerId: string) => {
    const result = await this.db
      .select({
        abilities: aiModels.abilities,
        config: aiModels.config,
        contextWindowTokens: aiModels.contextWindowTokens,
        description: aiModels.description,
        displayName: aiModels.displayName,
        enabled: aiModels.enabled,
        id: aiModels.id,
        parameters: aiModels.parameters,
        pricing: aiModels.pricing,
        releasedAt: aiModels.releasedAt,
        source: aiModels.source,
        type: aiModels.type,
      })
      .from(aiModels)
      .where(and(eq(aiModels.providerId, providerId), eq(aiModels.userId, this.userId)))
      .orderBy(
        asc(aiModels.sort),
        desc(aiModels.enabled),
        desc(aiModels.releasedAt),
        desc(aiModels.updatedAt),
      );

    return result as AiProviderModelListItem[];
  };

  /**
   * Fetch all models of the current user across providers.
   *
   * No filter on `enabled` is applied here; the rows are only typed as
   * `EnabledAiModel[]`.
   */
  getAllModels = async () => {
    const data = await this.db
      .select({
        abilities: aiModels.abilities,
        config: aiModels.config,
        contextWindowTokens: aiModels.contextWindowTokens,
        displayName: aiModels.displayName,
        enabled: aiModels.enabled,
        id: aiModels.id,
        parameters: aiModels.parameters,
        providerId: aiModels.providerId,
        sort: aiModels.sort,
        source: aiModels.source,
        type: aiModels.type,
      })
      .from(aiModels)
      .where(eq(aiModels.userId, this.userId));

    return data as EnabledAiModel[];
  };

  /**
   * Find the first model of the current user with the given id.
   * The provider is not part of the filter.
   *
   * @param id - Model id
   */
  findById = async (id: string) => {
    return this.db.query.aiModels.findFirst({
      where: and(eq(aiModels.id, id), eq(aiModels.userId, this.userId)),
    });
  };

  /**
   * Upsert a model: insert it when missing, otherwise apply `value` to the
   * existing row identified by (`id`, `providerId`, current user).
   *
   * `updatedAt` is stamped on both the insert and the conflict-update path,
   * which also guarantees the SET clause is never empty when `value` is `{}`.
   *
   * @param id - Model id
   * @param providerId - Provider the model belongs to
   * @param value - Columns to write
   */
  update = async (id: string, providerId: string, value: Partial<AiModelSelectItem>) => {
    const now = new Date();

    return this.db
      .insert(aiModels)
      .values({ ...value, id, providerId, updatedAt: now, userId: this.userId })
      .onConflictDoUpdate({
        set: { ...value, updatedAt: now },
        target: [aiModels.id, aiModels.providerId, aiModels.userId],
      });
  };

  /**
   * Enable or disable a single model, creating its row when it does not exist.
   *
   * On conflict only `enabled`, `updatedAt` and (when provided) `type` are
   * overwritten.
   *
   * @param value - Toggle parameters; spread as-is into the inserted row
   */
  toggleModelEnabled = async (value: ToggleAiModelEnableParams) => {
    const now = new Date();

    // `type` (when present) is already carried over by the spread.
    const insertValues = {
      ...value,
      updatedAt: now,
      userId: this.userId,
    } as typeof aiModels.$inferInsert;

    const updateValues: Partial<typeof aiModels.$inferInsert> = {
      enabled: value.enabled,
      updatedAt: now,
    };
    if (value.type) updateValues.type = value.type;

    return this.db
      .insert(aiModels)
      .values(insertValues)
      .onConflictDoUpdate({
        set: updateValues,
        target: [aiModels.id, aiModels.providerId, aiModels.userId],
      });
  };

  /**
   * Insert a batch of models for a provider, leaving already-existing rows
   * untouched (`ON CONFLICT DO NOTHING`).
   *
   * @param providerId - Provider applied to every record
   * @param models - Models to insert
   * @returns The rows that were actually inserted; `[]` for empty input
   */
  batchUpdateAiModels = async (providerId: string, models: AiProviderModelListItem[]) => {
    // Early return if models array is empty to prevent database insertion error
    if (this.isEmptyArray(models)) {
      return [];
    }

    const now = new Date();
    const records = models.map((model) => ({
      ...model,
      providerId,
      updatedAt: now,
      userId: this.userId,
    }));

    return this.db
      .insert(aiModels)
      .values(records)
      .onConflictDoNothing({
        target: [aiModels.id, aiModels.userId, aiModels.providerId],
      })
      .returning();
  };

  /**
   * Set `enabled` for a list of model ids of one provider inside a single
   * transaction: missing rows are inserted, existing rows are updated.
   *
   * @param providerId - Provider the models belong to
   * @param models - Model ids to toggle
   * @param enabled - Target enabled state
   */
  batchToggleAiModels = async (providerId: string, models: string[], enabled: boolean) => {
    // Early return if models array is empty to prevent database insertion error
    if (this.isEmptyArray(models)) {
      return;
    }

    return this.db.transaction(async (trx) => {
      const now = new Date();

      // 1. insert models that are not in the db
      const insertedRecords = await trx
        .insert(aiModels)
        .values(
          models.map((i) => ({
            enabled,
            id: i,
            providerId,
            // if the model is not in the db, it's a builtin model
            source: AiModelSourceEnum.Builtin,
            updatedAt: now,
            userId: this.userId,
          })),
        )
        .onConflictDoNothing({
          target: [aiModels.id, aiModels.userId, aiModels.providerId],
        })
        .returning();

      // 2. update models that are in the db
      const insertedIds = new Set(insertedRecords.map((r) => r.id));
      const recordsToUpdate = models.filter((r) => !insertedIds.has(r));

      // Everything was freshly inserted: avoid an UPDATE with an empty IN list.
      if (recordsToUpdate.length === 0) return;

      await trx
        .update(aiModels)
        // stamp updatedAt as well, matching the single-model toggle
        .set({ enabled, updatedAt: now })
        .where(
          and(
            eq(aiModels.providerId, providerId),
            inArray(aiModels.id, recordsToUpdate),
            eq(aiModels.userId, this.userId),
          ),
        );
    });
  };

  /**
   * Delete the current user's models of a provider whose `source` is
   * `AiModelSourceEnum.Remote`.
   *
   * @param providerId - Provider whose remote models are removed
   */
  clearRemoteModels(providerId: string) {
    return this.db
      .delete(aiModels)
      .where(
        and(
          eq(aiModels.providerId, providerId),
          eq(aiModels.source, AiModelSourceEnum.Remote),
          eq(aiModels.userId, this.userId),
        ),
      );
  }

  /**
   * Delete all of the current user's models of a provider.
   *
   * @param providerId - Provider whose models are removed
   */
  clearModelsByProvider(providerId: string) {
    return this.db
      .delete(aiModels)
      .where(and(eq(aiModels.providerId, providerId), eq(aiModels.userId, this.userId)));
  }

  /**
   * Persist a new ordering for a provider's models in one transaction.
   *
   * Each entry is upserted: a missing row is created as enabled with the given
   * `sort`; an existing row only has `sort`, `updatedAt` and (when provided)
   * `type` overwritten.
   *
   * @param providerId - Provider the models belong to
   * @param sortMap - Model ids with their target sort position
   */
  updateModelsOrder = async (providerId: string, sortMap: AiModelSortMap[]) => {
    // Early return if sortMap array is empty
    if (this.isEmptyArray(sortMap)) {
      return;
    }

    await this.db.transaction(async (tx) => {
      const now = new Date();

      const updates = sortMap.map(({ id, sort, type }) => {
        const insertValues: typeof aiModels.$inferInsert = {
          enabled: true,
          id,
          providerId,
          sort,
          // source: isBuiltin ? 'builtin' : 'custom',
          updatedAt: now,
          userId: this.userId,
        };
        if (type) insertValues.type = type;

        const updateValues: Partial<typeof aiModels.$inferInsert> = {
          sort,
          updatedAt: now,
        };
        if (type) updateValues.type = type;

        return tx
          .insert(aiModels)
          .values(insertValues)
          .onConflictDoUpdate({
            set: updateValues,
            target: [aiModels.id, aiModels.userId, aiModels.providerId],
          });
      });

      await Promise.all(updates);
    });
  };
}