@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
118 lines (101 loc) • 4.47 kB
text/typescript
import { act, renderHook } from '@testing-library/react';
import { describe, expect, it, vi } from 'vitest';
import { fileService } from '@/services/file';
import { ClientService } from '@/services/file/_deprecated';
import { messageService } from '@/services/message';
import { imageGenerationService } from '@/services/textToImage';
import { uploadService } from '@/services/upload';
import { chatSelectors } from '@/store/chat/selectors';
import { ChatMessage } from '@/types/message';
import { DallEImageItem } from '@/types/tool/dalle';
import { useChatStore } from '../../../store';
describe('chatToolSlice - dalle', () => {
describe('generateImageFromPrompts', () => {
it('should generate images from prompts, update items, and upload images', async () => {
const { result } = renderHook(() => useChatStore());
const initialMessageContent = JSON.stringify([
{ prompt: 'test prompt', previewUrl: 'old-url', imageId: 'old-id' },
]);
vi.spyOn(chatSelectors, 'getMessageById').mockImplementationOnce(
(id) => () =>
({
id,
content: initialMessageContent,
}) as ChatMessage,
);
const messageId = 'message-id';
const prompts = [
{ prompt: 'test prompt 1' },
{ prompt: 'test prompt 2' },
] as DallEImageItem[];
const mockUrl = 'https://example.com/image.png';
const mockId = 'image-id';
vi.spyOn(imageGenerationService, 'generateImage').mockResolvedValue(mockUrl);
vi.spyOn(uploadService, 'getImageFileByUrlWithCORS').mockResolvedValue(
new File(['1'], 'file.png', { type: 'image/png' }),
);
// @ts-ignore
vi.spyOn(uploadService, 'uploadToClientS3').mockResolvedValue({} as any);
vi.spyOn(ClientService.prototype, 'createFile').mockResolvedValue({
id: mockId,
url: '',
});
vi.spyOn(result.current, 'toggleDallEImageLoading');
vi.spyOn(ClientService.prototype, 'checkFileHash').mockImplementation(
async () => ({ isExist: false }) as any,
);
await act(async () => {
await result.current.generateImageFromPrompts(prompts, messageId);
});
// For each prompt, loading is toggled on and then off
expect(imageGenerationService.generateImage).toHaveBeenCalledTimes(prompts.length);
// @ts-ignore
expect(uploadService.uploadToClientS3).toHaveBeenCalledTimes(prompts.length);
expect(result.current.toggleDallEImageLoading).toHaveBeenCalledTimes(prompts.length * 2);
});
});
describe('updateImageItem', () => {
it('should update image item correctly', async () => {
const { result } = renderHook(() => useChatStore());
const messageId = 'message-id';
const initialMessageContent = JSON.stringify([
{ prompt: 'test prompt', previewUrl: 'old-url', imageId: 'old-id' },
]);
const updateFunction = (draft: any) => {
draft[0].previewUrl = 'new-url';
draft[0].imageId = 'new-id';
};
vi.spyOn(result.current, 'internal_updateMessageContent');
// 模拟 getMessageById 返回消息内容
vi.spyOn(chatSelectors, 'getMessageById').mockImplementationOnce(
(id) => () =>
({
id,
content: initialMessageContent,
}) as ChatMessage,
);
vi.spyOn(messageService, 'updateMessage').mockResolvedValueOnce(undefined);
await act(async () => {
await result.current.updateImageItem(messageId, updateFunction);
});
// 验证 internal_updateMessageContent 是否被正确调用以更新内容
expect(result.current.internal_updateMessageContent).toHaveBeenCalledWith(
messageId,
JSON.stringify([{ prompt: 'test prompt', previewUrl: 'new-url', imageId: 'new-id' }]),
);
});
});
describe('text2image', () => {
it('should call generateImageFromPrompts with provided data', async () => {
const { result } = renderHook(() => useChatStore());
const id = 'message-id';
const data = [{ prompt: 'prompt 1' }, { prompt: 'prompt 2' }] as DallEImageItem[];
// Mock generateImageFromPrompts
const generateImageFromPromptsMock = vi.spyOn(result.current, 'generateImageFromPrompts');
await act(async () => {
await result.current.text2image(id, data);
});
expect(generateImageFromPromptsMock).toHaveBeenCalledWith(data, id);
});
});
});