rag-cli-tester
Version:
A lightweight CLI tool for testing RAG (Retrieval-Augmented Generation) systems with different embedding combinations
331 lines • 15.8 kB
JavaScript
;
// Set the `__esModule` flag on the CommonJS exports object.
Object.defineProperty(exports, "__esModule", { value: true });
// Placeholder for the export; the real class is assigned after its definition below.
exports.EnhancedRAGTester = void 0;
const base_metric_1 = require("./metrics/base-metric");
/**
 * Runs retrieval experiments over column combinations of a table and scores
 * the retrieved answers with a pluggable metric.
 *
 * The class relies on two injected collaborators whose implementations are not
 * part of this file:
 *  - `dbConnection`: must expose `getTableInfo`, `getTableData` and
 *    `getTableDataSample` (as called below).
 *  - `embeddingGenerator`: must expose `initialize`,
 *    `generateColumnCombinations`, `processTrainingData` and `processQuery`.
 */
class EnhancedRAGTester {
    /**
     * @param {object} dbConnection - Data source used to read table rows.
     * @param {object} embeddingGenerator - Produces embeddings and answers queries.
     */
    constructor(dbConnection, embeddingGenerator) {
        this.metricCache = new Map();
        this.embeddingCache = new Map();
        this.db = dbConnection;
        this.embeddings = embeddingGenerator;
    }
    /**
     * Initializes the embedding generator. Must be awaited before running experiments.
     * @returns {Promise<void>}
     */
    async initialize() {
        await this.embeddings.initialize();
    }
    /**
     * Runs one experiment: validates the configuration, evaluates every column
     * combination and aggregates the per-combination results.
     *
     * A combination whose evaluation throws is reported on stderr and skipped.
     *
     * @param {object} config - Experiment configuration (test name, table,
     *   columns, metric type, batch size, sample caps, ...).
     * @returns {Promise<object>} `{ testName, timestamp, configuration, allResults, summary, processingTime }`.
     * @throws {Error} If validation fails or no combination produced a result.
     */
    async runEnhancedExperiment(config) {
        const startTime = Date.now();
        const memoryStart = process.memoryUsage();
        console.log(`\n๐งช Starting Enhanced RAG Experiment: ${config.testName}`);
        console.log(`๐ Table: ${config.tableName}`);
        console.log(`๐ Columns: ${config.selectedColumns.join(', ')}`);
        console.log(`๐ฏ Metric: ${config.metricType}`);
        console.log(`๐ฆ Batch Size: ${config.batchSize}`);
        console.log(`๐ Max Training Samples: ${config.maxTrainingSamples}`);
        console.log(`๐งช Max Testing Samples: ${config.maxTestingSamples}`);
        // Validate configuration
        const validation = await this.validateEnhancedConfiguration(config);
        if (!validation.isValid) {
            throw new Error(`Configuration validation failed: ${validation.errors.join(', ')}`);
        }
        // Warnings do not block the run, but they were previously computed and dropped.
        for (const warning of validation.warnings) {
            console.warn(` Warning: ${warning}`);
        }
        // Generate column combinations
        const combinations = this.embeddings.generateColumnCombinations(config.selectedColumns, config.maxCombinations || 20);
        console.log(`๐ Testing ${combinations.length} column combinations...\n`);
        const allResults = [];
        for (let i = 0; i < combinations.length; i++) {
            const combination = combinations[i];
            console.log(`[${i + 1}/${combinations.length}] Testing: ${combination.name}`);
            try {
                const result = await this.runEnhancedSingleTest(config, combination);
                allResults.push(result);
                console.log(` โ
Score: ${result.averageScore.toFixed(3)} (${result.totalTests} tests)`);
                console.log(` ๐ Confidence: ${result.detailedMetrics.confidence.toFixed(3)}`);
                console.log(` โฑ๏ธ Total Time: ${(result.processingStats.trainingTime + result.processingStats.testingTime).toFixed(1)}ms`);
            }
            catch (error) {
                console.error(` โ Failed: ${error instanceof Error ? error.message : String(error)}`);
                continue;
            }
        }
        if (allResults.length === 0) {
            throw new Error('No combinations produced valid results');
        }
        // Calculate enhanced summary statistics
        const summary = this.calculateEnhancedSummary(allResults);
        const processingTime = Date.now() - startTime;
        const memoryEnd = process.memoryUsage();
        const memoryUsed = memoryEnd.heapUsed - memoryStart.heapUsed;
        console.log(`\n๐ Memory Usage: ${(memoryUsed / 1024 / 1024).toFixed(2)} MB`);
        return {
            testName: config.testName,
            timestamp: new Date(),
            configuration: config,
            allResults,
            summary,
            processingTime
        };
    }
    /**
     * Evaluates a single column combination: loads training/testing rows,
     * embeds the training rows, answers every test query and aggregates scores.
     *
     * @param {object} config - Experiment configuration.
     * @param {object} combination - Column combination (uses `name` and `columns`).
     * @returns {Promise<object>} Per-combination result with scores, timings and
     *   embedding statistics. `detailedMetrics` is taken from the first test
     *   result only, as a representative.
     * @throws {Error} If there is no training/testing data or no query produced a result.
     */
    async runEnhancedSingleTest(config, combination) {
        const startTime = Date.now();
        // Get metric instance
        const metric = this.getMetric(config.metricType);
        // Fetch data in batches for large datasets
        const trainingData = await this.getTrainingData(config, combination);
        const testingData = await this.getTestingData(config, combination);
        if (trainingData.length === 0 || testingData.length === 0) {
            throw new Error('Insufficient data for training or testing');
        }
        // Generate embeddings for training data (with caching)
        const trainingStart = Date.now();
        const trainingEmbeddings = await this.generateTrainingEmbeddings(trainingData, combination, config);
        const trainingTime = Date.now() - trainingStart;
        // Process test queries in batches
        const testingStart = Date.now();
        const results = await this.processTestQueries(testingData, trainingEmbeddings, metric, config);
        const testingTime = Date.now() - testingStart;
        if (results.length === 0) {
            throw new Error('No valid test results generated');
        }
        // Calculate enhanced metrics
        const averageScore = results.reduce((sum, r) => sum + r.score, 0) / results.length;
        const averageSimilarity = results.reduce((sum, r) => sum + r.similarity, 0) / results.length;
        // Calculate confidence interval
        const scores = results.map(r => r.score);
        const confidenceInterval = this.calculateConfidenceInterval(scores);
        const totalTime = Date.now() - startTime;
        return {
            combination,
            averageScore,
            totalTests: results.length,
            processingTime: totalTime,
            detailedMetrics: results[0].detailedMetrics, // Use first result as representative
            confidenceInterval,
            processingStats: {
                trainingTime,
                testingTime,
                embeddingTime: trainingTime,
                memoryUsage: process.memoryUsage().heapUsed / 1024 / 1024
            },
            embeddingStats: {
                trainingEmbeddings: trainingEmbeddings.embeddings.length,
                testQueries: results.length,
                averageSimilarity
            }
        };
    }
    /**
     * Loads the rows used to build the retrieval index.
     *
     * @param {object} config - Uses `tableName`, `trainingRatio`,
     *   `maxTrainingSamples` and `dataSamplingStrategy`.
     * @param {object} combination - Unused; kept for signature compatibility.
     * @returns {Promise<Array>} Training rows.
     * @throws {Error} If the table is unknown to the database.
     */
    async getTrainingData(config, combination) {
        return this.loadPartition(config, 'training');
    }
    /**
     * Loads the rows used as test queries.
     *
     * @param {object} config - Uses `tableName`, `trainingRatio`,
     *   `maxTestingSamples` and `dataSamplingStrategy`.
     * @param {object} combination - Unused; kept for signature compatibility.
     * @returns {Promise<Array>} Testing rows.
     * @throws {Error} If the table is unknown to the database.
     */
    async getTestingData(config, combination) {
        return this.loadPartition(config, 'testing');
    }
    /**
     * Shared loader behind getTrainingData/getTestingData.
     *
     * For tables that are fully loaded into memory the rows are first split
     * into a training head and a testing tail at `trainingRatio`, and each
     * side is sampled from its own partition only. Sampling both sides from
     * the whole table (the previous behavior) lets test rows appear in the
     * training set, which inflates retrieval scores.
     *
     * NOTE(review): the split assumes `getTableData` returns rows in the same
     * order on both calls — confirm against the database layer. For tables
     * above the row threshold the sampling is delegated to
     * `getTableDataSample`, whose overlap guarantees are not visible here.
     *
     * @param {object} config - Experiment configuration.
     * @param {'training'|'testing'} partition - Which side of the split to load.
     * @returns {Promise<Array>} Sampled rows for the requested partition.
     * @throws {Error} If the table is unknown to the database.
     */
    async loadPartition(config, partition) {
        const LARGE_TABLE_ROW_THRESHOLD = 100000;
        const tableInfo = await this.db.getTableInfo(config.tableName);
        if (!tableInfo)
            throw new Error(`Table ${config.tableName} not found`);
        const isTraining = partition === 'training';
        const ratio = isTraining ? config.trainingRatio : 1 - config.trainingRatio;
        const cap = isTraining ? config.maxTrainingSamples : config.maxTestingSamples;
        const sampleSize = Math.min(cap, Math.floor(tableInfo.rowCount * ratio));
        // Use efficient sampling for large datasets
        if (tableInfo.rowCount > LARGE_TABLE_ROW_THRESHOLD) {
            return await this.db.getTableDataSample(config.tableName, sampleSize, ratio);
        }
        const allData = await this.db.getTableData(config.tableName);
        const splitIndex = Math.floor(allData.length * config.trainingRatio);
        const pool = isTraining ? allData.slice(0, splitIndex) : allData.slice(splitIndex);
        return this.sampleData(pool, sampleSize, config.dataSamplingStrategy);
    }
    /**
     * Produces embeddings for the training rows, reusing cached vectors when
     * caching is enabled and the cached entry was built from the same rows.
     *
     * The cache key is only `<combination name>_<row count>`, so a hit is
     * additionally verified against the combined context and answer of every
     * row. Without that check, a different sample of the same size would get
     * vectors that belong to other rows.
     *
     * @param {Array<object>} trainingData - Rows to embed.
     * @param {object} combination - Column combination (`name`, `columns`).
     * @param {object} config - Uses `enableCaching` and `answerColumn`.
     * @returns {Promise<object>} `{ embeddings, combination, totalRows }`; on a
     *   cache hit each embedding entry has `id: cached_<index>` and
     *   `metadata.cached === true`.
     */
    async generateTrainingEmbeddings(trainingData, combination, config) {
        const cacheKey = `${combination.name}_${trainingData.length}`;
        let contexts = null;
        let targets = null;
        if (config.enableCaching) {
            contexts = trainingData.map(row => (row ? this.combineColumns(row, combination.columns) : ''));
            targets = trainingData.map(row => row?.[config.answerColumn] ?? '');
            const cached = this.embeddingCache.get(cacheKey);
            if (this.isCacheEntryFor(cached, contexts, targets)) {
                console.log(` ๐พ Using cached embeddings for ${combination.name}`);
                return {
                    embeddings: cached.vectors.map((embedding, index) => ({
                        id: `cached_${index}`,
                        combination,
                        embedding,
                        context: contexts[index],
                        targetValue: targets[index],
                        metadata: { cached: true }
                    })),
                    combination,
                    totalRows: trainingData.length
                };
            }
        }
        // Generate new embeddings
        const generated = await this.embeddings.processTrainingData(trainingData, combination, config.answerColumn);
        // Cache only when the generator returned exactly one entry per row;
        // otherwise vectors cannot be re-associated with rows by index.
        if (config.enableCaching && generated.embeddings.length === trainingData.length) {
            this.embeddingCache.set(cacheKey, {
                vectors: generated.embeddings.map(e => e.embedding),
                contexts,
                targets
            });
        }
        return generated;
    }
    /**
     * Checks that a cache entry was built from rows with the given contexts
     * and answers (compared position by position).
     *
     * @param {object|undefined} entry - Value stored in `embeddingCache`.
     * @param {string[]} contexts - Combined column text per row.
     * @param {Array} targets - Answer value per row.
     * @returns {boolean} True when the entry can safely be reused.
     */
    isCacheEntryFor(entry, contexts, targets) {
        if (!entry || !Array.isArray(entry.vectors) || !Array.isArray(entry.contexts) || !Array.isArray(entry.targets)) {
            return false;
        }
        if (entry.vectors.length !== contexts.length || entry.contexts.length !== contexts.length) {
            return false;
        }
        return entry.contexts.every((context, index) => context === contexts[index])
            && entry.targets.every((target, index) => Object.is(target, targets[index]));
    }
    /**
     * Answers every test row's query against the training embeddings and
     * scores the retrieved answer with the metric.
     *
     * Rows with a falsy query or expected answer are skipped, as are rows for
     * which no match is returned. A row whose lookup or scoring throws is
     * reported with its 1-based row number and skipped.
     *
     * @param {Array<object>} testingData - Rows providing query and expected answer.
     * @param {object} trainingEmbeddings - Result of generateTrainingEmbeddings.
     * @param {object} metric - Object with `calculate(expected, actual, similarity)`.
     * @param {object} config - Uses `batchSize`, `queryColumn`, `answerColumn`.
     * @returns {Promise<Array<object>>} One entry per successfully scored query.
     */
    async processTestQueries(testingData, trainingEmbeddings, metric, config) {
        const results = [];
        // A zero, negative or non-numeric batch size would stall or end the
        // loop immediately; fall back to a batch size of 1.
        const batchSize = Math.max(1, Math.floor(Number(config.batchSize)) || 1);
        for (let i = 0; i < testingData.length; i += batchSize) {
            const batch = testingData.slice(i, i + batchSize);
            for (const [offset, testRow] of batch.entries()) {
                const query = testRow[config.queryColumn];
                const expectedAnswer = testRow[config.answerColumn];
                if (!query || !expectedAnswer)
                    continue;
                try {
                    // Find best match from training data
                    const matches = await this.embeddings.processQuery(query, trainingEmbeddings, 1);
                    if (matches.length === 0)
                        continue;
                    const bestMatch = matches[0];
                    const actualAnswer = bestMatch.result.targetValue;
                    // Calculate metric
                    const detailedMetrics = metric.calculate(expectedAnswer, actualAnswer, bestMatch.similarity);
                    results.push({
                        query,
                        expectedAnswer,
                        actualAnswer,
                        similarity: bestMatch.similarity,
                        score: detailedMetrics.overallScore,
                        detailedMetrics
                    });
                }
                catch (error) {
                    // Report the row's own position, not the start of its batch.
                    console.warn(` Skipped test query ${i + offset + 1}: ${error instanceof Error ? error.message : String(error)}`);
                    continue;
                }
            }
            // Progress indicator for large datasets
            if (testingData.length > 1000 && i % (batchSize * 10) === 0) {
                console.log(` Processed ${Math.min(i + batchSize, testingData.length)}/${testingData.length} test queries`);
            }
        }
        return results;
    }
    /**
     * Returns the metric for `metricType`, creating it through the metric
     * factory on first use and memoizing it per tester instance.
     *
     * @param {string} metricType - Metric identifier understood by the factory.
     * @returns {object} Metric instance.
     */
    getMetric(metricType) {
        if (!this.metricCache.has(metricType)) {
            const metric = base_metric_1.MetricFactory.getMetric(metricType);
            this.metricCache.set(metricType, metric);
        }
        return this.metricCache.get(metricType);
    }
    /**
     * Reduces `data` to at most `sampleSize` rows using the given strategy.
     * When `data` already fits, it is returned as-is (same array).
     *
     * @param {Array} data - Candidate rows.
     * @param {number} sampleSize - Maximum number of rows to return.
     * @param {string} strategy - 'random', 'stratified' or 'sequential';
     *   anything else falls back to random sampling.
     * @returns {Array} Sampled rows.
     */
    sampleData(data, sampleSize, strategy) {
        if (data.length <= sampleSize)
            return data;
        switch (strategy) {
            case 'stratified':
                return this.stratifiedSample(data, sampleSize);
            case 'sequential':
                return this.sequentialSample(data, sampleSize);
            case 'random':
            default:
                return this.randomSample(data, sampleSize);
        }
    }
    /**
     * Picks `sampleSize` distinct rows uniformly at random without mutating
     * `data`, using a partial Fisher–Yates shuffle. (Sorting with a random
     * comparator is biased and gives the sort an inconsistent ordering.)
     *
     * @param {Array} data - Candidate rows.
     * @param {number} sampleSize - Number of rows wanted; clamped to `[0, data.length]`.
     * @returns {Array} New array with the sampled rows.
     */
    randomSample(data, sampleSize) {
        const pool = [...data];
        const count = Math.max(0, Math.min(pool.length, Math.floor(sampleSize) || 0));
        for (let i = 0; i < count; i++) {
            const j = i + Math.floor(Math.random() * (pool.length - i));
            [pool[i], pool[j]] = [pool[j], pool[i]];
        }
        return pool.slice(0, count);
    }
    /**
     * Samples proportionally from ten strata obtained by ordering rows by
     * size (see measureRowSize), so short and long rows are both represented.
     *
     * @param {Array} data - Candidate rows.
     * @param {number} sampleSize - Maximum number of rows to return.
     * @returns {Array} Sampled rows, at most `sampleSize`.
     */
    stratifiedSample(data, sampleSize) {
        if (data.length === 0)
            return [];
        const STRATA_COUNT = 10;
        // Measure once per row, then sort on the numeric size. Row objects have
        // no `length`, so comparing `a.length - b.length` directly yields NaN.
        const sorted = data
            .map(row => ({ row, size: this.measureRowSize(row) }))
            .sort((a, b) => a.size - b.size)
            .map(entry => entry.row);
        const chunkSize = Math.ceil(data.length / STRATA_COUNT);
        const samples = [];
        for (let i = 0; i < STRATA_COUNT && samples.length < sampleSize; i++) {
            const chunk = sorted.slice(i * chunkSize, (i + 1) * chunkSize);
            const chunkSampleSize = Math.ceil((chunk.length / data.length) * sampleSize);
            const chunkSamples = this.randomSample(chunk, Math.min(chunkSampleSize, sampleSize - samples.length));
            samples.push(...chunkSamples);
        }
        return samples.slice(0, sampleSize);
    }
    /**
     * Size measure used for stratification: the `length` of strings/arrays,
     * or the total text length of an object's non-null values.
     *
     * @param {*} row - Row to measure.
     * @returns {number} Non-negative size.
     */
    measureRowSize(row) {
        if (row === null || row === undefined)
            return 0;
        if (typeof row.length === 'number')
            return row.length;
        if (typeof row === 'object') {
            return Object.values(row).reduce((total, value) => (value === null || value === undefined ? total : total + String(value).length), 0);
        }
        return String(row).length;
    }
    /**
     * Picks evenly spaced rows, starting with the first one.
     *
     * @param {Array} data - Candidate rows.
     * @param {number} sampleSize - Maximum number of rows to return.
     * @returns {Array} Sampled rows in their original order.
     */
    sequentialSample(data, sampleSize) {
        // A step of 0 (sampleSize > data.length) would repeat the first row.
        const step = Math.max(1, Math.floor(data.length / sampleSize) || 1);
        const samples = [];
        for (let i = 0; i < sampleSize && i * step < data.length; i++) {
            samples.push(data[i * step]);
        }
        return samples;
    }
    /**
     * Joins the non-null values of the given columns with ' [SEP] '.
     *
     * @param {object} row - Source row.
     * @param {string[]} columns - Column names to read, in order.
     * @returns {string} Combined text.
     */
    combineColumns(row, columns) {
        return columns
            .map(col => row[col])
            .filter(val => val !== null && val !== undefined)
            .join(' [SEP] ');
    }
    /**
     * Normal-approximation confidence interval for the mean of `scores`,
     * clamped to `[0, 1]`.
     *
     * @param {number[]} scores - Individual scores.
     * @param {number} [confidenceLevel=0.95] - Currently not used: the
     *   critical value is fixed at 1.96, i.e. roughly a 95% interval.
     * @returns {{lower: number, upper: number}} Interval bounds; for fewer
     *   than two scores both bounds equal the single score (or 0).
     */
    calculateConfidenceInterval(scores, confidenceLevel = 0.95) {
        if (scores.length < 2)
            return { lower: scores[0] || 0, upper: scores[0] || 0 };
        const mean = scores.reduce((sum, score) => sum + score, 0) / scores.length;
        const variance = scores.reduce((sum, score) => sum + Math.pow(score - mean, 2), 0) / (scores.length - 1);
        const standardError = Math.sqrt(variance / scores.length);
        // Simple t-distribution approximation for 95% confidence
        const tValue = 1.96; // Approximate for large samples
        return {
            lower: Math.max(0, mean - tValue * standardError),
            upper: Math.min(1, mean + tValue * standardError)
        };
    }
    /**
     * Aggregates per-combination results into summary statistics.
     *
     * The median is the mean of the two middle scores when the count is even.
     * `q1Score`/`q3Score` are the elements at the floored 25%/75% positions of
     * the sorted scores (no interpolation).
     *
     * @param {Array<object>} results - Per-combination results (each with
     *   `combination`, `averageScore`, `totalTests`, `detailedMetrics.confidence`).
     * @returns {object} Summary with best/worst combination and score statistics.
     * @throws {Error} If `results` is empty.
     */
    calculateEnhancedSummary(results) {
        if (results.length === 0) {
            throw new Error('Cannot summarize an empty result set');
        }
        const scores = results.map(r => r.averageScore);
        const bestResult = results.reduce((best, current) => current.averageScore > best.averageScore ? current : best);
        const worstResult = results.reduce((worst, current) => current.averageScore < worst.averageScore ? current : worst);
        // Calculate additional statistics
        const sortedScores = [...scores].sort((a, b) => a - b);
        const middle = Math.floor(sortedScores.length / 2);
        const median = sortedScores.length % 2 === 0
            ? (sortedScores[middle - 1] + sortedScores[middle]) / 2
            : sortedScores[middle];
        const q1 = sortedScores[Math.floor(sortedScores.length * 0.25)];
        const q3 = sortedScores[Math.floor(sortedScores.length * 0.75)];
        return {
            bestCombination: bestResult.combination,
            bestScore: bestResult.averageScore,
            worstCombination: worstResult.combination,
            worstScore: worstResult.averageScore,
            averageScore: scores.reduce((sum, score) => sum + score, 0) / scores.length,
            medianScore: median,
            q1Score: q1,
            q3Score: q3,
            totalCombinations: results.length,
            totalTests: results.reduce((sum, r) => sum + r.totalTests, 0),
            averageConfidence: results.reduce((sum, r) => sum + r.detailedMetrics.confidence, 0) / results.length
        };
    }
    /**
     * Checks the experiment configuration before a run.
     *
     * Numeric limits are tested with `!(value >= min)` so that `undefined`
     * and `NaN` are rejected too. `trainingRatio` must lie strictly between 0
     * and 1, since either partition would otherwise be empty.
     *
     * @param {object} config - Experiment configuration.
     * @returns {Promise<{isValid: boolean, errors: string[], warnings: string[]}>}
     */
    async validateEnhancedConfiguration(config) {
        const errors = [];
        const warnings = [];
        // Basic validation
        if (!(config.batchSize >= 1))
            errors.push('Batch size must be at least 1');
        if (!(config.maxTrainingSamples >= 10))
            errors.push('Maximum training samples must be at least 10');
        if (!(config.maxTestingSamples >= 5))
            errors.push('Maximum testing samples must be at least 5');
        if (!(config.trainingRatio > 0 && config.trainingRatio < 1))
            errors.push('Training ratio must be greater than 0 and less than 1');
        // Check if metric exists
        try {
            base_metric_1.MetricFactory.getMetric(config.metricType);
        }
        catch (error) {
            errors.push(`Metric '${config.metricType}' not found`);
        }
        // Performance warnings for large datasets
        if (config.maxTrainingSamples > 100000) {
            warnings.push('Large training sample size may cause memory issues');
        }
        if (config.batchSize > 1000) {
            warnings.push('Large batch size may cause memory issues');
        }
        return {
            isValid: errors.length === 0,
            errors,
            warnings
        };
    }
}
// Publish the class as the module's named export, replacing the placeholder set at the top.
exports.EnhancedRAGTester = EnhancedRAGTester;
//# sourceMappingURL=enhanced-tester.js.map