UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.

325 lines (272 loc) 10.5 kB
// @vitest-environment node import { Langfuse } from 'langfuse'; import { LangfuseGenerationClient, LangfuseTraceClient } from 'langfuse-core'; import { beforeEach, describe, expect, it, vi } from 'vitest'; import * as langfuseCfg from '@/config/langfuse'; import { JWTPayload } from '@/const/auth'; import { TraceNameMap } from '@/const/trace'; import { AgentRuntime, ChatStreamPayload, LobeOpenAI, ModelProvider } from '@/libs/model-runtime'; import { providerRuntimeMap } from '@/libs/model-runtime/runtimeMap'; import { CreateImagePayload } from '@/libs/model-runtime/types/image'; import { createTraceOptions } from '@/server/modules/AgentRuntime'; import { AgentChatOptions } from './ModelRuntime'; const specialProviders = [ { id: 'openai', payload: { apiKey: 'user-openai-key', baseURL: 'user-endpoint' } }, { id: ModelProvider.Azure, payload: { apiKey: 'user-azure-key', baseURL: 'user-azure-endpoint', apiVersion: '2024-06-01', }, }, { id: ModelProvider.AzureAI, payload: { apiKey: 'user-azure-key', baseURL: 'user-azure-endpoint', }, }, { id: ModelProvider.Bedrock, payload: { accessKeyId: 'user-aws-id', accessKeySecret: 'user-aws-secret', region: 'user-aws-region', }, }, { id: ModelProvider.Ollama, payload: { baseURL: 'https://user-ollama-url' }, }, { id: ModelProvider.Cloudflare, payload: { baseURLOrAccountID: 'https://user-ollama-url' }, }, ]; const testRuntime = (providerId: string, payload?: any) => { describe(`${providerId} provider runtime`, () => { it('should initialize correctly', async () => { const jwtPayload: JWTPayload = { apiKey: 'user-key', ...payload }; const runtime = await AgentRuntime.initializeWithProvider(providerId, jwtPayload); // @ts-ignore expect(runtime['_runtime']).toBeInstanceOf(providerRuntimeMap[providerId]); if (payload?.baseURL) { expect(runtime['_runtime'].baseURL).toBe(payload.baseURL); } }); }); }; let mockModelRuntime: AgentRuntime; beforeEach(async () => { const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' }; mockModelRuntime = await AgentRuntime.initializeWithProvider(ModelProvider.OpenAI, jwtPayload); }); describe('AgentRuntime', () => { describe('should initialize with various providers', () => { const providers = Object.values(ModelProvider); const specialProviderIds = [ModelProvider.VertexAI, ...specialProviders.map((p) => p.id)]; const generalTestProviders = providers.filter( (provider) => !specialProviderIds.includes(provider), ); generalTestProviders.forEach((provider) => { testRuntime(provider); }); specialProviders.forEach(({ id, payload }) => testRuntime(id, payload)); }); describe('AgentRuntime chat method', () => { it('should run correctly', async () => { const payload: ChatStreamPayload = { messages: [{ role: 'user', content: 'Hello, world!' }], model: 'text-davinci-002', temperature: 0, }; vi.spyOn(LobeOpenAI.prototype, 'chat').mockResolvedValue(new Response('')); await mockModelRuntime.chat(payload); }); it('should handle options correctly', async () => { const payload: ChatStreamPayload = { messages: [{ role: 'user', content: 'Hello, world!' }], model: 'text-davinci-002', temperature: 0, }; const options: AgentChatOptions = { provider: 'openai', trace: { traceId: 'test-trace-id', traceName: TraceNameMap.SummaryTopicTitle, sessionId: 'test-session-id', topicId: 'test-topic-id', tags: [], userId: 'test-user-id', }, }; vi.spyOn(LobeOpenAI.prototype, 'chat').mockResolvedValue(new Response('')); await mockModelRuntime.chat(payload, createTraceOptions(payload, options)); }); describe('callback', async () => { const payload: ChatStreamPayload = { messages: [{ role: 'user', content: 'Hello, world!' }], model: 'text-davinci-002', temperature: 0, }; const options: AgentChatOptions = { provider: 'openai', trace: { traceId: 'test-trace-id', traceName: TraceNameMap.SummaryTopicTitle, sessionId: 'test-session-id', topicId: 'test-topic-id', tags: [], userId: 'test-user-id', }, enableTrace: true, }; const updateMock = vi.fn(); it('should call onToolsCalling correctly', async () => { vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({ ENABLE_LANGFUSE: true, LANGFUSE_PUBLIC_KEY: 'abc', LANGFUSE_SECRET_KEY: 'DDD', } as any); // 使用 spyOn 模拟 chat 方法 vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation( async (payload, { callback }: any) => { // 模拟 onToolCall 回调的触发 if (callback?.onToolsCalling) { await callback.onToolsCalling(); } return new Response('abc'); }, ); vi.spyOn(LangfuseTraceClient.prototype, 'update').mockImplementation(updateMock); await mockModelRuntime.chat(payload, createTraceOptions(payload, options)); expect(updateMock).toHaveBeenCalledWith({ tags: ['Tools Calling'] }); }); it('should call onStart correctly', async () => { vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({ ENABLE_LANGFUSE: true, LANGFUSE_PUBLIC_KEY: 'abc', LANGFUSE_SECRET_KEY: 'DDD', } as any); vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock); vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation( async (payload, { callback }: any) => { if (callback?.onStart) { callback.onStart(); } return new Response('Success'); }, ); await mockModelRuntime.chat(payload, createTraceOptions(payload, options)); // Verify onStart was called expect(updateMock).toHaveBeenCalledWith({ completionStartTime: expect.any(Date) }); }); it('should call onCompletion correctly', async () => { vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({ ENABLE_LANGFUSE: true, LANGFUSE_PUBLIC_KEY: 'abc', LANGFUSE_SECRET_KEY: 'DDD', } as any); // Spy on the chat method and trigger onCompletion callback vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock); vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation( async (payload, { callback }: any) => { if (callback?.onCompletion) { await callback.onCompletion({ text: 'Test completion' }); } return new Response('Success'); }, ); await mockModelRuntime.chat(payload, createTraceOptions(payload, options)); // Verify onCompletion was called with expected output expect(updateMock).toHaveBeenCalledWith({ endTime: expect.any(Date), metadata: {}, output: 'Test completion', }); }); it.skip('should call onFinal correctly', async () => { vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({ ENABLE_LANGFUSE: true, LANGFUSE_PUBLIC_KEY: 'abc', LANGFUSE_SECRET_KEY: 'DDD', } as any); vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation( async (payload, { callback }: any) => { if (callback?.onFinal) { await callback.onFinal('Test completion'); } return new Response('Success'); }, ); const shutdownAsyncMock = vi.fn(); vi.spyOn(Langfuse.prototype, 'shutdownAsync').mockImplementation(shutdownAsyncMock); await mockModelRuntime.chat(payload, createTraceOptions(payload, options)); // Verify onCompletion was called with expected output expect(shutdownAsyncMock).toHaveBeenCalled(); }); }); }); describe('AgentRuntime createImage method', () => { it('should run correctly', async () => { const payload: CreateImagePayload = { model: 'dall-e-3', params: { prompt: 'A beautiful sunset over mountains', width: 1024, height: 1024, }, }; const mockResponse = { imageUrl: 'https://example.com/image.jpg', width: 1024, height: 1024, }; vi.spyOn(LobeOpenAI.prototype, 'createImage').mockResolvedValue(mockResponse); const result = await mockModelRuntime.createImage(payload); expect(LobeOpenAI.prototype.createImage).toHaveBeenCalledWith(payload); expect(result).toBe(mockResponse); }); it('should handle undefined createImage method gracefully', async () => { const payload: CreateImagePayload = { model: 'dall-e-3', params: { prompt: 'A beautiful sunset over mountains', width: 1024, height: 1024, }, }; // Mock runtime without createImage method const runtimeWithoutCreateImage = { createImage: undefined, }; // @ts-ignore - testing edge case mockModelRuntime['_runtime'] = runtimeWithoutCreateImage; const result = await mockModelRuntime.createImage(payload); expect(result).toBeUndefined(); }); }); describe('AgentRuntime models method', () => { it('should run correctly', async () => { const mockModels = [ { id: 'gpt-4', name: 'GPT-4' }, { id: 'gpt-3.5-turbo', name: 'GPT-3.5 Turbo' }, ]; vi.spyOn(LobeOpenAI.prototype, 'models').mockResolvedValue(mockModels); const result = await mockModelRuntime.models(); expect(LobeOpenAI.prototype.models).toHaveBeenCalled(); expect(result).toBe(mockModels); }); it('should handle undefined models method gracefully', async () => { // Mock runtime without models method const runtimeWithoutModels = { models: undefined, }; // @ts-ignore - testing edge case mockModelRuntime['_runtime'] = runtimeWithoutModels; const result = await mockModelRuntime.models(); expect(result).toBeUndefined(); }); }); });