UNPKG

rag-cli-tester

Version:

A lightweight CLI tool for testing RAG (Retrieval-Augmented Generation) systems with different embedding combinations

197 lines • 8.61 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.RAGTester = void 0; const metrics_1 = require("./metrics"); class RAGTester { constructor(dbConnection, embeddingGenerator) { this.db = dbConnection; this.embeddings = embeddingGenerator; this.similarityMetric = new metrics_1.SimilarityMetric(); this.brdrMetric = new metrics_1.BRDRMetric(); } async initialize() { await this.embeddings.initialize(); } async runExperiment(config) { const startTime = Date.now(); console.log(`\n๐Ÿงช Starting experiment: ${config.testName}`); console.log(`๐Ÿ“Š Table: ${config.tableName}`); console.log(`๐Ÿ“‹ Columns: ${config.selectedColumns.join(', ')}`); console.log(`๐ŸŽฏ Metric: ${config.metricType}`); // Generate column combinations const combinations = this.embeddings.generateColumnCombinations(config.selectedColumns, config.maxCombinations || 20); console.log(`๐Ÿ”„ Testing ${combinations.length} column combinations...\n`); const allResults = []; for (let i = 0; i < combinations.length; i++) { const combination = combinations[i]; console.log(`[${i + 1}/${combinations.length}] Testing: ${combination.name}`); try { const result = await this.runSingleTest(config, combination); allResults.push(result); console.log(` โœ… Score: ${result.averageScore.toFixed(3)} (${result.totalTests} tests)`); } catch (error) { console.error(` โŒ Failed: ${error instanceof Error ? error.message : String(error)}`); continue; } } if (allResults.length === 0) { throw new Error('No combinations produced valid results'); } // Calculate summary statistics const scores = allResults.map(r => r.averageScore); const bestResult = allResults.reduce((best, current) => current.averageScore > best.averageScore ? current : best); const worstResult = allResults.reduce((worst, current) => current.averageScore < worst.averageScore ? 
current : worst); const summary = { bestCombination: bestResult.combination, bestScore: bestResult.averageScore, worstCombination: worstResult.combination, worstScore: worstResult.averageScore, averageScore: scores.reduce((sum, score) => sum + score, 0) / scores.length, totalCombinations: allResults.length }; const processingTime = Date.now() - startTime; return { testName: config.testName, timestamp: new Date(), configuration: config, allResults, summary, processingTime }; } async runSingleTest(config, combination) { const startTime = Date.now(); // Fetch data from the table const data = await this.db.getTableData(config.tableName); if (data.length === 0) { throw new Error(`No data found in table ${config.tableName}`); } // Filter out rows with missing query or answer data const validData = data.filter(row => row[config.queryColumn] && row[config.answerColumn] && combination.columns.some(col => row[col])); if (validData.length < 5) { throw new Error('Not enough valid data rows for testing'); } // Split data into training and testing const shuffled = [...validData].sort(() => Math.random() - 0.5); const splitIndex = Math.floor(shuffled.length * config.trainingRatio); const trainingData = shuffled.slice(0, splitIndex); const testingData = shuffled.slice(splitIndex); if (testingData.length === 0) { throw new Error('No testing data available after split'); } // Generate embeddings for training data const trainingEmbeddings = await this.embeddings.processTrainingData(trainingData, combination, config.answerColumn); // Process test queries const results = []; let totalSimilarity = 0; for (let i = 0; i < testingData.length; i++) { const testRow = testingData[i]; const query = testRow[config.queryColumn]; const expectedAnswer = testRow[config.answerColumn]; try { // Find best match from training data const matches = await this.embeddings.processQuery(query, trainingEmbeddings, 1); if (matches.length === 0) { continue; } const bestMatch = matches[0]; const actualAnswer = 
bestMatch.result.targetValue; totalSimilarity += bestMatch.similarity; // Calculate metric based on type let score; if (config.metricType === 'similarity') { const result = this.similarityMetric.calculate(expectedAnswer, actualAnswer, bestMatch.similarity); score = result.overallScore; } else { const result = this.brdrMetric.calculate(expectedAnswer, actualAnswer, bestMatch.similarity); score = result.overallScore; } results.push({ query, expectedAnswer, actualAnswer, similarity: bestMatch.similarity, score }); } catch (error) { console.warn(` Skipped test query ${i + 1}: ${error instanceof Error ? error.message : String(error)}`); continue; } } if (results.length === 0) { throw new Error('No valid test results generated'); } // Calculate averages const averageScore = results.reduce((sum, r) => sum + r.score, 0) / results.length; const averageSimilarity = totalSimilarity / results.length; const processingTime = Date.now() - startTime; return { combination, averageScore, totalTests: results.length, processingTime, embeddingStats: { trainingEmbeddings: trainingEmbeddings.embeddings.length, testQueries: results.length, averageSimilarity } }; } async validateConfiguration(config) { const errors = []; const warnings = []; // Check if table exists and get info const tableInfo = await this.db.getTableInfo(config.tableName); if (!tableInfo) { errors.push(`Table "${config.tableName}" not found`); return { isValid: false, errors, warnings }; } const columnNames = tableInfo.columns.map(col => col.column_name); // Check if columns exist for (const column of config.selectedColumns) { if (!columnNames.includes(column)) { errors.push(`Column "${column}" not found in table "${config.tableName}"`); } } if (!columnNames.includes(config.queryColumn)) { errors.push(`Query column "${config.queryColumn}" not found`); } if (!columnNames.includes(config.answerColumn)) { errors.push(`Answer column "${config.answerColumn}" not found`); } // Validation checks if 
(config.selectedColumns.length === 0) { errors.push('At least one column must be selected for embeddings'); } if (config.selectedColumns.length > 5) { warnings.push('More than 5 columns may result in many combinations and slow processing'); } if (config.trainingRatio <= 0 || config.trainingRatio >= 1) { errors.push('Training ratio must be between 0 and 1'); } if (tableInfo.rowCount < 10) { warnings.push('Table has very few rows - results may not be reliable'); } return { isValid: errors.length === 0, errors, warnings }; } async getTableInfo(tableName) { return this.db.getTableInfo(tableName); } async getTables() { return this.db.getTables(); } async getSampleData(tableName, limit = 3) { return this.db.getTableData(tableName, ['*'], limit); } } exports.RAGTester = RAGTester; //# sourceMappingURL=tester.js.map