UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports one-click free deployment of your private ChatGPT/LLM web application.

466 lines (390 loc) • 16 kB
// @vitest-environment node
import { ModelProvider } from 'model-bank';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { testProvider } from '../../providerTestUtils';
import { LobeCerebrasAI, params } from './index';

/** Env var that `params.debug.chatCompletion` reads to decide whether debug output is on. */
const DEBUG_ENV_KEY = 'DEBUG_CEREBRAS_CHAT_COMPLETION';

/** Warning the tests expect `params.models` to log when the model list request rejects. */
const FETCH_MODELS_WARNING =
  'Failed to fetch Cerebras models. Please ensure your Cerebras API key is valid:';

/**
 * Builds a minimal stand-in for the client handed to `params.models`.
 * Only `models.list` is stubbed; `apiKey`/`baseURL` mirror a real Cerebras client.
 */
const createMockClient = (list: ReturnType<typeof vi.fn>, apiKey = 'test_api_key') =>
  ({
    apiKey,
    baseURL: 'https://api.cerebras.ai/v1',
    models: { list },
  }) as any;

// Shared provider conformance tests (API calls and error handling are skipped for Cerebras).
testProvider({
  Runtime: LobeCerebrasAI,
  bizErrorType: 'ProviderBizError',
  chatDebugEnv: DEBUG_ENV_KEY,
  chatModel: 'llama3.1-8b',
  defaultBaseURL: 'https://api.cerebras.ai/v1',
  invalidErrorType: 'InvalidProviderAPIKey',
  provider: ModelProvider.Cerebras,
  test: {
    skipAPICall: true,
    skipErrorHandle: true,
  },
});

describe('LobeCerebrasAI - custom features', () => {
  let instance: InstanceType<typeof LobeCerebrasAI>;

  /** Payload forwarded to the mocked `chat.completions.create` on its first call. */
  const getCalledPayload = () =>
    (instance['client'].chat.completions.create as any).mock.calls[0][0];

  beforeEach(() => {
    instance = new LobeCerebrasAI({ apiKey: 'test_api_key' });

    // Stub the underlying completion call so no request is made; tests inspect its arguments.
    vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
      new ReadableStream() as any,
    );
  });

  afterEach(() => {
    // Restore every spy (including console.warn) even if an assertion failed mid-test,
    // so mocks never leak into later tests.
    vi.restoreAllMocks();
  });

  describe('params configuration', () => {
    it('should export params object with correct baseURL', () => {
      expect(params.baseURL).toBe('https://api.cerebras.ai/v1');
    });

    it('should export params with correct provider', () => {
      expect(params.provider).toBe(ModelProvider.Cerebras);
    });

    it('should have chatCompletion handlePayload function', () => {
      expect(params.chatCompletion?.handlePayload).toBeDefined();
      expect(typeof params.chatCompletion?.handlePayload).toBe('function');
    });

    it('should have debug configuration', () => {
      expect(params.debug).toBeDefined();
      expect(params.debug.chatCompletion).toBeDefined();
      expect(typeof params.debug.chatCompletion).toBe('function');
    });

    it('should have models function', () => {
      expect(params.models).toBeDefined();
      expect(typeof params.models).toBe('function');
    });
  });

  describe('debug configuration', () => {
    let originalDebugEnv: string | undefined;

    beforeEach(() => {
      originalDebugEnv = process.env[DEBUG_ENV_KEY];
    });

    afterEach(() => {
      // Put the env var back exactly as we found it so other suites are unaffected.
      if (originalDebugEnv === undefined) {
        delete process.env[DEBUG_ENV_KEY];
      } else {
        process.env[DEBUG_ENV_KEY] = originalDebugEnv;
      }
    });

    it('should disable debug by default', () => {
      delete process.env[DEBUG_ENV_KEY];

      expect(params.debug.chatCompletion()).toBe(false);
    });

    it('should enable debug when env is set to 1', () => {
      process.env[DEBUG_ENV_KEY] = '1';

      expect(params.debug.chatCompletion()).toBe(true);
    });

    it('should disable debug when env is set to 0', () => {
      process.env[DEBUG_ENV_KEY] = '0';

      expect(params.debug.chatCompletion()).toBe(false);
    });

    it('should disable debug when env is empty string', () => {
      process.env[DEBUG_ENV_KEY] = '';

      expect(params.debug.chatCompletion()).toBe(false);
    });
  });

  describe('handlePayload', () => {
    it('should remove frequency_penalty and presence_penalty from payload', async () => {
      await instance.chat({
        frequency_penalty: 0.5,
        messages: [{ content: 'Hello', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0.5,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
      expect(calledPayload.model).toBe('llama3.1-8b');
    });

    it('should preserve model in the payload', async () => {
      await instance.chat({
        frequency_penalty: 0.5,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-70b',
        presence_penalty: 0.5,
      });

      expect(getCalledPayload().model).toBe('llama3.1-70b');
    });

    it('should preserve other payload properties', async () => {
      await instance.chat({
        frequency_penalty: 0.5,
        max_tokens: 1000,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0.5,
        stream: true,
        temperature: 0.8,
        top_p: 0.9,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.temperature).toBe(0.8);
      expect(calledPayload.max_tokens).toBe(1000);
      expect(calledPayload.top_p).toBe(0.9);
      expect(calledPayload.stream).toBe(true);
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should handle payload without frequency_penalty and presence_penalty', async () => {
      await instance.chat({
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        temperature: 0.7,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.model).toBe('llama3.1-8b');
      expect(calledPayload.temperature).toBe(0.7);
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should handle payload with only frequency_penalty', async () => {
      await instance.chat({
        frequency_penalty: 0.5,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should handle payload with only presence_penalty', async () => {
      await instance.chat({
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0.5,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should handle payload with zero values for penalties', async () => {
      await instance.chat({
        frequency_penalty: 0,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should call handlePayload directly and verify transformation', () => {
      const payload = {
        frequency_penalty: 0.5,
        max_tokens: 100,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0.5,
        temperature: 0.7,
      };

      const transformedPayload = params.chatCompletion!.handlePayload!(payload as any);

      expect(transformedPayload.model).toBe('llama3.1-8b');
      expect(transformedPayload.temperature).toBe(0.7);
      expect(transformedPayload.max_tokens).toBe(100);
      expect(transformedPayload.frequency_penalty).toBeUndefined();
      expect(transformedPayload.presence_penalty).toBeUndefined();
    });
  });

  describe('models function', () => {
    const sampleModels = [
      { id: 'llama3.1-8b', object: 'model', owned_by: 'cerebras' },
      { id: 'llama3.1-70b', object: 'model', owned_by: 'cerebras' },
    ];

    it('should fetch and process models with data property', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue({ data: sampleModels }));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toBeDefined();
      expect(Array.isArray(models)).toBe(true);
    });

    it('should handle models list without data property (direct array)', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue(sampleModels));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toBeDefined();
      expect(Array.isArray(models)).toBe(true);
    });

    it('should handle empty models list with data property', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue({ data: [] }));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
    });

    it('should handle empty models list without data property', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue([]));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
    });

    it('should handle null response', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue(null));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
    });

    it('should handle undefined response', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue(undefined));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
    });

    it('should handle response with non-array data', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue({ data: 'not-an-array' }));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
    });

    it('should handle network error and return empty array', async () => {
      const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
      const mockClient = createMockClient(vi.fn().mockRejectedValue(new Error('Network error')));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
      expect(consoleWarnSpy).toHaveBeenCalledWith(FETCH_MODELS_WARNING, expect.any(Error));
    });

    it('should handle API authentication error and return empty array', async () => {
      const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
      const mockClient = createMockClient(
        vi.fn().mockRejectedValue(new Error('401 Unauthorized')),
        'invalid_key',
      );

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
      expect(consoleWarnSpy).toHaveBeenCalledWith(FETCH_MODELS_WARNING, expect.any(Error));
    });

    it('should handle API rate limit error and return empty array', async () => {
      const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
      const mockClient = createMockClient(
        vi.fn().mockRejectedValue(new Error('429 Too Many Requests')),
      );

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
      expect(consoleWarnSpy).toHaveBeenCalledWith(FETCH_MODELS_WARNING, expect.any(Error));
    });

    it('should handle timeout error and return empty array', async () => {
      const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
      const mockClient = createMockClient(vi.fn().mockRejectedValue(new Error('Request timeout')));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
      expect(consoleWarnSpy).toHaveBeenCalledWith(FETCH_MODELS_WARNING, expect.any(Error));
    });

    it('should handle malformed JSON response', async () => {
      const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
      const mockClient = createMockClient(vi.fn().mockRejectedValue(new Error('Invalid JSON')));

      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
      expect(consoleWarnSpy).toHaveBeenCalledWith(FETCH_MODELS_WARNING, expect.any(Error));
    });

    it('should fetch the model list from the provided client and return an array', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue({ data: sampleModels }));

      const models = await params.models!({ client: mockClient });

      // processMultiProviderModelList is not mocked in this file, so its arguments cannot be
      // inspected here; assert only that the supplied client was queried and a list came back.
      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toBeDefined();
      expect(Array.isArray(models)).toBe(true);
    });
  });
});