@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
281 lines (238 loc) • 8.17 kB
text/typescript
import { and, asc, desc, eq } from 'drizzle-orm/expressions';
import { isEmpty } from 'lodash-es';
import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
import { LobeChatDatabase } from '@/database/type';
import { ModelProvider } from '@/libs/model-runtime';
import {
AiProviderDetailItem,
AiProviderListItem,
AiProviderRuntimeConfig,
CreateAiProviderParams,
UpdateAiProviderConfigParams,
} from '@/types/aiProvider';
import { merge } from '@/utils/merge';
import { AiProviderSelectItem, aiModels, aiProviders } from '../schemas';
type DecryptUserKeyVaults = (encryptKeyVaultsStr: string | null) => Promise<any>;
type EncryptUserKeyVaults = (keyVaults: string) => Promise<string>;
export class AiProviderModel {
private userId: string;
private db: LobeChatDatabase;
constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
this.db = db;
}
create = async (
{ keyVaults: userKey, ...params }: CreateAiProviderParams,
encryptor?: EncryptUserKeyVaults,
) => {
// eslint-disable-next-line unicorn/consistent-function-scoping
const defaultSerialize = (s: string) => s;
const encrypt = encryptor ?? defaultSerialize;
const keyVaults = await encrypt(JSON.stringify(userKey));
const [result] = await this.db
.insert(aiProviders)
.values({
...params,
// each new ai provider we will set it to enabled by default
enabled: true,
keyVaults,
userId: this.userId,
})
.returning();
return result;
};
delete = async (id: string) => {
return this.db.transaction(async (trx) => {
// 1. delete all models of the provider
await trx
.delete(aiModels)
.where(and(eq(aiModels.providerId, id), eq(aiModels.userId, this.userId)));
// 2. delete the provider
await trx
.delete(aiProviders)
.where(and(eq(aiProviders.id, id), eq(aiProviders.userId, this.userId)));
});
};
deleteAll = async () => {
return this.db.delete(aiProviders).where(eq(aiProviders.userId, this.userId));
};
query = async () => {
return this.db.query.aiProviders.findMany({
orderBy: [desc(aiProviders.updatedAt)],
where: eq(aiProviders.userId, this.userId),
});
};
getAiProviderList = async (): Promise<AiProviderListItem[]> => {
const result = await this.db
.select({
description: aiProviders.description,
enabled: aiProviders.enabled,
id: aiProviders.id,
logo: aiProviders.logo,
name: aiProviders.name,
sort: aiProviders.sort,
source: aiProviders.source,
})
.from(aiProviders)
.where(eq(aiProviders.userId, this.userId))
.orderBy(asc(aiProviders.sort), desc(aiProviders.updatedAt));
return result as AiProviderListItem[];
};
findById = async (id: string) => {
return this.db.query.aiProviders.findFirst({
where: and(eq(aiProviders.id, id), eq(aiProviders.userId, this.userId)),
});
};
update = async (id: string, value: Partial<AiProviderSelectItem>) => {
return this.db
.update(aiProviders)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(aiProviders.id, id), eq(aiProviders.userId, this.userId)));
};
updateConfig = async (
id: string,
value: UpdateAiProviderConfigParams,
encryptor?: EncryptUserKeyVaults,
) => {
// eslint-disable-next-line unicorn/consistent-function-scoping
const defaultSerialize = (s: string) => s;
const encrypt = encryptor ?? defaultSerialize;
const keyVaults = await encrypt(JSON.stringify(value.keyVaults));
const commonFields = {
checkModel: value.checkModel,
config: value.config,
fetchOnClient: value.fetchOnClient,
keyVaults,
};
return this.db
.insert(aiProviders)
.values({
...commonFields,
id,
source: this.getProviderSource(id),
userId: this.userId,
})
.onConflictDoUpdate({
set: commonFields,
target: [aiProviders.id, aiProviders.userId],
});
};
toggleProviderEnabled = async (id: string, enabled: boolean) => {
return this.db
.insert(aiProviders)
.values({
enabled,
id,
source: this.getProviderSource(id),
updatedAt: new Date(),
userId: this.userId,
})
.onConflictDoUpdate({
set: { enabled },
target: [aiProviders.id, aiProviders.userId],
});
};
updateOrder = async (sortMap: { id: string; sort: number }[]) => {
await this.db.transaction(async (tx) => {
const updates = sortMap.map(({ id, sort }) => {
return tx
.insert(aiProviders)
.values({
enabled: true,
id,
sort,
source: this.getProviderSource(id),
updatedAt: new Date(),
userId: this.userId,
})
.onConflictDoUpdate({
set: { sort, updatedAt: new Date() },
target: [aiProviders.id, aiProviders.userId],
});
});
await Promise.all(updates);
});
};
getAiProviderById = async (
id: string,
decryptor?: DecryptUserKeyVaults,
): Promise<AiProviderDetailItem | undefined> => {
const query = this.db
.select({
checkModel: aiProviders.checkModel,
config: aiProviders.config,
description: aiProviders.description,
enabled: aiProviders.enabled,
fetchOnClient: aiProviders.fetchOnClient,
id: aiProviders.id,
keyVaults: aiProviders.keyVaults,
logo: aiProviders.logo,
name: aiProviders.name,
settings: aiProviders.settings,
source: aiProviders.source,
})
.from(aiProviders)
.where(and(eq(aiProviders.id, id), eq(aiProviders.userId, this.userId)))
.limit(1);
const [result] = await query;
if (!result) {
// if the provider is builtin but not init, we will insert it to the db
if (this.isBuiltInProvider(id)) {
await this.db.insert(aiProviders).values({ id, source: 'builtin', userId: this.userId });
const resultAgain = await query;
return { ...resultAgain[0] } as unknown as AiProviderDetailItem;
}
return;
}
const decrypt = decryptor ?? JSON.parse;
let keyVaults = {};
if (!!result.keyVaults) {
try {
keyVaults = await decrypt(result.keyVaults);
} catch {
/* empty */
}
}
return {
...result,
fetchOnClient: typeof result.fetchOnClient === 'boolean' ? result.fetchOnClient : undefined,
keyVaults,
settings: isEmpty(result.settings) ? undefined : result.settings,
} as AiProviderDetailItem;
};
getAiProviderRuntimeConfig = async (decryptor?: DecryptUserKeyVaults) => {
const result = await this.db
.select({
config: aiProviders.config,
fetchOnClient: aiProviders.fetchOnClient,
id: aiProviders.id,
keyVaults: aiProviders.keyVaults,
settings: aiProviders.settings,
})
.from(aiProviders)
.where(and(eq(aiProviders.userId, this.userId)));
const decrypt = decryptor ?? JSON.parse;
let runtimeConfig: Record<string, AiProviderRuntimeConfig> = {};
for (const item of result) {
const builtin = DEFAULT_MODEL_PROVIDER_LIST.find((provider) => provider.id === item.id);
const userSettings = item.settings || {};
let keyVaults = {};
if (!!item.keyVaults) {
try {
keyVaults = await decrypt(item.keyVaults);
} catch {
/* empty */
}
}
runtimeConfig[item.id] = {
config: item.config || {},
fetchOnClient: typeof item.fetchOnClient === 'boolean' ? item.fetchOnClient : undefined,
keyVaults,
settings: !!builtin ? merge(builtin.settings, userSettings) : userSettings,
};
}
return runtimeConfig;
};
private isBuiltInProvider = (id: string) => Object.values(ModelProvider).includes(id as any);
private getProviderSource = (id: string) => (this.isBuiltInProvider(id) ? 'builtin' : 'custom');
}