UNPKG

@boundless-oss/atlas

Version:

Atlas - MCP Server for comprehensive startup project management

320 lines (246 loc) 10.7 kB
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { LocalEmbeddingModel } from '../embeddings.js';
import type { EmbeddingVector } from '../types.js';
import * as fs from 'fs';
import * as path from 'path';

// Mock the fs module
vi.mock('fs', () => ({
  promises: {
    mkdir: vi.fn(),
    readFile: vi.fn(),
    writeFile: vi.fn(),
    access: vi.fn()
  }
}));

/**
 * Unit tests for `LocalEmbeddingModel`.
 *
 * `fs.promises` is replaced by `vi.fn()` stubs (see the `vi.mock('fs')` call
 * above), so the filesystem calls asserted below are made against stubs.
 * The model's `generateEmbedding` method is reached through an `any` cast and
 * spied on in most tests, so assertions exercise the surrounding logic
 * (batching, caching, ordering, error propagation) with a stubbed vector.
 */
describe('LocalEmbeddingModel', () => {
  let model: LocalEmbeddingModel;
  // Loosely typed handle on the stubbed fs.promises API so the tests can
  // call mockResolvedValue / mockRejectedValue on each function.
  const mockFs = fs.promises as any;

  beforeEach(() => {
    vi.clearAllMocks();
    model = new LocalEmbeddingModel({
      modelName: 'test-model',
      dimension: 384,
      cachePath: '.atlas/test-cache'
    });
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  // Constructor defaults/overrides and cache-directory creation in initialize().
  describe('initialization', () => {
    it('should initialize with default configuration', () => {
      const defaultModel = new LocalEmbeddingModel();

      expect(defaultModel.modelName).toBe('all-MiniLM-L6-v2');
      expect(defaultModel.dimension).toBe(384);
    });

    it('should initialize with custom configuration', () => {
      expect(model.modelName).toBe('test-model');
      expect(model.dimension).toBe(384);
    });

    it('should create cache directory on initialization', async () => {
      // A rejected access() stands in for "directory does not exist".
      mockFs.access.mockRejectedValue(new Error('Not found'));
      mockFs.mkdir.mockResolvedValue(undefined);

      await model.initialize();

      expect(mockFs.mkdir).toHaveBeenCalledWith('.atlas/test-cache', { recursive: true });
    });

    it('should not create cache directory if it exists', async () => {
      mockFs.access.mockResolvedValue(undefined);

      await model.initialize();

      expect(mockFs.mkdir).not.toHaveBeenCalled();
    });
  });

  // embed(): multi-text entry point.
  describe('embed', () => {
    it('should generate embeddings for multiple texts', async () => {
      const texts = ['Hello world', 'Test embedding', 'Third text'];

      // Mock the internal embedding generation
      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockImplementation((text: string) => {
        const vector = new Float32Array(384);
        // Simple mock: use text length as first value
        vector[0] = text.length / 100;
        return Promise.resolve(vector);
      });

      const embeddings = await model.embed(texts);

      expect(embeddings).toHaveLength(3);
      expect(embeddings[0]).toBeInstanceOf(Float32Array);
      expect(embeddings[0].length).toBe(384);
      expect(embeddings[0][0]).toBeCloseTo(0.11); // "Hello world".length / 100
      expect(embeddings[1][0]).toBeCloseTo(0.14); // "Test embedding".length / 100
    });

    it('should handle empty array', async () => {
      const embeddings = await model.embed([]);

      expect(embeddings).toHaveLength(0);
    });

    it('should batch embeddings for efficiency', async () => {
      const texts = Array(50).fill('test');
      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockResolvedValue(new Float32Array(384));

      await model.embed(texts);

      // Should process in batches
      expect(mockEmbed).toHaveBeenCalledTimes(50);
    });
  });

  // embedSingle(): single-text entry point.
  describe('embedSingle', () => {
    it('should generate embedding for single text', async () => {
      const text = 'Single text example';
      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      const mockVector = new Float32Array(384);
      mockVector[0] = 0.5;
      mockEmbed.mockResolvedValue(mockVector);

      const embedding = await model.embedSingle(text);

      expect(embedding).toBeInstanceOf(Float32Array);
      expect(embedding.length).toBe(384);
      expect(embedding[0]).toBe(0.5);
    });

    it('should handle empty string', async () => {
      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockResolvedValue(new Float32Array(384));

      const embedding = await model.embedSingle('');

      expect(embedding).toBeInstanceOf(Float32Array);
      expect(embedding.length).toBe(384);
    });
  });

  // Disk cache: entries are expected at <cachePath>/<cacheKey>.json.
  describe('caching', () => {
    it('should cache embeddings to disk', async () => {
      const text = 'Cached text';
      const cacheKey = model['getCacheKey'](text);
      const cachePath = path.join('.atlas/test-cache', `${cacheKey}.json`);

      // First call - generate and cache
      mockFs.readFile.mockRejectedValue(new Error('Not found'));
      mockFs.writeFile.mockResolvedValue(undefined);

      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      const mockVector = new Float32Array(384);
      mockVector[0] = 0.123;
      mockEmbed.mockResolvedValue(mockVector);

      const embedding1 = await model.embedSingle(text);

      expect(mockFs.writeFile).toHaveBeenCalledWith(
        cachePath,
        expect.any(String)
      );
      expect(embedding1[0]).toBeCloseTo(0.123, 5);
    });

    it('should load embeddings from cache', async () => {
      const text = 'Cached text';
      const cacheKey = model['getCacheKey'](text);
      const cachePath = path.join('.atlas/test-cache', `${cacheKey}.json`);

      // Mock cached data
      const cachedVector = Array(384).fill(0);
      cachedVector[0] = 0.789;
      mockFs.readFile.mockResolvedValue(JSON.stringify(cachedVector));

      const embedding = await model.embedSingle(text);

      expect(mockFs.readFile).toHaveBeenCalledWith(cachePath, 'utf-8');
      expect(embedding[0]).toBeCloseTo(0.789);
    });

    it('should handle cache read errors gracefully', async () => {
      const text = 'Uncached text';

      mockFs.readFile.mockRejectedValue(new Error('Read error'));
      mockFs.writeFile.mockResolvedValue(undefined);

      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockResolvedValue(new Float32Array(384));

      const embedding = await model.embedSingle(text);

      expect(embedding).toBeInstanceOf(Float32Array);
      expect(mockEmbed).toHaveBeenCalled();
    });

    it('should handle cache write errors gracefully', async () => {
      const text = 'Write error text';

      mockFs.readFile.mockRejectedValue(new Error('Not found'));
      mockFs.writeFile.mockRejectedValue(new Error('Write error'));

      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockResolvedValue(new Float32Array(384));

      // Should not throw
      const embedding = await model.embedSingle(text);

      expect(embedding).toBeInstanceOf(Float32Array);
    });
  });

  // cosineSimilarity(): pure vector math, no mocks involved.
  describe('similarity', () => {
    it('should calculate cosine similarity between vectors', () => {
      const vec1 = new Float32Array([1, 0, 0]);
      const vec2 = new Float32Array([1, 0, 0]);

      const similarity = model.cosineSimilarity(vec1, vec2);

      expect(similarity).toBe(1); // Identical vectors
    });

    it('should handle orthogonal vectors', () => {
      const vec1 = new Float32Array([1, 0, 0]);
      const vec2 = new Float32Array([0, 1, 0]);

      const similarity = model.cosineSimilarity(vec1, vec2);

      expect(similarity).toBe(0); // Orthogonal vectors
    });

    it('should handle opposite vectors', () => {
      const vec1 = new Float32Array([1, 0, 0]);
      const vec2 = new Float32Array([-1, 0, 0]);

      const similarity = model.cosineSimilarity(vec1, vec2);

      expect(similarity).toBe(-1); // Opposite vectors
    });

    it('should handle normalized similarity calculation', () => {
      const vec1 = new Float32Array([3, 4, 0]); // Length 5
      const vec2 = new Float32Array([4, 3, 0]); // Length 5

      const similarity = model.cosineSimilarity(vec1, vec2);

      // (3*4 + 4*3 + 0*0) / (5 * 5) = 24/25 = 0.96
      expect(similarity).toBeCloseTo(0.96);
    });

    it('should handle zero vectors', () => {
      const vec1 = new Float32Array([0, 0, 0]);
      const vec2 = new Float32Array([1, 0, 0]);

      const similarity = model.cosineSimilarity(vec1, vec2);

      expect(similarity).toBe(0);
    });
  });

  // getCacheKey(): reached via bracket access, so presumably non-public.
  describe('getCacheKey', () => {
    it('should generate consistent cache keys', () => {
      const text = 'Test text for caching';

      const key1 = model['getCacheKey'](text);
      const key2 = model['getCacheKey'](text);

      expect(key1).toBe(key2);
      expect(key1).toMatch(/^[a-f0-9]+$/); // Should be hex
    });

    it('should generate different keys for different texts', () => {
      const key1 = model['getCacheKey']('Text 1');
      const key2 = model['getCacheKey']('Text 2');

      expect(key1).not.toBe(key2);
    });

    it('should handle special characters', () => {
      // Non-ASCII input written with Unicode escapes so the literal survives
      // re-encoding of this file: four CJK characters, an accented Latin
      // letter, and an astral-plane emoji (surrogate pair in UTF-16).
      const text = 'Text with \u7279\u6b8a\u5b57\u7b26 and \u00e9mojis \u{1F680}';

      const key = model['getCacheKey'](text);

      expect(key).toMatch(/^[a-f0-9]+$/);
    });
  });

  // Larger inputs through embed(): throughput sanity check and output order.
  describe('batch processing', () => {
    it('should process large batches efficiently', async () => {
      const texts = Array(100).fill(0).map((_, i) => `Text number ${i}`);

      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockResolvedValue(new Float32Array(384));

      const start = Date.now();
      const embeddings = await model.embed(texts);
      const duration = Date.now() - start;

      expect(embeddings).toHaveLength(100);
      expect(duration).toBeLessThan(5000); // Should be reasonably fast
    });

    it('should maintain order in batch processing', async () => {
      const texts = ['First', 'Second', 'Third', 'Fourth'];

      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockImplementation(async (text: string) => {
        const vector = new Float32Array(384);
        // Tag each vector with the index of its source text so the output
        // order can be checked against the input order.
        vector[0] = texts.indexOf(text);
        return vector;
      });

      const embeddings = await model.embed(texts);

      expect(embeddings[0][0]).toBe(0); // First
      expect(embeddings[1][0]).toBe(1); // Second
      expect(embeddings[2][0]).toBe(2); // Third
      expect(embeddings[3][0]).toBe(3); // Fourth
    });
  });

  // Failure propagation and tolerance of invalid input.
  describe('error handling', () => {
    it('should handle model loading errors', async () => {
      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockRejectedValue(new Error('Model loading failed'));

      await expect(model.embedSingle('test')).rejects.toThrow('Model loading failed');
    });

    it('should handle invalid input gracefully', async () => {
      const mockEmbed = vi.spyOn(model as any, 'generateEmbedding');
      mockEmbed.mockResolvedValue(new Float32Array(384));

      // @ts-expect-error - Testing invalid input
      const result = await model.embed(null);

      expect(result).toEqual([]);
    });
  });
});