UNPKG

@codai/memorai-core

Version:

Simplified advanced memory engine - no tiers, just powerful semantic search with persistence

196 lines (195 loc) 7.2 kB
import OpenAI from 'openai';
import { EmbeddingError } from '../types/index.js';

/**
 * Embedding provider backed by the OpenAI (or Azure OpenAI) embeddings API.
 */
export class OpenAIEmbeddingProvider {
  /**
   * @param {object} config - Provider configuration: `api_key` (required),
   *   `model`, and optionally `provider` ('azure'|'openai'), `endpoint`,
   *   `azure_api_version`.
   * @throws {EmbeddingError} When no API key is supplied.
   */
  constructor(config) {
    if (!config.api_key) {
      // This path fires for both Azure and standard OpenAI configurations,
      // so the message must not claim Azure specifically (was: "Azure OpenAI
      // API key is required").
      throw new EmbeddingError('OpenAI API key is required');
    }

    const clientConfig = {
      apiKey: config.api_key,
      // Allow browser environment for testing
      dangerouslyAllowBrowser:
        process.env.NODE_ENV === 'test' || process.env.VITEST === 'true',
    };

    // Azure OpenAI requires an `api-version` query parameter and an
    // `api-key` header instead of the standard bearer token.
    if (config.provider === 'azure' || config.endpoint?.includes('azure.com')) {
      clientConfig.baseURL = config.endpoint;
      clientConfig.defaultQuery = {
        'api-version': config.azure_api_version || '2024-02-15-preview',
      };
      clientConfig.defaultHeaders = {
        'api-key': config.api_key,
      };
    } else {
      // Standard OpenAI configuration; a custom endpoint is optional.
      clientConfig.baseURL = config.endpoint;
    }

    this.client = new OpenAI(clientConfig);
    this.model = config.model;
    this.dimension = this.getModelDimension(config.model);
  }

  /**
   * Embed a single text.
   * @param {string} text - Input text; newlines are flattened to spaces.
   * @returns {Promise<{embedding: number[], tokens: number, model: string}>}
   * @throws {EmbeddingError} On API failure or an empty API response.
   */
  async embed(text) {
    try {
      const response = await this.client.embeddings.create({
        model: this.model,
        input: text.replace(/\n/g, ' '),
        encoding_format: 'float',
      });

      const embedding = response.data[0]?.embedding;
      if (!embedding) {
        throw new EmbeddingError('No embedding returned from OpenAI');
      }

      return {
        embedding,
        // Guard: `usage` may be absent on some responses; report 0 rather
        // than converting a TypeError into a misleading EmbeddingError.
        tokens: response.usage?.total_tokens ?? 0,
        model: this.model,
      };
    } catch (error) {
      if (error instanceof Error) {
        throw new EmbeddingError(`OpenAI embedding failed: ${error.message}`, {
          text: text.substring(0, 100),
          model: this.model,
        });
      }
      throw new EmbeddingError('Unknown embedding error');
    }
  }

  /**
   * Embed several texts in one API call.
   * @param {string[]} texts
   * @returns {Promise<Array<{embedding: number[], tokens: number, model: string}>>}
   *   One result per input, in input order. Token counts are an even
   *   per-item approximation of the batch total.
   * @throws {EmbeddingError} On API failure.
   */
  async embedBatch(texts) {
    try {
      const cleanTexts = texts.map(text => text.replace(/\n/g, ' '));
      const response = await this.client.embeddings.create({
        model: this.model,
        input: cleanTexts,
        encoding_format: 'float',
      });

      const totalTokens = response.usage?.total_tokens ?? 0;
      return response.data.map(item => ({
        embedding: item.embedding,
        tokens: Math.floor(totalTokens / texts.length), // Approximate
        model: this.model,
      }));
    } catch (error) {
      if (error instanceof Error) {
        throw new EmbeddingError(
          `OpenAI batch embedding failed: ${error.message}`,
          {
            batch_size: texts.length,
            model: this.model,
          },
        );
      }
      throw new EmbeddingError('Unknown batch embedding error');
    }
  }

  /** @returns {number} The embedding dimension for the configured model. */
  getDimension() {
    return this.dimension;
  }

  /**
   * Look up the output dimension for a known model name.
   * @param {string} model
   * @returns {number} Known dimension, or 1536 for unrecognized models.
   */
  getModelDimension(model) {
    const dimensions = {
      'text-embedding-3-small': 1536,
      'text-embedding-3-large': 3072,
      'text-embedding-ada-002': 1536,
      // Azure OpenAI deployment names
      'memorai-model-r': 1536, // Azure deployment for text-embedding-ada-002
    };
    return dimensions[model] ?? 1536;
  }
}

/**
 * Deterministic hash-based embedding provider for offline/local use.
 * NOT suitable for production semantic search — vectors carry no meaning.
 */
export class LocalEmbeddingProvider {
  /** @param {number} [dimension=1536] - Length of generated vectors. */
  constructor(dimension = 1536) {
    this.dimension = dimension;
  }

  /**
   * Produce a pseudo-embedding from a text hash.
   * @param {string} text
   * @returns {Promise<{embedding: number[], tokens: number, model: string}>}
   */
  async embed(text) {
    // Simplified local embedding using text hash - NOT for production
    const hash = this.simpleHash(text);
    const embedding = this.hashToVector(hash, this.dimension);
    return {
      embedding,
      tokens: Math.ceil(text.length / 4), // Approximate token count
      model: 'local-hash',
    };
  }

  /**
   * @param {string[]} texts
   * @returns {Promise<Array<{embedding: number[], tokens: number, model: string}>>}
   */
  async embedBatch(texts) {
    return Promise.all(texts.map(text => this.embed(text)));
  }

  /** @returns {number} Configured vector dimension. */
  getDimension() {
    return this.dimension;
  }

  /**
   * 32-bit string hash (djb2-style shift-add), returned as a non-negative int.
   * @param {string} str
   * @returns {number}
   */
  simpleHash(str) {
    let hash = 0;
    for (let i = 0; i < str.length; i++) {
      const char = str.charCodeAt(i);
      hash = (hash << 5) - hash + char;
      hash = hash & hash; // Convert to 32-bit integer
    }
    return Math.abs(hash);
  }

  /**
   * Expand a hash into a unit-length pseudo-random vector.
   * @param {number} hash - Seed value.
   * @param {number} dimension - Output length.
   * @returns {number[]} L2-normalized vector.
   */
  hashToVector(hash, dimension) {
    const vector = [];
    let currentHash = hash;
    for (let i = 0; i < dimension; i++) {
      // Linear congruential generator: derive pseudo-random values from hash
      currentHash = (currentHash * 1103515245 + 12345) & 0x7fffffff;
      vector.push((currentHash / 0x7fffffff) * 2 - 1); // Normalize to [-1, 1]
    }
    // Normalize to unit length; guard the (degenerate) all-zero case so we
    // never divide by zero and emit NaN components.
    const magnitude = Math.sqrt(
      vector.reduce((sum, val) => sum + val * val, 0),
    );
    if (magnitude === 0) {
      return vector;
    }
    return vector.map(val => val / magnitude);
  }
}

/**
 * Facade that selects an embedding provider from config and adds input
 * validation plus retry support.
 */
export class EmbeddingService {
  /**
   * @param {object} config - `provider` is 'azure' | 'openai' | 'local';
   *   remaining fields are forwarded to the chosen provider.
   * @throws {EmbeddingError} For an unsupported provider name.
   */
  constructor(config) {
    switch (config.provider) {
      case 'azure':
        this.provider = new OpenAIEmbeddingProvider({
          ...config,
          endpoint: config.azure_endpoint || config.endpoint,
        });
        break;
      case 'openai':
        this.provider = new OpenAIEmbeddingProvider(config);
        break;
      case 'local':
        // Honor a configured dimension if present; defaults to 1536 as
        // before (previously the config value was silently ignored).
        this.provider = new LocalEmbeddingProvider(config.dimension ?? 1536);
        break;
      default:
        throw new EmbeddingError(
          `Unsupported embedding provider: ${config.provider}`,
        );
    }
  }

  /**
   * Embed one text after trimming it.
   * @param {string} text
   * @returns {Promise<{embedding: number[], tokens: number, model: string}>}
   * @throws {EmbeddingError} When the text is empty or whitespace-only.
   */
  async embed(text) {
    if (!text || text.trim().length === 0) {
      throw new EmbeddingError('Text cannot be empty');
    }
    return this.provider.embed(text.trim());
  }

  /**
   * Embed a batch, dropping empty/whitespace-only entries first.
   * NOTE: the result length matches the number of VALID texts, which may be
   * shorter than the input array.
   * @param {string[]} texts
   * @returns {Promise<Array<{embedding: number[], tokens: number, model: string}>>}
   * @throws {EmbeddingError} When no valid texts remain after filtering.
   */
  async embedBatch(texts) {
    if (texts.length === 0) {
      return [];
    }
    const validTexts = texts.filter(text => text && text.trim().length > 0);
    if (validTexts.length === 0) {
      throw new EmbeddingError('No valid texts provided');
    }
    return this.provider.embedBatch(validTexts.map(text => text.trim()));
  }

  /** @returns {number} Dimension reported by the underlying provider. */
  getDimension() {
    return this.provider.getDimension();
  }

  /**
   * Embed with exponential-backoff retries.
   * @param {string} text
   * @param {number} [maxRetries=3]
   * @param {number} [baseDelay=1000] - Milliseconds; doubles each attempt.
   * @returns {Promise<{embedding: number[], tokens: number, model: string}>}
   * @throws {EmbeddingError} After all attempts fail; includes the last
   *   error's message and a 100-char text preview.
   */
  async embedWithRetry(text, maxRetries = 3, baseDelay = 1000) {
    let lastError;
    for (let attempt = 1; attempt <= maxRetries; attempt++) {
      try {
        return await this.embed(text);
      } catch (error) {
        lastError = error instanceof Error ? error : new Error('Unknown error');
        if (attempt === maxRetries) {
          break;
        }
        // Exponential backoff: baseDelay * 2^(attempt-1)
        const delay = baseDelay * Math.pow(2, attempt - 1);
        await new Promise(resolve => setTimeout(resolve, delay));
      }
    }
    throw new EmbeddingError(
      `Failed to embed after ${maxRetries} attempts: ${lastError?.message}`,
      { text: text.substring(0, 100), attempts: maxRetries },
    );
  }
}