UNPKG

rag-cli-tester

Version:

A lightweight CLI tool for testing RAG (Retrieval-Augmented Generation) systems with different embedding combinations

197 lines • 9.63 kB
"use strict"; var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { if (k2 === undefined) k2 = k; var desc = Object.getOwnPropertyDescriptor(m, k); if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { desc = { enumerable: true, get: function() { return m[k]; } }; } Object.defineProperty(o, k2, desc); }) : (function(o, m, k, k2) { if (k2 === undefined) k2 = k; o[k2] = m[k]; })); var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { Object.defineProperty(o, "default", { enumerable: true, value: v }); }) : function(o, v) { o["default"] = v; }); var __importStar = (this && this.__importStar) || (function () { var ownKeys = function(o) { ownKeys = Object.getOwnPropertyNames || function (o) { var ar = []; for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k; return ar; }; return ownKeys(o); }; return function (mod) { if (mod && mod.__esModule) return mod; var result = {}; if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]); __setModuleDefault(result, mod); return result; }; })(); var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? 
mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); exports.EmbeddingService = void 0; const providers_1 = require("./providers"); const embedding_models_1 = require("./models/embedding-models"); const ora_1 = __importDefault(require("ora")); const chalk_1 = __importDefault(require("chalk")); const readline = __importStar(require("readline")); class EmbeddingService { constructor(database, modelId) { this.database = database; const model = (0, embedding_models_1.getModelById)(modelId); if (!model) { throw new Error(`Unknown embedding model: ${modelId}`); } this.selectedModel = model; // Create provider based on model type const config = { model: model.provider, localModel: model.modelPath, openaiApiKey: process.env.OPENAI_API_KEY, openaiModel: model.apiModel, geminiApiKey: process.env.GEMINI_API_KEY || process.env.GOOGLE_AI_API_KEY, geminiModel: model.apiModel }; this.embeddingProvider = providers_1.ProviderManager.createEmbeddingProvider(config); } async initialize() { await this.embeddingProvider.initialize(); } async generateEmbeddings(task) { const spinner = (0, ora_1.default)('Initializing embedding generation...').start(); try { // Validate table and columns exist const tableInfo = await this.database.getTableInfo(task.tableName); if (!tableInfo) { throw new Error(`Table '${task.tableName}' not found`); } // Check if all source columns exist const missingColumns = task.columns.filter(col => !tableInfo.columns.some(dbCol => dbCol.column_name === col)); if (missingColumns.length > 0) { throw new Error(`Columns not found in table: ${missingColumns.join(', ')}`); } // Check if embedding column exists const embeddingColumnExists = await this.database.checkColumnExists(task.tableName, task.embeddingColumn); if (!embeddingColumnExists) { throw new Error(`Embedding column '${task.embeddingColumn}' not found in table`); } // Get all rows from the table spinner.text = 'Fetching all rows from table...'; const allRows = await 
this.database.getTableData(task.tableName, task.columns, 10000); if (allRows.length === 0) { spinner.succeed(chalk_1.default.green(`āœ… No rows found in table '${task.tableName}'`)); return; } console.log(chalk_1.default.blue(`\nšŸ“Š Found ${allRows.length} rows to process`)); console.log(chalk_1.default.gray('─'.repeat(50))); let totalProcessed = 0; let totalSkipped = 0; // Process each row individually for (let i = 0; i < allRows.length; i++) { const row = allRows[i]; // Check if embedding already exists const existingEmbedding = await this.database.getRowColumnValue(task.tableName, row.id, task.embeddingColumn); if (existingEmbedding !== null && existingEmbedding !== '') { console.log(chalk_1.default.gray(`ā­ļø Skipping row ${row.id}: embedding already exists`)); totalSkipped++; continue; } // Generate embedding for this row try { const text = this.combineColumns(row, task.columns, task.customOrder); if (!text.trim()) { console.log(chalk_1.default.yellow(`āš ļø Skipping row ${row.id}: no valid text to embed`)); totalSkipped++; continue; } console.log(chalk_1.default.cyan(`šŸ“ Processing row ${row.id} (${i + 1}/${allRows.length}):`)); console.log(chalk_1.default.gray(`Combined text: ${text.substring(0, 200)}${text.length > 200 ? '...' 
: ''}`)); console.log(chalk_1.default.blue(` šŸ”„ Generating embedding (${this.selectedModel.name})...`)); const embedding = await this.embeddingProvider.generateEmbedding(text); // Validate embedding dimensions if (embedding.length !== this.selectedModel.dimensions) { throw new Error(`Expected ${this.selectedModel.dimensions} dimensions, got ${embedding.length}`); } console.log(chalk_1.default.green(` āœ… Generated embedding: ${embedding.length} dimensions`)); console.log(''); // Save embedding to database await this.database.updateRowEmbedding(task.tableName, row.id, task.embeddingColumn, embedding); totalProcessed++; spinner.text = `Processed ${totalProcessed} rows, skipped ${totalSkipped} rows...`; // Small delay to respect API rate limits await this.delay(100); } catch (error) { console.error(chalk_1.default.red(`āŒ Failed to process row ${row.id}: ${error.message}`)); // Continue with next row } } spinner.succeed(chalk_1.default.green(`āœ… Embedding generation completed!`)); console.log(chalk_1.default.blue(`\nšŸ“Š Summary:`)); console.log(chalk_1.default.gray(` • Table: ${task.tableName}`)); console.log(chalk_1.default.gray(` • Source columns: ${task.columns.join(', ')}`)); console.log(chalk_1.default.gray(` • Embedding column: ${task.embeddingColumn}`)); console.log(chalk_1.default.gray(` • Total rows processed: ${totalProcessed}`)); console.log(chalk_1.default.gray(` • Total rows skipped: ${totalSkipped}`)); console.log(chalk_1.default.gray(` • Total rows in table: ${allRows.length}`)); console.log(chalk_1.default.gray(` • Embedding provider: ${this.selectedModel.provider}`)); console.log(chalk_1.default.gray(` • Model: ${this.selectedModel.name} (${this.selectedModel.dimensions} dimensions)`)); console.log(''); } catch (error) { spinner.fail(chalk_1.default.red(`āŒ Embedding generation failed: ${error.message}`)); throw error; } } combineColumns(row, columns, customOrder) { if (customOrder) { // Use the exact order specified by user return columns 
.map(col => this.formatColumnValue(row[col])) .filter(val => val.length > 0) .join(' '); } else { // Use alphabetical order or natural database order const sortedColumns = [...columns].sort(); return sortedColumns .map(col => this.formatColumnValue(row[col])) .filter(val => val.length > 0) .join(' '); } } formatColumnValue(value) { if (value === null || value === undefined) { return ''; } if (typeof value === 'object') { return JSON.stringify(value); } return String(value).trim(); } delay(ms) { return new Promise(resolve => setTimeout(resolve, ms)); } async askUserConfirmation(prompt) { const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); return new Promise((resolve) => { rl.question(prompt, (answer) => { rl.close(); const lowerAnswer = answer.toLowerCase().trim(); resolve(lowerAnswer === 'yes' || lowerAnswer === 'y'); }); }); } } exports.EmbeddingService = EmbeddingService; //# sourceMappingURL=embedding-service.js.map