UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.

322 lines (280 loc) 10.2 kB
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { LobeOpenAICompatibleRuntime } from '@/libs/model-runtime';

import * as debugStreamModule from './utils/debugStream';

/**
 * Options accepted by {@link testProvider}.
 */
interface TestProviderParams {
  /** Runtime class under test; it is constructed with `{ apiKey }` and, in one case, `baseURL`. */
  Runtime: any;
  /** Error type expected when the mocked API call rejects with an `OpenAI.APIError`. */
  bizErrorType?: string;
  /** Name of the environment variable the DEBUG test sets to `'1'` before calling `chat`. */
  chatDebugEnv: string;
  /** Model id passed to every `chat` call made by the suite. */
  chatModel: string;
  /** Base URL a freshly constructed runtime is expected to expose and report in errors. */
  defaultBaseURL: string;
  /** Error type expected for a 401 response and for a missing API key. */
  invalidErrorType?: string;
  /** Provider id expected in error payloads; also used in the suite and test titles. */
  provider: string;
  test?: {
    /** When true, the "call API with corresponding options" test is not registered. */
    skipAPICall?: boolean;
  };
}

/**
 * Registers a shared vitest suite for an OpenAI-compatible provider runtime.
 *
 * The suite covers construction, a successful `chat` call, the payload forwarded to
 * `chat.completions.create`, the error objects thrown for different rejection kinds,
 * and the debug-stream path that is switched on through `chatDebugEnv`.
 *
 * Must be called at module level of a test file, because it registers
 * `beforeEach`/`afterEach` hooks and `describe` blocks.
 */
export const testProvider = ({
  provider,
  invalidErrorType = 'InvalidProviderAPIKey',
  bizErrorType = 'ProviderBizError',
  defaultBaseURL,
  Runtime,
  chatDebugEnv,
  chatModel,
  test = {},
}: TestProviderParams) => {
  // Mock the console.error to avoid polluting test output
  vi.spyOn(console, 'error').mockImplementation(() => {});

  let instance: LobeOpenAICompatibleRuntime;

  /** Sends the minimal "Hello" chat request that most tests in this suite share. */
  const sendHello = (runtime: LobeOpenAICompatibleRuntime) =>
    runtime.chat({
      messages: [{ content: 'Hello', role: 'user' }],
      model: chatModel,
      temperature: 0,
    });

  beforeEach(() => {
    instance = new Runtime({ apiKey: 'test' });

    // Stub chat.completions.create with vi.spyOn so no real request is made
    vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
      new ReadableStream() as any,
    );
  });

  afterEach(() => {
    vi.clearAllMocks();
  });

  describe(`${provider} Runtime`, () => {
    describe('init', () => {
      it('should correctly initialize with an API key', async () => {
        const instance = new Runtime({ apiKey: 'test_api_key' });

        expect(instance).toBeInstanceOf(Runtime);
        expect(instance.baseURL).toEqual(defaultBaseURL);
      });
    });

    describe('chat', () => {
      it('should return a StreamingTextResponse on successful API call', async () => {
        // Arrange
        const mockStream = new ReadableStream();
        const mockResponse = Promise.resolve(mockStream);

        (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

        // Act
        const result = await sendHello(instance);

        // Assert
        expect(result).toBeInstanceOf(Response);
      });

      if (!test?.skipAPICall) {
        it(`should call ${provider} API with corresponding options`, async () => {
          // Arrange
          const mockStream = new ReadableStream();
          const mockResponse = Promise.resolve(mockStream);

          (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

          // Act
          const result = await instance.chat({
            max_tokens: 1024,
            messages: [{ content: 'Hello', role: 'user' }],
            model: chatModel,
            temperature: 0.7,
            top_p: 1,
          });

          // Assert
          expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
            {
              max_tokens: 1024,
              messages: [{ content: 'Hello', role: 'user' }],
              model: chatModel,
              stream: true,
              stream_options: {
                include_usage: true,
              },
              temperature: 0.7,
              top_p: 1,
            },
            { headers: { Accept: '*/*' } },
          );
          expect(result).toBeInstanceOf(Response);
        });
      }

      describe('Error', () => {
        // The chat error tests below use `rejects` rather than try/catch so that a
        // call which unexpectedly resolves fails the test instead of passing with
        // zero assertions executed.

        it('should return ProviderBizError with an openai error response when OpenAI.APIError is thrown', async () => {
          // Arrange
          const apiError = new OpenAI.APIError(
            400,
            {
              error: {
                message: 'Bad Request',
              },
              status: 400,
            },
            'Error message',
            {},
          );

          vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

          // Act & Assert
          await expect(sendHello(instance)).rejects.toEqual({
            endpoint: defaultBaseURL,
            error: {
              error: { message: 'Bad Request' },
              status: 400,
            },
            errorType: bizErrorType,
            provider,
          });
        });

        it('should throw AgentRuntimeError with InvalidProviderAPIKey if no apiKey is provided', async () => {
          // NOTE(review): the assertion only runs when the constructor throws. It is left
          // tolerant because this file does not show whether every Runtime rejects a
          // missing key — confirm, then tighten so a non-throwing constructor fails.
          try {
            new Runtime({});
          } catch (e) {
            expect(e).toEqual({ errorType: invalidErrorType });
          }
        });

        it('should return ProviderBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
          // Arrange
          const errorInfo = {
            cause: {
              message: 'api is undefined',
            },
            stack: 'abc',
          };
          const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

          vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

          // Act & Assert
          await expect(sendHello(instance)).rejects.toEqual({
            endpoint: defaultBaseURL,
            error: {
              cause: { message: 'api is undefined' },
              stack: 'abc',
            },
            errorType: bizErrorType,
            provider,
          });
        });

        it('should return ProviderBizError with an cause response with desensitize Url', async () => {
          // Arrange
          const errorInfo = {
            cause: { message: 'api is undefined' },
            stack: 'abc',
          };
          const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

          // Use a custom base URL so the masked endpoint can be told apart from the default
          instance = new Runtime({
            apiKey: 'test',
            baseURL: 'https://api.abc.com/v1',
          });

          vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

          // Act & Assert
          await expect(sendHello(instance)).rejects.toEqual({
            endpoint: 'https://api.***.com/v1',
            error: {
              cause: { message: 'api is undefined' },
              stack: 'abc',
            },
            errorType: bizErrorType,
            provider,
          });
        });

        it(`should throw an InvalidAPIKey error type on 401 status code`, async () => {
          // Mock the API call to simulate a 401 error
          const error = new Error('Unauthorized') as any;
          error.status = 401;

          vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error);

          // Expect the chat method to reject with the provider's invalid-API-key error type
          await expect(sendHello(instance)).rejects.toEqual({
            endpoint: defaultBaseURL,
            error,
            errorType: invalidErrorType,
            provider,
          });
        });

        it('should return AgentRuntimeError for non-OpenAI errors', async () => {
          // Arrange
          const genericError = new Error('Generic Error');

          vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError);

          // Act & Assert
          await expect(sendHello(instance)).rejects.toEqual({
            endpoint: defaultBaseURL,
            error: {
              cause: genericError.cause,
              message: genericError.message,
              name: genericError.name,
              stack: genericError.stack,
            },
            errorType: 'AgentRuntimeError',
            provider,
          });
        });
      });

      describe('DEBUG', () => {
        it(`should call debugStream and return StreamingTextResponse when ${chatDebugEnv} is 1`, async () => {
          // Arrange
          const mockProdStream = new ReadableStream() as any; // stand-in for the production stream
          const mockDebugStream = new ReadableStream({
            start(controller) {
              controller.enqueue('Debug stream content');
              controller.close();
            },
          }) as any;
          mockDebugStream.toReadableStream = () => mockDebugStream; // add a toReadableStream method

          // Make chat.completions.create resolve to an object exposing a mocked tee()
          (instance['client'].chat.completions.create as Mock).mockResolvedValue({
            tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
          });

          // Remember the original value so it can be restored afterwards
          const originalDebugValue = process.env[chatDebugEnv];

          // Switch the debug flag on for this test only
          process.env[chatDebugEnv] = '1';

          try {
            vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() =>
              Promise.resolve(),
            );

            // Act
            await instance.chat({
              messages: [{ content: 'Hello', role: 'user' }],
              model: chatModel,
              stream: true,
              temperature: 0,
            });

            // Assert: debugStream was invoked
            expect(debugStreamModule.debugStream).toHaveBeenCalled();
          } finally {
            // Restore even if the test fails. Assigning `undefined` to process.env would
            // store the string "undefined", so an originally unset variable is deleted.
            if (originalDebugValue === undefined) {
              delete process.env[chatDebugEnv];
            } else {
              process.env[chatDebugEnv] = originalDebugValue;
            }
          }
        });
      });
    });
  });
};