@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. Supports one-click free deployment of your own private ChatGPT/LLM web application.
245 lines (208 loc) • 7.14 kB
text/typescript
import { TRPCError } from '@trpc/server';
import dayjs from 'dayjs';
import { eq } from 'drizzle-orm/expressions';
import type { AdapterAccount } from 'next-auth/adapters';
import type { PartialDeep } from 'type-fest';
import { LobeChatDatabase } from '@/database/type';
import { UserGuide, UserPreference } from '@/types/user';
import { UserKeyVaults, UserSettings } from '@/types/user/settings';
import { merge } from '@/utils/merge';
import { today } from '@/utils/time';
import {
NewUser,
UserItem,
UserSettingsItem,
nextauthAccounts,
userSettings,
users,
} from '../schemas';
/**
 * Callback that turns the stored `keyVaults` settings value (or `null` when none is stored)
 * into a `UserKeyVaults` object. Callers in this file also pass the owning user's id.
 */
type DecryptUserKeyVaults = (encryptKeyVaultsStr: string | null, userId?: string) => Promise<UserKeyVaults>;
/**
 * Raised when the current user's row cannot be found.
 * It is reported to tRPC clients with the `UNAUTHORIZED` code.
 */
export class UserNotFoundError extends TRPCError {
  constructor() {
    const errorOptions = { code: 'UNAUTHORIZED', message: 'user not found' } as const;
    super(errorOptions);
  }
}
/**
 * Data-access model scoped to a single user.
 *
 * Instance methods read and write that user's rows in the `users`, `userSettings` and
 * `nextauthAccounts` tables. Static helpers operate on an explicitly passed user id.
 */
export class UserModel {
  private readonly userId: string;
  private readonly db: LobeChatDatabase;

  constructor(db: LobeChatDatabase, userId: string) {
    this.userId = userId;
    this.db = db;
  }

  /**
   * Reports when the user registered and for how many calendar days they have been registered.
   *
   * `duration` counts days inclusively, so a user created today gets `1`. If the user row does
   * not exist, today's date and a duration of `1` are returned instead of throwing.
   * `updatedAt` is always today's date.
   */
  getUserRegistrationDuration = async (): Promise<{
    createdAt: string;
    duration: number;
    updatedAt: string;
  }> => {
    const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) });

    if (!user)
      return {
        createdAt: today().format('YYYY-MM-DD'),
        duration: 1,
        updatedAt: today().format('YYYY-MM-DD'),
      };

    return {
      createdAt: dayjs(user.createdAt).format('YYYY-MM-DD'),
      duration: dayjs().diff(dayjs(user.createdAt), 'day') + 1,
      updatedAt: today().format('YYYY-MM-DD'),
    };
  };

  /**
   * Loads the user's profile together with their settings, with key vaults decrypted.
   *
   * Settings come from a LEFT JOIN, so a user without a settings row still resolves. Every
   * missing settings section falls back to `{}`.
   *
   * @param decryptor - used to decrypt the stored `keyVaults` value.
   * @throws UserNotFoundError when no `users` row matches the current user id.
   */
  getUserState = async (decryptor: DecryptUserKeyVaults) => {
    const result = await this.db
      .select({
        avatar: users.avatar,
        email: users.email,
        firstName: users.firstName,
        fullName: users.fullName,
        isOnboarded: users.isOnboarded,
        lastName: users.lastName,
        preference: users.preference,

        settingsDefaultAgent: userSettings.defaultAgent,
        settingsGeneral: userSettings.general,
        settingsHotkey: userSettings.hotkey,
        settingsKeyVaults: userSettings.keyVaults,
        settingsLanguageModel: userSettings.languageModel,
        settingsSystemAgent: userSettings.systemAgent,
        settingsTTS: userSettings.tts,
        settingsTool: userSettings.tool,
        username: users.username,
      })
      .from(users)
      .leftJoin(userSettings, eq(users.id, userSettings.id))
      .where(eq(users.id, this.userId));

    const [state] = result;
    if (!state) {
      throw new UserNotFoundError();
    }

    // Best-effort decryption: if the vault is missing or cannot be decrypted, keyVaults falls
    // back to `{}` so the rest of the user state can still be loaded.
    let decryptKeyVaults = {};
    try {
      decryptKeyVaults = await decryptor(state.settingsKeyVaults, this.userId);
    } catch {
      /* intentionally ignored: see comment above */
    }

    const settings: PartialDeep<UserSettings> = {
      defaultAgent: state.settingsDefaultAgent || {},
      general: state.settingsGeneral || {},
      hotkey: state.settingsHotkey || {},
      keyVaults: decryptKeyVaults,
      languageModel: state.settingsLanguageModel || {},
      systemAgent: state.settingsSystemAgent || {},
      tool: state.settingsTool || {},
      tts: state.settingsTTS || {},
    };

    return {
      avatar: state.avatar || undefined,
      email: state.email || undefined,
      firstName: state.firstName || undefined,
      fullName: state.fullName || undefined,
      isOnboarded: state.isOnboarded,
      lastName: state.lastName || undefined,
      preference: state.preference as UserPreference,
      settings,
      userId: this.userId,
      username: state.username || undefined,
    };
  };

  /**
   * Lists the OAuth/SSO accounts linked to the user through `nextauthAccounts`.
   *
   * NOTE(review): the row is selected with an `expiresAt` key, but it is double-cast to
   * `AdapterAccount`, whose field is presumably `expires_at`. The cast hides that mismatch.
   * Confirm which key callers read before tightening the type.
   */
  getUserSSOProviders = async () => {
    const result = await this.db
      .select({
        expiresAt: nextauthAccounts.expires_at,
        provider: nextauthAccounts.provider,
        providerAccountId: nextauthAccounts.providerAccountId,
        scope: nextauthAccounts.scope,
        type: nextauthAccounts.type,
        userId: nextauthAccounts.userId,
      })
      .from(nextauthAccounts)
      .where(eq(nextauthAccounts.userId, this.userId));

    return result as unknown as AdapterAccount[];
  };

  /** Returns the user's raw settings row, or `undefined` when none exists. */
  getUserSettings = async () => {
    return this.db.query.userSettings.findFirst({ where: eq(userSettings.id, this.userId) });
  };

  /** Patches the user's `users` row and bumps `updatedAt`. */
  updateUser = async (value: Partial<UserItem>) => {
    return this.db
      .update(users)
      .set({ ...value, updatedAt: new Date() })
      .where(eq(users.id, this.userId));
  };

  /** Deletes the user's settings row, if any. */
  deleteSetting = async () => {
    return this.db.delete(userSettings).where(eq(userSettings.id, this.userId));
  };

  /**
   * Upserts the user's settings row: inserts it if missing, otherwise updates the given columns.
   *
   * The row is always keyed by the model's own user id. Any `id` in `value` is ignored, so a
   * payload cannot redirect the insert to another id or rewrite the row's primary key on conflict.
   */
  updateSetting = async (value: Partial<UserSettingsItem>) => {
    const settings: Partial<UserSettingsItem> = { ...value };
    delete settings.id;

    return this.db
      .insert(userSettings)
      .values({
        ...settings,
        // written last so it can never be overridden by the payload
        id: this.userId,
      })
      .onConflictDoUpdate({
        set: settings,
        target: userSettings.id,
      });
  };

  /**
   * Merges `value` into the stored `preference` using `merge` from `@/utils/merge`.
   * This is a read-modify-write. Resolves to `undefined` (no-op) when the user row does not exist.
   */
  updatePreference = async (value: Partial<UserPreference>) => {
    const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) });
    if (!user) return;

    return this.db
      .update(users)
      .set({ preference: merge(user.preference, value) })
      .where(eq(users.id, this.userId));
  };

  /**
   * Merges `value` into `preference.guide`, keeping the other preference keys as they are.
   * Resolves to `undefined` (no-op) when the user row does not exist.
   */
  updateGuide = async (value: Partial<UserGuide>) => {
    const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) });
    if (!user) return;

    const prevPreference = (user.preference || {}) as UserPreference;
    return this.db
      .update(users)
      .set({ preference: { ...prevPreference, guide: merge(prevPreference.guide || {}, value) } })
      .where(eq(users.id, this.userId));
  };

  // Static method

  /** Inserts a bare `users` row for `userId`, doing nothing if the id already exists. */
  static makeSureUserExist = async (db: LobeChatDatabase, userId: string) => {
    await db.insert(users).values({ id: userId }).onConflictDoNothing();
  };

  /**
   * Creates a user.
   *
   * When `params.id` is given and that id already exists, nothing is inserted and
   * `{ duplicate: true }` is returned. Otherwise the inserted row is returned as `user`.
   *
   * NOTE(review): the existence check and the insert are separate statements, so two
   * concurrent calls with the same id can still race into a unique-constraint error.
   */
  static createUser = async (db: LobeChatDatabase, params: NewUser) => {
    if (params.id) {
      const existing = await db.query.users.findFirst({ where: eq(users.id, params.id) });
      if (existing) return { duplicate: true };
    }

    const [user] = await db.insert(users).values(params).returning();

    return { duplicate: false, user };
  };

  /** Deletes the `users` row with the given id. */
  static deleteUser = async (db: LobeChatDatabase, id: string) => {
    return db.delete(users).where(eq(users.id, id));
  };

  /** Looks a user up by id; resolves to `undefined` when not found. */
  static findById = async (db: LobeChatDatabase, id: string) => {
    return db.query.users.findFirst({ where: eq(users.id, id) });
  };

  /** Looks a user up by email; resolves to `undefined` when not found. */
  static findByEmail = async (db: LobeChatDatabase, email: string) => {
    return db.query.users.findFirst({ where: eq(users.email, email) });
  };

  /**
   * Decrypts and returns the key vaults stored in the settings row of user `id`.
   *
   * Unlike `getUserState`, decryption errors are not caught here and propagate to the caller.
   *
   * @throws UserNotFoundError when no `userSettings` row exists for `id`. Note that this is
   *   thrown even if the `users` row itself exists.
   */
  static getUserApiKeys = async (
    db: LobeChatDatabase,
    id: string,
    decryptor: DecryptUserKeyVaults,
  ) => {
    const result = await db
      .select({
        settingsKeyVaults: userSettings.keyVaults,
      })
      .from(userSettings)
      .where(eq(userSettings.id, id));

    const [state] = result;
    if (!state) {
      throw new UserNotFoundError();
    }

    return decryptor(state.settingsKeyVaults, id);
  };
}