@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
329 lines (289 loc) • 11.8 kB
text/typescript
import { isEmpty } from 'lodash-es';
import {
AIChatModelCard,
AiModelSourceEnum,
AiProviderModelListItem,
EnabledAiModel,
} from 'model-bank';
import pMap from 'p-map';
import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
import {
AiProviderDetailItem,
AiProviderListItem,
AiProviderRuntimeState,
EnabledProvider,
} from '@/types/aiProvider';
import { ProviderConfig } from '@/types/user/settings';
import { merge, mergeArrayById } from '@/utils/merge';
import { AiModelModel } from '../../models/aiModel';
import { AiProviderModel } from '../../models/aiProvider';
import { LobeChatDatabase } from '../../type';
type DecryptUserKeyVaults = (encryptKeyVaultsStr: string | null) => Promise<any>;
/**
* Provider 级默认表(只在本地内置模型没给出 settings.searchImpl 和 settings.searchProvider 时使用)
* 注意:不在 DB 存储,纯读取时注入
*/
const PROVIDER_SEARCH_DEFAULTS: Record<
string,
{ searchImpl?: 'tool' | 'params' | 'internal'; searchProvider?: string }
> = {
ai360: { searchImpl: 'params' },
aihubmix: { searchImpl: 'params' },
anthropic: { searchImpl: 'params' },
baichuan: { searchImpl: 'params' },
default: { searchImpl: 'params' },
google: { searchImpl: 'params', searchProvider: 'google' },
hunyuan: { searchImpl: 'params' },
jina: { searchImpl: 'internal' },
minimax: { searchImpl: 'params' },
// openai: 默认 params,但对 -search- 型号做 internal 特判
openai: { searchImpl: 'params' },
// perplexity: 默认 internal
perplexity: { searchImpl: 'internal' },
qwen: { searchImpl: 'params' },
spark: { searchImpl: 'params' }, // 某些模型(如 max-32k)若内置标了 internal,会优先使用内置
stepfun: { searchImpl: 'params' },
vertexai: { searchImpl: 'params', searchProvider: 'google' },
wenxin: { searchImpl: 'params' },
xai: { searchImpl: 'params' },
zhipu: { searchImpl: 'params' },
};
// 特殊模型配置 - 模型级别的特殊设置会覆盖服务商默认配置
const MODEL_SEARCH_DEFAULTS: Record<
string,
Record<string, { searchImpl?: 'tool' | 'params' | 'internal'; searchProvider?: string }>
> = {
openai: {
'gpt-4o-mini-search-preview': { searchImpl: 'internal' },
'gpt-4o-search-preview': { searchImpl: 'internal' },
// 可在此处添加其他特殊模型配置
},
spark: {
'max-32k': { searchImpl: 'internal' },
},
// 可在此处添加其他服务商的特殊模型配置
};
// 根据 providerId + modelId 推断默认 settings
const inferProviderSearchDefaults = (
providerId: string | undefined,
modelId: string,
): { searchImpl?: 'tool' | 'params' | 'internal'; searchProvider?: string } => {
const modelSpecificConfig = providerId ? MODEL_SEARCH_DEFAULTS[providerId]?.[modelId] : undefined;
if (modelSpecificConfig) {
return modelSpecificConfig;
}
return (providerId && PROVIDER_SEARCH_DEFAULTS[providerId]) || PROVIDER_SEARCH_DEFAULTS.default;
};
// 仅在读取时注入 settings; 根据 abilities.search 来添加或删去settings 中的 search 相关字段
const injectSearchSettings = (providerId: string, item: any) => {
const abilities = item?.abilities || {};
// 模型显式关闭搜索能力:移除 settings 中的 search 相关字段,确保 UI 不显示启用模型内置搜索
if (abilities.search === false) {
if (item?.settings?.searchImpl || item?.settings?.searchProvider) {
const next = { ...item } as any;
if (next.settings) {
const { searchImpl, searchProvider, ...restSettings } = next.settings;
next.settings = Object.keys(restSettings).length > 0 ? restSettings : undefined;
}
return next;
}
return item;
}
// 模型显式开启搜索能力:添加 settings 中的 search 相关字段
else if (abilities.search === true) {
// 内置(本地)模型如果已经带了任一字段,直接保留,不覆盖
if (item?.settings?.searchImpl || item?.settings?.searchProvider) return item;
// 否则按 providerId + modelId
const searchSettings = inferProviderSearchDefaults(providerId, item.id);
return {
...item,
settings: {
...item.settings,
...searchSettings,
},
};
}
// 兼容老版本中数据库没有存储 abilities.search 字段的情况
return item;
};
export class AiInfraRepos {
private userId: string;
private db: LobeChatDatabase;
aiProviderModel: AiProviderModel;
private readonly providerConfigs: Record<string, ProviderConfig>;
aiModelModel: AiModelModel;
constructor(
db: LobeChatDatabase,
userId: string,
providerConfigs: Record<string, ProviderConfig>,
) {
this.userId = userId;
this.db = db;
this.aiProviderModel = new AiProviderModel(db, userId);
this.aiModelModel = new AiModelModel(db, userId);
this.providerConfigs = providerConfigs;
}
/**
* Calculate the final providerList based on the known providerConfig
*/
getAiProviderList = async () => {
const userProviders = await this.aiProviderModel.getAiProviderList();
// 1. 先创建一个基于 DEFAULT_MODEL_PROVIDER_LIST id 顺序的映射
const orderMap = new Map(DEFAULT_MODEL_PROVIDER_LIST.map((item, index) => [item.id, index]));
const builtinProviders = DEFAULT_MODEL_PROVIDER_LIST.map((item) => ({
description: item.description,
enabled:
userProviders.some((provider) => provider.id === item.id && provider.enabled) ||
this.providerConfigs[item.id]?.enabled,
id: item.id,
name: item.name,
source: 'builtin',
})) as AiProviderListItem[];
const mergedProviders = mergeArrayById(builtinProviders, userProviders);
// 3. 根据 orderMap 排序
return mergedProviders.sort((a, b) => {
const orderA = orderMap.get(a.id) ?? Number.MAX_SAFE_INTEGER;
const orderB = orderMap.get(b.id) ?? Number.MAX_SAFE_INTEGER;
return orderA - orderB;
});
};
/**
* used in the chat page. to show the enabled providers
*/
getUserEnabledProviderList = async () => {
const list = await this.getAiProviderList();
return list
.filter((item) => item.enabled)
.sort((a, b) => a.sort! - b.sort!)
.map(
(item): EnabledProvider => ({
id: item.id,
logo: item.logo,
name: item.name,
source: item.source,
}),
);
};
/**
* used in the chat page. to show the enabled models
*/
getEnabledModels = async (filterEnabled: boolean = true) => {
const [providers, allModels] = await Promise.all([
this.getAiProviderList(),
this.aiModelModel.getAllModels(),
]);
const enabledProviders = providers.filter((item) => (filterEnabled ? item.enabled : true));
const builtinModelList = await pMap(
enabledProviders,
async (provider) => {
const aiModels = await this.fetchBuiltinModels(provider.id);
return (aiModels || [])
.map<EnabledAiModel & { enabled?: boolean | null }>((item) => {
const user = allModels.find((m) => m.id === item.id && m.providerId === provider.id);
// 用户未修改本地模型
if (!user)
return {
...item,
abilities: item.abilities || {},
providerId: provider.id,
};
const mergedModel = {
...item,
abilities: !isEmpty(user.abilities) ? user.abilities : item.abilities || {},
config: !isEmpty(user.config) ? user.config : item.config,
contextWindowTokens:
typeof user.contextWindowTokens === 'number'
? user.contextWindowTokens
: item.contextWindowTokens,
displayName: user?.displayName || item.displayName,
enabled: typeof user.enabled === 'boolean' ? user.enabled : item.enabled,
id: item.id,
providerId: provider.id,
settings: item.settings,
sort: user.sort || undefined,
type: user.type || item.type,
};
return injectSearchSettings(provider.id, mergedModel); // 用户修改本地模型,检查搜索设置
})
.filter((item) => (filterEnabled ? item.enabled : true));
},
{ concurrency: 10 },
);
const enabledProviderIds = new Set(enabledProviders.map((item) => item.id));
// 用户数据库模型,检查搜索设置
const appendedUserModels = allModels
.filter((item) =>
filterEnabled ? enabledProviderIds.has(item.providerId) && item.enabled : true,
)
.map((item) => injectSearchSettings(item.providerId, item));
return [...builtinModelList.flat(), ...appendedUserModels].sort(
(a, b) => (a?.sort || -1) - (b?.sort || -1),
) as EnabledAiModel[];
};
getAiProviderRuntimeState = async (
decryptor?: DecryptUserKeyVaults,
): Promise<AiProviderRuntimeState> => {
const [result, enabledAiProviders, allModels] = await Promise.all([
this.aiProviderModel.getAiProviderRuntimeConfig(decryptor),
this.getUserEnabledProviderList(),
this.getEnabledModels(false),
]);
const runtimeConfig = result;
Object.entries(result).forEach(([key, value]) => {
runtimeConfig[key] = merge(this.providerConfigs[key] || {}, value);
});
const enabledAiModels = allModels.filter((model) => model.enabled);
const enabledChatAiProviders = enabledAiProviders.filter((provider) => {
return allModels.some((model) => model.providerId === provider.id && model.type === 'chat');
});
const enabledImageAiProviders = enabledAiProviders.filter((provider) => {
return allModels.some((model) => model.providerId === provider.id && model.type === 'image');
});
return {
enabledAiModels,
enabledAiProviders,
enabledChatAiProviders,
enabledImageAiProviders,
runtimeConfig,
};
};
getAiProviderModelList = async (providerId: string) => {
const aiModels = await this.aiModelModel.getModelListByProviderId(providerId);
const defaultModels: AiProviderModelListItem[] =
(await this.fetchBuiltinModels(providerId)) || [];
// 这里不修改搜索设置不影响使用,但是为了get数据统一
const mergedModel = mergeArrayById(defaultModels, aiModels) as AiProviderModelListItem[];
return mergedModel.map((m) => injectSearchSettings(providerId, m));
};
/**
* use in the `/settings?active=provider&provider=[id]` page
*/
getAiProviderDetail = async (id: string, decryptor?: DecryptUserKeyVaults) => {
const config = await this.aiProviderModel.getAiProviderById(id, decryptor);
return merge(this.providerConfigs[id] || {}, config) as AiProviderDetailItem;
};
/**
* Fetch builtin models from config
*/
private fetchBuiltinModels = async (
providerId: string,
): Promise<AiProviderModelListItem[] | undefined> => {
try {
const modules = await import('model-bank');
// TODO: when model-bank is a separate module, we will try import from model-bank/[prividerId] again
// @ts-expect-error providerId is string
const providerModels = modules[providerId];
// use the serverModelLists as the defined server model list
// fallback to empty array for custom provider
const presetList = this.providerConfigs[providerId]?.serverModelLists || providerModels || [];
return (presetList as AIChatModelCard[]).map<AiProviderModelListItem>((m) => ({
...m,
enabled: m.enabled || false,
source: AiModelSourceEnum.Builtin,
}));
} catch (error) {
console.error(error);
// maybe provider id not exist
}
};
}