@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports one-click free deployment of your private ChatGPT/LLM web application.
284 lines (247 loc) • 7.87 kB
text/typescript
import { and, asc, desc, eq, inArray } from 'drizzle-orm';
import {
AiModelSortMap,
AiModelSourceEnum,
AiProviderModelListItem,
EnabledAiModel,
ToggleAiModelEnableParams,
} from 'model-bank';
import { AiModelSelectItem, NewAiModelItem, aiModels } from '../schemas';
import { LobeChatDatabase } from '../type';
/**
 * Data-access layer for the `aiModels` table, scoped to a single user.
 *
 * Every statement issued by this class either filters on, or writes, the `userId`
 * supplied to the constructor. Upserts use the `(id, providerId, userId)` column
 * triple as their conflict target.
 */
export class AiModelModel {
  private readonly userId: string;
  private readonly db: LobeChatDatabase;

  /**
   * @param db - Database handle used for every query.
   * @param userId - Owner id applied to every read and write.
   */
  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
  }

  /**
   * Reports whether an array has no elements. Callers use it to skip statements
   * that would otherwise be built from an empty list.
   *
   * @param array - Array to inspect.
   * @returns `true` when the array is empty, `false` otherwise.
   */
  private isEmptyArray(array: readonly unknown[]): boolean {
    return array.length === 0;
  }

  /**
   * Inserts a new model row owned by the current user.
   *
   * `source` is always forced to `AiModelSourceEnum.Custom` and `userId` to the
   * current user, regardless of what `params` contains. `enabled` defaults to `true`
   * when it is `null`/`undefined` in `params`.
   *
   * @param params - Column values for the new row.
   * @returns The inserted row as reported by `.returning()`.
   */
  create = async (params: NewAiModelItem) => {
    const [result] = await this.db
      .insert(aiModels)
      .values({
        ...params,
        enabled: params.enabled ?? true, // enabled by default, but respect explicit value
        source: AiModelSourceEnum.Custom,
        userId: this.userId,
      })
      .returning();

    return result;
  };

  /**
   * Deletes the current user's row matching both `id` and `providerId`.
   *
   * @param id - Model id.
   * @param providerId - Provider the model belongs to.
   */
  delete = async (id: string, providerId: string) => {
    return this.db
      .delete(aiModels)
      .where(
        and(
          eq(aiModels.id, id),
          eq(aiModels.providerId, providerId),
          eq(aiModels.userId, this.userId),
        ),
      );
  };

  /** Deletes every model row owned by the current user. */
  deleteAll = async () => {
    return this.db.delete(aiModels).where(eq(aiModels.userId, this.userId));
  };

  /**
   * Lists all of the current user's model rows, most recently updated first.
   */
  query = async () => {
    return this.db.query.aiModels.findMany({
      orderBy: [desc(aiModels.updatedAt)],
      where: eq(aiModels.userId, this.userId),
    });
  };

  /**
   * Lists the current user's models for one provider.
   *
   * Rows are ordered by `sort` ascending, then `enabled`, `releasedAt` and
   * `updatedAt` descending.
   *
   * @param providerId - Provider whose models should be listed.
   * @returns The selected columns, cast to `AiProviderModelListItem[]`.
   */
  getModelListByProviderId = async (providerId: string) => {
    const result = await this.db
      .select({
        abilities: aiModels.abilities,
        config: aiModels.config,
        contextWindowTokens: aiModels.contextWindowTokens,
        description: aiModels.description,
        displayName: aiModels.displayName,
        enabled: aiModels.enabled,
        id: aiModels.id,
        parameters: aiModels.parameters,
        pricing: aiModels.pricing,
        releasedAt: aiModels.releasedAt,
        source: aiModels.source,
        type: aiModels.type,
      })
      .from(aiModels)
      .where(and(eq(aiModels.providerId, providerId), eq(aiModels.userId, this.userId)))
      .orderBy(
        asc(aiModels.sort),
        desc(aiModels.enabled),
        desc(aiModels.releasedAt),
        desc(aiModels.updatedAt),
      );

    return result as AiProviderModelListItem[];
  };

  /**
   * Lists a projection of every model row owned by the current user.
   *
   * Note that, despite the `EnabledAiModel` return type, no filter on the `enabled`
   * column is applied here; rows are returned whatever their flag is.
   *
   * @returns The selected columns, cast to `EnabledAiModel[]`.
   */
  getAllModels = async () => {
    const data = await this.db
      .select({
        abilities: aiModels.abilities,
        config: aiModels.config,
        contextWindowTokens: aiModels.contextWindowTokens,
        displayName: aiModels.displayName,
        enabled: aiModels.enabled,
        id: aiModels.id,
        parameters: aiModels.parameters,
        providerId: aiModels.providerId,
        sort: aiModels.sort,
        source: aiModels.source,
        type: aiModels.type,
      })
      .from(aiModels)
      .where(eq(aiModels.userId, this.userId));

    return data as EnabledAiModel[];
  };

  /**
   * Finds the first row of the current user with the given model id.
   *
   * The lookup does not constrain `providerId`, whereas the upserts in this class
   * key rows on `(id, providerId, userId)`. If the same id exists under several
   * providers, whichever row the query yields first is returned.
   *
   * @param id - Model id.
   */
  findById = async (id: string) => {
    return this.db.query.aiModels.findFirst({
      where: and(eq(aiModels.id, id), eq(aiModels.userId, this.userId)),
    });
  };

  /**
   * Upserts a model row: inserts it when missing, otherwise applies `value` to the
   * existing row.
   *
   * `updatedAt` is written on both the insert and the conflict path so that an
   * update of an existing row refreshes its timestamp, matching the other upserts
   * in this class.
   *
   * @param id - Model id.
   * @param providerId - Provider the model belongs to.
   * @param value - Columns to write.
   */
  update = async (id: string, providerId: string, value: Partial<AiModelSelectItem>) => {
    const now = new Date();

    return this.db
      .insert(aiModels)
      .values({ ...value, id, providerId, updatedAt: now, userId: this.userId })
      .onConflictDoUpdate({
        // Include updatedAt explicitly; it also keeps the set non-empty for `value = {}`.
        set: { ...value, updatedAt: now },
        target: [aiModels.id, aiModels.providerId, aiModels.userId],
      });
  };

  /**
   * Upserts the `enabled` flag of a single model.
   *
   * When the row does not exist yet, all fields of `value` are inserted. When it
   * does, only `enabled`, `updatedAt` and (if truthy) `type` are updated.
   *
   * @param value - Toggle parameters; spread into the inserted row.
   */
  toggleModelEnabled = async (value: ToggleAiModelEnableParams) => {
    const now = new Date();

    // `type` (when present) is already carried over by the spread.
    const insertValues = {
      ...value,
      updatedAt: now,
      userId: this.userId,
    } as typeof aiModels.$inferInsert;

    const updateValues: Partial<typeof aiModels.$inferInsert> = {
      enabled: value.enabled,
      updatedAt: now,
    };
    // Only overwrite the stored type when the caller actually supplied one.
    if (value.type) updateValues.type = value.type;

    return this.db
      .insert(aiModels)
      .values(insertValues)
      .onConflictDoUpdate({
        set: updateValues,
        target: [aiModels.id, aiModels.providerId, aiModels.userId],
      });
  };

  /**
   * Inserts a batch of models for a provider, leaving rows that already exist
   * (same id, user and provider) untouched.
   *
   * @param providerId - Provider the models belong to.
   * @param models - Models to insert.
   * @returns `[]` for an empty input, otherwise the rows reported by `.returning()`.
   */
  batchUpdateAiModels = async (providerId: string, models: AiProviderModelListItem[]) => {
    // Early return if models array is empty to prevent database insertion error
    if (this.isEmptyArray(models)) {
      return [];
    }

    // One timestamp for the whole batch.
    const now = new Date();
    const records = models.map((model) => ({
      ...model,
      providerId,
      updatedAt: now,
      userId: this.userId,
    }));

    return this.db
      .insert(aiModels)
      .values(records)
      .onConflictDoNothing({
        target: [aiModels.id, aiModels.userId, aiModels.providerId],
      })
      .returning();
  };

  /**
   * Sets the `enabled` flag for several models of one provider inside a transaction.
   *
   * Ids without a row are inserted as `AiModelSourceEnum.Builtin` models carrying the
   * requested flag; ids that were not reported as inserted are then updated in place.
   *
   * @param providerId - Provider the models belong to.
   * @param models - Model ids to toggle.
   * @param enabled - Flag value to store.
   */
  batchToggleAiModels = async (providerId: string, models: string[], enabled: boolean) => {
    // Early return if models array is empty to prevent database insertion error
    if (this.isEmptyArray(models)) {
      return;
    }

    return this.db.transaction(async (trx) => {
      const now = new Date();

      // 1. insert models that are not in the db
      const insertedRecords = await trx
        .insert(aiModels)
        .values(
          models.map((id) => ({
            enabled,
            id,
            providerId,
            // if the model is not in the db, it's a builtin model
            source: AiModelSourceEnum.Builtin,
            updatedAt: now,
            userId: this.userId,
          })),
        )
        .onConflictDoNothing({
          target: [aiModels.id, aiModels.userId, aiModels.providerId],
        })
        .returning();

      // 2. update models that are in the db
      const insertedIds = new Set(insertedRecords.map((record) => record.id));
      const recordsToUpdate = models.filter((id) => !insertedIds.has(id));

      // Everything was freshly inserted: skip the UPDATE rather than issuing it
      // with an empty `inArray` list.
      if (this.isEmptyArray(recordsToUpdate)) {
        return;
      }

      await trx
        .update(aiModels)
        .set({ enabled, updatedAt: now })
        .where(
          and(
            eq(aiModels.providerId, providerId),
            inArray(aiModels.id, recordsToUpdate),
            eq(aiModels.userId, this.userId),
          ),
        );
    });
  };

  /**
   * Deletes the current user's models of one provider whose `source` is
   * `AiModelSourceEnum.Remote`.
   *
   * @param providerId - Provider whose remote models should be removed.
   */
  clearRemoteModels(providerId: string) {
    return this.db
      .delete(aiModels)
      .where(
        and(
          eq(aiModels.providerId, providerId),
          eq(aiModels.source, AiModelSourceEnum.Remote),
          eq(aiModels.userId, this.userId),
        ),
      );
  }

  /**
   * Deletes all of the current user's models of one provider.
   *
   * @param providerId - Provider whose models should be removed.
   */
  clearModelsByProvider(providerId: string) {
    return this.db
      .delete(aiModels)
      .where(and(eq(aiModels.providerId, providerId), eq(aiModels.userId, this.userId)));
  }

  /**
   * Persists the sort position of several models of one provider inside a transaction.
   *
   * Each entry is upserted: a missing row is inserted as enabled with the given
   * `sort` (and `type`, if truthy); an existing row only has `sort`, `updatedAt`
   * and (if truthy) `type` updated.
   *
   * @param providerId - Provider the models belong to.
   * @param sortMap - Entries of model id, sort position and optional type.
   */
  updateModelsOrder = async (providerId: string, sortMap: AiModelSortMap[]) => {
    // Early return if sortMap array is empty
    if (this.isEmptyArray(sortMap)) {
      return;
    }

    await this.db.transaction(async (tx) => {
      // One timestamp for the whole reorder operation.
      const now = new Date();

      const updates = sortMap.map(({ id, sort, type }) => {
        // `source` is not set here, so whatever default the schema defines applies.
        const insertValues: typeof aiModels.$inferInsert = {
          enabled: true,
          id,
          providerId,
          sort,
          updatedAt: now,
          userId: this.userId,
        };
        if (type) insertValues.type = type;

        const updateValues: Partial<typeof aiModels.$inferInsert> = {
          sort,
          updatedAt: now,
        };
        if (type) updateValues.type = type;

        return tx
          .insert(aiModels)
          .values(insertValues)
          .onConflictDoUpdate({
            set: updateValues,
            target: [aiModels.id, aiModels.userId, aiModels.providerId],
          });
      });

      await Promise.all(updates);
    });
  };
}