/*
 * @boundless-oss/atlas
 * Version: (not specified)
 * Atlas - MCP Server for comprehensive startup project management
 * 320 lines (246 loc) • 10.7 kB
 * text/typescript
 */
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { LocalEmbeddingModel } from '../embeddings.js';
import type { EmbeddingVector } from '../types.js';
import * as fs from 'fs';
import * as path from 'path';
// Stub out `fs` entirely: only the `promises` API is exposed, and each of its
// functions is a bare vi.fn() whose behaviour the tests configure themselves.
vi.mock('fs', () => {
  const promises = {
    access: vi.fn(),
    mkdir: vi.fn(),
    readFile: vi.fn(),
    writeFile: vi.fn(),
  };
  return { promises };
});
/**
 * Unit tests for `LocalEmbeddingModel`.
 *
 * The `fs` module is stubbed by the `vi.mock('fs', ...)` factory above, and every test that
 * embeds text stubs the model's private `generateEmbedding`, so the tests control every
 * embedding value instead of depending on a real model.
 */
describe('LocalEmbeddingModel', () => {
  const DIMENSION = 384;
  const CACHE_DIR = '.atlas/test-cache';
  /** The stubbed `fs.promises`, typed as Vitest mocks. */
  const mockFs = vi.mocked(fs.promises);

  let model: LocalEmbeddingModel;

  /**
   * Spies on the private `generateEmbedding` so the test decides what the "model" returns.
   * Always give the spy an implementation: a bare `vi.spyOn` calls through to the real method.
   */
  const spyOnGenerate = () =>
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- the method is private
    vi.spyOn(model as any, 'generateEmbedding');

  /** Builds a zero vector of the test dimension with `first` stored at index 0. */
  const vectorWith = (first = 0): Float32Array => {
    const vector = new Float32Array(DIMENSION);
    vector[0] = first;
    return vector;
  };

  beforeEach(() => {
    // resetAllMocks (unlike clearAllMocks) also discards implementations installed with
    // mockResolvedValue/mockRejectedValue, so fs behaviour cannot leak from one test into the next.
    vi.resetAllMocks();
    // Default fs behaviour: the cache directory exists, cache lookups miss, cache writes succeed.
    // Individual tests override whichever of these they exercise.
    mockFs.access.mockResolvedValue(undefined);
    mockFs.mkdir.mockResolvedValue(undefined);
    mockFs.readFile.mockRejectedValue(new Error('Not found'));
    mockFs.writeFile.mockResolvedValue(undefined);

    model = new LocalEmbeddingModel({
      modelName: 'test-model',
      dimension: DIMENSION,
      cachePath: CACHE_DIR,
    });
  });

  afterEach(() => {
    // Undo the per-test vi.spyOn spies installed on the model.
    vi.restoreAllMocks();
  });

  describe('initialization', () => {
    it('should initialize with default configuration', () => {
      const defaultModel = new LocalEmbeddingModel();
      expect(defaultModel.modelName).toBe('all-MiniLM-L6-v2');
      expect(defaultModel.dimension).toBe(384);
    });

    it('should initialize with custom configuration', () => {
      expect(model.modelName).toBe('test-model');
      expect(model.dimension).toBe(DIMENSION);
    });

    it('should create cache directory on initialization', async () => {
      mockFs.access.mockRejectedValue(new Error('Not found'));

      await model.initialize();

      expect(mockFs.mkdir).toHaveBeenCalledWith(CACHE_DIR, { recursive: true });
    });

    it('should not create cache directory if it exists', async () => {
      mockFs.access.mockResolvedValue(undefined);

      await model.initialize();

      expect(mockFs.mkdir).not.toHaveBeenCalled();
    });
  });

  describe('embed', () => {
    it('should generate embeddings for multiple texts', async () => {
      const texts = ['Hello world', 'Test embedding', 'Third text'];
      // Encode each text's length at index 0 so every output can be traced back to its input.
      spyOnGenerate().mockImplementation(async (text: string) => vectorWith(text.length / 100));

      const embeddings = await model.embed(texts);

      expect(embeddings).toHaveLength(3);
      expect(embeddings[0]).toBeInstanceOf(Float32Array);
      expect(embeddings[0].length).toBe(DIMENSION);
      expect(embeddings[0][0]).toBeCloseTo(0.11); // 'Hello world'.length / 100
      expect(embeddings[1][0]).toBeCloseTo(0.14); // 'Test embedding'.length / 100
      expect(embeddings[2][0]).toBeCloseTo(0.1); // 'Third text'.length / 100
    });

    it('should handle empty array', async () => {
      const generate = spyOnGenerate().mockResolvedValue(vectorWith());

      const embeddings = await model.embed([]);

      expect(embeddings).toHaveLength(0);
      expect(generate).not.toHaveBeenCalled();
    });

    it('should batch embeddings for efficiency', async () => {
      const texts: string[] = Array(50).fill('test');
      const generate = spyOnGenerate().mockResolvedValue(vectorWith());

      await model.embed(texts);

      // Batch size is not observable from here; what is observable is that every
      // input is embedded exactly once.
      expect(generate).toHaveBeenCalledTimes(50);
    });
  });

  describe('embedSingle', () => {
    it('should generate embedding for single text', async () => {
      spyOnGenerate().mockResolvedValue(vectorWith(0.5));

      const embedding = await model.embedSingle('Single text example');

      expect(embedding).toBeInstanceOf(Float32Array);
      expect(embedding.length).toBe(DIMENSION);
      expect(embedding[0]).toBe(0.5);
    });

    it('should handle empty string', async () => {
      spyOnGenerate().mockResolvedValue(vectorWith());

      const embedding = await model.embedSingle('');

      expect(embedding).toBeInstanceOf(Float32Array);
      expect(embedding.length).toBe(DIMENSION);
    });
  });

  describe('caching', () => {
    /** The cache file path the model is expected to use for `text`. */
    const cacheFileFor = (text: string): string =>
      path.join(CACHE_DIR, `${model['getCacheKey'](text)}.json`);

    it('should cache embeddings to disk', async () => {
      const text = 'Cached text';
      spyOnGenerate().mockResolvedValue(vectorWith(0.123));

      // readFile rejects by default (cache miss), so the embedding is generated and then written.
      const embedding = await model.embedSingle(text);

      expect(mockFs.writeFile).toHaveBeenCalledWith(cacheFileFor(text), expect.any(String));
      expect(embedding[0]).toBeCloseTo(0.123, 5);
    });

    it('should load embeddings from cache', async () => {
      const text = 'Cached text';
      const cachedVector: number[] = Array(DIMENSION).fill(0);
      cachedVector[0] = 0.789;
      mockFs.readFile.mockResolvedValue(JSON.stringify(cachedVector));
      const generate = spyOnGenerate().mockResolvedValue(vectorWith());

      const embedding = await model.embedSingle(text);

      expect(mockFs.readFile).toHaveBeenCalledWith(cacheFileFor(text), 'utf-8');
      expect(embedding[0]).toBeCloseTo(0.789);
      // A cache hit must not fall back to generating a fresh embedding.
      expect(generate).not.toHaveBeenCalled();
    });

    it('should handle cache read errors gracefully', async () => {
      mockFs.readFile.mockRejectedValue(new Error('Read error'));
      const generate = spyOnGenerate().mockResolvedValue(vectorWith());

      const embedding = await model.embedSingle('Uncached text');

      expect(embedding).toBeInstanceOf(Float32Array);
      expect(generate).toHaveBeenCalled();
    });

    it('should handle cache write errors gracefully', async () => {
      mockFs.writeFile.mockRejectedValue(new Error('Write error'));
      spyOnGenerate().mockResolvedValue(vectorWith());

      // A failed cache write must not surface as a failed embedding.
      const embedding = await model.embedSingle('Write error text');

      expect(embedding).toBeInstanceOf(Float32Array);
    });
  });

  describe('similarity', () => {
    // toBeCloseTo rather than toBe: results are floating point, and toBe (Object.is)
    // would also distinguish 0 from -0.
    it('should calculate cosine similarity between vectors', () => {
      const similarity = model.cosineSimilarity(
        new Float32Array([1, 0, 0]),
        new Float32Array([1, 0, 0]),
      );
      expect(similarity).toBeCloseTo(1); // Identical vectors
    });

    it('should handle orthogonal vectors', () => {
      const similarity = model.cosineSimilarity(
        new Float32Array([1, 0, 0]),
        new Float32Array([0, 1, 0]),
      );
      expect(similarity).toBeCloseTo(0); // Orthogonal vectors
    });

    it('should handle opposite vectors', () => {
      const similarity = model.cosineSimilarity(
        new Float32Array([1, 0, 0]),
        new Float32Array([-1, 0, 0]),
      );
      expect(similarity).toBeCloseTo(-1); // Opposite vectors
    });

    it('should handle normalized similarity calculation', () => {
      const vec1 = new Float32Array([3, 4, 0]); // Length 5
      const vec2 = new Float32Array([4, 3, 0]); // Length 5
      // (3*4 + 4*3 + 0*0) / (5 * 5) = 24/25 = 0.96
      expect(model.cosineSimilarity(vec1, vec2)).toBeCloseTo(0.96);
    });

    it('should handle zero vectors', () => {
      const similarity = model.cosineSimilarity(
        new Float32Array([0, 0, 0]),
        new Float32Array([1, 0, 0]),
      );
      expect(similarity).toBeCloseTo(0);
    });
  });

  describe('getCacheKey', () => {
    const HEX = /^[a-f0-9]+$/;

    it('should generate consistent cache keys', () => {
      const text = 'Test text for caching';
      const key1 = model['getCacheKey'](text);
      const key2 = model['getCacheKey'](text);
      expect(key1).toBe(key2);
      expect(key1).toMatch(HEX);
    });

    it('should generate different keys for different texts', () => {
      expect(model['getCacheKey']('Text 1')).not.toBe(model['getCacheKey']('Text 2'));
    });

    it('should handle special characters', () => {
      // Accented Latin, CJK and an astral-plane emoji, written as escapes so the
      // source file's encoding cannot mangle them.
      const text = 'Text with \u00fcn\u00efc\u00f6d\u00e9 \u4e2d\u6587 and \u00e9mojis \u{1F389}';
      expect(model['getCacheKey'](text)).toMatch(HEX);
    });
  });

  describe('batch processing', () => {
    it('should process large batches efficiently', async () => {
      const texts = Array.from({ length: 100 }, (_, i) => `Text number ${i}`);
      const generate = spyOnGenerate().mockResolvedValue(vectorWith());

      const start = Date.now();
      const embeddings = await model.embed(texts);
      const duration = Date.now() - start;

      expect(embeddings).toHaveLength(100);
      expect(generate).toHaveBeenCalledTimes(100);
      expect(duration).toBeLessThan(5000); // Generous bound: generation itself is mocked
    });

    it('should maintain order in batch processing', async () => {
      const texts = ['First', 'Second', 'Third', 'Fourth'];
      // Tag each vector with its input's position so ordering is observable in the output.
      spyOnGenerate().mockImplementation(async (text: string) => vectorWith(texts.indexOf(text)));

      const embeddings = await model.embed(texts);

      expect(embeddings[0][0]).toBe(0); // First
      expect(embeddings[1][0]).toBe(1); // Second
      expect(embeddings[2][0]).toBe(2); // Third
      expect(embeddings[3][0]).toBe(3); // Fourth
    });
  });

  describe('error handling', () => {
    it('should handle model loading errors', async () => {
      spyOnGenerate().mockRejectedValue(new Error('Model loading failed'));

      await expect(model.embedSingle('test')).rejects.toThrow('Model loading failed');
    });

    it('should handle invalid input gracefully', async () => {
      spyOnGenerate().mockResolvedValue(vectorWith());

      // @ts-expect-error - Testing invalid input
      const result = await model.embed(null);

      expect(result).toEqual([]);
    });
  });
});