UNPKG

rag-cli-tester

Version:

A lightweight CLI tool for testing RAG (Retrieval-Augmented Generation) systems with different embedding combinations

392 lines โ€ข 18.1 kB
"use strict"; var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); exports.ProviderManager = exports.OpenAICompatibleLLMProvider = exports.GeminiEmbeddingProvider = exports.OpenAIEmbeddingProvider = exports.LocalEmbeddingProvider = void 0; const chalk_1 = __importDefault(require("chalk")); class LocalEmbeddingProvider { constructor(config) { this.pipeline = null; this.config = config; } async initialize() { try { const modelName = this.config.localModel || 'Xenova/all-MiniLM-L6-v2'; console.log(`๐Ÿ”„ Initializing local embedding model: ${modelName}`); console.log(`๐Ÿ”ง Model configuration:`, { configModel: this.config.localModel, defaultModel: 'Xenova/all-MiniLM-L6-v2', finalModel: modelName, envModel: process.env.EMBEDDING_MODEL, envLocalModel: process.env.LOCAL_EMBEDDING_MODEL }); // Ensure we're using the correct model name if (!modelName.includes('all-MiniLM-L6-v2')) { console.warn(`โš ๏ธ Warning: Expected 'all-MiniLM-L6-v2' model, got '${modelName}'`); console.warn(` This might cause dimension mismatch issues.`); } const transformers = await eval('import("@xenova/transformers")'); console.log(`๐Ÿ“ฆ Transformers library loaded successfully`); this.pipeline = await transformers.pipeline('feature-extraction', modelName); console.log(`โœ… Local embedding model initialized successfully`); // Test the model with a simple input to verify it's working console.log(`๐Ÿงช Testing model with sample input...`); const testResult = await this.pipeline('test'); console.log(`โœ… Model test successful. Output shape:`, { hasData: !!testResult.data, dataType: typeof testResult.data, dataLength: testResult.data ? (Array.isArray(testResult.data) ? testResult.data.length : 'not array') : 'no data' }); if (testResult.data && Array.isArray(testResult.data)) { console.log(`โœ… Expected 384 dimensions, got ${testResult.data.length}`); if (testResult.data.length !== 384) { console.warn(`โš ๏ธ Warning: Model returned ${testResult.data.length} dimensions instead of expected 384`); console.warn(` This might indicate the wrong model was loaded or there's a configuration issue.`); // Try to determine what happened if (testResult.data.length > 1000) { console.warn(` Large dimension count (${testResult.data.length}) suggests the model output is not being processed correctly.`); console.warn(` This could be due to a flattened 2D array or incorrect model loading.`); } } } else { console.warn(`โš ๏ธ Warning: Model test result does not have expected structure`); console.warn(` Result:`, testResult); } } catch (error) { console.error('Failed to initialize local embedding model:', error); throw error; } } async generateEmbedding(text) { if (!this.pipeline) { throw new Error('Pipeline not initialized'); } try { console.log(` ๐Ÿ“ Input text length: ${text.length} characters`); const result = await this.pipeline(text); console.log(` ๐Ÿ” Raw model output:`, { hasResult: !!result, resultType: typeof result, hasData: result && !!result.data, dataType: result && result.data ? typeof result.data : 'no data', dataLength: result && result.data ? (Array.isArray(result.data) ? result.data.length : 'not array') : 'no data' }); // Declare embedding variable let embedding; // If result.data is a 2D array, we need to handle it differently if (result && result.data && Array.isArray(result.data) && result.data.length > 0 && Array.isArray(result.data[0])) { console.log(` ๐Ÿ” Detected 2D array output: ${result.data.length}x${result.data[0].length}`); console.log(` ๐Ÿ” This suggests the model is returning a batch of embeddings instead of a single embedding`); // Take the first embedding if it's a batch if (result.data.length === 1) { embedding = result.data[0]; console.log(` โœ… Using first (and only) embedding from batch`); } else { console.warn(` โš ๏ธ Model returned ${result.data.length} embeddings, using the first one`); embedding = result.data[0]; } } else { // Extract the embedding data if (result && result.data) { // Handle different result formats if (Array.isArray(result.data)) { embedding = result.data; console.log(` โœ… Using result.data (array)`); } else if (result.data.data) { // Sometimes the data is nested embedding = Array.from(result.data.data); console.log(` โœ… Using result.data.data (nested)`); } else { // Convert to array if it's a tensor-like object embedding = Array.from(result.data); console.log(` โœ… Using result.data (converted to array)`); } } else if (Array.isArray(result)) { embedding = result; console.log(` โœ… Using result directly (array)`); } else { // Fallback: try to convert the entire result embedding = Array.from(result); console.log(` โœ… Using result (converted to array)`); } } // Validate embedding dimensions if (!Array.isArray(embedding) || embedding.length === 0) { throw new Error(`Invalid embedding format: expected array, got ${typeof embedding}`); } // For All-MiniLM-L6-v2, we expect 384 dimensions const expectedDimensions = 384; if (embedding.length !== expectedDimensions) { console.warn(`โš ๏ธ Warning: Expected ${expectedDimensions} dimensions, got ${embedding.length}`); console.warn(` This might indicate a model loading issue.`); // If the embedding is too large, it might be a flattened 2D array if (embedding.length > expectedDimensions * 10) { console.warn(` Large embedding detected (${embedding.length} dimensions).`); console.warn(` This suggests the model output is not being processed correctly.`); throw new Error(`Expected ${expectedDimensions} dimensions, got ${embedding.length} - model may not be loaded correctly`); } } return embedding; } catch (error) { console.error('Failed to generate embedding:', error); throw error; } } } exports.LocalEmbeddingProvider = LocalEmbeddingProvider; class OpenAIEmbeddingProvider { constructor(config) { this.config = config; } async initialize() { if (!this.config.openaiApiKey) { throw new Error('OpenAI API key is required'); } } async generateEmbedding(text) { try { const response = await fetch('https://api.openai.com/v1/embeddings', { method: 'POST', headers: { 'Authorization': `Bearer ${this.config.openaiApiKey}`, 'Content-Type': 'application/json', }, body: JSON.stringify({ model: this.config.openaiModel || 'text-embedding-3-small', input: text, }), }); const data = await response.json(); if (!response.ok) { throw new Error(`OpenAI API error: ${data.error?.message || response.statusText}`); } return data.data[0].embedding; } catch (error) { console.error('Failed to generate OpenAI embedding:', error); throw error; } } } exports.OpenAIEmbeddingProvider = OpenAIEmbeddingProvider; class GeminiEmbeddingProvider { constructor(config) { this.config = config; } async initialize() { if (!this.config.geminiApiKey) { throw new Error('Gemini API key is required'); } } async generateEmbedding(text) { try { const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${this.config.geminiModel || 'embedding-001'}:embedContent?key=${this.config.geminiApiKey}`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify({ content: { parts: [{ text }] } }), }); const data = await response.json(); if (!response.ok) { throw new Error(`Gemini API error: ${data.error?.message || response.statusText}`); } return data.embedding.values; } catch (error) { console.error('Failed to generate Gemini embedding:', error); throw error; } } } exports.GeminiEmbeddingProvider = GeminiEmbeddingProvider; class OpenAICompatibleLLMProvider { constructor(config) { this.config = config; } async initialize() { if (!this.config.apiKey) { throw new Error('API key is required'); } } async generateText(prompt, context) { try { // Determine the API endpoint and format based on provider const { endpoint, headers, body, responseExtractor } = this.getProviderConfig(prompt, context); console.log(chalk_1.default.gray(` ๐Ÿ”— API Endpoint: ${endpoint}`)); console.log(chalk_1.default.gray(` ๐Ÿ“ค Request Body: ${JSON.stringify(body, null, 2)}`)); const response = await fetch(endpoint, { method: 'POST', headers, body: JSON.stringify(body), }); // Check if response is HTML (error page) instead of JSON const contentType = response.headers.get('content-type'); if (contentType && contentType.includes('text/html')) { const htmlResponse = await response.text(); console.error(chalk_1.default.red(`โŒ API returned HTML instead of JSON:`)); console.error(chalk_1.default.red(` Status: ${response.status} ${response.statusText}`)); console.error(chalk_1.default.red(` Endpoint: ${endpoint}`)); console.error(chalk_1.default.red(` Response preview: ${htmlResponse.substring(0, 200)}...`)); throw new Error(`API endpoint returned HTML instead of JSON. Check your endpoint URL: ${endpoint}`); } const data = await response.json(); if (!response.ok) { throw new Error(`${this.config.provider} API error: ${data.error?.message || response.statusText}`); } return responseExtractor(data); } catch (error) { console.error(`Failed to generate ${this.config.provider} text:`, error); throw error; } } getProviderConfig(prompt, context) { const basePrompt = `${prompt}\n\nContext: ${context}`; switch (this.config.provider) { case 'openai': return { endpoint: 'https://api.openai.com/v1/chat/completions', headers: { 'Authorization': `Bearer ${this.config.apiKey}`, 'Content-Type': 'application/json', }, body: { model: this.config.model || 'gpt-4o', messages: [ { role: 'system', content: prompt }, { role: 'user', content: context } ], temperature: 0.7, max_tokens: 1000, }, responseExtractor: (data) => data.choices[0].message.content.trim() }; case 'gemini': return { endpoint: `https://generativelanguage.googleapis.com/v1beta/models/${this.config.model || 'gemini-1.5-flash'}:generateContent?key=${this.config.apiKey}`, headers: { 'Content-Type': 'application/json', }, body: { contents: [{ parts: [{ text: basePrompt }] }] }, responseExtractor: (data) => data.candidates[0].content.parts[0].text.trim() }; case 'anthropic': return { endpoint: 'https://api.anthropic.com/v1/messages', headers: { 'x-api-key': this.config.apiKey, 'Content-Type': 'application/json', 'anthropic-version': '2023-06-01', }, body: { model: this.config.model || 'claude-3-sonnet-20240229', max_tokens: 1000, messages: [ { role: 'user', content: basePrompt } ], }, responseExtractor: (data) => data.content[0].text.trim() }; case 'custom': // For any OpenAI-compatible API let customEndpoint = this.config.endpoint || 'https://api.openai.com/v1/chat/completions'; // Ensure the endpoint has the correct path for chat completions if (!customEndpoint.endsWith('/chat/completions')) { if (customEndpoint.endsWith('/')) { customEndpoint = customEndpoint + 'chat/completions'; } else if (customEndpoint.endsWith('/v1')) { customEndpoint = customEndpoint + '/chat/completions'; } else if (customEndpoint.endsWith('/v1/')) { customEndpoint = customEndpoint + 'chat/completions'; } else { customEndpoint = customEndpoint + '/chat/completions'; } } return { endpoint: customEndpoint, headers: { 'Authorization': `Bearer ${this.config.apiKey}`, 'Content-Type': 'application/json', }, body: { model: this.config.model || 'gpt-4o', messages: [ { role: 'system', content: prompt }, { role: 'user', content: context } ], temperature: 0.7, max_tokens: 1000, enable_thinking: false, stream: false, }, responseExtractor: (data) => data.choices[0].message.content.trim() }; default: throw new Error(`Unsupported provider: ${this.config.provider}`); } } } exports.OpenAICompatibleLLMProvider = OpenAICompatibleLLMProvider; class ProviderManager { static createEmbeddingProvider(config) { switch (config.model) { case 'openai': return new OpenAIEmbeddingProvider(config); case 'gemini': return new GeminiEmbeddingProvider(config); case 'local': default: return new LocalEmbeddingProvider(config); } } static createLLMProvider(config) { return new OpenAICompatibleLLMProvider(config); } static detectAvailableProviders() { const providers = { embedding: ['local'], llm: [] }; // Check for API keys in environment if (process.env.OPENAI_API_KEY) { providers.embedding.push('openai'); providers.llm.push('openai'); } if (process.env.GEMINI_API_KEY || process.env.GOOGLE_AI_API_KEY) { providers.embedding.push('gemini'); providers.llm.push('gemini'); } if (process.env.ANTHROPIC_API_KEY) { providers.llm.push('anthropic'); } // Check for custom OpenAI-compatible API if (process.env.CUSTOM_API_KEY) { providers.llm.push('custom'); } return providers; } } exports.ProviderManager = ProviderManager; //# sourceMappingURL=providers.js.map