@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports one-click free deployment of your private ChatGPT/LLM web application.
466 lines (390 loc) • 16 kB
text/typescript
// @vitest-environment node
import { ModelProvider } from 'model-bank';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { testProvider } from '../../providerTestUtils';
import { LobeCerebrasAI, params } from './index';
/**
 * Options for the shared provider conformance suite.
 * Live API-call and error-handling cases are skipped for Cerebras;
 * provider-specific behaviour is exercised by the custom suite further down.
 */
const cerebrasProviderTestOptions = {
  Runtime: LobeCerebrasAI,
  bizErrorType: 'ProviderBizError',
  chatDebugEnv: 'DEBUG_CEREBRAS_CHAT_COMPLETION',
  chatModel: 'llama3.1-8b',
  defaultBaseURL: 'https://api.cerebras.ai/v1',
  invalidErrorType: 'InvalidProviderAPIKey',
  provider: ModelProvider.Cerebras,
  test: {
    skipAPICall: true,
    skipErrorHandle: true,
  },
} satisfies Parameters<typeof testProvider>[0];

testProvider(cerebrasProviderTestOptions);
describe('LobeCerebrasAI - custom features', () => {
  let instance: InstanceType<typeof LobeCerebrasAI>;

  beforeEach(() => {
    instance = new LobeCerebrasAI({ apiKey: 'test_api_key' });
    // Stub the underlying client so `chat` never performs a real request and the outgoing
    // payload can be inspected via the spy's recorded calls.
    vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
      new ReadableStream() as any,
    );
  });

  afterEach(() => {
    // Undo every vi.spyOn (including the console.warn spies below) even when a test fails
    // midway, so spies never leak into subsequent tests.
    vi.restoreAllMocks();
  });

  /** Returns the payload passed to the first (mocked) `chat.completions.create` call. */
  const getCalledPayload = () =>
    (instance['client'].chat.completions.create as any).mock.calls[0][0];

  describe('params configuration', () => {
    it('should export params object with correct baseURL', () => {
      expect(params.baseURL).toBe('https://api.cerebras.ai/v1');
    });

    it('should export params with correct provider', () => {
      expect(params.provider).toBe(ModelProvider.Cerebras);
    });

    it('should have chatCompletion handlePayload function', () => {
      expect(params.chatCompletion?.handlePayload).toBeDefined();
      expect(typeof params.chatCompletion?.handlePayload).toBe('function');
    });

    it('should have debug configuration', () => {
      expect(params.debug).toBeDefined();
      expect(params.debug.chatCompletion).toBeDefined();
      expect(typeof params.debug.chatCompletion).toBe('function');
    });

    it('should have models function', () => {
      expect(params.models).toBeDefined();
      expect(typeof params.models).toBe('function');
    });
  });

  describe('debug configuration', () => {
    const DEBUG_ENV_KEY = 'DEBUG_CEREBRAS_CHAT_COMPLETION';
    let originalDebugEnv: string | undefined;

    beforeEach(() => {
      originalDebugEnv = process.env[DEBUG_ENV_KEY];
    });

    afterEach(() => {
      // Restore the env var exactly as it was, so these tests do not leak state
      // into other suites running in the same process.
      if (originalDebugEnv === undefined) {
        delete process.env[DEBUG_ENV_KEY];
      } else {
        process.env[DEBUG_ENV_KEY] = originalDebugEnv;
      }
    });

    it('should disable debug by default', () => {
      delete process.env[DEBUG_ENV_KEY];
      expect(params.debug.chatCompletion()).toBe(false);
    });

    it('should enable debug when env is set to 1', () => {
      process.env[DEBUG_ENV_KEY] = '1';
      expect(params.debug.chatCompletion()).toBe(true);
    });

    it('should disable debug when env is set to 0', () => {
      process.env[DEBUG_ENV_KEY] = '0';
      expect(params.debug.chatCompletion()).toBe(false);
    });

    it('should disable debug when env is empty string', () => {
      process.env[DEBUG_ENV_KEY] = '';
      expect(params.debug.chatCompletion()).toBe(false);
    });
  });

  describe('handlePayload', () => {
    it('should remove frequency_penalty and presence_penalty from payload', async () => {
      await instance.chat({
        frequency_penalty: 0.5,
        messages: [{ content: 'Hello', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0.5,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
      expect(calledPayload.model).toBe('llama3.1-8b');
    });

    it('should preserve model in the payload', async () => {
      await instance.chat({
        frequency_penalty: 0.5,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-70b',
        presence_penalty: 0.5,
      });

      expect(getCalledPayload().model).toBe('llama3.1-70b');
    });

    it('should preserve other payload properties', async () => {
      await instance.chat({
        frequency_penalty: 0.5,
        max_tokens: 1000,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0.5,
        stream: true,
        temperature: 0.8,
        top_p: 0.9,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.temperature).toBe(0.8);
      expect(calledPayload.max_tokens).toBe(1000);
      expect(calledPayload.top_p).toBe(0.9);
      expect(calledPayload.stream).toBe(true);
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should handle payload without frequency_penalty and presence_penalty', async () => {
      await instance.chat({
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        temperature: 0.7,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.model).toBe('llama3.1-8b');
      expect(calledPayload.temperature).toBe(0.7);
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should handle payload with only frequency_penalty', async () => {
      await instance.chat({
        frequency_penalty: 0.5,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should handle payload with only presence_penalty', async () => {
      await instance.chat({
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0.5,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should handle payload with zero values for penalties', async () => {
      await instance.chat({
        frequency_penalty: 0,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0,
      });

      const calledPayload = getCalledPayload();
      expect(calledPayload.frequency_penalty).toBeUndefined();
      expect(calledPayload.presence_penalty).toBeUndefined();
    });

    it('should call handlePayload directly and verify transformation', () => {
      const payload = {
        frequency_penalty: 0.5,
        max_tokens: 100,
        messages: [{ content: 'Test', role: 'user' }],
        model: 'llama3.1-8b',
        presence_penalty: 0.5,
        temperature: 0.7,
      };

      const transformedPayload = params.chatCompletion!.handlePayload!(payload as any);
      expect(transformedPayload.model).toBe('llama3.1-8b');
      expect(transformedPayload.temperature).toBe(0.7);
      expect(transformedPayload.max_tokens).toBe(100);
      expect(transformedPayload.frequency_penalty).toBeUndefined();
      expect(transformedPayload.presence_penalty).toBeUndefined();
    });
  });

  describe('models function', () => {
    const FETCH_MODELS_WARNING =
      'Failed to fetch Cerebras models. Please ensure your Cerebras API key is valid:';

    /** Fresh sample model list per call, so no test can observe another test's mutations. */
    const createSampleModels = () => [
      { id: 'llama3.1-8b', object: 'model', owned_by: 'cerebras' },
      { id: 'llama3.1-70b', object: 'model', owned_by: 'cerebras' },
    ];

    /**
     * Builds a minimal client stub for `params.models`.
     * @param list - the mocked `models.list` function
     * @param apiKey - API key exposed on the stub (defaults to the suite's test key)
     */
    const createMockClient = (list: unknown, apiKey = 'test_api_key') =>
      ({
        apiKey,
        baseURL: 'https://api.cerebras.ai/v1',
        models: { list },
      }) as any;

    /** Asserts that `models.list` resolving to `listResult` yields an empty model array. */
    const expectEmptyModels = async (listResult: unknown) => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue(listResult));
      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(Array.isArray(models)).toBe(true);
      expect(models).toHaveLength(0);
    };

    /**
     * Asserts that a rejecting `models.list` is handled gracefully: `[]` is returned and the
     * standard warning is logged. The console.warn spy is restored by the suite's afterEach.
     */
    const expectGracefulFailure = async (error: Error, apiKey?: string) => {
      const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
      const mockClient = createMockClient(vi.fn().mockRejectedValue(error), apiKey);
      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(models).toEqual([]);
      expect(consoleWarnSpy).toHaveBeenCalledWith(FETCH_MODELS_WARNING, expect.any(Error));
    };

    it('should fetch and process models with data property', async () => {
      const mockClient = createMockClient(
        vi.fn().mockResolvedValue({ data: createSampleModels() }),
      );
      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(Array.isArray(models)).toBe(true);
    });

    it('should handle models list without data property (direct array)', async () => {
      const mockClient = createMockClient(vi.fn().mockResolvedValue(createSampleModels()));
      const models = await params.models!({ client: mockClient });

      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(Array.isArray(models)).toBe(true);
    });

    it('should handle empty models list with data property', async () => {
      await expectEmptyModels({ data: [] });
    });

    it('should handle empty models list without data property', async () => {
      await expectEmptyModels([]);
    });

    it('should handle null response', async () => {
      await expectEmptyModels(null);
    });

    it('should handle undefined response', async () => {
      await expectEmptyModels(undefined);
    });

    it('should handle response with non-array data', async () => {
      await expectEmptyModels({ data: 'not-an-array' });
    });

    it('should handle network error and return empty array', async () => {
      await expectGracefulFailure(new Error('Network error'));
    });

    it('should handle API authentication error and return empty array', async () => {
      await expectGracefulFailure(new Error('401 Unauthorized'), 'invalid_key');
    });

    it('should handle API rate limit error and return empty array', async () => {
      await expectGracefulFailure(new Error('429 Too Many Requests'));
    });

    it('should handle timeout error and return empty array', async () => {
      await expectGracefulFailure(new Error('Request timeout'));
    });

    it('should handle malformed JSON response', async () => {
      await expectGracefulFailure(new Error('Invalid JSON'));
    });

    it('should pass correct client to processMultiProviderModelList', async () => {
      const mockClient = createMockClient(
        vi.fn().mockResolvedValue({ data: createSampleModels() }),
      );
      const models = await params.models!({ client: mockClient });

      // NOTE(review): processMultiProviderModelList (presumably used by params.models) is not
      // mocked here, so this only checks that the client was queried once and an array came
      // back. It does not inspect the arguments passed to that helper.
      expect(mockClient.models.list).toHaveBeenCalledTimes(1);
      expect(Array.isArray(models)).toBe(true);
    });
  });
});