@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
245 lines (210 loc) • 8.22 kB
text/typescript
// @vitest-environment node
import { Langfuse } from 'langfuse';
import { LangfuseGenerationClient, LangfuseTraceClient } from 'langfuse-core';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import * as langfuseCfg from '@/config/langfuse';
import { JWTPayload } from '@/const/auth';
import { TraceNameMap } from '@/const/trace';
import { AgentRuntime, ChatStreamPayload, LobeOpenAI, ModelProvider } from '@/libs/model-runtime';
import { providerRuntimeMap } from '@/libs/model-runtime/runtimeMap';
import { createTraceOptions } from '@/server/modules/AgentRuntime';
import { AgentChatOptions } from './ModelRuntime';
/**
 * Providers exercised with a provider-specific payload instead of the default `{ apiKey }`.
 * Their ids are excluded from the generic provider loop, so every id is written as a
 * `ModelProvider` member to guarantee it matches the enum values that loop filters on.
 * All payload values are dummy fixtures.
 */
const specialProviders = [
  { id: ModelProvider.OpenAI, payload: { apiKey: 'user-openai-key', baseURL: 'user-endpoint' } },
  {
    id: ModelProvider.Azure,
    payload: {
      apiKey: 'user-azure-key',
      baseURL: 'user-azure-endpoint',
      apiVersion: '2024-06-01',
    },
  },
  {
    id: ModelProvider.AzureAI,
    payload: {
      apiKey: 'user-azure-key',
      baseURL: 'user-azure-endpoint',
    },
  },
  {
    id: ModelProvider.Bedrock,
    payload: {
      accessKeyId: 'user-aws-id',
      accessKeySecret: 'user-aws-secret',
      region: 'user-aws-region',
    },
  },
  {
    id: ModelProvider.Ollama,
    payload: { baseURL: 'https://user-ollama-url' },
  },
  {
    id: ModelProvider.Cloudflare,
    // Dummy endpoint; previously a copy-pasted Ollama URL.
    payload: { baseURLOrAccountID: 'https://user-cloudflare-url' },
  },
];
/**
 * Registers a suite checking that `AgentRuntime.initializeWithProvider` builds the runtime
 * class registered for `providerId` in `providerRuntimeMap`.
 *
 * @param providerId - provider key handed to `initializeWithProvider`.
 * @param payload - optional extra fields spread over the default `{ apiKey: 'user-key' }`;
 *   when it carries a truthy `baseURL`, the created runtime must expose that exact value.
 */
const testRuntime = (providerId: string, payload?: any) => {
  const suiteTitle = `${providerId} provider runtime`;

  describe(suiteTitle, () => {
    it('should initialize correctly', async () => {
      const params: JWTPayload = { apiKey: 'user-key', ...payload };
      const agentRuntime = await AgentRuntime.initializeWithProvider(providerId, params);
      const innerRuntime = agentRuntime['_runtime'];

      // NOTE(review): presumably suppressed because providerRuntimeMap cannot be indexed by an
      // arbitrary string — confirm before switching to @ts-expect-error.
      // @ts-ignore
      expect(innerRuntime).toBeInstanceOf(providerRuntimeMap[providerId]);

      const expectedBaseURL = payload?.baseURL;
      if (!expectedBaseURL) return;

      expect(innerRuntime.baseURL).toBe(expectedBaseURL);
    });
  });
};
/**
 * OpenAI-backed runtime used by the chat tests. Despite the name, it is a real `AgentRuntime`;
 * individual tests stub `LobeOpenAI.prototype.chat` themselves. It is rebuilt before every test.
 */
let mockModelRuntime: AgentRuntime;

beforeEach(async () => {
  // Undo every vi.spyOn installed by earlier tests (chat, Langfuse config/clients). Each test
  // then starts from the real implementations and only sees the mocks it installs itself.
  vi.restoreAllMocks();

  const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' };
  mockModelRuntime = await AgentRuntime.initializeWithProvider(ModelProvider.OpenAI, jwtPayload);
});
describe('AgentRuntime', () => {
  describe('should initialize with various providers', () => {
    const providers = Object.values(ModelProvider);
    // VertexAI is left out of the generic run entirely (NOTE(review): presumably it cannot be
    // built from a plain apiKey — confirm). The special providers are exercised separately
    // with their own payloads.
    const specialProviderIds = [ModelProvider.VertexAI, ...specialProviders.map((p) => p.id)];
    const generalTestProviders = providers.filter(
      (provider) => !specialProviderIds.includes(provider),
    );

    generalTestProviders.forEach((provider) => {
      testRuntime(provider);
    });

    specialProviders.forEach(({ id, payload }) => testRuntime(id, payload));
  });

  describe('AgentRuntime chat method', () => {
    // Shared fixtures; nothing in these tests mutates them.
    const payload: ChatStreamPayload = {
      messages: [{ role: 'user', content: 'Hello, world!' }],
      model: 'text-davinci-002',
      temperature: 0,
    };

    const options: AgentChatOptions = {
      provider: 'openai',
      trace: {
        traceId: 'test-trace-id',
        traceName: TraceNameMap.SummaryTopicTitle,
        sessionId: 'test-session-id',
        topicId: 'test-topic-id',
        tags: [],
        userId: 'test-user-id',
      },
    };

    it('should run correctly', async () => {
      const chatSpy = vi.spyOn(LobeOpenAI.prototype, 'chat').mockResolvedValue(new Response(''));

      await mockModelRuntime.chat(payload);

      expect(chatSpy).toHaveBeenCalled();
    });

    it('should handle options correctly', async () => {
      const chatSpy = vi.spyOn(LobeOpenAI.prototype, 'chat').mockResolvedValue(new Response(''));

      await mockModelRuntime.chat(payload, createTraceOptions(payload, options));

      expect(chatSpy).toHaveBeenCalled();
    });

    // Suite factories are collected synchronously; nothing here is awaited, so no `async`.
    describe('callback', () => {
      const tracedOptions: AgentChatOptions = { ...options, enableTrace: true };
      const updateMock = vi.fn();

      beforeEach(() => {
        // Every test asserts on updateMock; drop calls recorded by earlier tests so an
        // assertion cannot be satisfied by another test's side effects.
        updateMock.mockClear();

        // Turn Langfuse on so the tracing callbacks are wired up; the update spies in the
        // tests below depend on it.
        vi.spyOn(langfuseCfg, 'getLangfuseConfig').mockReturnValue({
          ENABLE_LANGFUSE: true,
          LANGFUSE_PUBLIC_KEY: 'abc',
          LANGFUSE_SECRET_KEY: 'DDD',
        } as any);
      });

      it('should call onToolsCalling correctly', async () => {
        // Stub chat so it fires the onToolsCalling callback, as a tool-calling response would.
        vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
          async (_payload, { callback }: any) => {
            if (callback?.onToolsCalling) {
              await callback.onToolsCalling();
            }
            return new Response('abc');
          },
        );
        vi.spyOn(LangfuseTraceClient.prototype, 'update').mockImplementation(updateMock);

        await mockModelRuntime.chat(payload, createTraceOptions(payload, tracedOptions));

        expect(updateMock).toHaveBeenCalledWith({ tags: ['Tools Calling'] });
      });

      it('should call onStart correctly', async () => {
        vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock);
        vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
          async (_payload, { callback }: any) => {
            if (callback?.onStart) {
              callback.onStart();
            }
            return new Response('Success');
          },
        );

        await mockModelRuntime.chat(payload, createTraceOptions(payload, tracedOptions));

        // onStart should stamp the generation with its start time.
        expect(updateMock).toHaveBeenCalledWith({ completionStartTime: expect.any(Date) });
      });

      it('should call onCompletion correctly', async () => {
        vi.spyOn(LangfuseGenerationClient.prototype, 'update').mockImplementation(updateMock);
        // Stub chat so it fires onCompletion with a fixed completion text.
        vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
          async (_payload, { callback }: any) => {
            if (callback?.onCompletion) {
              await callback.onCompletion({ text: 'Test completion' });
            }
            return new Response('Success');
          },
        );

        await mockModelRuntime.chat(payload, createTraceOptions(payload, tracedOptions));

        // onCompletion should close the generation with the completion text as output.
        expect(updateMock).toHaveBeenCalledWith({
          endTime: expect.any(Date),
          metadata: {},
          output: 'Test completion',
        });
      });

      it.skip('should call onFinal correctly', async () => {
        vi.spyOn(LobeOpenAI.prototype, 'chat').mockImplementation(
          async (_payload, { callback }: any) => {
            if (callback?.onFinal) {
              await callback.onFinal('Test completion');
            }
            return new Response('Success');
          },
        );
        const shutdownAsyncMock = vi.fn();
        vi.spyOn(Langfuse.prototype, 'shutdownAsync').mockImplementation(shutdownAsyncMock);

        await mockModelRuntime.chat(payload, createTraceOptions(payload, tracedOptions));

        // onFinal should flush and shut down the Langfuse client.
        expect(shutdownAsyncMock).toHaveBeenCalled();
      });
    });
  });
});