UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.

442 lines (367 loc) • 13.4 kB
import { beforeEach, describe, expect, it, vi } from 'vitest';

import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
import { clientDB, initializeDB } from '@/database/client/db';
import { AiProviderModelListItem, EnabledAiModel } from '@/types/aiModel';
import {
  AiProviderDetailItem,
  AiProviderListItem,
  AiProviderRuntimeConfig,
  EnabledProvider,
} from '@/types/aiProvider';

import { AiInfraRepos } from './index';

/**
 * Unit tests for `AiInfraRepos`.
 *
 * Each test stubs one layer with `vi.spyOn` — either the underlying models
 * (`repo.aiProviderModel`, `repo.aiModelModel`), the private
 * `fetchBuiltinModels` loader, or a public method of the repo itself — and
 * then checks what the public method under test returns.
 */

const userId = 'test-user-id';

// Provider config handed to the repo constructor; several expectations below
// rely on `openai` being enabled here.
const mockProviderConfigs = {
  openai: { enabled: true },
  anthropic: { enabled: false },
};

let repo: AiInfraRepos;

beforeEach(async () => {
  await initializeDB();
  vi.clearAllMocks();

  // A fresh repo is constructed for every test, so spies attached to the
  // previous instance are not reused.
  repo = new AiInfraRepos(clientDB as any, userId, mockProviderConfigs);
});

describe('AiInfraRepos', () => {
  describe('getAiProviderList', () => {
    it('should merge builtin and user providers correctly', async () => {
      const mockUserProviders = [
        { id: 'openai', enabled: true, name: 'Custom OpenAI' },
        { id: 'custom', enabled: true, name: 'Custom Provider' },
      ] as AiProviderListItem[];

      vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValueOnce(mockUserProviders);

      const result = await repo.getAiProviderList();

      expect(result).toBeDefined();
      expect(result.length).toBeGreaterThan(0);

      // Verify the merge logic
      const openaiProvider = result.find((p) => p.id === 'openai');
      expect(openaiProvider).toMatchObject({ enabled: true, name: 'Custom OpenAI' });
    });

    it('should sort providers according to DEFAULT_MODEL_PROVIDER_LIST order', async () => {
      vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValue([]);

      const result = await repo.getAiProviderList();

      // Every builtin provider is present and tagged as builtin...
      expect(result).toEqual(
        expect.arrayContaining(
          DEFAULT_MODEL_PROVIDER_LIST.map((item) =>
            expect.objectContaining({
              id: item.id,
              source: 'builtin',
            }),
          ),
        ),
      );

      // ...and, because `arrayContaining` ignores order, check the sequence
      // explicitly: with no user providers the ids must follow the default list.
      expect(result.map((item) => item.id)).toEqual(
        DEFAULT_MODEL_PROVIDER_LIST.map((item) => item.id),
      );
    });
  });

  describe('getUserEnabledProviderList', () => {
    // Stubs the repo-level `getAiProviderList`.
    it('should return only enabled providers', async () => {
      const mockProviders = [
        { id: 'openai', enabled: true, name: 'OpenAI', sort: 1 },
        { id: 'anthropic', enabled: false, name: 'Anthropic', sort: 2 },
      ] as AiProviderListItem[];

      vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);

      const result = await repo.getUserEnabledProviderList();

      expect(result).toHaveLength(1);
      expect(result[0]).toMatchObject({
        id: 'openai',
        name: 'OpenAI',
      });
    });

    // Stubs the model-level `getAiProviderList` instead and pins the exact
    // shape of the returned entries.
    it('should return only enabled provider', async () => {
      const mockProviders = [
        {
          enabled: true,
          id: 'openai',
          logo: 'logo1',
          name: 'OpenAI',
          sort: 1,
          source: 'builtin' as const,
        },
        {
          enabled: false,
          id: 'anthropic',
          logo: 'logo2',
          name: 'Anthropic',
          sort: 2,
          source: 'builtin' as const,
        },
      ];

      vi.spyOn(repo.aiProviderModel, 'getAiProviderList').mockResolvedValue(mockProviders);

      const result = await repo.getUserEnabledProviderList();

      expect(result).toEqual([
        {
          id: 'openai',
          logo: 'logo1',
          name: 'OpenAI',
          source: 'builtin',
        },
      ]);
    });
  });

  describe('getEnabledModels', () => {
    it('should merge and filter enabled models', async () => {
      const mockProviders = [{ id: 'openai', enabled: true }] as AiProviderListItem[];
      const mockAllModels = [
        { id: 'gpt-4', providerId: 'openai', enabled: true },
      ] as EnabledAiModel[];

      vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);
      vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels);
      vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([
        { id: 'gpt-4', enabled: true, type: 'chat' },
      ]);

      const result = await repo.getEnabledModels();

      expect(result).toBeDefined();
      expect(result.length).toBeGreaterThan(0);
      expect(result[0]).toMatchObject({
        id: 'gpt-4',
        providerId: 'openai',
      });
    });

    // The user record and the builtin record share id `gpt-4`; the expectation
    // is that the user's fields appear in the merged result.
    it('should merge builtin and user models correctly', async () => {
      const mockProviders = [
        { enabled: true, id: 'openai', name: 'OpenAI', sort: 1, source: 'builtin' as const },
      ];
      const mockAllModels = [
        {
          abilities: { vision: true },
          displayName: 'Custom GPT-4',
          enabled: true,
          id: 'gpt-4',
          providerId: 'openai',
          sort: 1,
          type: 'chat' as const,
          contextWindowTokens: 10,
        },
      ];

      vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);
      vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels);
      vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([
        {
          abilities: {},
          displayName: 'GPT-4',
          enabled: true,
          id: 'gpt-4',
          type: 'chat' as const,
        },
      ]);

      const result = await repo.getEnabledModels();

      expect(result).toContainEqual(
        expect.objectContaining({
          abilities: { vision: true },
          displayName: 'Custom GPT-4',
          enabled: true,
          contextWindowTokens: 10,
          id: 'gpt-4',
          providerId: 'openai',
          sort: 1,
          type: 'chat',
        }),
      );
    });

    // No user models at all: the builtin model is still expected in the
    // result, tagged with its provider id.
    it('should handle case when user model not found', async () => {
      const mockProviders = [
        { enabled: true, id: 'openai', name: 'OpenAI', sort: 1, source: 'builtin' as const },
      ];
      const mockAllModels: any[] = [];

      vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);
      vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels);
      vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([
        {
          abilities: { reasoning: true },
          displayName: 'GPT-4',
          enabled: true,
          id: 'gpt-4',
          type: 'chat' as const,
        },
      ]);

      const result = await repo.getEnabledModels();

      expect(result[0]).toEqual(
        expect.objectContaining({
          abilities: { reasoning: true },
          enabled: true,
          id: 'gpt-4',
          providerId: 'openai',
        }),
      );
    });

    it('should include settings property from builtin model', async () => {
      const mockProviders = [
        { enabled: true, id: 'openai', name: 'OpenAI', source: 'builtin' },
      ] as AiProviderListItem[];
      const mockAllModels: EnabledAiModel[] = [];
      const mockSettings = { searchImpl: 'tool' as const };

      vi.spyOn(repo, 'getAiProviderList').mockResolvedValue(mockProviders);
      vi.spyOn(repo.aiModelModel, 'getAllModels').mockResolvedValue(mockAllModels);
      vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue([
        {
          enabled: true,
          id: 'gpt-4',
          settings: mockSettings,
          type: 'chat',
        },
      ]);

      const result = await repo.getEnabledModels();

      expect(result[0]).toMatchObject({
        id: 'gpt-4',
        settings: mockSettings,
      });
    });
  });

  describe('getAiProviderModelList', () => {
    // Distinct ids on each side: both entries are expected in the output.
    it('should merge builtin and user models', async () => {
      const providerId = 'openai';
      const mockUserModels = [
        { id: 'custom-gpt4', enabled: true, type: 'chat' },
      ] as AiProviderModelListItem[];
      const mockBuiltinModels = [{ id: 'gpt-4', enabled: true }];

      vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(mockUserModels);
      vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(mockBuiltinModels);

      const result = await repo.getAiProviderModelList(providerId);

      expect(result).toHaveLength(2);
      expect(result).toEqual(
        expect.arrayContaining([
          expect.objectContaining({ id: 'custom-gpt4' }),
          expect.objectContaining({ id: 'gpt-4' }),
        ]),
      );
    });

    // Same id on both sides: the custom record's values are expected to win.
    it('should merge default and custom models', async () => {
      const mockCustomModels = [
        {
          displayName: 'Custom GPT-4',
          enabled: false,
          id: 'gpt-4',
          type: 'chat' as const,
        },
      ];
      const mockDefaultModels = [
        {
          displayName: 'GPT-4',
          enabled: true,
          id: 'gpt-4',
          type: 'chat' as const,
        },
      ];

      vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue(mockCustomModels);
      vi.spyOn(repo as any, 'fetchBuiltinModels').mockResolvedValue(mockDefaultModels);

      const result = await repo.getAiProviderModelList('openai');

      expect(result).toContainEqual(
        expect.objectContaining({
          displayName: 'Custom GPT-4',
          enabled: false,
          id: 'gpt-4',
        }),
      );
    });

    // `fetchBuiltinModels` is NOT stubbed here, so the real builtin lookup runs.
    // NOTE(review): the expected ids/count presumably mirror the project's
    // builtin model config for `ai21` — confirm and update if that list changes.
    it('should use builtin models', async () => {
      const providerId = 'ai21';

      vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue([]);

      const result = await repo.getAiProviderModelList(providerId);

      expect(result).toHaveLength(2);
      expect(result).toEqual(
        expect.arrayContaining([
          expect.objectContaining({ id: 'jamba-mini' }),
          expect.objectContaining({ id: 'jamba-large' }),
        ]),
      );
    });

    // Also uses the real builtin lookup, with a provider id that is expected
    // to have no builtin models.
    it('should return empty if not exist provider', async () => {
      const providerId = 'abc';

      vi.spyOn(repo.aiModelModel, 'getModelListByProviderId').mockResolvedValue([]);

      const result = await repo.getAiProviderModelList(providerId);

      expect(result).toHaveLength(0);
    });
  });

  describe('getAiProviderRuntimeState', () => {
    it('should return complete runtime state', async () => {
      const mockRuntimeConfig = {
        openai: { apiKey: 'test-key' },
      } as unknown as Record<string, AiProviderRuntimeConfig>;
      const mockEnabledProviders = [{ id: 'openai', name: 'OpenAI' }] as EnabledProvider[];
      const mockEnabledModels = [{ id: 'gpt-4', providerId: 'openai' }] as EnabledAiModel[];

      vi.spyOn(repo.aiProviderModel, 'getAiProviderRuntimeConfig').mockResolvedValue(
        mockRuntimeConfig,
      );
      vi.spyOn(repo, 'getUserEnabledProviderList').mockResolvedValue(mockEnabledProviders);
      vi.spyOn(repo, 'getEnabledModels').mockResolvedValue(mockEnabledModels);

      const result = await repo.getAiProviderRuntimeState();

      expect(result).toMatchObject({
        enabledAiProviders: mockEnabledProviders,
        enabledAiModels: mockEnabledModels,
        runtimeConfig: expect.any(Object),
      });
    });

    it('should return provider runtime state', async () => {
      const mockRuntimeConfig = {
        openai: {
          apiKey: 'test-key',
        },
      } as unknown as Record<string, AiProviderRuntimeConfig>;

      vi.spyOn(repo.aiProviderModel, 'getAiProviderRuntimeConfig').mockResolvedValue(
        mockRuntimeConfig,
      );
      vi.spyOn(repo, 'getUserEnabledProviderList').mockResolvedValue([
        { id: 'openai', logo: 'logo1', name: 'OpenAI', source: 'builtin' },
      ]);
      vi.spyOn(repo, 'getEnabledModels').mockResolvedValue([
        {
          abilities: {},
          enabled: true,
          id: 'gpt-4',
          providerId: 'openai',
          type: 'chat',
        },
      ]);

      const result = await repo.getAiProviderRuntimeState();

      expect(result).toEqual({
        enabledAiModels: [
          expect.objectContaining({
            enabled: true,
            id: 'gpt-4',
            providerId: 'openai',
          }),
        ],
        enabledAiProviders: [{ id: 'openai', logo: 'logo1', name: 'OpenAI', source: 'builtin' }],
        runtimeConfig: {
          openai: {
            apiKey: 'test-key',
            // Not part of the mocked runtime config; presumably merged in from
            // mockProviderConfigs — confirm against the implementation.
            enabled: true,
          },
        },
      });
    });
  });

  describe('getAiProviderDetail', () => {
    it('should merge provider config with user settings', async () => {
      const providerId = 'openai';
      const mockProviderDetail = {
        id: providerId,
        customSetting: 'test',
      } as unknown as AiProviderDetailItem;

      vi.spyOn(repo.aiProviderModel, 'getAiProviderById').mockResolvedValue(mockProviderDetail);

      const result = await repo.getAiProviderDetail(providerId);

      expect(result).toMatchObject({
        id: providerId,
        customSetting: 'test',
        enabled: true, // from mockProviderConfigs
      });
    });

    it('should merge provider configs correctly', async () => {
      const mockProviderDetail = {
        enabled: true,
        id: 'openai',
        keyVaults: { apiKey: 'test-key' },
        name: 'Custom OpenAI',
        settings: {},
        source: 'builtin' as const,
      };

      vi.spyOn(repo.aiProviderModel, 'getAiProviderById').mockResolvedValue(mockProviderDetail);

      const result = await repo.getAiProviderDetail('openai');

      expect(result).toEqual({
        enabled: true,
        id: 'openai',
        keyVaults: { apiKey: 'test-key' },
        name: 'Custom OpenAI',
        settings: {},
        source: 'builtin',
      });
    });
  });
});