UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.

278 lines (240 loc) 9.88 kB
import { describe, expect, it } from 'vitest';

import { MODEL_REGISTRY } from '@/server/services/comfyui/config/modelRegistry';
import {
  getAllModelNames,
  getModelConfig,
  getModelsByVariant,
} from '@/server/services/comfyui/utils/staticModelLookup';

/** Config type returned by `getModelConfig`, derived here so no extra type import is needed. */
type RegistryEntry = NonNullable<ReturnType<typeof getModelConfig>>;

/** A model name paired with the config that `getModelConfig` returns for it. */
interface RegisteredModel {
  config: RegistryEntry;
  name: string;
}

/** Predicate used to pick a fixture model out of the names returned by `getAllModelNames`. */
type ModelPredicate = (config: RegistryEntry, name: string) => boolean;

const isFlux: ModelPredicate = (config) => config.modelFamily === 'FLUX';
const isFluxDev: ModelPredicate = (config) =>
  config.modelFamily === 'FLUX' && config.variant === 'dev';

/**
 * Finds the first model (in `getAllModelNames` order) whose config satisfies `matches`.
 *
 * @returns the matching name/config pair, or `undefined` when nothing matches.
 */
const findModel = (matches: ModelPredicate): RegisteredModel | undefined => {
  for (const name of getAllModelNames()) {
    const config = getModelConfig(name);
    if (config != null && matches(config, name)) return { config, name };
  }
  return undefined;
};

/**
 * Same as `findModel`, but a missing fixture is an error.
 *
 * Wrapping assertions in `if (model) { ... }` lets a test pass without asserting
 * anything once the fixture disappears; throwing makes that loss of coverage visible.
 *
 * @param description - human-readable fixture description used in the error message.
 * @throws Error when no model satisfies `matches`.
 */
const requireModel = (description: string, matches: ModelPredicate): RegisteredModel => {
  const model = findModel(matches);
  if (model === undefined) {
    throw new Error(`No ${description} model is available to use as a test fixture`);
  }
  return model;
};

describe('ModelRegistry', () => {
  describe('MODEL_REGISTRY', () => {
    it('should be a non-empty object with valid structure', () => {
      expect(typeof MODEL_REGISTRY).toBe('object');
      expect(Object.keys(MODEL_REGISTRY).length).toBeGreaterThan(0);

      // Check that all models have required fields
      for (const config of Object.values(MODEL_REGISTRY)) {
        expect(config).toBeDefined();
        expect(config.modelFamily).toBeDefined();
        expect(config.priority).toBeTypeOf('number');

        // recommendedDtype is optional; validate it only when present
        if (config.recommendedDtype) {
          expect(['default', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2']).toContain(
            config.recommendedDtype,
          );
        }
      }
    });

    it('should contain essential model families', () => {
      const modelFamilies = Object.values(MODEL_REGISTRY).map((c) => c.modelFamily);
      const uniqueFamilies = [...new Set(modelFamilies)];

      // Should have at least one model family and FLUX should be included
      expect(uniqueFamilies.length).toBeGreaterThan(0);
      expect(uniqueFamilies).toContain('FLUX');
    });

    it('should have valid priority ranges', () => {
      for (const config of Object.values(MODEL_REGISTRY)) {
        // Priorities should be positive numbers no greater than 10
        expect(config.priority).toBeGreaterThan(0);
        expect(config.priority).toBeLessThanOrEqual(10);
      }
    });
  });

  describe('getModelConfig', () => {
    it('should return model config for valid name', () => {
      // Use any available FLUX model instead of hardcoding a file name
      const { name } = requireModel('FLUX', isFlux);

      const config = getModelConfig(name);
      expect(config).toBeDefined();
      expect(config?.modelFamily).toBe('FLUX');
    });

    it('should return undefined for invalid name', () => {
      const config = getModelConfig('nonexistent.safetensors');
      expect(config).toBeUndefined();
    });
  });

  describe('getAllModelNames', () => {
    it('should return all model names', () => {
      const names = getAllModelNames();
      expect(names.length).toBeGreaterThan(0);

      // Check that at least one FLUX model exists instead of hardcoding a name
      const hasFluxModel = names.some((name) => getModelConfig(name)?.modelFamily === 'FLUX');
      expect(hasFluxModel).toBe(true);
    });

    it('should return unique names', () => {
      const names = getAllModelNames();
      const uniqueNames = [...new Set(names)];
      expect(uniqueNames.length).toBe(names.length);
    });
  });

  describe('getModelsByVariant', () => {
    it('should return model names for valid variant', () => {
      const modelNames = getModelsByVariant('dev');
      expect(modelNames.length).toBeGreaterThan(0);
      expect(Array.isArray(modelNames)).toBe(true);

      // Verify all returned names are strings and correspond to dev variant models
      for (const name of modelNames) {
        expect(typeof name).toBe('string');
        const config = getModelConfig(name);
        expect(config).toBeDefined();
        expect(config?.variant).toBe('dev');
      }
    });

    it('should return models sorted by priority', () => {
      const modelNames = getModelsByVariant('dev');
      expect(modelNames.length).toBeGreaterThan(1);

      // A name without a config maps to NaN, which fails every comparison below
      // rather than being masked by a fallback value.
      const priorities = modelNames.map((name) => getModelConfig(name)?.priority ?? Number.NaN);

      // Verify priority sorting (lower priority number = higher priority)
      priorities.slice(1).forEach((next, index) => {
        expect(priorities[index]).toBeLessThanOrEqual(next);
      });
    });

    it('should return empty array for invalid variant', () => {
      // The cast deliberately passes a value outside the accepted variants
      const models = getModelsByVariant('nonexistent' as any);
      expect(models).toEqual([]);
    });
  });

  describe('getModelConfig with options', () => {
    it('should support case-insensitive lookup', () => {
      // Prefer a FLUX dev model; fall back to any FLUX model if no dev variant exists
      const { config: expected, name } = findModel(isFluxDev) ?? requireModel('FLUX', isFlux);

      const config = getModelConfig(name.toUpperCase(), { caseInsensitive: true });
      expect(config).toBeDefined();
      expect(config?.modelFamily).toBe('FLUX');
      expect(config?.variant).toBe(expected.variant);
    });

    it('should return undefined for non-matching case without caseInsensitive option', () => {
      // The fixture must contain a lowercase character, otherwise upper-casing is a no-op
      const { name } = requireModel(
        'FLUX (with a lowercase character in its name)',
        (config, modelName) => isFlux(config, modelName) && modelName !== modelName.toUpperCase(),
      );

      const config = getModelConfig(name.toUpperCase());
      expect(config).toBeUndefined();
    });

    it('should filter by variant', () => {
      const { name } = requireModel('dev variant', (config) => config.variant === 'dev');

      // Matching variant
      const config = getModelConfig(name, { variant: 'dev' });
      expect(config).toBeDefined();
      expect(config?.variant).toBe('dev');

      // Non-matching variant
      const nonMatchingConfig = getModelConfig(name, { variant: 'schnell' });
      expect(nonMatchingConfig).toBeUndefined();
    });

    it('should filter by modelFamily', () => {
      // Test the SD3.5 model family
      const config = getModelConfig('sd3.5_large.safetensors', { modelFamily: 'SD3' });
      expect(config).toBeDefined();
      expect(config?.modelFamily).toBe('SD3');

      // Test a non-matching modelFamily
      const nonMatchingConfig = getModelConfig('sd3.5_large.safetensors', {
        modelFamily: 'FLUX',
      });
      expect(nonMatchingConfig).toBeUndefined();
    });

    it('should filter by priority', () => {
      // Any model works: filter on its own priority, then on one that matches nothing
      const { config: expected, name } = requireModel('registered', () => true);

      const config = getModelConfig(name, { priority: expected.priority });
      expect(config).toBeDefined();

      // Non-matching priority
      const nonMatchingConfig = getModelConfig(name, { priority: 999 });
      expect(nonMatchingConfig).toBeUndefined();
    });

    it('should filter by recommendedDtype', () => {
      // flux_shakker_labs_union_pro-fp8_e4m3fn is expected to recommend fp8_e4m3fn
      const config = getModelConfig('flux_shakker_labs_union_pro-fp8_e4m3fn.safetensors', {
        recommendedDtype: 'fp8_e4m3fn',
      });
      expect(config).toBeDefined();
      expect(config?.recommendedDtype).toBe('fp8_e4m3fn');

      // Test a non-matching recommendedDtype
      const nonMatchingConfig = getModelConfig(
        'flux_shakker_labs_union_pro-fp8_e4m3fn.safetensors',
        { recommendedDtype: 'default' },
      );
      expect(nonMatchingConfig).toBeUndefined();
    });

    it('should combine multiple filters', () => {
      const { config: expected, name } = requireModel('FLUX dev', isFluxDev);

      // All filters match
      const config = getModelConfig(name, {
        modelFamily: 'FLUX',
        priority: expected.priority,
        variant: 'dev',
      });
      expect(config).toBeDefined();

      // One filter doesn't match
      const nonMatchingConfig = getModelConfig(name, {
        modelFamily: 'FLUX',
        priority: 999, // Wrong priority
        variant: 'dev',
      });
      expect(nonMatchingConfig).toBeUndefined();
    });

    it('should handle case-insensitive with other filters', () => {
      const { name } = requireModel('FLUX dev', isFluxDev);

      const config = getModelConfig(name.toUpperCase(), {
        caseInsensitive: true,
        modelFamily: 'FLUX',
        variant: 'dev',
      });
      expect(config).toBeDefined();
    });
  });
});