@codai/memorai-core
Version: (not captured in this extract)
Simplified advanced memory engine - no tiers, just powerful semantic search with persistence
196 lines (195 loc) • 7.2 kB
JavaScript
import OpenAI from 'openai';
import { EmbeddingError } from '../types/index.js';
export class OpenAIEmbeddingProvider {
    /**
     * Embedding provider backed by the OpenAI (or Azure OpenAI) embeddings API.
     *
     * @param {object} config - Provider settings.
     * @param {string} config.api_key - API key (required for both Azure and standard OpenAI).
     * @param {string} config.model - Model name, or an Azure deployment name.
     * @param {string} [config.endpoint] - Custom base URL; Azure endpoints are auto-detected.
     * @param {string} [config.provider] - 'azure' forces Azure-style auth.
     * @param {string} [config.azure_api_version] - Azure `api-version` query parameter.
     * @throws {EmbeddingError} When no API key is supplied.
     */
    constructor(config) {
        if (!config.api_key) {
            // FIX: the old message said "Azure OpenAI" even for standard OpenAI configs.
            throw new EmbeddingError('OpenAI API key is required');
        }
        // Configure for Azure OpenAI or standard OpenAI
        const clientConfig = {
            apiKey: config.api_key,
            // Allow browser environment for testing
            dangerouslyAllowBrowser: process.env.NODE_ENV === 'test' || process.env.VITEST === 'true',
        };
        if (config.provider === 'azure' || config.endpoint?.includes('azure.com')) {
            // Azure OpenAI authenticates with an `api-key` header and requires
            // an api-version query parameter on every request.
            clientConfig.baseURL = config.endpoint;
            clientConfig.defaultQuery = {
                'api-version': config.azure_api_version || '2024-02-15-preview',
            };
            clientConfig.defaultHeaders = {
                'api-key': config.api_key,
            };
        }
        else {
            // Standard OpenAI: an undefined endpoint falls back to the SDK default.
            clientConfig.baseURL = config.endpoint;
        }
        this.client = new OpenAI(clientConfig);
        this.model = config.model;
        this.dimension = this.getModelDimension(config.model);
    }
    /**
     * Embed a single text string.
     * @param {string} text - Input text; newlines are flattened to spaces.
     * @returns {Promise<{embedding: number[], tokens: number, model: string}>}
     * @throws {EmbeddingError} When the API call fails or returns no embedding.
     */
    async embed(text) {
        try {
            const response = await this.client.embeddings.create({
                model: this.model,
                input: text.replace(/\n/g, ' '),
                encoding_format: 'float',
            });
            const embedding = response.data[0]?.embedding;
            if (!embedding) {
                throw new EmbeddingError('No embedding returned from OpenAI');
            }
            return {
                embedding,
                // FIX: some gateway responses omit `usage`; report 0 tokens
                // instead of crashing and masking the successful embedding.
                tokens: response.usage?.total_tokens ?? 0,
                model: this.model,
            };
        }
        catch (error) {
            // FIX: re-throw our own EmbeddingError untouched instead of
            // double-wrapping it (e.g. the empty-response case above).
            if (error instanceof EmbeddingError) {
                throw error;
            }
            if (error instanceof Error) {
                throw new EmbeddingError(`OpenAI embedding failed: ${error.message}`, {
                    text: text.substring(0, 100),
                    model: this.model,
                });
            }
            throw new EmbeddingError('Unknown embedding error');
        }
    }
    /**
     * Embed several texts in a single API call.
     * @param {string[]} texts - Inputs; newlines are flattened to spaces.
     * @returns {Promise<Array<{embedding: number[], tokens: number, model: string}>>}
     * @throws {EmbeddingError} When the API call fails.
     */
    async embedBatch(texts) {
        try {
            const cleanTexts = texts.map(text => text.replace(/\n/g, ' '));
            const response = await this.client.embeddings.create({
                model: this.model,
                input: cleanTexts,
                encoding_format: 'float',
            });
            // FIX: tolerate a missing `usage` block, as in embed().
            const totalTokens = response.usage?.total_tokens ?? 0;
            return response.data.map(item => ({
                embedding: item.embedding,
                tokens: Math.floor(totalTokens / texts.length), // Approximate per-item share
                model: this.model,
            }));
        }
        catch (error) {
            if (error instanceof EmbeddingError) {
                throw error;
            }
            if (error instanceof Error) {
                throw new EmbeddingError(`OpenAI batch embedding failed: ${error.message}`, {
                    batch_size: texts.length,
                    model: this.model,
                });
            }
            throw new EmbeddingError('Unknown batch embedding error');
        }
    }
    /** @returns {number} Vector dimension for the configured model. */
    getDimension() {
        return this.dimension;
    }
    /**
     * Map a model (or Azure deployment) name to its embedding dimension.
     * Unknown names fall back to 1536.
     * @param {string} model
     * @returns {number}
     */
    getModelDimension(model) {
        const dimensions = {
            'text-embedding-3-small': 1536,
            'text-embedding-3-large': 3072,
            'text-embedding-ada-002': 1536,
            // Azure OpenAI deployment names
            'memorai-model-r': 1536, // Azure deployment for text-embedding-ada-002
        };
        return dimensions[model] ?? 1536;
    }
}
export class LocalEmbeddingProvider {
    /**
     * Deterministic, hash-based pseudo-embedding provider for offline/dev use.
     * NOT suitable for production semantic search.
     * @param {number} [dimension=1536] - Length of the generated vectors.
     */
    constructor(dimension = 1536) {
        this.dimension = dimension;
    }
    /**
     * Produce a unit-length pseudo-embedding for one text.
     * @param {string} text
     * @returns {Promise<{embedding: number[], tokens: number, model: string}>}
     */
    async embed(text) {
        const seed = this.simpleHash(text);
        return {
            embedding: this.hashToVector(seed, this.dimension),
            tokens: Math.ceil(text.length / 4), // rough ~4-chars-per-token heuristic
            model: 'local-hash',
        };
    }
    /**
     * Embed several texts; each is independent, so run them concurrently.
     * @param {string[]} texts
     */
    async embedBatch(texts) {
        const jobs = texts.map(text => this.embed(text));
        return Promise.all(jobs);
    }
    /** @returns {number} Configured vector dimension. */
    getDimension() {
        return this.dimension;
    }
    /**
     * 32-bit rolling string hash (djb2-style), returned as a non-negative integer.
     * @param {string} str
     * @returns {number}
     */
    simpleHash(str) {
        let acc = 0;
        for (let idx = 0; idx < str.length; idx += 1) {
            acc = (acc << 5) - acc + str.charCodeAt(idx);
            acc &= acc; // clamp back to a signed 32-bit integer
        }
        return Math.abs(acc);
    }
    /**
     * Expand a hash seed into a normalized vector of the requested dimension.
     * NOTE(review): `state * 1103515245` can exceed 2^53 before the mask is
     * applied, so this is not an exact integer LCG — it is still fully
     * deterministic, which is all this provider promises.
     * @param {number} hash - Non-negative seed.
     * @param {number} dimension - Output length.
     * @returns {number[]} Unit-length vector with components derived from the seed.
     */
    hashToVector(hash, dimension) {
        const raw = [];
        let state = hash;
        while (raw.length < dimension) {
            state = (state * 1103515245 + 12345) & 0x7fffffff;
            raw.push((state / 0x7fffffff) * 2 - 1); // map into [-1, 1]
        }
        // Scale to unit length so cosine similarity behaves sensibly.
        const norm = Math.sqrt(raw.reduce((sum, v) => sum + v * v, 0));
        return raw.map(v => v / norm);
    }
}
export class EmbeddingService {
    /**
     * Facade that selects and wraps a concrete embedding provider.
     * @param {object} config - Must set `provider`: 'azure' | 'openai' | 'local'.
     * @throws {EmbeddingError} For any other provider value.
     */
    constructor(config) {
        if (config.provider === 'azure') {
            // Azure configs may carry the endpoint under `azure_endpoint`.
            this.provider = new OpenAIEmbeddingProvider({
                ...config,
                endpoint: config.azure_endpoint || config.endpoint,
            });
        } else if (config.provider === 'openai') {
            this.provider = new OpenAIEmbeddingProvider(config);
        } else if (config.provider === 'local') {
            this.provider = new LocalEmbeddingProvider();
        } else {
            throw new EmbeddingError(`Unsupported embedding provider: ${config.provider}`);
        }
    }
    /**
     * Embed one text after trimming it.
     * @param {string} text
     * @throws {EmbeddingError} When the text is missing or whitespace-only.
     */
    async embed(text) {
        if (!text) {
            throw new EmbeddingError('Text cannot be empty');
        }
        const trimmed = text.trim();
        if (trimmed.length === 0) {
            throw new EmbeddingError('Text cannot be empty');
        }
        return this.provider.embed(trimmed);
    }
    /**
     * Embed a batch of texts, silently dropping empty/whitespace entries.
     * Note: results therefore align with the FILTERED list, not the input list.
     * @param {string[]} texts
     * @throws {EmbeddingError} When every entry is empty after trimming.
     */
    async embedBatch(texts) {
        if (texts.length === 0) {
            return [];
        }
        const usable = [];
        for (const candidate of texts) {
            if (candidate && candidate.trim().length > 0) {
                usable.push(candidate.trim());
            }
        }
        if (usable.length === 0) {
            throw new EmbeddingError('No valid texts provided');
        }
        return this.provider.embedBatch(usable);
    }
    /** @returns {number} Dimension reported by the active provider. */
    getDimension() {
        return this.provider.getDimension();
    }
    /**
     * Embed with exponential-backoff retries.
     * @param {string} text
     * @param {number} [maxRetries=3] - Total attempts before giving up.
     * @param {number} [baseDelay=1000] - Backoff base in milliseconds (doubles per attempt).
     * @throws {EmbeddingError} After all attempts fail.
     */
    async embedWithRetry(text, maxRetries = 3, baseDelay = 1000) {
        let lastError;
        let attempt = 0;
        while (attempt < maxRetries) {
            attempt += 1;
            try {
                return await this.embed(text);
            } catch (error) {
                lastError = error instanceof Error ? error : new Error('Unknown error');
                if (attempt < maxRetries) {
                    // Exponential backoff: baseDelay, 2x, 4x, ...
                    const waitMs = baseDelay * 2 ** (attempt - 1);
                    await new Promise(resolve => setTimeout(resolve, waitMs));
                }
            }
        }
        throw new EmbeddingError(`Failed to embed after ${maxRetries} attempts: ${lastError?.message}`, { text: text.substring(0, 100), attempts: maxRetries });
    }
}