rag-cli-tester
Version:
A lightweight CLI tool for testing RAG (Retrieval-Augmented Generation) systems with different embedding combinations
278 lines • 15.2 kB
JavaScript
;
// TypeScript-emitted helper: exposes property `k` of module object `m` on
// object `o` under the name `k2` (which defaults to `k`).
// When Object.create is available, it installs an enumerable getter that reads
// m[k] on every access. The exception is when m's own descriptor can be reused
// as-is: an accessor on a module flagged `__esModule`, or a data property that
// is neither writable nor configurable. Without Object.create it copies the
// current value once. An existing `this.__createBinding` is reused if present.
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
if (k2 === undefined) k2 = k;
var desc = Object.getOwnPropertyDescriptor(m, k);
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
desc = { enumerable: true, get: function() { return m[k]; } };
}
Object.defineProperty(o, k2, desc);
}) : (function(o, m, k, k2) {
if (k2 === undefined) k2 = k;
o[k2] = m[k];
}));
// TypeScript-emitted helper: sets `o.default` to `v`. It uses an enumerable
// defineProperty when Object.create is available, and plain assignment
// otherwise. An existing `this.__setModuleDefault` is reused if present.
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
Object.defineProperty(o, "default", { enumerable: true, value: v });
}) : function(o, v) {
o["default"] = v;
});
// TypeScript-emitted namespace-import helper (used below to wrap
// require("readline")).
// It returns `mod` unchanged when `mod.__esModule` is truthy. Otherwise it
// builds a new object with a binding (via __createBinding) for every own
// property name of `mod` except "default", then sets `default` to `mod` itself
// (via __setModuleDefault). An existing `this.__importStar` is reused if present.
var __importStar = (this && this.__importStar) || (function () {
// ownKeys prefers Object.getOwnPropertyNames and falls back to a for-in loop
// filtered by hasOwnProperty. On its first call it replaces itself with the
// chosen implementation.
var ownKeys = function(o) {
ownKeys = Object.getOwnPropertyNames || function (o) {
var ar = [];
for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
return ar;
};
return ownKeys(o);
};
return function (mod) {
if (mod && mod.__esModule) return mod;
var result = {};
if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
__setModuleDefault(result, mod);
return result;
};
})();
// TypeScript-emitted default-import helper (used below for "ora" and "chalk",
// which are then accessed through `.default`).
// It returns `mod` as-is when flagged `__esModule`; otherwise it wraps it as
// `{ default: mod }`. An existing `this.__importDefault` is reused if present.
var __importDefault = (this && this.__importDefault) || function (mod) {
return (mod && mod.__esModule) ? mod : { "default": mod };
};
// Flag this CommonJS module as transpiled from an ES module, so that the
// __importDefault/__importStar helpers of importing modules pass it through
// unchanged.
Object.defineProperty(exports, "__esModule", { value: true });
// Declare the named export up front; the class is assigned to it at the end of
// the file.
exports.EmbeddingService = void 0;
const providers_1 = require("./providers");
const ora_1 = __importDefault(require("ora"));
const chalk_1 = __importDefault(require("chalk"));
const readline = __importStar(require("readline"));
/**
 * Generates vector embeddings for rows of a database table and stores them in
 * an embedding column of the same table.
 *
 * `database` is only used through the methods awaited below: getTableInfo,
 * checkColumnExists, getColumnDataCount, getEmptyColumnCount,
 * getRowsWithoutEmbeddings and updateRowEmbedding. The embedding provider is
 * built from `config` by ProviderManager.createEmbeddingProvider.
 */
class EmbeddingService {
    /**
     * @param {object} database - Data-access object for the target table.
     * @param {object} config - Provider configuration; `config.model` is echoed in logs.
     */
    constructor(database, config) {
        this.database = database;
        this.config = config;
        this.embeddingProvider = providers_1.ProviderManager.createEmbeddingProvider(config);
    }
    /** Delegates to the embedding provider's own `initialize()`. */
    async initialize() {
        await this.embeddingProvider.initialize();
    }
    /**
     * Fills `task.embeddingColumn` for every row of `task.tableName` that has no
     * embedding yet, using the text combined from `task.columns`.
     *
     * Rows are fetched and embedded in batches of `task.batchSize`. A row that
     * yields no text or whose embedding fails is attempted only once per run
     * and then excluded. This prevents such rows from being re-fetched forever.
     *
     * @param {object} task - Uses `tableName`, `columns`, `embeddingColumn`,
     *   `batchSize`, and optionally `customOrder` and `expectedDimensions`
     *   (forwarded to processBatch).
     * @returns {Promise<void>} Resolves without writing anything if the user
     *   declines to continue when embeddings already exist.
     * @throws {Error} On an invalid batch size, missing table/columns, or any
     *   database/provider error (the spinner is marked failed first).
     */
    async generateEmbeddings(task) {
        const spinner = (0, ora_1.default)('Initializing embedding generation...').start();
        try {
            // Coerce so a numeric string still works. Reject 0/NaN/fractions,
            // which would otherwise turn into a bogus fetch limit.
            const batchSize = Number(task.batchSize);
            if (!Number.isInteger(batchSize) || batchSize < 1) {
                throw new Error(`Invalid batch size '${task.batchSize}': expected a positive integer`);
            }
            // Validate that the table, source columns and embedding column exist.
            const tableInfo = await this.database.getTableInfo(task.tableName);
            if (!tableInfo) {
                throw new Error(`Table '${task.tableName}' not found`);
            }
            const missingColumns = task.columns.filter(col => !tableInfo.columns.some(dbCol => dbCol.column_name === col));
            if (missingColumns.length > 0) {
                throw new Error(`Columns not found in table: ${missingColumns.join(', ')}`);
            }
            const embeddingColumnExists = await this.database.checkColumnExists(task.tableName, task.embeddingColumn);
            if (!embeddingColumnExists) {
                throw new Error(`Embedding column '${task.embeddingColumn}' not found in table`);
            }
            // Resolved once from the schema fetched above instead of once per batch.
            const idColumn = tableInfo.primaryKey || 'id';
            // Ask for confirmation when the embedding column already holds data.
            spinner.text = 'Checking for existing embeddings...';
            const existingEmbeddingCount = await this.database.getColumnDataCount(task.tableName, task.embeddingColumn);
            console.log(chalk_1.default.gray(` Found ${existingEmbeddingCount} rows with existing embeddings`));
            if (existingEmbeddingCount > 0) {
                spinner.stop();
                console.log(chalk_1.default.yellow(`⚠️ Warning: Column '${task.embeddingColumn}' already contains embeddings in ${existingEmbeddingCount} rows!`));
                console.log(chalk_1.default.yellow(` This operation will overwrite existing embeddings.`));
                const confirm = await this.askUserConfirmation(`Are you sure you want to continue and potentially overwrite existing embeddings in '${task.embeddingColumn}'? (yes/no): `);
                if (!confirm) {
                    console.log(chalk_1.default.blue('Operation cancelled by user.'));
                    return;
                }
                spinner.start('Continuing with embedding generation...');
            }
            spinner.text = 'Counting rows that need processing...';
            const totalRowsToProcess = await this.database.getEmptyColumnCount(task.tableName, task.embeddingColumn);
            console.log(chalk_1.default.gray(` Found ${totalRowsToProcess} rows that need embeddings`));
            if (totalRowsToProcess === 0) {
                spinner.succeed(chalk_1.default.green(`✅ Column '${task.embeddingColumn}' already has all embeddings generated!`));
                return;
            }
            spinner.text = `Found ${totalRowsToProcess} rows to process...`;
            console.log(chalk_1.default.blue(`\n📊 Total rows to process: ${totalRowsToProcess}`));
            console.log(chalk_1.default.gray('─'.repeat(50)));
            // IDs embedded and written during this run, and IDs attempted without
            // success. Failed rows still lack an embedding, so the query keeps
            // returning them; they must be excluded explicitly.
            const succeededIds = new Set();
            const failedIds = new Set();
            const attemptedCount = () => succeededIds.size + failedIds.size;
            let batchNumber = 0;
            while (attemptedCount() < totalRowsToProcess) {
                const currentBatchSize = Math.min(batchSize, totalRowsToProcess - attemptedCount());
                console.log(chalk_1.default.gray(`\n📥 Fetching batch of ${currentBatchSize} rows...`));
                // Over-fetch by the number of failed rows so they cannot crowd out
                // rows that have not been attempted yet.
                const rows = await this.database.getRowsWithoutEmbeddings(task.tableName, task.embeddingColumn, task.columns, currentBatchSize + failedIds.size);
                if (rows.length === 0) {
                    console.log(chalk_1.default.yellow(`⚠️ No more rows to process`));
                    break;
                }
                console.log(chalk_1.default.gray(`📋 Retrieved ${rows.length} rows from database`));
                const batchRows = rows
                    .filter(row => !succeededIds.has(row?.[idColumn]) && !failedIds.has(row?.[idColumn]))
                    .slice(0, currentBatchSize);
                if (batchRows.length === 0) {
                    console.log(chalk_1.default.yellow(`⚠️ All rows in this batch were already processed`));
                    break;
                }
                // Debug output describing the first row of the batch.
                const firstRow = batchRows[0];
                console.log(chalk_1.default.gray(`🔍 First row sample: ${String(JSON.stringify(firstRow, null, 2)).substring(0, 200)}...`));
                if (firstRow && typeof firstRow === 'object') {
                    console.log(chalk_1.default.gray(`🔍 Row keys: ${Object.keys(firstRow).join(', ')}`));
                    console.log(chalk_1.default.gray(`🔍 Row ID value: ${firstRow[idColumn] ?? 'undefined'}`));
                    console.log(chalk_1.default.gray(`🔍 Row ID type: ${typeof firstRow[idColumn]}`));
                }
                batchNumber += 1;
                spinner.text = `Processing batch ${batchNumber} (${batchRows.length} rows)...`;
                console.log(chalk_1.default.blue(`\n🔄 Processing batch ${batchNumber} (${batchRows.length} rows)`));
                console.log(chalk_1.default.gray('─'.repeat(50)));
                const results = await this.processBatch(batchRows, task, idColumn, task.expectedDimensions);
                spinner.text = `Updating database with ${results.length} embeddings...`;
                console.log(chalk_1.default.green(`✅ Generated embeddings for ${results.length} rows`));
                for (const result of results) {
                    await this.database.updateRowEmbedding(task.tableName, result.id, task.embeddingColumn, result.embedding, idColumn);
                    succeededIds.add(result.id);
                }
                // Any row in this batch without a result was skipped or errored.
                // Record it so it is not retried on the next iteration.
                for (const row of batchRows) {
                    const rowId = row?.[idColumn];
                    if (!succeededIds.has(rowId)) {
                        failedIds.add(rowId);
                    }
                }
                const progressPercentage = Math.round((succeededIds.size / totalRowsToProcess) * 100);
                spinner.text = `Processed ${succeededIds.size}/${totalRowsToProcess} rows (${progressPercentage}%)...`;
                console.log(chalk_1.default.gray('─'.repeat(50)));
                console.log('');
                // Pause between batches to respect API rate limits.
                if (attemptedCount() < totalRowsToProcess) {
                    await this.delay(500);
                }
            }
            const totalProcessed = succeededIds.size;
            spinner.succeed(chalk_1.default.green(`✅ Successfully generated embeddings for ${totalProcessed} rows`));
            if (failedIds.size > 0) {
                console.log(chalk_1.default.yellow(`⚠️ ${failedIds.size} rows could not be embedded and were left without an embedding`));
            }
            console.log(chalk_1.default.blue(`\n📊 Summary:`));
            console.log(chalk_1.default.gray(` • Table: ${task.tableName}`));
            console.log(chalk_1.default.gray(` • Source columns: ${task.columns.join(', ')}`));
            console.log(chalk_1.default.gray(` • Embedding column: ${task.embeddingColumn}`));
            console.log(chalk_1.default.gray(` • Total rows processed: ${totalProcessed}`));
            // Computed from the pre-run counts; re-querying the filled count now
            // would count this run's new embeddings twice.
            console.log(chalk_1.default.gray(` • Total rows in table: ${Number(existingEmbeddingCount) + Number(totalRowsToProcess)}`));
            console.log(chalk_1.default.gray(` • Embedding provider: ${this.config.model}`));
            console.log(chalk_1.default.gray(` • Batch size used: ${task.batchSize}`));
            console.log('');
        }
        catch (error) {
            spinner.fail(chalk_1.default.red(`❌ Embedding generation failed: ${error.message}`));
            throw error;
        }
    }
    /**
     * Embeds each row of one batch, one provider call at a time.
     *
     * Some rows are logged and left out of the result, after which processing
     * continues with the next row. This applies to rows that are not objects,
     * lack an `idColumn` value, produce empty text, or get an embedding of the
     * wrong length.
     *
     * @param {object[]} rows - Rows to embed.
     * @param {object} task - Supplies `columns` and `customOrder`.
     * @param {string} [idColumn='id'] - Column holding each row's identifier.
     * @param {number|null} [expectedDimensions=384] - Required embedding length
     *   (384 is the all-MiniLM-L6-v2 size). Pass null to accept any non-empty length.
     * @returns {Promise<Array<{id: *, embedding: number[]}>>} One entry per
     *   successfully embedded row, in input order.
     */
    async processBatch(rows, task, idColumn = 'id', expectedDimensions = 384) {
        const results = [];
        for (const [index, row] of rows.entries()) {
            try {
                if (!row || typeof row !== 'object') {
                    console.warn(chalk_1.default.yellow(`⚠️ Skipping invalid row: ${JSON.stringify(row)}`));
                    continue;
                }
                const rowId = row[idColumn];
                // Only a missing or blank id is rejected; 0 is a valid key.
                if (rowId === undefined || rowId === null || rowId === '') {
                    console.warn(chalk_1.default.yellow(`⚠️ Skipping row without ${idColumn}: ${JSON.stringify(row)}`));
                    continue;
                }
                const text = this.combineColumns(row, task.columns, task.customOrder);
                if (!text.trim()) {
                    console.warn(chalk_1.default.yellow(`⚠️ Skipping row ${rowId}: no valid text generated from columns`));
                    continue;
                }
                console.log(chalk_1.default.cyan(`📝 Processing row ${rowId} (${index + 1}/${rows.length}):`));
                console.log(chalk_1.default.gray(`Combined text: ${text.substring(0, 200)}${text.length > 200 ? '...' : ''}`));
                console.log(chalk_1.default.blue(` 🔄 Generating embedding (${this.config.model})...`));
                const embedding = await this.embeddingProvider.generateEmbedding(text);
                if (!Array.isArray(embedding) || embedding.length === 0) {
                    const got = Array.isArray(embedding) ? 'an empty array' : typeof embedding;
                    throw new Error(`Invalid embedding: expected a non-empty array, got ${got}`);
                }
                if (expectedDimensions != null && embedding.length !== expectedDimensions) {
                    throw new Error(`Expected ${expectedDimensions} dimensions, got ${embedding.length}`);
                }
                console.log(chalk_1.default.green(` ✅ Generated embedding: ${embedding.length} dimensions`));
                console.log('');
                results.push({ id: rowId, embedding });
                // Small pause between provider calls to respect API rate limits.
                await this.delay(100);
            }
            catch (error) {
                const rowId = row?.[idColumn] ?? 'unknown';
                console.error(chalk_1.default.red(`❌ Failed to process row ${rowId}: ${error.message}`));
                // Continue with the remaining rows.
            }
        }
        return results;
    }
    /**
     * Joins the formatted, non-empty values of `columns` with single spaces.
     * With `customOrder` the caller's column order is kept; otherwise columns
     * are sorted alphabetically, so the text does not depend on selection order.
     */
    combineColumns(row, columns, customOrder) {
        const orderedColumns = customOrder ? columns : [...columns].sort();
        return orderedColumns
            .map(col => this.formatColumnValue(row[col]))
            .filter(val => val.length > 0)
            .join(' ');
    }
    /**
     * Converts a column value to text: null/undefined become '', objects are
     * JSON-encoded, and everything else is stringified and trimmed.
     */
    formatColumnValue(value) {
        if (value === null || value === undefined) {
            return '';
        }
        if (typeof value === 'object') {
            return JSON.stringify(value);
        }
        return String(value).trim();
    }
    /** Resolves after `ms` milliseconds. */
    delay(ms) {
        return new Promise(resolve => setTimeout(resolve, ms));
    }
    /**
     * Prompts on stdout and reads one line from stdin.
     * @returns {Promise<boolean>} true only for "yes"/"y" (case-insensitive).
     *   Resolves false if the interface closes before an answer arrives,
     *   e.g. stdin reaches EOF.
     */
    async askUserConfirmation(prompt) {
        const rl = readline.createInterface({
            input: process.stdin,
            output: process.stdout
        });
        return new Promise((resolve) => {
            rl.once('close', () => resolve(false));
            rl.question(prompt, (answer) => {
                const lowerAnswer = answer.toLowerCase().trim();
                // Resolve before closing. rl.close() also fires 'close', whose
                // resolve(false) is then ignored.
                resolve(lowerAnswer === 'yes' || lowerAnswer === 'y');
                rl.close();
            });
        });
    }
    /**
     * Reports how many rows of `tableName` have a value in `embeddingColumn`.
     * `total` comes from the table info's `rowCount`; `remaining` comes from the
     * same empty-column count query used by generateEmbeddings, so it is not
     * capped by a fetch limit.
     * @returns {Promise<{total: number, completed: number, remaining: number, percentage: number}>}
     *   All zeros if any query fails (the error is logged).
     */
    async getEmbeddingProgress(tableName, embeddingColumn) {
        try {
            const tableInfo = await this.database.getTableInfo(tableName);
            const total = tableInfo?.rowCount || 0;
            const remaining = Number(await this.database.getEmptyColumnCount(tableName, embeddingColumn)) || 0;
            // Clamp in case rowCount is stale relative to the live count.
            const completed = Math.max(0, total - remaining);
            const percentage = total > 0 ? Math.round((completed / total) * 100) : 0;
            return { total, completed, remaining, percentage };
        }
        catch (error) {
            console.error('Failed to get embedding progress:', error);
            return { total: 0, completed: 0, remaining: 0, percentage: 0 };
        }
    }
}
// Publish the class as this module's named export (declared as void 0 near the
// top of the file).
exports.EmbeddingService = EmbeddingService;
//# sourceMappingURL=embedding-service.js.map