@escher-dbai/rag-module
Enterprise RAG module with chat context storage, vector search, and session management. Provides complete chat history retrieval and streaming content extraction for Electron apps.
465 lines (389 LOC) • 15 kB • JavaScript
// const { pipeline, env } = require('@xenova/transformers'); // Will be loaded dynamically
const fs = require('fs-extra');
const path = require('path');
const { EventEmitter } = require('events');
/**
 * Local Embedding Service - Supports multiple models including BGE-M3.
 *
 * Downloads and runs embedding models locally (via @xenova/transformers)
 * for privacy. All model files are cached under `<basePath>/models`.
 *
 * Events emitted: 'initialized', 'loading-model', 'model-loaded',
 * 'model-fallback', 'batch-processing', 'batch-completed',
 * 'downloading-model', 'model-downloaded', 'closed', 'error'.
 */
class EmbeddingService extends EventEmitter {
  /**
   * @param {string} basePath - Root data directory; models are cached in `<basePath>/models`.
   * @param {Object} configManager - Config accessor exposing `getConfig()` and async `updateConfig(patch)`.
   */
  constructor(basePath, configManager) {
    super();
    this.basePath = basePath;
    this.configManager = configManager;
    this.modelsPath = path.join(basePath, 'models');

    // State for the currently loaded model
    this.currentModel = null; // entry from this.supportedModels
    this.modelName = null; // requested model id (e.g. 'BAAI/bge-m3')
    this.loadedPipeline = null; // transformers.js feature-extraction pipeline
    this.dimensions = null; // output embedding dimensionality

    // Model registry keyed by the id callers pass to loadModel()
    this.supportedModels = {
      'BAAI/bge-m3': {
        dimensions: 1024,
        type: 'sentence-transformers',
        description: 'BGE-M3 multilingual embedding model (1024 dimensions)',
        xenova_model: 'Xenova/bge-m3' // Xenova version for transformers.js compatibility
      },
      'sentence-transformers/all-MiniLM-L6-v2': {
        dimensions: 384,
        type: 'sentence-transformers',
        description: 'Compact multilingual model (384 dimensions)'
      },
      'sentence-transformers/all-mpnet-base-v2': {
        dimensions: 768,
        type: 'sentence-transformers',
        description: 'High-quality English model (768 dimensions)'
      },
      'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2': {
        dimensions: 384,
        type: 'sentence-transformers',
        description: 'Multilingual paraphrase model (384 dimensions)'
      }
    };

    this.initialized = false;
  }

  /**
   * Initialize the embedding service: dynamically import transformers.js,
   * point its cache at our models directory, and load the configured
   * (or default BGE-M3) model.
   * @returns {Promise<boolean>} true on success
   * @throws {Error} if transformers.js cannot be imported or the model fails to load
   */
  async initialize() {
    try {
      // Load transformers dynamically (ESM-only package; cannot be require()d)
      const { pipeline, env } = await import('@xenova/transformers');
      this.pipeline = pipeline;
      this.env = env;

      // Ensure models directory exists and route the transformers cache into it
      await fs.ensureDir(this.modelsPath);
      this.env.cacheDir = this.modelsPath;

      // Load default model from config or use BGE-M3
      const config = this.configManager.getConfig();
      const defaultModel = config.embeddingModel || 'BAAI/bge-m3';
      await this.loadModel(defaultModel);

      this.initialized = true;
      this.emit('initialized', `Embedding service initialized with model: ${this.modelName}`);
      return true;
    } catch (error) {
      this.emit('error', `Failed to initialize embedding service: ${error.message}`);
      throw error;
    }
  }

  /**
   * Load or download an embedding model with robust fallback.
   * BGE-M3 is tried via its Xenova port first, then smaller fallback models
   * (padded to 1024 dimensions downstream for backend compatibility).
   * @param {string} modelName - Model identifier (key of supportedModels)
   * @returns {Promise<boolean>} true on success
   * @throws {Error} for unsupported ids or when every load attempt fails
   */
  async loadModel(modelName) {
    try {
      this.emit('loading-model', `Loading embedding model: ${modelName}`);

      // Validate model
      if (!this.supportedModels[modelName]) {
        throw new Error(`Unsupported model: ${modelName}. Supported models: ${Object.keys(this.supportedModels).join(', ')}`);
      }

      // Release any previously loaded pipeline
      if (this.loadedPipeline) {
        this.loadedPipeline = null;
      }

      // BGE-M3 gets a multi-source fallback chain; everything else loads directly.
      const actualModelName = modelName === 'BAAI/bge-m3'
        ? await this._loadBgeM3WithFallback()
        : await this._loadStandardModel(modelName);

      // Update model info
      this.modelName = modelName;
      this.currentModel = this.supportedModels[modelName];
      this.actualModelName = actualModelName; // Track what was actually loaded

      // Persist the user's model choice
      await this.configManager.updateConfig({ embeddingModel: modelName });

      this.emit('model-loaded', `Model loaded successfully: ${modelName} -> ${actualModelName} (${this.dimensions}D)`);
      return true;
    } catch (error) {
      this.emit('error', `Failed to load model ${modelName}: ${error.message}`);
      throw error;
    }
  }

  /**
   * Try the BGE-M3 variants in order; on a fallback model, keep reporting
   * 1024 dimensions so downstream storage stays schema-compatible
   * (embed() pads/truncates to match).
   * @returns {Promise<string>} the model id that was actually loaded
   * @throws {Error} when every variant fails to load
   */
  async _loadBgeM3WithFallback() {
    const modelVariants = [
      'Xenova/bge-m3', // First try Xenova version
      'sentence-transformers/all-MiniLM-L6-v2', // Fallback to MiniLM
      'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2' // Another fallback
    ];
    for (const variant of modelVariants) {
      try {
        console.log(` 🔄 Attempting to load: ${variant}`);
        this.loadedPipeline = await this.pipeline('feature-extraction', variant, {
          cache_dir: this.modelsPath,
          local_files_only: false,
          revision: 'main'
        });
        if (variant !== 'Xenova/bge-m3') {
          console.log(` ⚠️ BGE-M3 not available, using fallback: ${variant}`);
          // Keep BGE-M3 dimensions for compatibility with backend
          this.dimensions = 1024;
          this.emit('model-fallback', `Using fallback model ${variant} with BGE-M3 dimensions for compatibility`);
        } else {
          console.log(` ✅ Successfully loaded BGE-M3: ${variant}`);
          this.dimensions = this.supportedModels['BAAI/bge-m3'].dimensions;
        }
        return variant;
      } catch (loadError) {
        console.log(` ❌ Failed to load ${variant}: ${loadError.message}`);
      }
    }
    throw new Error(`All BGE-M3 variants failed to load. Try using a different model or check network connectivity.`);
  }

  /**
   * Load a non-BGE model, preferring its Xenova port when one is configured.
   * @param {string} modelName - Validated key of supportedModels
   * @returns {Promise<string>} the model id that was actually loaded
   */
  async _loadStandardModel(modelName) {
    let actualModelName = modelName;
    if (this.supportedModels[modelName].xenova_model) {
      actualModelName = this.supportedModels[modelName].xenova_model;
      console.log(` 🔄 Using Xenova version: ${actualModelName}`);
    }
    this.loadedPipeline = await this.pipeline('feature-extraction', actualModelName, {
      cache_dir: this.modelsPath,
      local_files_only: false,
      revision: 'main'
    });
    this.dimensions = this.supportedModels[modelName].dimensions;
    return actualModelName;
  }

  /**
   * Generate an embedding for a single text, padding/truncating to 1024
   * dimensions when BGE-M3 is selected so fallback models stay compatible
   * with the backend's fixed vector width.
   * @param {string} text - Text to embed
   * @returns {Promise<number[]>} Embedding vector (always BGE-M3 width when BGE-M3 is selected)
   * @throws {Error} on invalid input or dimension mismatch
   */
  async embed(text) {
    this._ensureInitialized();
    if (!text || typeof text !== 'string') {
      throw new Error('Text must be a non-empty string');
    }
    try {
      // Use CLS pooling for BGE-M3 (its recommended strategy), mean pooling otherwise
      const poolingStrategy = this.modelName === 'BAAI/bge-m3' ? 'cls' : 'mean';
      const output = await this.loadedPipeline(text, {
        pooling: poolingStrategy,
        normalize: true
      });

      // Convert tensor to a plain array
      let embedding = Array.from(output.data);

      if (this.modelName === 'BAAI/bge-m3') {
        // Force the backend-expected width regardless of which fallback loaded
        const expectedDimensions = 1024;
        if (embedding.length < expectedDimensions) {
          // Pad with zeros if smaller (unlikely)
          console.log(` 📏 Padding embedding from ${embedding.length} to ${expectedDimensions} dimensions`);
          const padding = new Array(expectedDimensions - embedding.length).fill(0);
          embedding = embedding.concat(padding);
        } else if (embedding.length > expectedDimensions) {
          // Truncate if larger (more likely with fallback models)
          console.log(` ✂️ Truncating embedding from ${embedding.length} to ${expectedDimensions} dimensions`);
          embedding = embedding.slice(0, expectedDimensions);
        }
        // Final validation
        if (embedding.length !== expectedDimensions) {
          throw new Error(`Failed to normalize embedding dimensions: got ${embedding.length}, expected ${expectedDimensions}`);
        }
      } else {
        // For non-BGE models, validate against expected dimensions
        if (embedding.length !== this.dimensions) {
          throw new Error(`Unexpected embedding dimensions: got ${embedding.length}, expected ${this.dimensions}`);
        }
      }
      return embedding;
    } catch (error) {
      this.emit('error', `Failed to generate embedding: ${error.message}`);
      throw error;
    }
  }

  /**
   * Generate embeddings for multiple texts (batch processing).
   * @param {string[]} texts - Array of texts to embed
   * @param {Object} [options] - Processing options
   * @param {number} [options.batchSize=32] - Batch size for processing
   * @param {Function} [options.onProgress] - Progress callback
   * @returns {Promise<number[][]>} Array of embedding vectors, in input order
   * @throws {Error} if texts is not a non-empty array or any embed fails
   */
  async embedBatch(texts, options = {}) {
    this._ensureInitialized();
    if (!Array.isArray(texts) || texts.length === 0) {
      throw new Error('Texts must be a non-empty array');
    }
    const { batchSize = 32, onProgress } = options;
    const embeddings = [];
    const totalBatches = Math.ceil(texts.length / batchSize);
    try {
      for (let i = 0; i < texts.length; i += batchSize) {
        const batch = texts.slice(i, i + batchSize);
        const batchNum = Math.floor(i / batchSize) + 1;
        this.emit('batch-processing', {
          batch: batchNum,
          total: totalBatches,
          texts: batch.length
        });
        // Embed the whole batch concurrently; Promise.all preserves order
        const batchEmbeddings = await Promise.all(
          batch.map(text => this.embed(text))
        );
        embeddings.push(...batchEmbeddings);
        // Call progress callback if provided
        if (onProgress) {
          onProgress({
            completed: i + batch.length,
            total: texts.length,
            batch: batchNum,
            totalBatches
          });
        }
      }
      this.emit('batch-completed', `Generated ${embeddings.length} embeddings`);
      return embeddings;
    } catch (error) {
      this.emit('error', `Failed to process batch embeddings: ${error.message}`);
      throw error;
    }
  }

  /**
   * Calculate cosine similarity between two embeddings.
   * @param {number[]} embedding1 - First embedding vector
   * @param {number[]} embedding2 - Second embedding vector
   * @returns {number} Similarity score in [-1, 1] (0 when either vector is all zeros)
   * @throws {Error} if inputs are not arrays of equal length
   */
  calculateSimilarity(embedding1, embedding2) {
    if (!Array.isArray(embedding1) || !Array.isArray(embedding2)) {
      throw new Error('Embeddings must be arrays');
    }
    if (embedding1.length !== embedding2.length) {
      throw new Error('Embeddings must have the same dimensions');
    }
    // Single pass: dot product and both squared norms
    let dotProduct = 0;
    let norm1 = 0;
    let norm2 = 0;
    for (let i = 0; i < embedding1.length; i++) {
      dotProduct += embedding1[i] * embedding2[i];
      norm1 += embedding1[i] * embedding1[i];
      norm2 += embedding2[i] * embedding2[i];
    }
    const magnitude = Math.sqrt(norm1) * Math.sqrt(norm2);
    if (magnitude === 0) return 0; // avoid division by zero for zero vectors
    return dotProduct / magnitude;
  }

  /**
   * Get information about the loaded model.
   * @returns {Object} Model information (name, actualModel, dimensions, type,
   *   description, loaded flag, and whether a fallback model is in use)
   */
  getModelInfo() {
    this._ensureInitialized();
    return {
      name: this.modelName,
      actualModel: this.actualModelName || this.modelName,
      dimensions: this.dimensions,
      type: this.currentModel.type,
      description: this.currentModel.description,
      loaded: !!this.loadedPipeline,
      isFallback: this.actualModelName !== this.modelName
    };
  }

  /**
   * Get the list of supported models.
   * @returns {Object} Shallow copy of the supported-models registry
   */
  getSupportedModels() {
    return { ...this.supportedModels };
  }

  /**
   * Check if a model is downloaded locally.
   * @param {string} modelName - Model identifier
   * @returns {Promise<boolean>} true when the cache directory exists
   */
  async isModelDownloaded(modelName) {
    return await fs.pathExists(this._modelCachePath(modelName));
  }

  /**
   * Download a model into the local cache without making it the active model.
   * Uses the Xenova alias when one is configured, matching what loadModel()
   * would actually fetch (downloading the raw 'BAAI/bge-m3' id would cache a
   * repo loadModel() never uses).
   * @param {string} modelName - Model identifier
   * @returns {Promise<boolean>} true on success
   * @throws {Error} for unsupported ids or download failures
   */
  async downloadModel(modelName) {
    if (!this.supportedModels[modelName]) {
      throw new Error(`Unsupported model: ${modelName}`);
    }
    this.emit('downloading-model', `Downloading model: ${modelName}`);
    try {
      // Resolve to the transformers.js-compatible repo, same as loadModel()
      const actualModelName = this.supportedModels[modelName].xenova_model || modelName;
      // Create temporary pipeline to trigger download
      await this.pipeline('feature-extraction', actualModelName, {
        cache_dir: this.modelsPath,
        local_files_only: false
      });
      this.emit('model-downloaded', `Model downloaded: ${modelName}`);
      return true;
    } catch (error) {
      this.emit('error', `Failed to download model ${modelName}: ${error.message}`);
      throw error;
    }
  }

  /**
   * Get storage info for all supported models.
   * Sizes are computed by walking each model's cache directory recursively;
   * stat() on the directory itself would only report the inode size.
   * @returns {Promise<Object>} { models, totalSize, totalSizeFormatted, modelsPath }
   */
  async getStorageInfo() {
    const models = [];
    let totalSize = 0;
    for (const modelName of Object.keys(this.supportedModels)) {
      const isDownloaded = await this.isModelDownloaded(modelName);
      let size = 0;
      if (isDownloaded) {
        try {
          size = await this._directorySize(this._modelCachePath(modelName));
          totalSize += size;
        } catch (error) {
          // Ignore filesystem errors; report this model's size as 0
        }
      }
      models.push({
        name: modelName,
        ...this.supportedModels[modelName],
        downloaded: isDownloaded,
        size: size,
        sizeFormatted: this._formatFileSize(size)
      });
    }
    return {
      models,
      totalSize,
      totalSizeFormatted: this._formatFileSize(totalSize),
      modelsPath: this.modelsPath
    };
  }

  /**
   * Clean up resources. Safe to call multiple times.
   */
  async close() {
    if (this.loadedPipeline) {
      this.loadedPipeline = null;
    }
    this.emit('closed', 'Embedding service closed');
  }

  // ============ PRIVATE METHODS ============

  /**
   * Throw unless initialize() succeeded and a pipeline is loaded.
   * @throws {Error} when the service is not ready for embedding calls
   */
  _ensureInitialized() {
    if (!this.initialized) {
      throw new Error('Embedding service must be initialized before use');
    }
    if (!this.loadedPipeline) {
      throw new Error('No model loaded. Call loadModel() first');
    }
  }

  /**
   * Resolve a model id to its on-disk cache directory.
   * transformers.js cache layout: `models--<org>--<name>`.
   * @param {string} modelName - Model identifier
   * @returns {string} absolute cache directory path
   */
  _modelCachePath(modelName) {
    return path.join(this.modelsPath, 'models--' + modelName.replace('/', '--'));
  }

  /**
   * Recursively sum the sizes of all regular files under a directory.
   * @param {string} dirPath - Directory to walk
   * @returns {Promise<number>} total bytes
   */
  async _directorySize(dirPath) {
    let total = 0;
    const entries = await fs.readdir(dirPath, { withFileTypes: true });
    for (const entry of entries) {
      const fullPath = path.join(dirPath, entry.name);
      if (entry.isDirectory()) {
        total += await this._directorySize(fullPath);
      } else if (entry.isFile()) {
        const stats = await fs.stat(fullPath);
        total += stats.size;
      }
      // Symlinks and other entry types are skipped deliberately
    }
    return total;
  }

  /**
   * Format a byte count as a human-readable string.
   * @param {number} bytes - Size in bytes
   * @returns {string} e.g. '1.5 KB'; non-positive or non-finite input yields '0 B'
   */
  _formatFileSize(bytes) {
    if (!Number.isFinite(bytes) || bytes <= 0) return '0 B';
    const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];
    // Clamp the unit index so huge values fall back to TB instead of indexing
    // past the array, and sub-1 values never produce a negative index.
    const i = Math.min(
      Math.max(0, Math.floor(Math.log(bytes) / Math.log(1024))),
      sizes.length - 1
    );
    return Math.round(bytes / Math.pow(1024, i) * 100) / 100 + ' ' + sizes[i];
  }
}
// CommonJS export: the module's sole public entry point is the EmbeddingService class.
module.exports = EmbeddingService;