UNPKG

@quantumai/quantum-cli-core

Version:

Quantum CLI Core - Multi-LLM Collaboration System

671 lines 26.9 kB
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */
import { QueryType } from './types.js';

/**
 * Clamp a number into the closed interval [0, 1].
 * @param {number} value
 * @returns {number}
 */
const clamp01 = (value) => Math.max(0, Math.min(1, value));

/**
 * Chooses which registered provider(s) should handle a query.
 *
 * Selection combines running per-provider metrics (maintained by
 * `updateMetrics`), optional static weights from the config, and how well a
 * provider's declared strengths (`provider.getStrengths()`) match a query type.
 */
export class ModelSelector {
    /** Registered providers keyed by provider id. */
    providers = new Map();
    /** Running performance metrics keyed by provider id. */
    performanceMetrics = new Map();
    /** Cursor used by round-robin selection. */
    roundRobinIndex = 0;
    /** Effective configuration (defaults merged with constructor overrides). */
    config;
    /**
     * @param {object} [config] - Overrides merged on top of the defaults
     *   `{ strategy: 'weighted', enableFailureRecovery: true, maxRetries: 3 }`.
     */
    constructor(config) {
        this.config = {
            strategy: 'weighted',
            enableFailureRecovery: true,
            maxRetries: 3,
            ...config,
        };
    }
    /**
     * Register a provider with the selector.
     *
     * Default metrics are seeded only the first time an id is seen, so
     * re-registering a provider keeps its accumulated metrics.
     * @param {object} provider - Used via `id`, `isEnabled()` and `getStrengths()`.
     */
    registerProvider(provider) {
        this.providers.set(provider.id, provider);
        if (!this.performanceMetrics.has(provider.id)) {
            this.performanceMetrics.set(provider.id, {
                providerId: provider.id,
                averageLatency: 1000, // Default 1s
                successRate: 0.9, // Default 90%
                costPerToken: 0.0001, // Default cost
                lastUpdateTime: new Date(),
                totalQueries: 0,
                failureCount: 0,
            });
        }
    }
    /**
     * Select the single best provider for a prompt.
     *
     * Each enabled provider is scored out of 100:
     * - base performance: 40 points;
     * - cost efficiency: 20 points;
     * - contextual suitability for `options.queryType`: 40 points.
     * A provider whose average latency exceeds `options.maxLatency` has its
     * score halved.
     *
     * @param {string} prompt - Not used by the current scoring.
     * @param {object} [options] - `{ queryType?, maxCost?, maxLatency? }`.
     * @returns {Promise<object>} The highest-scoring provider.
     * @throws {Error} If no provider is enabled.
     */
    async selectBestProvider(prompt, options) {
        const availableProviders = this.getAvailableProviders();
        if (availableProviders.length === 0) {
            throw new Error('No providers available');
        }
        if (availableProviders.length === 1) {
            return availableProviders[0];
        }
        // Without a positive budget, normalise cost against a $0.001 baseline.
        const costBaseline = options?.maxCost > 0 ? options.maxCost : 0.001;
        const providerScores = availableProviders.map((provider) => {
            const metrics = this.performanceMetrics.get(provider.id);
            // Base performance (40 points): success rate weighted 3:1 over
            // latency; 10s or more of latency earns no latency credit.
            const latencyScore = clamp01(1 - metrics.averageLatency / 10000);
            let score = (metrics.successRate * 0.75 + latencyScore * 0.25) * 40;
            // Cost efficiency (20 points), never negative.
            score += clamp01(1 - metrics.costPerToken / costBaseline) * 20;
            // Contextual suitability (40 points): map the 1.0-1.5 weight onto 0-1.
            if (options?.queryType) {
                const contextWeight = this.getContextualWeight(provider, options.queryType);
                score += ((contextWeight - 1) / 0.5) * 40;
            }
            // Heavy penalty for exceeding the caller's latency requirement.
            if (options?.maxLatency && metrics.averageLatency > options.maxLatency) {
                score *= 0.5;
            }
            return { provider, score };
        });
        providerScores.sort((a, b) => b.score - a.score);
        return providerScores[0].provider;
    }
    /**
     * Select several providers for parallel execution.
     *
     * The best-scoring provider is always included. Further candidates are
     * taken in score order when they are "diverse" enough from every provider
     * already chosen, meaning `calculateProviderDiversity` exceeds
     * `diversityWeight`. If too few diverse candidates exist, the remaining
     * slots are filled with the best leftover candidates, so the result has
     * `min(count, enabled providers)` entries.
     *
     * @param {string} prompt - Forwarded to `selectBestProvider` when count is 1.
     * @param {object} [options] - `{ count? = 2, diversityWeight? = 0.3, queryType? }`.
     *   A `diversityWeight` of 0 or less disables the diversity filter.
     * @returns {Promise<object[]>} Selected providers, best first.
     * @throws {Error} If no provider is enabled.
     */
    async selectMultipleProviders(prompt, options) {
        const availableProviders = this.getAvailableProviders();
        if (availableProviders.length === 0) {
            throw new Error('No providers available');
        }
        const count = Math.min(options?.count ?? 2, availableProviders.length);
        if (count <= 0) {
            return [];
        }
        if (count === 1) {
            return [await this.selectBestProvider(prompt, options)];
        }
        const providerScores = availableProviders.map((provider) => {
            const metrics = this.performanceMetrics.get(provider.id);
            let score = 0;
            // Performance
            score += metrics.successRate * 30;
            score += Math.max(0, 1 - metrics.averageLatency / 5000) * 20;
            // Cost efficiency
            score += Math.max(0, 1 - metrics.costPerToken / 0.001) * 20;
            // Contextual fit
            if (options?.queryType) {
                score += (this.getContextualWeight(provider, options.queryType) - 1) * 30;
            }
            return { provider, score };
        });
        providerScores.sort((a, b) => b.score - a.score);
        const ranked = providerScores.map(({ provider }) => provider);
        // `??` so an explicit 0 really disables the diversity filter.
        const diversityWeight = options?.diversityWeight ?? 0.3;
        const selectedProviders = [ranked[0]];
        const skipped = [];
        for (const candidate of ranked.slice(1)) {
            if (selectedProviders.length >= count) {
                break;
            }
            const isAcceptable = diversityWeight <= 0 ||
                selectedProviders.length < 2 ||
                selectedProviders.every((selected) => this.calculateProviderDiversity(selected, candidate) > diversityWeight);
            if (isAcceptable) {
                selectedProviders.push(candidate);
            }
            else {
                skipped.push(candidate);
            }
        }
        // Top up with the best non-diverse candidates so `count` is honoured.
        for (const candidate of skipped) {
            if (selectedProviders.length >= count) {
                break;
            }
            selectedProviders.push(candidate);
        }
        return selectedProviders;
    }
    /**
     * Pick a provider for the given context using the configured strategy.
     * @param {object} [context] - Optional `{ type }` query context.
     * @returns {object|null} The chosen provider, or null if none is enabled.
     */
    selectProvider(context) {
        const availableProviders = this.getAvailableProviders();
        if (availableProviders.length === 0) {
            return null;
        }
        if (availableProviders.length === 1) {
            return availableProviders[0];
        }
        return this.selectWithStrategy(this.config.strategy, availableProviders, context);
    }
    /**
     * Dispatch to the selection routine for `strategy`. Unknown strategies
     * fall back to weighted selection.
     * @param {string} strategy
     * @param {object[]} providers - Non-empty candidate list.
     * @param {object} [context]
     * @returns {object}
     */
    selectWithStrategy(strategy, providers, context) {
        switch (strategy) {
            case 'round-robin':
                return this.selectRoundRobin(providers);
            case 'performance':
                return this.selectByPerformance(providers, context);
            case 'cost-optimized':
                return this.selectByCost(providers, context);
            case 'quality-first':
                return this.selectByQuality(providers, context);
            case 'weighted':
            default:
                return this.selectWeighted(providers, context);
        }
    }
    /**
     * Round-robin selection over `providers`.
     */
    selectRoundRobin(providers) {
        const provider = providers[this.roundRobinIndex % providers.length];
        this.roundRobinIndex = (this.roundRobinIndex + 1) % providers.length;
        return provider;
    }
    /**
     * Weighted random selection.
     *
     * Each provider's weight starts from `config.providerWeights[id]` (default
     * 1.0). It is multiplied by the contextual weight, the success rate and a
     * latency factor, and floored at 0.1.
     */
    selectWeighted(providers, context) {
        const weights = new Map();
        for (const provider of providers) {
            let weight = this.config.providerWeights?.[provider.id] ?? 1.0;
            if (context && context.type) {
                weight *= this.getContextualWeight(provider, context.type);
            }
            const metrics = this.performanceMetrics.get(provider.id);
            if (metrics) {
                weight *= metrics.successRate;
                weight *= Math.max(0.1, 1 - metrics.averageLatency / 10000); // Penalty for high latency
            }
            weights.set(provider.id, Math.max(0.1, weight)); // Minimum weight 0.1
        }
        const totalWeight = Array.from(weights.values()).reduce((sum, w) => sum + w, 0);
        let random = Math.random() * totalWeight;
        for (const provider of providers) {
            random -= weights.get(provider.id);
            if (random <= 0) {
                return provider;
            }
        }
        // Guards against floating-point residue leaving `random` slightly > 0.
        return providers[0];
    }
    /**
     * Pick the provider with the best `successRate - averageLatency / 10000`.
     *
     * Candidates are first restricted to those meeting
     * `config.performanceThresholds`; unspecified threshold fields do not
     * filter. If no candidate meets the thresholds, all candidates are used.
     * The caller's array is not mutated.
     */
    selectByPerformance(providers, context) {
        const thresholds = this.config.performanceThresholds;
        let eligibleProviders = providers.filter((provider) => {
            const metrics = this.performanceMetrics.get(provider.id);
            if (!metrics || !thresholds)
                return true;
            return (metrics.successRate >= (thresholds.minSuccessRate ?? 0) &&
                metrics.averageLatency <= (thresholds.maxLatency ?? Infinity) &&
                metrics.costPerToken <= (thresholds.maxCostPerToken ?? Infinity));
        });
        if (eligibleProviders.length === 0) {
            eligibleProviders = [...providers];
        }
        const performanceScore = (provider) => {
            const metrics = this.performanceMetrics.get(provider.id);
            return metrics.successRate - metrics.averageLatency / 10000;
        };
        eligibleProviders.sort((a, b) => performanceScore(b) - performanceScore(a));
        return eligibleProviders[0];
    }
    /**
     * Cost-aware selection.
     *
     * Once usage exceeds 80% of `config.costBudget.dailyLimit`, the cheapest
     * provider per token is chosen (first one wins ties). Otherwise this
     * defers to weighted selection. The caller's array is not mutated.
     */
    selectByCost(providers, context) {
        const budget = this.config.costBudget;
        const nearBudgetLimit = budget && budget.currentUsage / budget.dailyLimit > 0.8;
        if (!nearBudgetLimit) {
            return this.selectWeighted(providers, context);
        }
        const costOf = (provider) => this.performanceMetrics.get(provider.id).costPerToken;
        return providers.reduce((cheapest, provider) => costOf(provider) < costOf(cheapest) ? provider : cheapest);
    }
    /**
     * Quality-first selection: prefer providers whose strengths match the
     * query type, then pick among them by performance.
     */
    selectByQuality(providers, context) {
        if (context && context.type) {
            const qualityProviders = providers.filter((provider) => this.matchesQueryType(provider.getStrengths(), context.type));
            if (qualityProviders.length > 0) {
                return this.selectByPerformance(qualityProviders, context);
            }
        }
        return this.selectByPerformance(providers, context);
    }
    /**
     * Contextual multiplier for a provider on a query type.
     * @returns {number} 1.5 when the provider has a strength relevant to
     *   `queryType` (CODE, CREATIVE, ANALYSIS or SECURITY); otherwise 1.0.
     */
    getContextualWeight(provider, queryType) {
        const strengths = provider.getStrengths();
        const hasAny = (...wanted) => wanted.some((strength) => strengths.includes(strength));
        switch (queryType) {
            case QueryType.CODE:
                return hasAny('code-generation', 'technical-accuracy') ? 1.5 : 1.0;
            case QueryType.CREATIVE:
                return hasAny('creative-writing', 'language-fluency') ? 1.5 : 1.0;
            case QueryType.ANALYSIS:
                return hasAny('reasoning', 'analysis') ? 1.5 : 1.0;
            case QueryType.SECURITY:
                return hasAny('security', 'safety-filtering') ? 1.5 : 1.0;
            default:
                return 1.0;
        }
    }
    /**
     * Whether any of `strengths` is relevant to `queryType`.
     * Unknown query types (including inherited object keys) match nothing.
     */
    matchesQueryType(strengths, queryType) {
        const typeMapping = {
            [QueryType.CODE]: ['code-generation', 'technical-accuracy', 'programming'],
            [QueryType.CREATIVE]: ['creative-writing', 'language-fluency', 'storytelling'],
            [QueryType.ANALYSIS]: ['reasoning', 'analysis', 'data-analysis'],
            [QueryType.SECURITY]: ['security', 'safety-filtering', 'compliance'],
            [QueryType.GENERAL]: ['general-purpose', 'conversation'],
        };
        const requiredStrengths = Object.hasOwn(typeMapping, queryType) ? typeMapping[queryType] : [];
        return requiredStrengths.some((strength) => strengths.includes(strength));
    }
    /**
     * Record the outcome of one query.
     *
     * Latency and (optional) cost are tracked as exponential moving averages
     * with alpha = 0.1. The success rate is recomputed from lifetime counters.
     * Unknown provider ids are ignored.
     */
    updateMetrics(providerId, latency, success, cost) {
        const metrics = this.performanceMetrics.get(providerId);
        if (!metrics)
            return;
        const alpha = 0.1; // Learning rate
        metrics.averageLatency = metrics.averageLatency * (1 - alpha) + latency * alpha;
        metrics.totalQueries++;
        if (!success) {
            metrics.failureCount++;
        }
        metrics.successRate = (metrics.totalQueries - metrics.failureCount) / metrics.totalQueries;
        if (cost !== undefined) {
            metrics.costPerToken = metrics.costPerToken * (1 - alpha) + cost * alpha;
        }
        metrics.lastUpdateTime = new Date();
    }
    /**
     * Choose a replacement after `failedProviderId` failed.
     *
     * Only enabled providers other than the failed one are considered. They
     * are chosen with `config.fallbackStrategy` (default 'round-robin'). The
     * configured primary strategy is left untouched.
     * @returns {object|null} A different provider, or null if none remain.
     */
    getFallbackProvider(failedProviderId, context) {
        const candidates = this.getAvailableProviders().filter((provider) => provider.id !== failedProviderId);
        if (candidates.length === 0) {
            return null;
        }
        if (candidates.length === 1) {
            return candidates[0];
        }
        const strategy = this.config.fallbackStrategy || 'round-robin';
        return this.selectWithStrategy(strategy, candidates, context);
    }
    /**
     * Whether a provider should be retried. Returns true for unknown
     * providers or when the lifetime failure rate is below 50%.
     */
    shouldRetryProvider(providerId) {
        const metrics = this.performanceMetrics.get(providerId);
        if (!metrics)
            return true;
        const failureRate = metrics.failureCount / Math.max(1, metrics.totalQueries);
        return failureRate < 0.5;
    }
    /**
     * Snapshot of all provider metrics. The records are copies, so mutating
     * them does not affect the selector's internal state.
     */
    getPerformanceMetrics() {
        return new Map(Array.from(this.performanceMetrics, ([id, metrics]) => [
            id,
            { ...metrics, lastUpdateTime: new Date(metrics.lastUpdateTime) },
        ]));
    }
    /**
     * Shallow-merge `config` into the current configuration.
     */
    updateConfig(config) {
        this.config = { ...this.config, ...config };
    }
    /**
     * Registered providers whose `isEnabled()` returns truthy.
     */
    getAvailableProviders() {
        return Array.from(this.providers.values()).filter((provider) => provider.isEnabled());
    }
    /**
     * Diversity between two providers: 1 minus the Jaccard similarity of
     * their strength sets. Two providers with no strengths count as fully
     * diverse (1).
     */
    calculateProviderDiversity(provider1, provider2) {
        const strengths1 = new Set(provider1.getStrengths());
        const strengths2 = new Set(provider2.getStrengths());
        const intersection = new Set([...strengths1].filter((x) => strengths2.has(x)));
        const union = new Set([...strengths1, ...strengths2]);
        return union.size > 0 ? 1 - intersection.size / union.size : 1;
    }
    /**
     * Weighted score in [0, 1] combining performance, cost, reliability and
     * contextual fit. `weights` overrides the defaults
     * `{ performance: 0.3, cost: 0.2, reliability: 0.3, contextual: 0.2 }`.
     * @returns {number} 0 for providers without metrics.
     */
    getProviderScore(provider, weights = {}, context) {
        const defaultWeights = {
            performance: 0.3,
            cost: 0.2,
            reliability: 0.3,
            contextual: 0.2,
        };
        const w = { ...defaultWeights, ...weights };
        const metrics = this.performanceMetrics.get(provider.id);
        if (!metrics)
            return 0;
        let score = 0;
        const performanceScore = metrics.successRate * 0.7 +
            Math.max(0, 1 - metrics.averageLatency / 5000) * 0.3;
        score += performanceScore * w.performance;
        const costScore = Math.max(0, 1 - metrics.costPerToken / 0.001);
        score += costScore * w.cost;
        const reliabilityScore = Math.max(0, 1 - metrics.failureCount / Math.max(1, metrics.totalQueries));
        score += reliabilityScore * w.reliability;
        let contextualScore = 1.0;
        if (context?.type) {
            contextualScore = this.getContextualWeight(provider, context.type);
        }
        score += (contextualScore - 1) * w.contextual;
        return clamp01(score);
    }
    /**
     * Pick the highest `getProviderScore` using a weight preset chosen from
     * `options.preferFast`, `preferCheap` or `preferAccurate`, checked in
     * that order.
     * @throws {Error} If no provider is enabled.
     */
    async selectAdaptive(prompt, options) {
        const availableProviders = this.getAvailableProviders();
        if (availableProviders.length === 0) {
            throw new Error('No providers available');
        }
        if (availableProviders.length === 1) {
            return availableProviders[0];
        }
        let weights = { performance: 0.25, cost: 0.25, reliability: 0.25, contextual: 0.25 };
        if (options?.preferFast) {
            weights = { performance: 0.5, cost: 0.2, reliability: 0.2, contextual: 0.1 };
        }
        else if (options?.preferCheap) {
            weights = { performance: 0.2, cost: 0.5, reliability: 0.2, contextual: 0.1 };
        }
        else if (options?.preferAccurate) {
            weights = { performance: 0.2, cost: 0.1, reliability: 0.4, contextual: 0.3 };
        }
        const providerScores = availableProviders.map((provider) => ({
            provider,
            score: this.getProviderScore(provider, weights, { type: options?.queryType }),
        }));
        providerScores.sort((a, b) => b.score - a.score);
        return providerScores[0].provider;
    }
    /**
     * Load-balancing advice based on lifetime query counts and health.
     * Possible recommendations:
     * - 'distribute': some provider's load deviates more than 30% from the
     *   mean; underutilised providers are suggested.
     * - 'throttle': more than half of the providers perform poorly; the
     *   healthy ones are suggested.
     * - 'fallback': the normal case; all enabled providers are suggested.
     */
    getLoadBalancingRecommendation() {
        const providers = this.getAvailableProviders();
        const loadByProvider = new Map();
        providers.forEach((provider) => {
            const metrics = this.performanceMetrics.get(provider.id);
            if (metrics) {
                loadByProvider.set(provider.id, metrics.totalQueries);
            }
        });
        const totalQueries = Array.from(loadByProvider.values()).reduce((a, b) => a + b, 0);
        const averageLoad = totalQueries / providers.length;
        const loadOf = (provider) => loadByProvider.get(provider.id) || 0;
        const imbalancedProviders = providers.filter((provider) => Math.abs(loadOf(provider) - averageLoad) > averageLoad * 0.3); // 30% threshold
        if (imbalancedProviders.length > 0) {
            const underutilized = providers.filter((provider) => loadOf(provider) < averageLoad * 0.7);
            return {
                recommendation: 'distribute',
                reason: 'Load imbalance detected - some providers are underutilized',
                suggestedProviders: underutilized.map((p) => p.id),
            };
        }
        const poorPerformingProviders = providers.filter((provider) => {
            const metrics = this.performanceMetrics.get(provider.id);
            return (metrics &&
                (metrics.successRate < 0.8 || metrics.averageLatency > 5000));
        });
        if (poorPerformingProviders.length > providers.length / 2) {
            return {
                recommendation: 'throttle',
                reason: 'Multiple providers showing poor performance',
                suggestedProviders: providers
                    .filter((p) => !poorPerformingProviders.includes(p))
                    .map((p) => p.id),
            };
        }
        return {
            recommendation: 'fallback',
            reason: 'System operating normally',
            suggestedProviders: providers.map((p) => p.id),
        };
    }
    /**
     * Recommend a selection strategy from observed metrics.
     *
     * Only providers that have served at least one query are considered.
     * Means and variances are computed over that same set, so seeded default
     * metrics do not skew the statistics.
     */
    adjustStrategyDynamically() {
        const observed = this.getAvailableProviders()
            .map((provider) => this.performanceMetrics.get(provider.id))
            .filter((metrics) => metrics && metrics.totalQueries > 0);
        if (observed.length === 0) {
            return {
                recommendedStrategy: 'weighted',
                reason: 'Insufficient data for optimization',
                confidence: 0.5,
            };
        }
        const mean = (values) => values.reduce((sum, v) => sum + v, 0) / values.length;
        const variance = (values) => {
            const avg = mean(values);
            return mean(values.map((v) => (v - avg) * (v - avg)));
        };
        const successRates = observed.map((m) => m.successRate);
        const costs = observed.map((m) => m.costPerToken);
        const avgSuccessRate = mean(successRates);
        const avgLatency = mean(observed.map((m) => m.averageLatency));
        const avgCost = mean(costs);
        // High performance variance -> use performance-based selection
        if (variance(successRates) > 0.1) {
            return {
                recommendedStrategy: 'performance',
                reason: 'High performance variance detected',
                confidence: 0.8,
            };
        }
        // High cost variance (relative to the mean cost) -> optimise cost
        if (variance(costs) > avgCost * avgCost * 0.25) {
            return {
                recommendedStrategy: 'cost-optimized',
                reason: 'Significant cost differences between providers',
                confidence: 0.7,
            };
        }
        // Consistent performance -> use round-robin for fair distribution
        if (avgSuccessRate > 0.9 && avgLatency < 3000) {
            return {
                recommendedStrategy: 'round-robin',
                reason: 'All providers performing well - distribute load evenly',
                confidence: 0.9,
            };
        }
        return {
            recommendedStrategy: 'weighted',
            reason: 'Balanced approach recommended',
            confidence: 0.6,
        };
    }
    /**
     * Reset round-robin index (useful for testing)
     */
    resetRoundRobin() {
        this.roundRobinIndex = 0;
    }
    /**
     * Explainable recommendation: the best provider by `getProviderScore`,
     * human-readable reasons, a confidence in [0.5, 0.9] driven by the margin
     * over the runner-up, and the remaining providers ranked by score.
     * @throws {Error} If no provider is enabled.
     */
    getSelectionRecommendation(context, options) {
        const availableProviders = this.getAvailableProviders();
        if (availableProviders.length === 0) {
            throw new Error('No providers available');
        }
        const weights = {
            performance: options?.considerLatency ? 0.4 : 0.3,
            cost: options?.considerCost ? 0.4 : 0.2,
            reliability: options?.considerAccuracy ? 0.4 : 0.3,
            contextual: context?.type ? 0.3 : 0.2,
        };
        const ranked = availableProviders
            .map((provider) => ({
            provider,
            score: this.getProviderScore(provider, weights, context),
        }))
            .sort((a, b) => b.score - a.score);
        const [best, runnerUp] = ranked;
        const bestMetrics = this.performanceMetrics.get(best.provider.id);
        const reasoning = [
            `Selected ${best.provider.id} with score ${(best.score * 100).toFixed(1)}%`,
        ];
        if (bestMetrics.successRate > 0.9) {
            reasoning.push(`High reliability: ${(bestMetrics.successRate * 100).toFixed(1)}% success rate`);
        }
        if (bestMetrics.averageLatency < 2000) {
            // Latency is a moving average; round it for display.
            reasoning.push(`Fast response: ${Math.round(bestMetrics.averageLatency)}ms average latency`);
        }
        if (context?.type) {
            const contextWeight = this.getContextualWeight(best.provider, context.type);
            if (contextWeight > 1.2) {
                reasoning.push(`Well-suited for ${context.type} queries`);
            }
        }
        const scoreDifference = runnerUp ? best.score - runnerUp.score : best.score;
        const confidence = Math.min(0.9, 0.5 + scoreDifference);
        return {
            provider: best.provider,
            reasoning,
            confidence,
            alternatives: ranked.slice(1), // Exclude the selected provider
        };
    }
}
//# sourceMappingURL=model-selector.js.map