/*
@lobehub/chat
Version:
Lobe Chat – an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports one-click, free deployment of your private ChatGPT/LLM web application.
556 lines (441 loc) • 17.4 kB
text/typescript
*/
// @vitest-environment node
import { TraceNameMap } from '@lobechat/types';
import { ClientSecretPayload } from '@lobechat/types';
import { Langfuse } from 'langfuse';
import { LangfuseGenerationClient, LangfuseTraceClient } from 'langfuse-core';
import { ModelProvider } from 'model-bank';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import * as langfuseCfg from '@/envs/langfuse';
import { createTraceOptions } from '@/server/modules/ModelRuntime';
import { ChatStreamPayload, LobeOpenAI, ModelRuntime } from '../index';
import { providerRuntimeMap } from '../runtimeMap';
import { CreateImagePayload } from '../types/image';
import { AgentChatOptions } from './ModelRuntime';
/**
 * Providers whose runtimes are initialized with provider-specific credentials
 * instead of only the generic `{ apiKey: 'user-key' }` payload. Each `payload`
 * is spread over that default by `testRuntime`, so fields here override it.
 */
const specialProviders = [
  {
    id: ModelProvider.OpenAI,
    payload: { apiKey: 'user-openai-key', baseURL: 'user-endpoint' },
  },
  {
    id: ModelProvider.Azure,
    payload: {
      apiKey: 'user-azure-key',
      baseURL: 'user-azure-endpoint',
      apiVersion: '2024-06-01',
    },
  },
  {
    id: ModelProvider.AzureAI,
    payload: {
      apiKey: 'user-azure-key',
      baseURL: 'user-azure-endpoint',
    },
  },
  {
    id: ModelProvider.Bedrock,
    payload: {
      accessKeyId: 'user-aws-id',
      accessKeySecret: 'user-aws-secret',
      region: 'user-aws-region',
    },
  },
  {
    id: ModelProvider.Ollama,
    payload: { baseURL: 'https://user-ollama-url' },
  },
  {
    id: ModelProvider.Cloudflare,
    // The field name suggests either a base URL or an account ID is accepted; a URL is used here.
    payload: { baseURLOrAccountID: 'https://user-cloudflare-url' },
  },
];
/**
 * Registers a suite checking that `ModelRuntime.initializeWithProvider` builds
 * the runtime class registered for `providerId` in `providerRuntimeMap`, and
 * that an explicit `baseURL` in `payload` ends up on that runtime.
 *
 * @param providerId - provider key passed to `initializeWithProvider`
 * @param payload - extra credential fields spread over `{ apiKey: 'user-key' }`
 */
function testRuntime(providerId: string, payload?: any) {
  describe(`${providerId} provider runtime`, () => {
    it('should initialize correctly', async () => {
      const secretPayload: ClientSecretPayload = { apiKey: 'user-key', ...payload };
      const runtime = await ModelRuntime.initializeWithProvider(providerId, secretPayload);
      const innerRuntime = runtime['_runtime'];

      // @ts-ignore
      expect(innerRuntime).toBeInstanceOf(providerRuntimeMap[providerId]);

      // Only providers configured with a baseURL have one to verify.
      if (!payload?.baseURL) return;
      expect(innerRuntime.baseURL).toBe(payload.baseURL);
    });
  });
}
/**
 * OpenAI-backed runtime used by the suites below. It is rebuilt before every
 * test because several tests overwrite its internal `_runtime`.
 */
let mockModelRuntime: ModelRuntime;

beforeEach(async () => {
  const openaiSecret: ClientSecretPayload = {
    apiKey: 'user-openai-key',
    baseURL: 'user-endpoint',
  };

  mockModelRuntime = await ModelRuntime.initializeWithProvider(
    ModelProvider.OpenAI,
    openaiSecret,
  );
});
describe('ModelRuntime', () => {
describe('should initialize with various providers', () => {
  // VertexAI (presumably needs credentials this file does not build) and every
  // entry of `specialProviders` are kept out of the generic `{ apiKey }` pass.
  const excludedIds = new Set<string>([
    ModelProvider.VertexAI,
    ...specialProviders.map(({ id }) => id),
  ]);

  for (const provider of Object.values(ModelProvider)) {
    if (provider === 'lobehub' || excludedIds.has(provider)) continue;
    testRuntime(provider);
  }

  // Special providers are registered afterwards with their bespoke payloads.
  for (const { id, payload } of specialProviders) {
    testRuntime(id, payload);
  }
});
describe('ModelRuntime chat method', () => {
  it('should run correctly', async () => {
    const payload: ChatStreamPayload = {
      messages: [{ role: 'user', content: 'Hello, world!' }],
      model: 'text-davinci-002',
      temperature: 0,
    };
    const chatSpy = vi.spyOn(LobeOpenAI.prototype, 'chat').mockResolvedValue(new Response(''));

    await mockModelRuntime.chat(payload);

    // The request must be delegated to the underlying OpenAI runtime.
    expect(chatSpy).toHaveBeenCalled();
  });

  it('should handle options correctly', async () => {
    const payload: ChatStreamPayload = {
      messages: [{ role: 'user', content: 'Hello, world!' }],
      model: 'text-davinci-002',
      temperature: 0,
    };
    const options: AgentChatOptions = {
      provider: 'openai',
      trace: {
        traceId: 'test-trace-id',
        traceName: TraceNameMap.SummaryTopicTitle,
        sessionId: 'test-session-id',
        topicId: 'test-topic-id',
        tags: [],
        userId: 'test-user-id',
      },
    };
    const chatSpy = vi.spyOn(LobeOpenAI.prototype, 'chat').mockResolvedValue(new Response(''));

    await mockModelRuntime.chat(payload, createTraceOptions(payload, options));

    expect(chatSpy).toHaveBeenCalled();
  });

  /**
   * Each test makes the mocked provider fire one stream callback and checks the
   * resulting Langfuse `update` call made through `createTraceOptions`.
   */
  describe('callback', () => {
    const payload: ChatStreamPayload = {
      messages: [{ role: 'user', content: 'Hello, world!' }],
      model: 'text-davinci-002',
      temperature: 0,
    };
    const options: AgentChatOptions = {
      provider: 'openai',
      trace: {
        traceId: 'test-trace-id',
        traceName: TraceNameMap.SummaryTopicTitle,
        sessionId: 'test-session-id',
        topicId: 'test-topic-id',
        tags: [],
        userId: 'test-user-id',
      },
      enableTrace: true,
    };
    const updateMock = vi.fn();

    beforeEach(() => {
      // `updateMock` is shared by every test here; clear its history so each
      // `toHaveBeenCalledWith` can only match calls made by the current test.
      updateMock.mockClear();

      // Every test in this suite runs with Langfuse enabled.
      vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({
        ENABLE_LANGFUSE: true,
        LANGFUSE_PUBLIC_KEY: 'abc',
        LANGFUSE_SECRET_KEY: 'DDD',
      } as any);
    });

    it('should call onToolsCalling correctly', async () => {
      // Mock chat so that it triggers the onToolsCalling callback.
      vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
        async (_payload, { callback }: any) => {
          if (callback?.onToolsCalling) {
            await callback.onToolsCalling();
          }
          return new Response('abc');
        },
      );
      vi.spyOn(LangfuseTraceClient.prototype, 'update').mockImplementation(updateMock);

      await mockModelRuntime.chat(payload, createTraceOptions(payload, options));

      expect(updateMock).toHaveBeenCalledWith({ tags: ['Tools Calling'] });
    });

    it('should call onStart correctly', async () => {
      vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock);
      vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
        async (_payload, { callback }: any) => {
          if (callback?.onStart) {
            // Awaited like the other callbacks so a rejected promise fails the test.
            await callback.onStart();
          }
          return new Response('Success');
        },
      );

      await mockModelRuntime.chat(payload, createTraceOptions(payload, options));

      // Verify onStart recorded the completion start time.
      expect(updateMock).toHaveBeenCalledWith({ completionStartTime: expect.any(Date) });
    });

    it('should call onCompletion correctly', async () => {
      vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock);
      // Mock chat so that it triggers the onCompletion callback.
      vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
        async (_payload, { callback }: any) => {
          if (callback?.onCompletion) {
            await callback.onCompletion({ text: 'Test completion' });
          }
          return new Response('Success');
        },
      );

      await mockModelRuntime.chat(payload, createTraceOptions(payload, options));

      // Verify onCompletion reported the expected output.
      expect(updateMock).toHaveBeenCalledWith({
        endTime: expect.any(Date),
        metadata: {},
        output: 'Test completion',
      });
    });

    it.skip('should call onFinal correctly', async () => {
      vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
        async (_payload, { callback }: any) => {
          if (callback?.onFinal) {
            await callback.onFinal('Test completion');
          }
          return new Response('Success');
        },
      );
      const shutdownAsyncMock = vi.fn();
      vi.spyOn(Langfuse.prototype, 'shutdownAsync').mockImplementation(shutdownAsyncMock);

      await mockModelRuntime.chat(payload, createTraceOptions(payload, options));

      // Verify onFinal triggered Langfuse.shutdownAsync.
      expect(shutdownAsyncMock).toHaveBeenCalled();
    });
  });
});
describe('ModelRuntime generateObject method', () => {
  it('should run correctly', async () => {
    // JSON-schema describing the object the model is asked to produce.
    const personSchema = {
      name: 'PersonSchema',
      schema: {
        type: 'object' as const,
        properties: { name: { type: 'string' } },
      },
    };
    const payload = {
      model: 'gpt-4',
      messages: [{ role: 'user' as const, content: 'Generate a JSON object' }],
      schema: personSchema,
    };
    const generated = { name: 'John Doe' };
    const generateObjectSpy = vi
      .spyOn(LobeOpenAI.prototype, 'generateObject')
      .mockResolvedValue(generated);

    const result = await mockModelRuntime.generateObject(payload);

    // The payload is forwarded untouched and the provider result returned as-is.
    expect(generateObjectSpy).toHaveBeenCalledWith(payload);
    expect(result).toBe(generated);
  });
});
describe('ModelRuntime textToImage method', () => {
  // Builds a fresh request object for each test.
  const makePayload = () => ({
    model: 'stable-diffusion',
    prompt: 'A beautiful landscape',
  });

  it('should run correctly', async () => {
    const payload = makePayload();
    const imageUrls = ['https://example.com/image1.jpg', 'https://example.com/image2.jpg'];
    const textToImageSpy = vi
      .spyOn(LobeOpenAI.prototype, 'textToImage')
      .mockResolvedValue(imageUrls);

    const result = await mockModelRuntime.textToImage(payload);

    expect(textToImageSpy).toHaveBeenCalledWith(payload);
    expect(result).toBe(imageUrls);
  });

  it('should handle undefined textToImage method gracefully', async () => {
    // Swap in a provider runtime that lacks textToImage entirely.
    // @ts-ignore - testing edge case
    mockModelRuntime['_runtime'] = { textToImage: undefined };

    const result = await mockModelRuntime.textToImage(makePayload());

    expect(result).toBeUndefined();
  });
});
describe('ModelRuntime createImage method', () => {
  // Builds a fresh 1024x1024 image request for each test.
  const makePayload = (): CreateImagePayload => ({
    model: 'dall-e-3',
    params: {
      prompt: 'A beautiful sunset over mountains',
      width: 1024,
      height: 1024,
    },
  });

  it('should run correctly', async () => {
    const payload = makePayload();
    const image = {
      imageUrl: 'https://example.com/image.jpg',
      width: 1024,
      height: 1024,
    };
    const createImageSpy = vi.spyOn(LobeOpenAI.prototype, 'createImage').mockResolvedValue(image);

    const result = await mockModelRuntime.createImage(payload);

    expect(createImageSpy).toHaveBeenCalledWith(payload);
    expect(result).toBe(image);
  });

  it('should handle undefined createImage method gracefully', async () => {
    // Swap in a provider runtime that lacks createImage entirely.
    // @ts-ignore - testing edge case
    mockModelRuntime['_runtime'] = { createImage: undefined };

    const result = await mockModelRuntime.createImage(makePayload());

    expect(result).toBeUndefined();
  });
});
describe('ModelRuntime models method', () => {
  it('should run correctly', async () => {
    const modelList = [
      { id: 'gpt-4', name: 'GPT-4' },
      { id: 'gpt-3.5-turbo', name: 'GPT-3.5 Turbo' },
    ];
    const modelsSpy = vi.spyOn(LobeOpenAI.prototype, 'models').mockResolvedValue(modelList);

    const result = await mockModelRuntime.models();

    // The provider's list is returned by reference, not copied.
    expect(modelsSpy).toHaveBeenCalled();
    expect(result).toBe(modelList);
  });

  it('should handle undefined models method gracefully', async () => {
    // Swap in a provider runtime that lacks models entirely.
    // @ts-ignore - testing edge case
    mockModelRuntime['_runtime'] = { models: undefined };

    const result = await mockModelRuntime.models();

    expect(result).toBeUndefined();
  });
});
describe('ModelRuntime embeddings method', () => {
  // Builds a fresh embeddings request for each test.
  const makePayload = () => ({
    model: 'text-embedding-ada-002',
    input: 'Hello world',
  });

  it('should run correctly', async () => {
    const payload = makePayload();
    const vectors = [[0.1, 0.2, 0.3]];
    const embeddingsSpy = vi.spyOn(LobeOpenAI.prototype, 'embeddings').mockResolvedValue(vectors);

    const result = await mockModelRuntime.embeddings(payload);

    // Without options, `undefined` is forwarded as the second argument.
    expect(embeddingsSpy).toHaveBeenCalledWith(payload, undefined);
    expect(result).toBe(vectors);
  });

  it('should handle options correctly', async () => {
    const payload = makePayload();
    const options = {};
    const vectors = [[0.1, 0.2, 0.3]];
    const embeddingsSpy = vi.spyOn(LobeOpenAI.prototype, 'embeddings').mockResolvedValue(vectors);

    const result = await mockModelRuntime.embeddings(payload, options);

    expect(embeddingsSpy).toHaveBeenCalledWith(payload, options);
    expect(result).toBe(vectors);
  });

  it('should handle undefined embeddings method gracefully', async () => {
    // Swap in a provider runtime that lacks embeddings entirely.
    // @ts-ignore - testing edge case
    mockModelRuntime['_runtime'] = { embeddings: undefined };

    const result = await mockModelRuntime.embeddings(makePayload());

    expect(result).toBeUndefined();
  });
});
describe('ModelRuntime textToSpeech method', () => {
  // Builds a fresh speech request for each test.
  const makePayload = () => ({
    model: 'tts-1',
    input: 'Hello world',
    voice: 'alloy',
  });

  it('should run correctly', async () => {
    const payload = makePayload();
    const audio = new ArrayBuffer(8);
    const textToSpeechSpy = vi
      .spyOn(LobeOpenAI.prototype, 'textToSpeech')
      .mockResolvedValue(audio);

    const result = await mockModelRuntime.textToSpeech(payload);

    // Without options, `undefined` is forwarded as the second argument.
    expect(textToSpeechSpy).toHaveBeenCalledWith(payload, undefined);
    expect(result).toBe(audio);
  });

  it('should handle options correctly', async () => {
    const payload = makePayload();
    const options = {};
    const audio = new ArrayBuffer(8);
    const textToSpeechSpy = vi
      .spyOn(LobeOpenAI.prototype, 'textToSpeech')
      .mockResolvedValue(audio);

    const result = await mockModelRuntime.textToSpeech(payload, options);

    expect(textToSpeechSpy).toHaveBeenCalledWith(payload, options);
    expect(result).toBe(audio);
  });

  it('should handle undefined textToSpeech method gracefully', async () => {
    // Swap in a provider runtime that lacks textToSpeech entirely.
    // @ts-ignore - testing edge case
    mockModelRuntime['_runtime'] = { textToSpeech: undefined };

    const result = await mockModelRuntime.textToSpeech(makePayload());

    expect(result).toBeUndefined();
  });
});
describe('ModelRuntime pullModel method', () => {
  // Builds fresh pull parameters for each test.
  const makeParams = () => ({
    model: 'llama2',
  });

  it('should run correctly', async () => {
    const params = makeParams();
    const response = new Response('Success');
    const pullModelMock = vi.fn().mockResolvedValue(response);
    // Attach a stub pullModel directly onto the live provider runtime.
    mockModelRuntime['_runtime'].pullModel = pullModelMock;

    const result = await mockModelRuntime.pullModel(params);

    // Without options, `undefined` is forwarded as the second argument.
    expect(pullModelMock).toHaveBeenCalledWith(params, undefined);
    expect(result).toBe(response);
  });

  it('should handle options correctly', async () => {
    const params = makeParams();
    const options = {};
    const response = new Response('Success');
    const pullModelMock = vi.fn().mockResolvedValue(response);
    // Attach a stub pullModel directly onto the live provider runtime.
    mockModelRuntime['_runtime'].pullModel = pullModelMock;

    const result = await mockModelRuntime.pullModel(params, options);

    expect(pullModelMock).toHaveBeenCalledWith(params, options);
    expect(result).toBe(response);
  });

  it('should handle undefined pullModel method gracefully', async () => {
    // Swap in a provider runtime that lacks pullModel entirely.
    // @ts-ignore - testing edge case
    mockModelRuntime['_runtime'] = { pullModel: undefined };

    const result = await mockModelRuntime.pullModel(makeParams());

    expect(result).toBeUndefined();
  });
});
});