@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
541 lines (478 loc) • 17.1 kB
text/typescript
import debug from 'debug';
import { sql } from 'drizzle-orm';
import { eq } from 'drizzle-orm/expressions';
import {
oidcAccessTokens,
oidcAuthorizationCodes,
oidcClients,
oidcDeviceCodes,
oidcGrants,
oidcInteractions,
oidcRefreshTokens,
oidcSessions,
} from '@/database/schemas/oidc';
import { LobeChatDatabase } from '@/database/type';
// Create the debug logger used throughout this adapter, under the 'lobe-oidc:adapter' namespace
const log = debug('lobe-oidc:adapter');
/**
 * Drizzle-backed storage adapter for OIDC models.
 *
 * One instance is created per model name (e.g. 'AccessToken', 'Session', 'Client') through
 * `OIDCAdapter.createAdapterFactory`. Each instance maps its model onto one of the `oidc*`
 * tables. Most models store their whole payload in the table's `data` column; 'Client'
 * payloads are mapped onto dedicated columns instead.
 *
 * NOTE(review): the method set (upsert / find / findByUid / findByUserCode / consume /
 * destroy / revokeByGrantId) presumably implements the `oidc-provider` adapter interface —
 * confirm against the caller.
 */
class OIDCAdapter {
  private readonly db: LobeChatDatabase;

  private readonly name: string;

  constructor(name: string, db: LobeChatDatabase) {
    log('[%s] Constructor called with name: %s', name, name);
    this.name = name;
    this.db = db;
  }

  /**
   * Resolve the Drizzle table that backs the current model.
   *
   * Several models share a table:
   * - ClientCredentials, InitialAccessToken and RegistrationAccessToken reuse the
   *   access-token table.
   * - PushedAuthorizationRequest reuses the authorization-code table.
   *
   * @returns the table, or `null` for ReplayDetection, which is not persisted
   * @throws Error when the model name is not supported
   */
  private getTable() {
    log('Getting table for model: %s', this.name);
    switch (this.name) {
      case 'AccessToken':
      case 'ClientCredentials':
      case 'InitialAccessToken':
      case 'RegistrationAccessToken': {
        // Token-like models share the access-token table
        return oidcAccessTokens;
      }
      case 'AuthorizationCode':
      case 'PushedAuthorizationRequest': {
        // Pushed authorization requests share the authorization-code table
        return oidcAuthorizationCodes;
      }
      case 'RefreshToken': {
        return oidcRefreshTokens;
      }
      case 'DeviceCode': {
        return oidcDeviceCodes;
      }
      case 'Client': {
        return oidcClients;
      }
      case 'Interaction': {
        return oidcInteractions;
      }
      case 'ReplayDetection': {
        // Replay detection needs no persistence
        log('ReplayDetection - no persistent storage needed');
        return null;
      }
      case 'Grant': {
        return oidcGrants;
      }
      case 'Session': {
        return oidcSessions;
      }
      default: {
        const error = `不支持的模型: ${this.name}`;
        log('ERROR: %s', error);
        throw new Error(error);
      }
    }
  }

  /**
   * Map a client metadata payload onto `oidcClients` column values (without the `id`).
   *
   * Both the insert and the conflict-update of a client use this single map, so both paths
   * always write identical values.
   *
   * Key handling:
   * - Snake_case metadata keys are accepted, which is the shape `find` returns for clients.
   * - The camelCase aliases (`clientSecret`, `redirectUris`) are also accepted, for callers
   *   that still pass them.
   * - Scopes come from a `scopes` array when present; otherwise the space-separated `scope`
   *   string is split.
   */
  private static toClientColumns(payload: any): Record<string, any> {
    let scopes: string[] = [];
    if (Array.isArray(payload.scopes)) {
      scopes = payload.scopes;
    } else if (payload.scope) {
      scopes = payload.scope.split(' ');
    }

    return {
      applicationType: payload.application_type,
      clientSecret: payload.client_secret ?? payload.clientSecret,
      clientUri: payload.client_uri,
      description: payload.description,
      grants: payload.grant_types || [],
      isFirstParty: !!payload.isFirstParty,
      logoUri: payload.logo_uri,
      name: payload.name,
      policyUri: payload.policy_uri,
      redirectUris: payload.redirect_uris || payload.redirectUris || [],
      responseTypes: payload.response_types || [],
      scopes,
      tokenEndpointAuthMethod: payload.token_endpoint_auth_method,
      tosUri: payload.tos_uri,
    };
  }

  /**
   * Create or update a model instance (insert, or update on a primary-key conflict).
   *
   * @param id primary key of the record
   * @param payload model payload. For 'Client' it is mapped onto dedicated columns; for every
   *   other model it is stored whole in `data`, and `accountId` / `clientId` / `grantId`
   *   (plus `userCode` for DeviceCode) are copied into their own columns.
   * @param expiresIn lifetime in seconds; a falsy value stores no expiry
   * @throws rethrows any database error
   */
  async upsert(id: string, payload: any, expiresIn: number): Promise<void> {
    log('[%s] upsert called - id: %s, expiresIn: %s', this.name, id, `${expiresIn}s`);
    log('[%s] payload: %O', this.name, payload);
    const table = this.getTable();
    if (!table) {
      log('[%s] upsert - No table for model, returning early', this.name);
      return;
    }

    if (this.name === 'Client') {
      // Clients are stored in dedicated columns rather than a JSON `data` blob
      log('[Client] Upserting client record');
      const columns = OIDCAdapter.toClientColumns(payload);
      try {
        await this.db
          .insert(table)
          .values({ ...columns, id } as any)
          .onConflictDoUpdate({
            set: columns as any,
            target: (table as any).id,
          });
        log('[Client] Successfully upserted client: %s', id);
      } catch (error) {
        log('[Client] ERROR upserting client: %O', error);
        throw error;
      }
      return;
    }

    // Every other model stores the full payload plus metadata columns
    const expiresAt = expiresIn ? new Date(Date.now() + expiresIn * 1000) : undefined;
    log('[%s] expiresAt set to: %s', this.name, expiresAt ? expiresAt.toISOString() : 'undefined');
    const record: Record<string, any> = {
      data: payload,
      expiresAt,
      id,
    };

    // Copy model-specific fields into their own columns
    if (payload.accountId) {
      record.userId = payload.accountId;
      log('[%s] Setting userId: %s', this.name, payload.accountId);
    } else {
      // Best-effort fallback: use the currently authenticated user.
      // This also writes `payload.accountId`, so the userId is persisted inside `data` too.
      try {
        const { getUserAuth } = await import('@/utils/server/auth');
        try {
          const { userId } = await getUserAuth();
          if (userId) {
            payload.accountId = userId;
            record.userId = userId;
            log('[%s] Setting userId from auth context: %s', this.name, userId);
          }
        } catch (authError) {
          log('[%s] Error getting userId from auth context: %O', this.name, authError);
          // Failing to resolve the userId is not fatal; continue without it
        }
      } catch (importError) {
        log('[%s] Error importing auth module: %O', this.name, importError);
        // Failing to load the auth module is not fatal; continue without it
      }
    }
    if (payload.clientId) {
      record.clientId = payload.clientId;
      log('[%s] Setting clientId: %s', this.name, payload.clientId);
    }
    if (payload.grantId) {
      record.grantId = payload.grantId;
      log('[%s] Setting grantId: %s', this.name, payload.grantId);
    }
    if (this.name === 'DeviceCode' && payload.userCode) {
      record.userCode = payload.userCode;
      log('[DeviceCode] Setting userCode: %s', payload.userCode);
    }

    try {
      log('[%s] Executing upsert DB operation', this.name);
      await this.db
        .insert(table)
        .values(record as any)
        .onConflictDoUpdate({
          set: {
            data: payload,
            expiresAt,
            ...(payload.accountId ? { userId: payload.accountId } : {}),
            ...(payload.clientId ? { clientId: payload.clientId } : {}),
            ...(payload.grantId ? { grantId: payload.grantId } : {}),
            ...(this.name === 'DeviceCode' && payload.userCode
              ? { userCode: payload.userCode }
              : {}),
          } as any,
          target: (table as any).id,
        });
      log('[%s] Successfully upserted record: %s', this.name, id);
    } catch (error) {
      log('[%s] ERROR upserting record: %O', this.name, error);
      console.error(`[OIDC Adapter] Error upserting ${this.name}:`, error);
      throw error;
    }
  }

  /**
   * Look up a model instance by primary key.
   *
   * @returns
   * - for 'Client', a snake_case client metadata object built from the columns (no expiry
   *   check is applied);
   * - for other models, the stored `data` payload;
   * - `undefined` when the record is missing, expired, consumed, or the query fails.
   *   Query errors are logged, not thrown.
   */
  async find(id: string): Promise<any> {
    log('[%s] find called - id: %s', this.name, id);
    const table = this.getTable();
    if (!table) {
      log('[%s] find - No table for model, returning undefined', this.name);
      return undefined;
    }

    try {
      log('[%s] Executing find DB query', this.name);
      const result = await this.db
        .select()
        .from(table)
        .where(eq((table as any).id, id))
        .limit(1);
      log('[%s] Find query results: %O', this.name, result);

      if (!result || result.length === 0) {
        log('[%s] No record found for id: %s', this.name, id);
        return undefined;
      }

      const model = result[0] as any;

      // Clients are converted back into the snake_case metadata shape
      if (this.name === 'Client') {
        log('[Client] Converting client record to expected format');
        // Guard against a null/missing scopes column so the client is still returned
        const scopes: string[] = Array.isArray(model.scopes) ? model.scopes : [];
        return {
          application_type: model.applicationType,
          client_id: model.id,
          client_secret: model.clientSecret,
          client_uri: model.clientUri,
          grant_types: model.grants,
          isFirstParty: model.isFirstParty,
          logo_uri: model.logoUri,
          policy_uri: model.policyUri,
          redirect_uris: model.redirectUris,
          response_types: model.responseTypes,
          scope: scopes.join(' '),
          token_endpoint_auth_method: model.tokenEndpointAuthMethod,
          tos_uri: model.tosUri,
        };
      }

      // Expired records are treated as missing
      if (model.expiresAt && new Date() > new Date(model.expiresAt)) {
        log('[%s] Record expired (expiresAt: %s), returning undefined', this.name, model.expiresAt);
        return undefined;
      }

      // Consumed records are treated as missing
      if (model.consumedAt) {
        log(
          '[%s] Record already consumed (consumedAt: %s), returning undefined',
          this.name,
          model.consumedAt,
        );
        return undefined;
      }

      log('[%s] Successfully found and returning record data', this.name);
      return model.data;
    } catch (error) {
      log('[%s] ERROR finding record: %O', this.name, error);
      console.error(`[OIDC Adapter] Error finding ${this.name}:`, error);
      return undefined;
    }
  }

  /**
   * Look up a DeviceCode instance by its `userCode` column.
   *
   * @returns the stored `data`; or `undefined` when the record is missing, expired,
   *   consumed, or the query fails (errors are logged, not thrown)
   * @throws Error when called on any model other than 'DeviceCode'
   */
  async findByUserCode(userCode: string): Promise<any> {
    log('[DeviceCode] findByUserCode called - userCode: %s', userCode);
    if (this.name !== 'DeviceCode') {
      const error = 'findByUserCode 只能用于 DeviceCode 模型';
      log('ERROR: %s', error);
      throw new Error(error);
    }

    try {
      log('[DeviceCode] Executing findByUserCode DB query');
      const result = await this.db
        .select()
        .from(oidcDeviceCodes)
        .where(eq(oidcDeviceCodes.userCode, userCode))
        .limit(1);
      log('[DeviceCode] findByUserCode query results: %O', result);

      if (!result || result.length === 0) {
        log('[DeviceCode] No record found for userCode: %s', userCode);
        return undefined;
      }

      const model = result[0];

      // Expired or already-consumed records are treated as missing
      if (model.expiresAt && new Date() > new Date(model.expiresAt)) {
        log('[DeviceCode] Record expired (expiresAt: %s), returning undefined', model.expiresAt);
        return undefined;
      }
      if (model.consumedAt) {
        log(
          '[DeviceCode] Record already consumed (consumedAt: %s), returning undefined',
          model.consumedAt,
        );
        return undefined;
      }

      log('[DeviceCode] Successfully found and returning record data by userCode');
      return model.data;
    } catch (error) {
      log('[DeviceCode] ERROR finding record by userCode: %O', error);
      console.error('[OIDC Adapter] Error finding DeviceCode by userCode:', error);
      return undefined;
    }
  }

  /**
   * Look up an instance by `uid`.
   *
   * For 'Session', the match is on `data->>'uid'` in the JSONB `data` column:
   * - An expired match is deleted and `undefined` is returned.
   * - If that query throws, the error is logged and the lookup falls back to `find(uid)`.
   *
   * For every other model (e.g. 'Interaction'), this simply delegates to `find(uid)`.
   */
  async findByUid(uid: string): Promise<any> {
    log('[%s] findByUid called - uid: %s', this.name, uid);
    const table = this.getTable();

    if (this.name === 'Session' && table) {
      try {
        const jsonbUidEq = sql`${(table as any).data}->>'uid' = ${uid}`;
        const results = await this.db.select().from(table).where(jsonbUidEq).limit(1);
        log('[Session] Find by data.uid query results: %O', results);

        if (!results || results.length === 0) {
          log('[Session] No record found by data.uid: %s', uid);
          return undefined;
        }

        const model = results[0] as any;

        // Expired sessions are removed (by primary key) and treated as missing
        if (model.expiresAt && model.expiresAt < new Date()) {
          log('[Session] Record found by data.uid but expired: %s', uid);
          await this.destroy(model.id);
          return undefined;
        }

        log('[Session] Successfully found by data.uid and returning record data for uid %s', uid);
        return model.data;
      } catch (error) {
        // Not rethrown: fall through to the primary-key lookup below
        log('[Session] ERROR during findSessionByUid operation for %s: %O', uid, error);
        console.error(`[OIDC Adapter] Error finding Session by uid:`, error);
      }
    }

    // Reuse the primary-key lookup
    log('[Interaction] Delegating to find() method');
    return this.find(uid);
  }

  /**
   * Find the stored `data` of a session belonging to `userId`.
   *
   * Intended for session pre-synchronisation. The query takes the first matching row;
   * no ordering is applied.
   *
   * @returns the session `data`; or `undefined` when this is not the Session model, no row
   *   matches, or the query fails (errors are logged, not thrown)
   */
  async findSessionByUserId(userId: string): Promise<any> {
    log('[%s] findSessionByUserId called - userId: %s', this.name, userId);
    if (this.name !== 'Session') {
      log('[%s] findSessionByUserId - Not a Session model, returning undefined', this.name);
      return undefined;
    }

    const table = this.getTable();
    if (!table) {
      log('[%s] findSessionByUserId - No table for model, returning undefined', this.name);
      return undefined;
    }

    try {
      log('[%s] Executing findSessionByUserId DB query', this.name);
      const result = await this.db
        .select()
        .from(table)
        .where(eq((table as any).userId, userId))
        .limit(1);
      log('[%s] findSessionByUserId query results: %O', this.name, result);

      if (!result || result.length === 0) {
        log('[%s] No session found for userId: %s', this.name, userId);
        return undefined;
      }

      return (result[0] as { data: any }).data;
    } catch (error) {
      log('[%s] ERROR finding session by userId: %O', this.name, error);
      console.error(`[OIDC Adapter] Error finding session by userId:`, error);
      return undefined;
    }
  }

  /**
   * Delete a model instance by primary key.
   *
   * This is a no-op for models without a table.
   *
   * @throws rethrows any database error
   */
  async destroy(id: string): Promise<void> {
    log('[%s] destroy called - id: %s', this.name, id);
    const table = this.getTable();
    if (!table) {
      log('[%s] destroy - No table for model, returning early', this.name);
      return;
    }

    try {
      log('[%s] Executing destroy DB operation', this.name);
      await this.db.delete(table).where(eq((table as any).id, id));
      log('[%s] Successfully destroyed record: %s', this.name, id);
    } catch (error) {
      log('[%s] ERROR destroying record: %O', this.name, error);
      console.error(`[OIDC Adapter] Error destroying ${this.name}:`, error);
      throw error;
    }
  }

  /**
   * Mark a model instance as consumed by setting `consumedAt` to the current time.
   *
   * `find` / `findByUserCode` then treat the record as missing. This is a no-op for models
   * without a table.
   *
   * @throws rethrows any database error
   */
  async consume(id: string): Promise<void> {
    log('[%s] consume called - id: %s', this.name, id);
    const table = this.getTable();
    if (!table) {
      log('[%s] consume - No table for model, returning early', this.name);
      return;
    }

    try {
      log('[%s] Executing consume DB operation', this.name);
      await this.db
        .update(table)
        // The table union's typed set-source does not know every table has `consumedAt`
        .set({ consumedAt: new Date() } as any)
        .where(eq((table as any).id, id));
      log('[%s] Successfully consumed record: %s', this.name, id);
    } catch (error) {
      log('[%s] ERROR consuming record: %O', this.name, error);
      console.error(`[OIDC Adapter] Error consuming ${this.name}:`, error);
      throw error;
    }
  }

  /**
   * Revoke every record tied to `grantId`.
   *
   * Deletes from the access-token, authorization-code, refresh-token and device-code tables,
   * whatever model this adapter represents. The deletes run inside one transaction. Calling
   * this on the 'Grant' model is a no-op.
   *
   * @throws Error for unsupported model names; rethrows any database error
   */
  async revokeByGrantId(grantId: string): Promise<void> {
    log('[%s] revokeByGrantId called - grantId: %s', this.name, grantId);
    // A grant is not revoked through its own grantId
    if (this.name === 'Grant') {
      log('[Grant] revokeByGrantId skipped for Grant model, as it is the grant itself');
      return;
    }

    // Validate the model name up front (throws if unsupported), even though the table is
    // not used below
    this.getTable();

    try {
      log('[%s] Starting transaction for revokeByGrantId operations', this.name);
      // Delete every grant-scoped record in one transaction so the revocation is atomic
      await this.db.transaction(async (tx) => {
        // Every table that may carry a grantId, labelled for logging
        const grantScopedTables = [
          { label: 'oidcAccessTokens', table: oidcAccessTokens },
          { label: 'oidcAuthorizationCodes', table: oidcAuthorizationCodes },
          { label: 'oidcRefreshTokens', table: oidcRefreshTokens },
          { label: 'oidcDeviceCodes', table: oidcDeviceCodes },
        ];

        for (const { label, table } of grantScopedTables) {
          if ('grantId' in table) {
            log('[%s] Revoking %s records by grantId: %s', this.name, label, grantId);
            await tx.delete(table).where(eq((table as any).grantId, grantId));
          }
        }
      });
      log(
        '[%s] Successfully completed transaction for revoking all records by grantId: %s',
        this.name,
        grantId,
      );
    } catch (error) {
      log('[%s] ERROR in revokeByGrantId transaction: %O', this.name, error);
      console.error(`[OIDC Adapter] Error in revokeByGrantId transaction:`, error);
      throw error;
    }
  }

  /**
   * Build a factory that creates one adapter per model name, all sharing `db`.
   */
  static createAdapterFactory = (db: LobeChatDatabase) => {
    log('Creating adapter factory with database instance');
    return (name: string) => new OIDCAdapter(name, db);
  };
}
// Public export: the adapter class is exposed to importers under the name `DrizzleAdapter`
export { OIDCAdapter as DrizzleAdapter };