/*
 * @lobehub/chat
 * Version:
 * Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis,
 * multimodal input, and an extensible Function Call plugin system. Supports one-click free
 * deployment of your private ChatGPT/LLM web application.
 * 508 lines (420 loc) • 16.2 kB
 * text/typescript
 */
import { type Mock, beforeEach, describe, expect, it, vi } from 'vitest';
import type { ModelConfig } from '@/server/services/comfyui/config/modelRegistry';
import { resolveModel } from '@/server/services/comfyui/utils/staticModelLookup';
import {
type SD3Variant,
WorkflowDetector,
} from '@/server/services/comfyui/utils/workflowDetector';
// Mock the static model lookup so every test controls what `resolveModel` returns.
// The specifier deliberately matches the aliased import of `resolveModel` in this file,
// so the mock targets the exact module under test no matter where this spec lives.
vi.mock('@/server/services/comfyui/utils/staticModelLookup', () => ({
  getModelConfig: vi.fn(),
  resolveModel: vi.fn(),
}));
/**
 * Unit tests for `WorkflowDetector.detectModelType`.
 *
 * `resolveModel` is a `vi.fn()` installed by the module-level `vi.mock` call, so each test
 * decides which model config the detector sees by configuring `mockedResolveModel`.
 */
describe('WorkflowDetector', () => {
  const mockedResolveModel = resolveModel as Mock;

  beforeEach(() => {
    // resetAllMocks (not clearAllMocks) also discards any `mockReturnValue` configured by a
    // previous test, so a test that forgets to set up the mock cannot silently reuse a stale
    // config left over from an earlier test.
    vi.resetAllMocks();
  });

  describe('detectModelType', () => {
    describe('Input Processing', () => {
      it('should remove "comfyui/" prefix from modelId', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'dev',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('comfyui/flux-dev');

        expect(mockedResolveModel).toHaveBeenCalledWith('flux-dev');
        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'dev',
        });
      });

      it('should handle modelId without comfyui prefix', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'schnell',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('flux-schnell');

        expect(mockedResolveModel).toHaveBeenCalledWith('flux-schnell');
        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'schnell',
        });
      });

      it('should handle multiple comfyui prefixes correctly', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'SD3',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'sd35',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        // Only the first "comfyui/" should be removed
        const result = WorkflowDetector.detectModelType('comfyui/comfyui/model');

        expect(mockedResolveModel).toHaveBeenCalledWith('comfyui/model');
        expect(result).toEqual({
          architecture: 'SD3',
          isSupported: true,
          variant: 'sd35',
        });
      });
    });

    describe('FLUX Model Detection', () => {
      it('should detect FLUX dev variant', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'dev',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('flux-dev');

        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'dev',
        });
      });

      it('should detect FLUX schnell variant', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 2,
          recommendedDtype: 'fp8_e4m3fn',
          variant: 'schnell',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('flux-schnell-fp8');

        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'schnell',
        });
      });

      it('should detect FLUX kontext variant', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'kontext',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('flux-kontext-dev');

        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'kontext',
        });
      });

      it('should detect FLUX krea model with dev variant', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'dev',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('flux-krea-dev');

        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'dev',
        });
      });

      it('should handle FLUX model with comfyui prefix', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 2,
          recommendedDtype: 'fp8_e5m2',
          variant: 'dev',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('comfyui/custom-flux-model');

        expect(mockedResolveModel).toHaveBeenCalledWith('custom-flux-model');
        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'dev',
        });
      });
    });

    describe('Custom SD Model Detection', () => {
      it('should detect custom SD model', () => {
        const result = WorkflowDetector.detectModelType('stable-diffusion-custom');

        // Custom SD models are hardcoded and don't use resolveModel
        expect(mockedResolveModel).not.toHaveBeenCalled();
        expect(result).toEqual({
          architecture: 'SDXL', // Uses SDXL for img2img support
          isSupported: true,
          variant: 'custom-sd',
        });
      });

      it('should detect custom SD refiner model', () => {
        const result = WorkflowDetector.detectModelType('stable-diffusion-custom-refiner');

        // Custom SD models are hardcoded and don't use resolveModel
        expect(mockedResolveModel).not.toHaveBeenCalled();
        expect(result).toEqual({
          architecture: 'SDXL', // Uses SDXL for img2img support
          isSupported: true,
          variant: 'custom-sd',
        });
      });

      it('should handle custom SD with comfyui prefix', () => {
        const result = WorkflowDetector.detectModelType('comfyui/stable-diffusion-custom');

        // Custom SD models are hardcoded and don't use resolveModel
        expect(mockedResolveModel).not.toHaveBeenCalled();
        expect(result).toEqual({
          architecture: 'SDXL', // Uses SDXL for img2img support
          isSupported: true,
          variant: 'custom-sd',
        });
      });
    });

    describe('SD3 Model Detection', () => {
      it('should detect SD3 sd35 variant', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'SD3',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'sd35',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('sd3.5_large');

        expect(result).toEqual({
          architecture: 'SD3',
          isSupported: true,
          variant: 'sd35',
        });
      });

      it('should handle SD3 model with comfyui prefix', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'SD3',
          priority: 2,
          recommendedDtype: 'default',
          variant: 'sd35',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('comfyui/sd3.5_medium');

        expect(mockedResolveModel).toHaveBeenCalledWith('sd3.5_medium');
        expect(result).toEqual({
          architecture: 'SD3',
          isSupported: true,
          variant: 'sd35',
        });
      });
    });

    // SDXL and SD1 families are supported; these cases previously lived under the
    // "Unknown/Unsupported" group, which misdescribed what they verify.
    describe('SDXL and SD1 Model Detection', () => {
      it('should return SDXL architecture for SDXL model family', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'SDXL' as any,
          priority: 1,
          recommendedDtype: 'default',
          variant: 'sdxl-t2i',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('sdxl-base');

        expect(result).toEqual({
          architecture: 'SDXL',
          isSupported: true,
          variant: 'sdxl-t2i',
        });
      });

      it('should return SD1 architecture for SD1 model family', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'SD1' as any,
          priority: 3,
          recommendedDtype: 'default',
          variant: 'sd15-t2i',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('stable-diffusion-v1-5');

        expect(result).toEqual({
          architecture: 'SD1',
          isSupported: true,
          variant: 'sd15-t2i',
        });
      });
    });

    describe('Unknown/Unsupported Model Detection', () => {
      it('should return unknown architecture when model is not found', () => {
        mockedResolveModel.mockReturnValue(null);

        const result = WorkflowDetector.detectModelType('unknown-model');

        expect(result).toEqual({
          architecture: 'unknown',
          isSupported: false,
        });
      });

      it('should handle null modelId by causing runtime error (expected behavior)', () => {
        // According to the function signature, modelId is expected to be a string.
        // Passing null is a caller error and surfaces as a TypeError. Assert on the error
        // type rather than the message text, which is JS-engine specific
        // (V8: "Cannot read properties of null").
        expect(() => {
          WorkflowDetector.detectModelType(null as any);
        }).toThrow(TypeError);
      });

      it('should handle undefined modelId by causing runtime error (expected behavior)', () => {
        // According to the function signature, modelId is expected to be a string.
        // Passing undefined is a caller error and surfaces as a TypeError. Assert on the
        // error type rather than the message text, which is JS-engine specific
        // (V8: "Cannot read properties of undefined").
        expect(() => {
          WorkflowDetector.detectModelType(undefined as any);
        }).toThrow(TypeError);
      });

      it('should handle empty string modelId', () => {
        mockedResolveModel.mockReturnValue(null);

        const result = WorkflowDetector.detectModelType('');

        expect(mockedResolveModel).toHaveBeenCalledWith('');
        expect(result).toEqual({
          architecture: 'unknown',
          isSupported: false,
        });
      });

      it('should handle whitespace-only modelId', () => {
        mockedResolveModel.mockReturnValue(null);

        const result = WorkflowDetector.detectModelType(' ');

        expect(mockedResolveModel).toHaveBeenCalledWith(' ');
        expect(result).toEqual({
          architecture: 'unknown',
          isSupported: false,
        });
      });
    });

    describe('Type Casting', () => {
      it('should properly cast FLUX variant to FluxVariant type', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'dev',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('flux-model');

        expect(result.variant).toBe('dev');
        expect(typeof result.variant).toBe('string');

        // Test with dev variant (krea uses dev workflow)
        const mockKreaConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'dev',
        };
        mockedResolveModel.mockReturnValue(mockKreaConfig);

        const kreaResult = WorkflowDetector.detectModelType('flux-krea-model');

        expect(kreaResult.variant).toBe('dev');
      });

      it('should properly cast SD3 variant to SD3Variant type', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'SD3',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'sd35',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('sd3-model');

        expect(result.variant).toBe('sd35');
        expect(typeof result.variant).toBe('string');

        // Verify it matches SD3Variant type expectations
        const sd3Variants: SD3Variant[] = ['sd35'];
        expect(sd3Variants).toContain(result.variant as SD3Variant);
      });
    });

    describe('Edge Cases', () => {
      it('should handle special characters in modelId', () => {
        mockedResolveModel.mockReturnValue(null);

        const result = WorkflowDetector.detectModelType('model-with-special!@#$%^&*()_+');

        expect(mockedResolveModel).toHaveBeenCalledWith('model-with-special!@#$%^&*()_+');
        expect(result).toEqual({
          architecture: 'unknown',
          isSupported: false,
        });
      });

      it('should handle modelId with path separators', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'dev',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('path/to/model.safetensors');

        expect(mockedResolveModel).toHaveBeenCalledWith('path/to/model.safetensors');
        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'dev',
        });
      });

      it('should handle very long modelId', () => {
        const longModelId = 'a'.repeat(1000);
        mockedResolveModel.mockReturnValue(null);

        const result = WorkflowDetector.detectModelType(longModelId);

        expect(mockedResolveModel).toHaveBeenCalledWith(longModelId);
        expect(result).toEqual({
          architecture: 'unknown',
          isSupported: false,
        });
      });

      it('should handle modelId that is only "comfyui/"', () => {
        mockedResolveModel.mockReturnValue(null);

        const result = WorkflowDetector.detectModelType('comfyui/');

        expect(mockedResolveModel).toHaveBeenCalledWith('');
        expect(result).toEqual({
          architecture: 'unknown',
          isSupported: false,
        });
      });

      it('should handle case sensitivity in modelId', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          variant: 'dev',
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('COMFYUI/FLUX-DEV');

        // Should not match the prefix replacement since it's case sensitive
        expect(mockedResolveModel).toHaveBeenCalledWith('COMFYUI/FLUX-DEV');
        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: 'dev',
        });
      });
    });

    describe('Configuration Edge Cases', () => {
      it('should handle config with missing variant property', () => {
        const mockConfig: Partial<ModelConfig> = {
          modelFamily: 'FLUX',
          priority: 1,
          recommendedDtype: 'default',
          // variant is missing
        };
        // Deliberately malformed config: the cast lets us feed it through the typed mock.
        mockedResolveModel.mockReturnValue(mockConfig as ModelConfig);

        const result = WorkflowDetector.detectModelType('flux-model');

        expect(result).toEqual({
          architecture: 'FLUX',
          isSupported: true,
          variant: undefined, // Will be cast to FluxVariant but is undefined
        });
      });

      it('should handle config with null variant', () => {
        const mockConfig: ModelConfig = {
          modelFamily: 'SD3',
          priority: 1,
          recommendedDtype: 'default',
          variant: null as any,
        };
        mockedResolveModel.mockReturnValue(mockConfig);

        const result = WorkflowDetector.detectModelType('sd3-model');

        expect(result).toEqual({
          architecture: 'SD3',
          isSupported: true,
          variant: null, // Will be cast to SD3Variant but is null
        });
      });
    });
  });
});