UNPKG

rag-cli-tester

Version:

A lightweight CLI tool for testing RAG (Retrieval-Augmented Generation) systems with different embedding combinations

156 lines 5.9 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.EmbeddingGenerator = exports.TransformersPipelineProvider = void 0; class TransformersPipelineProvider { async createPipeline(task, model) { // Use eval to bypass TypeScript's import compilation const transformers = await eval('import("@xenova/transformers")'); return await transformers.pipeline(task, model); } } exports.TransformersPipelineProvider = TransformersPipelineProvider; class EmbeddingGenerator { constructor(config, pipelineProvider) { this.embeddingPipeline = null; this.config = config; this.pipelineProvider = pipelineProvider || new TransformersPipelineProvider(); } async initialize() { try { this.embeddingPipeline = await this.pipelineProvider.createPipeline('feature-extraction', this.config.localModel || 'Xenova/all-MiniLM-L6-v2-small'); } catch (error) { console.error('Failed to initialize embedding model:', error); throw error; } } generateColumnCombinations(columns, maxCombinations = 20) { const combinations = []; const n = Math.min(columns.length, 5); // Limit to 5 columns max // Generate all possible combinations from 1 to n columns for (let i = 1; i <= n; i++) { const combos = this.getCombinations(columns.slice(0, n), i); combos.forEach(combo => { if (combinations.length < maxCombinations) { combinations.push({ columns: combo, name: combo.join(' + ') }); } }); if (combinations.length >= maxCombinations) break; } return combinations; } getCombinations(arr, k) { if (k === 1) return arr.map(item => [item]); if (k === arr.length) return [arr]; const combinations = []; for (let i = 0; i <= arr.length - k; i++) { const head = arr[i]; const tailCombos = this.getCombinations(arr.slice(i + 1), k - 1); tailCombos.forEach(combo => { combinations.push([head, ...combo]); }); } return combinations; } async generateEmbedding(text) { try { if (!this.embeddingPipeline) { throw new Error('Embedding pipeline not initialized'); } const result = await this.embeddingPipeline(text); // Convert to flat array if needed return Array.isArray(result.data) ? result.data : Array.from(result.data); } catch (error) { console.error('Failed to generate embedding:', error); throw error; } } createContext(row, combination) { const contextParts = combination.columns .filter(col => row[col] !== null && row[col] !== undefined) .map(col => { const value = row[col]; return `${col}: ${typeof value === 'object' ? JSON.stringify(value) : value}`; }); return contextParts.join(' | '); } async processTrainingData(data, combination, targetColumn, idColumn) { const embeddings = []; for (let i = 0; i < data.length; i++) { const row = data[i]; const context = this.createContext(row, combination); if (!context.trim()) { console.warn(`Skipping row ${i} - no valid context generated`); continue; } try { const embedding = await this.generateEmbedding(context); embeddings.push({ id: idColumn ? row[idColumn] : `row_${i}`, combination, embedding, context, targetValue: row[targetColumn], metadata: { originalRow: row, rowIndex: i } }); // Progress logging every 50 rows to reduce noise if ((i + 1) % 50 === 0) { console.log(` Processed ${i + 1}/${data.length} rows`); } } catch (error) { console.error(`Failed to process row ${i}:`, error); continue; } } return { embeddings, combination, totalRows: data.length }; } calculateCosineSimilarity(a, b) { if (a.length !== b.length) { throw new Error('Vectors must have the same length'); } let dotProduct = 0; let normA = 0; let normB = 0; for (let i = 0; i < a.length; i++) { dotProduct += a[i] * b[i]; normA += a[i] * a[i]; normB += b[i] * b[i]; } normA = Math.sqrt(normA); normB = Math.sqrt(normB); if (normA === 0 || normB === 0) { return 0; } return dotProduct / (normA * normB); } async findBestMatch(queryEmbedding, trainingData, topK = 1) { const similarities = trainingData.embeddings.map(item => ({ result: item, similarity: this.calculateCosineSimilarity(queryEmbedding, item.embedding) })); // Sort by similarity (highest first) similarities.sort((a, b) => b.similarity - a.similarity); return similarities.slice(0, topK); } async processQuery(query, trainingData, topK = 1) { const queryEmbedding = await this.generateEmbedding(query); return this.findBestMatch(queryEmbedding, trainingData, topK); } } exports.EmbeddingGenerator = EmbeddingGenerator; //# sourceMappingURL=embeddings.js.map