rag-cli-tester
Version:
A lightweight CLI tool for testing RAG (Retrieval-Augmented Generation) systems with different embedding combinations
156 lines • 5.9 kB
JavaScript
// NOTE(review): stray empty statement; presumably the remains of a compiler-emitted
// `"use strict";` directive that was lost when this file was extracted — confirm.
;
// CommonJS interop boilerplate as emitted by tsc: flag the module as compiled from
// ES-module syntax, then pre-declare both exported names (as undefined) before the
// classes below are defined and assigned.
Object.defineProperty(exports, "__esModule", { value: true });
exports.EmbeddingGenerator = exports.TransformersPipelineProvider = void 0;
/**
 * Default pipeline provider: loads `@xenova/transformers` on demand and asks it
 * to build a pipeline for the requested task and model.
 */
class TransformersPipelineProvider {
    /**
     * Build a Transformers.js pipeline.
     * @param {string} task - Pipeline task name, forwarded unchanged to `pipeline()`
     *   (this file uses 'feature-extraction').
     * @param {string} model - Model identifier, forwarded unchanged to `pipeline()`.
     * @returns {Promise<*>} Whatever `transformers.pipeline(task, model)` resolves to.
     */
    async createPipeline(task, model) {
        // The dynamic import lives inside a string so the TypeScript compiler cannot
        // rewrite it into a require() call when emitting CommonJS.
        const loader = 'import("@xenova/transformers")';
        const transformersModule = await eval(loader);
        const created = await transformersModule.pipeline(task, model);
        return created;
    }
}
// Publish the provider class on the CommonJS exports object.
exports.TransformersPipelineProvider = TransformersPipelineProvider;
/**
 * Builds text contexts from tabular rows, embeds them with a feature-extraction
 * pipeline, and performs cosine-similarity lookups over the resulting vectors.
 */
class EmbeddingGenerator {
    /**
     * @param {object} config - Generator configuration; `config.localModel` (optional)
     *   overrides the default model id used by `initialize()`.
     * @param {object} [pipelineProvider] - Object exposing `createPipeline(task, model)`;
     *   defaults to a new `TransformersPipelineProvider`.
     */
    constructor(config, pipelineProvider) {
        this.embeddingPipeline = null;
        this.config = config;
        this.pipelineProvider = pipelineProvider || new TransformersPipelineProvider();
    }
    /**
     * Create the feature-extraction pipeline. Must complete before any embedding
     * method is called.
     * @returns {Promise<void>}
     * @throws Re-throws whatever the provider throws, after logging it.
     */
    async initialize() {
        try {
            // NOTE(review): confirm this default model id exists on the Hugging Face Hub;
            // the widely published Transformers.js port is 'Xenova/all-MiniLM-L6-v2'.
            this.embeddingPipeline = await this.pipelineProvider.createPipeline('feature-extraction', this.config.localModel || 'Xenova/all-MiniLM-L6-v2-small');
        }
        catch (error) {
            console.error('Failed to initialize embedding model:', error);
            throw error;
        }
    }
    /**
     * Enumerate non-empty column subsets, smallest subsets first, each in the
     * order the columns appear in `columns`.
     * @param {string[]} columns - Candidate column names.
     * @param {number} [maxCombinations=20] - Maximum number of combinations returned.
     * @param {number} [maxColumns=5] - Only the first `maxColumns` columns are considered;
     *   the subset count grows as 2^n, so this keeps enumeration bounded.
     * @returns {{columns: string[], name: string}[]} Combinations named like "a + b".
     */
    generateColumnCombinations(columns, maxCombinations = 20, maxColumns = 5) {
        const combinations = [];
        // Guard against negative limits, which slice() would interpret from the end.
        const candidates = columns.slice(0, Math.max(0, maxColumns));
        for (let size = 1; size <= candidates.length && combinations.length < maxCombinations; size++) {
            for (const combo of this.getCombinations(candidates, size)) {
                if (combinations.length >= maxCombinations)
                    break;
                combinations.push({
                    columns: combo,
                    name: combo.join(' + ')
                });
            }
        }
        return combinations;
    }
    /**
     * Return every k-element subset of `arr`, preserving element order.
     * Each subset is a new array, so results never alias `arr`.
     * @param {Array} arr - Source elements.
     * @param {number} k - Subset size.
     * @returns {Array[]} `[[]]` for k === 0; empty when k < 0 or k > arr.length.
     */
    getCombinations(arr, k) {
        if (k === 0)
            return [[]];
        // Without this guard, k < 0 on a non-empty array recursed without bound.
        if (k < 0 || k > arr.length)
            return [];
        if (k === 1)
            return arr.map(item => [item]);
        if (k === arr.length)
            return [[...arr]];
        const combinations = [];
        for (let i = 0; i <= arr.length - k; i++) {
            const head = arr[i];
            for (const tail of this.getCombinations(arr.slice(i + 1), k - 1)) {
                combinations.push([head, ...tail]);
            }
        }
        return combinations;
    }
    /**
     * Embed one text into a single fixed-length vector.
     *
     * In Transformers.js the feature-extraction pipeline defaults to
     * `pooling: 'none'`, which yields one vector per token. The flattened data
     * would then vary in length with the text and break
     * `calculateCosineSimilarity`. Mean pooling produces one sentence vector,
     * and `normalize` scales it to unit length.
     * @param {string} text - Text to embed.
     * @returns {Promise<number[]>} The embedding as a plain array.
     * @throws {Error} If `initialize()` has not completed; pipeline errors are
     *   logged and re-thrown.
     */
    async generateEmbedding(text) {
        try {
            if (!this.embeddingPipeline) {
                throw new Error('Embedding pipeline not initialized');
            }
            const result = await this.embeddingPipeline(text, { pooling: 'mean', normalize: true });
            // Convert typed-array tensor data into a plain array.
            return Array.isArray(result.data) ? result.data : Array.from(result.data);
        }
        catch (error) {
            console.error('Failed to generate embedding:', error);
            throw error;
        }
    }
    /**
     * Render a row as "col: value | col: value" for the columns in `combination`.
     * Null/undefined values are omitted; object values are JSON-encoded.
     * @param {object} row - Row keyed by column name.
     * @param {{columns: string[]}} combination - Columns to include, in order.
     * @returns {string} The context string (empty if no column had a value).
     */
    createContext(row, combination) {
        const contextParts = combination.columns
            .filter(col => row[col] !== null && row[col] !== undefined)
            .map(col => {
            const value = row[col];
            return `${col}: ${typeof value === 'object' ? JSON.stringify(value) : value}`;
        });
        return contextParts.join(' | ');
    }
    /**
     * Embed every row of `data` using the columns in `combination`.
     * Rows whose context is blank (e.g. every selected column is null/undefined)
     * are skipped with a warning. Rows whose embedding fails are logged and skipped.
     * @param {object[]} data - Rows keyed by column name.
     * @param {{columns: string[], name: string}} combination - Columns used to build each context.
     * @param {string} targetColumn - Column copied into each result's `targetValue`.
     * @param {string} [idColumn] - Column used as the result id; when falsy, ids are `row_<index>`.
     * @returns {Promise<{embeddings: object[], combination: object, totalRows: number}>}
     *   `totalRows` counts all input rows, including skipped ones.
     */
    async processTrainingData(data, combination, targetColumn, idColumn) {
        const embeddings = [];
        for (let i = 0; i < data.length; i++) {
            const row = data[i];
            const context = this.createContext(row, combination);
            if (!context.trim()) {
                console.warn(`Skipping row ${i} - no valid context generated`);
                continue;
            }
            try {
                const embedding = await this.generateEmbedding(context);
                embeddings.push({
                    id: idColumn ? row[idColumn] : `row_${i}`,
                    combination,
                    embedding,
                    context,
                    targetValue: row[targetColumn],
                    metadata: {
                        originalRow: row,
                        rowIndex: i
                    }
                });
                // Progress logging every 50 rows to reduce noise
                if ((i + 1) % 50 === 0) {
                    console.log(`   Processed ${i + 1}/${data.length} rows`);
                }
            }
            catch (error) {
                console.error(`Failed to process row ${i}:`, error);
                continue;
            }
        }
        return {
            embeddings,
            combination,
            totalRows: data.length
        };
    }
    /**
     * Cosine similarity of two equal-length numeric vectors.
     * @param {ArrayLike<number>} a
     * @param {ArrayLike<number>} b
     * @returns {number} Similarity; 0 when either vector has zero norm.
     * @throws {Error} If the vectors differ in length.
     */
    calculateCosineSimilarity(a, b) {
        if (a.length !== b.length) {
            throw new Error('Vectors must have the same length');
        }
        let dotProduct = 0;
        let normA = 0;
        let normB = 0;
        for (let i = 0; i < a.length; i++) {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        normA = Math.sqrt(normA);
        normB = Math.sqrt(normB);
        if (normA === 0 || normB === 0) {
            return 0;
        }
        return dotProduct / (normA * normB);
    }
    /**
     * Score every training embedding against `queryEmbedding`.
     * @param {number[]} queryEmbedding - Query vector.
     * @param {{embeddings: {embedding: number[]}[]}} trainingData - Output of `processTrainingData`.
     * @param {number} [topK=1] - Number of results to return.
     * @returns {Promise<{result: object, similarity: number}[]>} Best matches, highest similarity first.
     */
    async findBestMatch(queryEmbedding, trainingData, topK = 1) {
        const similarities = trainingData.embeddings.map(item => ({
            result: item,
            similarity: this.calculateCosineSimilarity(queryEmbedding, item.embedding)
        }));
        // Sort by similarity (highest first)
        similarities.sort((a, b) => b.similarity - a.similarity);
        return similarities.slice(0, topK);
    }
    /**
     * Embed `query` and return its `topK` nearest training entries.
     * @param {string} query - Query text.
     * @param {object} trainingData - Output of `processTrainingData`.
     * @param {number} [topK=1] - Number of results to return.
     * @returns {Promise<{result: object, similarity: number}[]>}
     */
    async processQuery(query, trainingData, topK = 1) {
        const queryEmbedding = await this.generateEmbedding(query);
        return this.findBestMatch(queryEmbedding, trainingData, topK);
    }
}
// Publish the generator class on the CommonJS exports object.
exports.EmbeddingGenerator = EmbeddingGenerator;
//# sourceMappingURL=embeddings.js.map