rag-cli-tester
Version:
A lightweight CLI tool for testing RAG (Retrieval-Augmented Generation) systems with different embedding combinations
197 lines • 9.63 kB
JavaScript
;
// Compiler-emitted CommonJS interop helper (this file appears to be TypeScript output,
// see the sourceMappingURL comment at the end). Reuses `this.__createBinding` if defined.
// __createBinding(o, m, k, k2): exposes property `k` of module `m` on object `o`
// under the name `k2` (which defaults to `k`).
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
if (k2 === undefined) k2 = k;
var desc = Object.getOwnPropertyDescriptor(m, k);
// Swap in a live, enumerable getter unless the source property is an accessor on an
// `__esModule` module, or a non-writable and non-configurable data property; in those
// two cases the original descriptor is copied as-is.
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
desc = { enumerable: true, get: function() { return m[k]; } };
}
Object.defineProperty(o, k2, desc);
}) : (function(o, m, k, k2) {
// Fallback when Object.create is unavailable: copies the current value once
// (plain assignment, so later changes to m[k] are not reflected).
if (k2 === undefined) k2 = k;
o[k2] = m[k];
}));
// Compiler-emitted interop helper: attaches `v` to `o` as its `default` member.
// With Object.create available, defineProperty yields an enumerable, non-writable,
// non-configurable property; otherwise a plain assignment is used. The capability
// check runs once, when the helper is created.
var __setModuleDefault = (this && this.__setModuleDefault) || (function () {
    if (!Object.create) {
        return function (o, v) {
            o["default"] = v;
        };
    }
    return function (o, v) {
        Object.defineProperty(o, "default", { enumerable: true, value: v });
    };
})();
// Compiler-emitted interop helper (used below to wrap `require("readline")`).
// Presumably implements TypeScript's namespace-import interop: modules flagged
// `__esModule` are returned unchanged; otherwise a fresh object receives a binding
// (via __createBinding) for every own property name except "default", and its
// `default` is set to the module itself (via __setModuleDefault).
var __importStar = (this && this.__importStar) || (function () {
// Returns the own property names of `o`. On first call it overwrites itself with
// Object.getOwnPropertyNames, or with a for...in + hasOwnProperty fallback when
// that is unavailable, so the capability check happens only once.
var ownKeys = function(o) {
ownKeys = Object.getOwnPropertyNames || function (o) {
var ar = [];
for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
return ar;
};
return ownKeys(o);
};
return function (mod) {
if (mod && mod.__esModule) return mod;
var result = {};
// null/undefined modules contribute no bindings; result then only gets `default`.
if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
__setModuleDefault(result, mod);
return result;
};
})();
// Compiler-emitted interop helper: lets a CommonJS export be read through `.default`.
// Modules already flagged `__esModule` pass through untouched; anything else
// (including null/undefined) is wrapped as `{ default: mod }`.
var __importDefault = (this && this.__importDefault) || function (mod) {
    const isTranspiledEsm = Boolean(mod && mod.__esModule);
    if (isTranspiledEsm) {
        return mod;
    }
    return { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.EmbeddingService = void 0;
const providers_1 = require("./providers");
const embedding_models_1 = require("./models/embedding-models");
const ora_1 = __importDefault(require("ora"));
const chalk_1 = __importDefault(require("chalk"));
const readline = __importStar(require("readline"));
/**
 * Generates embeddings for rows of a database table using a configured embedding
 * model, and stores each vector in an embedding column of the same table.
 */
class EmbeddingService {
    /**
     * @param {object} database - adapter exposing getTableInfo, checkColumnExists,
     *   getTableData, getRowColumnValue and updateRowEmbedding (as used below).
     * @param {string} modelId - model id resolved through getModelById.
     * @throws {Error} if `modelId` does not resolve to a known model.
     */
    constructor(database, modelId) {
        this.database = database;
        const model = (0, embedding_models_1.getModelById)(modelId);
        if (!model) {
            throw new Error(`Unknown embedding model: ${modelId}`);
        }
        this.selectedModel = model;
        // The config carries settings for every backend; `model` selects which one
        // ProviderManager builds (presumably it ignores the unrelated keys).
        const config = {
            model: model.provider,
            localModel: model.modelPath,
            openaiApiKey: process.env.OPENAI_API_KEY,
            openaiModel: model.apiModel,
            geminiApiKey: process.env.GEMINI_API_KEY || process.env.GOOGLE_AI_API_KEY,
            geminiModel: model.apiModel
        };
        this.embeddingProvider = providers_1.ProviderManager.createEmbeddingProvider(config);
    }
    /** Initializes the underlying embedding provider (presumably required before generating). */
    async initialize() {
        await this.embeddingProvider.initialize();
    }
    /**
     * Fills `task.embeddingColumn` for each fetched row of `task.tableName` that has
     * no value there yet, embedding the combined text of `task.columns`.
     *
     * Rows are processed sequentially. A failure on one row is logged and counted,
     * and processing continues; the final summary reports processed, skipped and
     * failed rows.
     *
     * @param {object} task
     * @param {string} task.tableName - table to read from and write to.
     * @param {string[]} task.columns - source columns whose values are embedded.
     * @param {string} task.embeddingColumn - column that receives the embedding.
     * @param {boolean} [task.customOrder] - keep `columns` order instead of sorting.
     * @param {number} [task.rowLimit=10000] - maximum number of rows fetched.
     * @throws {Error} if the table, a source column or the embedding column is
     *   missing, or if a database call outside the per-row handler fails.
     */
    async generateEmbeddings(task) {
        const spinner = (0, ora_1.default)('Initializing embedding generation...').start();
        try {
            // Validate table and columns exist
            const tableInfo = await this.database.getTableInfo(task.tableName);
            if (!tableInfo) {
                throw new Error(`Table '${task.tableName}' not found`);
            }
            const missingColumns = task.columns.filter(col => !tableInfo.columns.some(dbCol => dbCol.column_name === col));
            if (missingColumns.length > 0) {
                throw new Error(`Columns not found in table: ${missingColumns.join(', ')}`);
            }
            const embeddingColumnExists = await this.database.checkColumnExists(task.tableName, task.embeddingColumn);
            if (!embeddingColumnExists) {
                throw new Error(`Embedding column '${task.embeddingColumn}' not found in table`);
            }
            // getTableData is capped; the cap used to be a hidden constant, so allow
            // callers to raise it and warn below when it is hit.
            const rowLimit = Number.isInteger(task.rowLimit) && task.rowLimit > 0 ? task.rowLimit : 10000;
            spinner.text = 'Fetching all rows from table...';
            const allRows = await this.database.getTableData(task.tableName, task.columns, rowLimit);
            if (allRows.length === 0) {
                spinner.succeed(chalk_1.default.green(`ā
No rows found in table '${task.tableName}'`));
                return;
            }
            // From here on progress is printed line by line; an animating spinner
            // would keep redrawing its line and interleave with that output.
            spinner.stop();
            console.log(chalk_1.default.blue(`\nš Found ${allRows.length} rows to process`));
            if (allRows.length >= rowLimit) {
                console.log(chalk_1.default.yellow(`Row limit of ${rowLimit} reached; any rows beyond it were not fetched and will not be processed.`));
            }
            console.log(chalk_1.default.gray('ā'.repeat(50)));
            let totalProcessed = 0;
            let totalSkipped = 0;
            let totalFailed = 0;
            for (let i = 0; i < allRows.length; i++) {
                const row = allRows[i];
                // NOTE(review): relies on each fetched row carrying an `id` field - confirm getTableData includes it.
                const existingEmbedding = await this.database.getRowColumnValue(task.tableName, row.id, task.embeddingColumn);
                // An undefined result means "no value" just like null; treating it as
                // present would silently skip rows that still need an embedding.
                const hasEmbedding = existingEmbedding !== null && existingEmbedding !== undefined && existingEmbedding !== '';
                if (hasEmbedding) {
                    console.log(chalk_1.default.gray(`āļø Skipping row ${row.id}: embedding already exists`));
                    totalSkipped++;
                    continue;
                }
                try {
                    const text = this.combineColumns(row, task.columns, task.customOrder);
                    if (!text.trim()) {
                        console.log(chalk_1.default.yellow(`ā ļø Skipping row ${row.id}: no valid text to embed`));
                        totalSkipped++;
                        continue;
                    }
                    console.log(chalk_1.default.cyan(`š Processing row ${row.id} (${i + 1}/${allRows.length}):`));
                    console.log(chalk_1.default.gray(`Combined text: ${text.substring(0, 200)}${text.length > 200 ? '...' : ''}`));
                    console.log(chalk_1.default.blue(` š Generating embedding (${this.selectedModel.name})...`));
                    const embedding = await this.embeddingProvider.generateEmbedding(text);
                    // Guard against writing a vector of the wrong size into the column.
                    if (embedding.length !== this.selectedModel.dimensions) {
                        throw new Error(`Expected ${this.selectedModel.dimensions} dimensions, got ${embedding.length}`);
                    }
                    console.log(chalk_1.default.green(` ā
Generated embedding: ${embedding.length} dimensions`));
                    console.log('');
                    await this.database.updateRowEmbedding(task.tableName, row.id, task.embeddingColumn, embedding);
                    totalProcessed++;
                }
                catch (error) {
                    const message = error instanceof Error ? error.message : String(error);
                    console.error(chalk_1.default.red(`ā Failed to process row ${row.id}: ${message}`));
                    totalFailed++;
                }
                // Small delay to respect API rate limits - applied after failures too,
                // so a provider that just errored is not hit again immediately.
                await this.delay(100);
            }
            if (totalFailed > 0) {
                spinner.warn(chalk_1.default.yellow(`Embedding generation completed with ${totalFailed} failed row(s)`));
            }
            else {
                spinner.succeed(chalk_1.default.green(`ā
Embedding generation completed!`));
            }
            console.log(chalk_1.default.blue(`\nš Summary:`));
            console.log(chalk_1.default.gray(` ⢠Table: ${task.tableName}`));
            console.log(chalk_1.default.gray(` ⢠Source columns: ${task.columns.join(', ')}`));
            console.log(chalk_1.default.gray(` ⢠Embedding column: ${task.embeddingColumn}`));
            console.log(chalk_1.default.gray(` ⢠Total rows processed: ${totalProcessed}`));
            console.log(chalk_1.default.gray(` ⢠Total rows skipped: ${totalSkipped}`));
            console.log(chalk_1.default.gray(` ⢠Total rows failed: ${totalFailed}`));
            console.log(chalk_1.default.gray(` ⢠Total rows in table: ${allRows.length}`));
            console.log(chalk_1.default.gray(` ⢠Embedding provider: ${this.selectedModel.provider}`));
            console.log(chalk_1.default.gray(` ⢠Model: ${this.selectedModel.name} (${this.selectedModel.dimensions} dimensions)`));
            console.log('');
        }
        catch (error) {
            spinner.fail(chalk_1.default.red(`ā Embedding generation failed: ${error.message}`));
            throw error;
        }
    }
    /**
     * Joins the formatted values of `columns` from `row` with single spaces,
     * dropping values that format to an empty string.
     *
     * @param {object} row - fetched row keyed by column name.
     * @param {string[]} columns - columns to include; not mutated.
     * @param {boolean} customOrder - truthy keeps the given order; otherwise the
     *   column names are sorted with the default Array#sort (UTF-16 code-unit order).
     * @returns {string} combined text (possibly empty).
     */
    combineColumns(row, columns, customOrder) {
        const orderedColumns = customOrder ? columns : [...columns].sort();
        return orderedColumns
            .map(col => this.formatColumnValue(row[col]))
            .filter(val => val.length > 0)
            .join(' ');
    }
    /**
     * Converts one column value to text: null/undefined become '', objects
     * (including arrays) are JSON-serialized, anything else is stringified and trimmed.
     * NOTE(review): Date values are objects, so they serialize as a JSON string
     * literal including the surrounding double quotes - confirm that is intended.
     */
    formatColumnValue(value) {
        if (value === null || value === undefined) {
            return '';
        }
        if (typeof value === 'object') {
            return JSON.stringify(value);
        }
        return String(value).trim();
    }
    /** Resolves after `ms` milliseconds. */
    delay(ms) {
        return new Promise(resolve => setTimeout(resolve, ms));
    }
    /**
     * Prompts on stdin/stdout and resolves true when the answer is "y" or "yes"
     * (case-insensitive, surrounding whitespace ignored), false otherwise.
     * Also resolves false if the input closes (e.g. EOF) before an answer arrives.
     *
     * @param {string} prompt - text shown to the user.
     * @returns {Promise<boolean>}
     */
    async askUserConfirmation(prompt) {
        const rl = readline.createInterface({
            input: process.stdin,
            output: process.stdout
        });
        return new Promise((resolve) => {
            // Without this, closing stdin leaves the promise pending forever.
            rl.on('close', () => resolve(false));
            rl.question(prompt, (answer) => {
                const lowerAnswer = answer.toLowerCase().trim();
                // Resolve before close(): close() emits 'close', whose handler would
                // otherwise settle the promise with false first.
                resolve(lowerAnswer === 'yes' || lowerAnswer === 'y');
                rl.close();
            });
        });
    }
}
exports.EmbeddingService = EmbeddingService;
//# sourceMappingURL=embedding-service.js.map