// @quantumai/quantum-cli-core
// Version:
// Quantum CLI Core - Multi-LLM Collaboration System
// 671 lines • 26.9 kB
// JavaScript
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { QueryType } from './types.js';
/**
 * Chooses which registered provider(s) should handle a query.
 *
 * A provider is any object exposing `id`, `isEnabled()` and `getStrengths()`
 * (those are the only members this class touches). For every registered
 * provider the selector keeps a metrics record (latency, success rate, cost)
 * that callers feed through {@link ModelSelector#updateMetrics} and that the
 * selection strategies read.
 */
export class ModelSelector {
  /** Registered providers keyed by provider id. */
  providers = new Map();
  /** Performance metrics keyed by provider id. */
  performanceMetrics = new Map();
  /** Cursor used by the round-robin strategy. */
  roundRobinIndex = 0;
  /** Effective configuration: built-in defaults overlaid with caller values. */
  config;

  /**
   * @param {object} [config] - Selection configuration. Keys read by this
   *   class: `strategy`, `fallbackStrategy`, `providerWeights`,
   *   `performanceThresholds`, `costBudget`.
   */
  constructor(config) {
    this.config = {
      strategy: 'weighted',
      enableFailureRecovery: true,
      maxRetries: 3,
      ...config,
    };
  }

  /**
   * Register a provider with the selector.
   *
   * Re-registering an id replaces the provider object but keeps any metrics
   * already collected for that id.
   *
   * @param {object} provider - Provider exposing `id`, `isEnabled()`, `getStrengths()`.
   */
  registerProvider(provider) {
    this.providers.set(provider.id, provider);
    if (this.performanceMetrics.has(provider.id)) {
      return;
    }
    // Seed with optimistic placeholders so a new provider is selectable
    // before any real measurements exist.
    this.performanceMetrics.set(provider.id, {
      providerId: provider.id,
      averageLatency: 1000, // Default 1s
      successRate: 0.9, // Default 90%
      costPerToken: 0.0001, // Default cost
      lastUpdateTime: new Date(),
      totalQueries: 0,
      failureCount: 0,
    });
  }

  /**
   * Select the single highest-scoring enabled provider.
   *
   * @param {string} prompt - The prompt being routed (not inspected here).
   * @param {object} [options] - Optional `maxCost`, `maxLatency`, `queryType`.
   * @returns {Promise<object>} The best provider.
   * @throws {Error} If no provider is enabled.
   */
  async selectBestProvider(prompt, options) {
    const availableProviders = this.getAvailableProviders();
    if (availableProviders.length === 0) {
      throw new Error('No providers available');
    }
    if (availableProviders.length === 1) {
      return availableProviders[0];
    }
    const providerScores = availableProviders.map((provider) => {
      const metrics = this.performanceMetrics.get(provider.id);
      let score = 0;
      // Base performance: success rate contributes up to 12 points, latency
      // (normalised against 10s, not clamped) up to 4.
      score +=
        (metrics.successRate * 0.3 +
          (1 - metrics.averageLatency / 10000) * 0.1) *
        40;
      // Cost efficiency: up to 20 points, relative to the caller's budget
      // when given, otherwise to a $0.001 per-token baseline.
      if (options?.maxCost) {
        const costEfficiency = Math.max(
          0,
          1 - metrics.costPerToken / options.maxCost,
        );
        score += costEfficiency * 20;
      } else {
        score += (1 - metrics.costPerToken / 0.001) * 20;
      }
      // Contextual suitability: the 1.0-1.5 weight maps to 0-20 points.
      if (options?.queryType) {
        const contextWeight = this.getContextualWeight(
          provider,
          options.queryType,
        );
        score += (contextWeight - 1) * 40;
      }
      // Halve the score of providers slower than the caller's latency ceiling.
      if (options?.maxLatency && metrics.averageLatency > options.maxLatency) {
        score *= 0.5;
      }
      return { provider, score };
    });
    providerScores.sort((a, b) => b.score - a.score);
    return providerScores[0].provider;
  }

  /**
   * Select several providers for parallel execution.
   *
   * The top-scoring provider is always included. Remaining slots go to the
   * next candidates in score order; when `diversityWeight` is positive, a
   * candidate beyond the second slot is preferred only if its strengths
   * differ enough from every provider already chosen. If diversity filtering
   * leaves slots unfilled, they are backfilled with the best skipped
   * candidates so the caller receives `count` providers whenever that many
   * are enabled.
   *
   * @param {string} prompt - The prompt being routed.
   * @param {object} [options] - Optional `count` (default 2; `0` also means
   *   the default), `diversityWeight` (default 0.3; `0` disables diversity
   *   filtering), `queryType`, plus anything `selectBestProvider` accepts.
   * @returns {Promise<object[]>} Selected providers, best first.
   * @throws {Error} If no provider is enabled.
   */
  async selectMultipleProviders(prompt, options) {
    const availableProviders = this.getAvailableProviders();
    if (availableProviders.length === 0) {
      throw new Error('No providers available');
    }
    const count = Math.min(options?.count || 2, availableProviders.length);
    if (count === 1) {
      return [await this.selectBestProvider(prompt, options)];
    }
    const providerScores = availableProviders.map((provider) => {
      const metrics = this.performanceMetrics.get(provider.id);
      let score = 0;
      // Performance
      score += metrics.successRate * 30;
      score += Math.max(0, 1 - metrics.averageLatency / 5000) * 20;
      // Cost efficiency
      score += Math.max(0, 1 - metrics.costPerToken / 0.001) * 20;
      // Contextual fit
      if (options?.queryType) {
        score +=
          (this.getContextualWeight(provider, options.queryType) - 1) * 30;
      }
      return { provider, score };
    });
    providerScores.sort((a, b) => b.score - a.score);
    const ranked = providerScores.map((entry) => entry.provider);

    // `??` rather than `||`: an explicit 0 must switch diversity filtering off.
    const diversityWeight = options?.diversityWeight ?? 0.3;
    const useDiversity = diversityWeight > 0;

    const selectedProviders = [ranked[0]];
    const skipped = [];
    for (const candidate of ranked.slice(1)) {
      if (selectedProviders.length >= count) {
        break;
      }
      const accept =
        !useDiversity ||
        selectedProviders.length < 2 ||
        selectedProviders.every(
          (selected) =>
            this.calculateProviderDiversity(selected, candidate) >
            diversityWeight,
        );
      if (accept) {
        selectedProviders.push(candidate);
      } else {
        skipped.push(candidate);
      }
    }
    // Not enough diverse candidates: fill the remaining slots by score.
    for (const candidate of skipped) {
      if (selectedProviders.length >= count) {
        break;
      }
      selectedProviders.push(candidate);
    }
    return selectedProviders;
  }

  /**
   * Pick one enabled provider using the configured strategy.
   *
   * @param {object} [context] - Query context; `context.type` is used for
   *   contextual weighting when present.
   * @returns {object|null} The chosen provider, or `null` if none is enabled.
   */
  selectProvider(context) {
    return this.selectFromProviders(
      this.getAvailableProviders(),
      this.config.strategy,
      context,
    );
  }

  /**
   * Apply a named strategy to an explicit candidate list.
   *
   * @param {object[]} providers - Candidates to choose from.
   * @param {string} strategy - 'round-robin', 'weighted', 'performance',
   *   'cost-optimized' or 'quality-first'; anything else falls back to weighted.
   * @param {object} [context] - Query context forwarded to the strategy.
   * @returns {object|null} The chosen provider, or `null` for an empty list.
   */
  selectFromProviders(providers, strategy, context) {
    if (providers.length === 0) {
      return null;
    }
    if (providers.length === 1) {
      return providers[0];
    }
    switch (strategy) {
      case 'round-robin':
        return this.selectRoundRobin(providers);
      case 'performance':
        return this.selectByPerformance(providers, context);
      case 'cost-optimized':
        return this.selectByCost(providers, context);
      case 'quality-first':
        return this.selectByQuality(providers, context);
      case 'weighted':
      default:
        return this.selectWeighted(providers, context);
    }
  }

  /**
   * Round-robin selection: return the provider under the cursor and advance it.
   *
   * @param {object[]} providers - Candidates to rotate through.
   * @returns {object|undefined} The next provider; `undefined` for an empty list.
   */
  selectRoundRobin(providers) {
    if (providers.length === 0) {
      // Avoid `x % 0`, which would leave the cursor stuck at NaN.
      return undefined;
    }
    const provider = providers[this.roundRobinIndex % providers.length];
    this.roundRobinIndex = (this.roundRobinIndex + 1) % providers.length;
    return provider;
  }

  /**
   * Weighted random selection.
   *
   * Each provider's weight is its configured weight (default 1.0) multiplied
   * by its contextual weight, success rate and a latency factor, floored at 0.1.
   *
   * @param {object[]} providers - Candidates.
   * @param {object} [context] - Query context (`type` drives contextual weight).
   * @returns {object|undefined} A provider drawn proportionally to its weight.
   */
  selectWeighted(providers, context) {
    const weights = new Map();
    for (const provider of providers) {
      let weight = this.config.providerWeights?.[provider.id] ?? 1.0;
      if (context && context.type) {
        weight *= this.getContextualWeight(provider, context.type);
      }
      const metrics = this.performanceMetrics.get(provider.id);
      if (metrics) {
        weight *= metrics.successRate;
        // Latency penalty, normalised against 10s and never below 0.1.
        weight *= Math.max(0.1, 1 - metrics.averageLatency / 10000);
      }
      // The floor keeps every provider selectable with some probability.
      weights.set(provider.id, Math.max(0.1, weight));
    }
    const totalWeight = Array.from(weights.values()).reduce(
      (sum, w) => sum + w,
      0,
    );
    let random = Math.random() * totalWeight;
    for (const provider of providers) {
      random -= weights.get(provider.id);
      if (random <= 0) {
        return provider;
      }
    }
    // Only reachable through floating-point drift.
    return providers[0];
  }

  /**
   * Pick the provider with the best `successRate - averageLatency / 10000`.
   *
   * Providers failing `config.performanceThresholds` are excluded unless that
   * would exclude everyone. Providers without metrics pass the threshold
   * filter and rank last. The input array is not modified.
   *
   * @param {object[]} providers - Candidates.
   * @param {object} [context] - Unused; kept for a uniform strategy signature.
   * @returns {object|undefined} Best provider; `undefined` for an empty list.
   */
  selectByPerformance(providers, context) {
    const thresholds = this.config.performanceThresholds;
    const meetsThresholds = (provider) => {
      const metrics = this.performanceMetrics.get(provider.id);
      if (!metrics || !thresholds) {
        return true;
      }
      return (
        metrics.successRate >= thresholds.minSuccessRate &&
        metrics.averageLatency <= thresholds.maxLatency &&
        metrics.costPerToken <= thresholds.maxCostPerToken
      );
    };
    const eligible = providers.filter(meetsThresholds);
    // If nobody qualifies, degrade to ranking all candidates.
    const pool = eligible.length > 0 ? eligible : providers;
    const ranked = pool.map((provider) => {
      const metrics = this.performanceMetrics.get(provider.id);
      const score = metrics
        ? metrics.successRate - metrics.averageLatency / 10000
        : Number.NEGATIVE_INFINITY;
      return { provider, score };
    });
    ranked.sort((a, b) => b.score - a.score);
    return ranked[0]?.provider;
  }

  /**
   * Cost-aware selection.
   *
   * When more than 80% of the daily budget is used, returns the provider
   * with the lowest cost per token; otherwise defers to weighted selection.
   * The input array is not modified.
   *
   * @param {object[]} providers - Candidates.
   * @param {object} [context] - Query context forwarded to weighted selection.
   * @returns {object|undefined} The chosen provider.
   */
  selectByCost(providers, context) {
    const budget = this.config.costBudget;
    const nearBudgetLimit =
      budget && budget.currentUsage / budget.dailyLimit > 0.8;
    if (!nearBudgetLimit) {
      return this.selectWeighted(providers, context);
    }
    const costOf = (provider) =>
      this.performanceMetrics.get(provider.id)?.costPerToken ??
      Number.POSITIVE_INFINITY;
    const cheapestFirst = [...providers].sort((a, b) => costOf(a) - costOf(b));
    return cheapestFirst[0];
  }

  /**
   * Quality-first selection: restrict to providers whose strengths match the
   * query type (when any do), then rank by performance.
   *
   * @param {object[]} providers - Candidates.
   * @param {object} [context] - Query context; `type` selects the strengths.
   * @returns {object|undefined} The chosen provider.
   */
  selectByQuality(providers, context) {
    if (context && context.type) {
      const qualityProviders = providers.filter((provider) =>
        this.matchesQueryType(provider.getStrengths(), context.type),
      );
      if (qualityProviders.length > 0) {
        return this.selectByPerformance(qualityProviders, context);
      }
    }
    return this.selectByPerformance(providers, context);
  }

  /**
   * Contextual weight for a provider and query type.
   *
   * @param {object} provider - Provider whose strengths are inspected.
   * @param {*} queryType - One of the `QueryType` values.
   * @returns {number} 1.5 if the provider lists a strength associated with
   *   the query type, otherwise 1.0.
   */
  getContextualWeight(provider, queryType) {
    const strengths = provider.getStrengths();
    const hasAny = (...names) => names.some((name) => strengths.includes(name));
    let matched = false;
    switch (queryType) {
      case QueryType.CODE:
        matched = hasAny('code-generation', 'technical-accuracy');
        break;
      case QueryType.CREATIVE:
        matched = hasAny('creative-writing', 'language-fluency');
        break;
      case QueryType.ANALYSIS:
        matched = hasAny('reasoning', 'analysis');
        break;
      case QueryType.SECURITY:
        matched = hasAny('security', 'safety-filtering');
        break;
      default:
        matched = false;
    }
    return matched ? 1.5 : 1.0;
  }

  /**
   * Check whether any of the given strengths is associated with a query type.
   *
   * @param {string[]} strengths - Strength labels of a provider.
   * @param {*} queryType - One of the `QueryType` values.
   * @returns {boolean} `true` on at least one match; `false` for unknown types.
   */
  matchesQueryType(strengths, queryType) {
    const typeMapping = {
      [QueryType.CODE]: ['code-generation', 'technical-accuracy', 'programming'],
      [QueryType.CREATIVE]: [
        'creative-writing',
        'language-fluency',
        'storytelling',
      ],
      [QueryType.ANALYSIS]: ['reasoning', 'analysis', 'data-analysis'],
      [QueryType.SECURITY]: ['security', 'safety-filtering', 'compliance'],
      [QueryType.GENERAL]: ['general-purpose', 'conversation'],
    };
    const requiredStrengths = typeMapping[queryType] || [];
    return requiredStrengths.some((strength) => strengths.includes(strength));
  }

  /**
   * Record the outcome of one query against a provider.
   *
   * Latency and cost are exponential moving averages (alpha 0.1); success
   * rate is the exact ratio over all recorded queries. Unknown provider ids
   * are ignored.
   *
   * @param {string} providerId - Id of the provider that served the query.
   * @param {number} latency - Observed latency (same unit as `averageLatency`).
   * @param {boolean} success - Whether the query succeeded.
   * @param {number} [cost] - Observed cost per token; omitted leaves cost unchanged.
   */
  updateMetrics(providerId, latency, success, cost) {
    const metrics = this.performanceMetrics.get(providerId);
    if (!metrics) {
      return;
    }
    const alpha = 0.1; // Smoothing factor for the moving averages
    metrics.averageLatency =
      metrics.averageLatency * (1 - alpha) + latency * alpha;
    metrics.totalQueries++;
    if (!success) {
      metrics.failureCount++;
    }
    metrics.successRate =
      (metrics.totalQueries - metrics.failureCount) / metrics.totalQueries;
    if (cost !== undefined) {
      metrics.costPerToken = metrics.costPerToken * (1 - alpha) + cost * alpha;
    }
    metrics.lastUpdateTime = new Date();
  }

  /**
   * Pick a replacement after a provider has failed.
   *
   * The failed provider is never returned. Selection uses
   * `config.fallbackStrategy` (default 'round-robin') and leaves
   * `config.strategy` untouched.
   *
   * @param {string} failedProviderId - Id of the provider to exclude.
   * @param {object} [context] - Query context forwarded to the strategy.
   * @returns {object|null} A different enabled provider, or `null` if none.
   */
  getFallbackProvider(failedProviderId, context) {
    const candidates = this.getAvailableProviders().filter(
      (provider) => provider.id !== failedProviderId,
    );
    const strategy = this.config.fallbackStrategy || 'round-robin';
    return this.selectFromProviders(candidates, strategy, context);
  }

  /**
   * Decide whether a provider is still worth retrying.
   *
   * @param {string} providerId - Provider id.
   * @returns {boolean} `true` when the provider is unknown or its overall
   *   failure rate is below 50%.
   */
  shouldRetryProvider(providerId) {
    const metrics = this.performanceMetrics.get(providerId);
    if (!metrics) {
      return true;
    }
    const failureRate = metrics.failureCount / Math.max(1, metrics.totalQueries);
    return failureRate < 0.5;
  }

  /**
   * Get current performance metrics for all providers.
   *
   * @returns {Map<string, object>} A new Map; the metric objects themselves
   *   are shared with the selector, not cloned.
   */
  getPerformanceMetrics() {
    return new Map(this.performanceMetrics);
  }

  /**
   * Shallow-merge new values into the selection configuration.
   *
   * @param {object} config - Partial configuration to overlay.
   */
  updateConfig(config) {
    this.config = { ...this.config, ...config };
  }

  /**
   * Get the registered providers that currently report themselves enabled.
   *
   * @returns {object[]} A fresh array in registration order.
   */
  getAvailableProviders() {
    return Array.from(this.providers.values()).filter((provider) =>
      provider.isEnabled(),
    );
  }

  /**
   * Diversity between two providers: one minus the Jaccard similarity of
   * their strength sets.
   *
   * @param {object} provider1 - First provider.
   * @param {object} provider2 - Second provider.
   * @returns {number} Value in [0, 1]; 1 when both strength sets are empty.
   */
  calculateProviderDiversity(provider1, provider2) {
    const strengths1 = new Set(provider1.getStrengths());
    const strengths2 = new Set(provider2.getStrengths());
    const intersection = new Set(
      [...strengths1].filter((x) => strengths2.has(x)),
    );
    const union = new Set([...strengths1, ...strengths2]);
    return union.size > 0 ? 1 - intersection.size / union.size : 1;
  }

  /**
   * Weighted score for a provider, clamped to [0, 1].
   *
   * @param {object} provider - Provider to score.
   * @param {object} [weights] - Overrides for `performance`, `cost`,
   *   `reliability`, `contextual` (defaults 0.3 / 0.2 / 0.3 / 0.2).
   * @param {object} [context] - Query context; `type` enables the contextual term.
   * @returns {number} Score in [0, 1]; 0 when the provider has no metrics.
   */
  getProviderScore(provider, weights = {}, context) {
    const defaultWeights = {
      performance: 0.3,
      cost: 0.2,
      reliability: 0.3,
      contextual: 0.2,
    };
    const w = { ...defaultWeights, ...weights };
    const metrics = this.performanceMetrics.get(provider.id);
    if (!metrics) {
      return 0;
    }
    let score = 0;
    // Performance: 70% success rate, 30% latency normalised against 5s.
    const performanceScore =
      metrics.successRate * 0.7 +
      Math.max(0, 1 - metrics.averageLatency / 5000) * 0.3;
    score += performanceScore * w.performance;
    // Cost efficiency against a $0.001 per-token baseline.
    const costScore = Math.max(0, 1 - metrics.costPerToken / 0.001);
    score += costScore * w.cost;
    // Reliability: share of recorded queries that did not fail.
    const reliabilityScore = Math.max(
      0,
      1 - metrics.failureCount / Math.max(1, metrics.totalQueries),
    );
    score += reliabilityScore * w.reliability;
    // Contextual fit: only the bonus above the neutral 1.0 weight counts.
    let contextualScore = 1.0;
    if (context?.type) {
      contextualScore = this.getContextualWeight(provider, context.type);
    }
    score += (contextualScore - 1) * w.contextual;
    return Math.max(0, Math.min(1, score));
  }

  /**
   * Select the best provider using weights derived from caller preferences.
   *
   * `preferFast`, `preferCheap` and `preferAccurate` are checked in that
   * order; the first truthy one decides the weights.
   *
   * @param {string} prompt - The prompt being routed (not inspected here).
   * @param {object} [options] - Optional preference flags and `queryType`.
   * @returns {Promise<object>} The highest-scoring provider.
   * @throws {Error} If no provider is enabled.
   */
  async selectAdaptive(prompt, options) {
    const availableProviders = this.getAvailableProviders();
    if (availableProviders.length === 0) {
      throw new Error('No providers available');
    }
    if (availableProviders.length === 1) {
      return availableProviders[0];
    }
    let weights = {
      performance: 0.25,
      cost: 0.25,
      reliability: 0.25,
      contextual: 0.25,
    };
    if (options?.preferFast) {
      weights = { performance: 0.5, cost: 0.2, reliability: 0.2, contextual: 0.1 };
    } else if (options?.preferCheap) {
      weights = { performance: 0.2, cost: 0.5, reliability: 0.2, contextual: 0.1 };
    } else if (options?.preferAccurate) {
      weights = { performance: 0.2, cost: 0.1, reliability: 0.4, contextual: 0.3 };
    }
    const providerScores = availableProviders.map((provider) => ({
      provider,
      score: this.getProviderScore(provider, weights, {
        type: options?.queryType,
      }),
    }));
    providerScores.sort((a, b) => b.score - a.score);
    return providerScores[0].provider;
  }

  /**
   * Get a load balancing recommendation from the recorded query counts and
   * performance metrics of the enabled providers.
   *
   * @returns {{recommendation: string, reason: string, suggestedProviders: string[]}}
   *   'distribute' when any provider's query count deviates more than 30%
   *   from the mean, 'throttle' when more than half the providers perform
   *   poorly, otherwise 'fallback'.
   */
  getLoadBalancingRecommendation() {
    const providers = this.getAvailableProviders();
    const loadByProvider = new Map();
    for (const provider of providers) {
      const metrics = this.performanceMetrics.get(provider.id);
      if (metrics) {
        loadByProvider.set(provider.id, metrics.totalQueries);
      }
    }
    const loadOf = (provider) => loadByProvider.get(provider.id) || 0;
    const totalQueries = Array.from(loadByProvider.values()).reduce(
      (a, b) => a + b,
      0,
    );
    const averageLoad = totalQueries / providers.length;
    // Check for load imbalance (30% threshold)
    const hasImbalance = providers.some(
      (provider) => Math.abs(loadOf(provider) - averageLoad) > averageLoad * 0.3,
    );
    if (hasImbalance) {
      const underutilized = providers.filter(
        (provider) => loadOf(provider) < averageLoad * 0.7,
      );
      return {
        recommendation: 'distribute',
        reason: 'Load imbalance detected - some providers are underutilized',
        suggestedProviders: underutilized.map((p) => p.id),
      };
    }
    // Check for performance issues
    const poorPerformingProviders = providers.filter((provider) => {
      const metrics = this.performanceMetrics.get(provider.id);
      return (
        metrics && (metrics.successRate < 0.8 || metrics.averageLatency > 5000)
      );
    });
    if (poorPerformingProviders.length > providers.length / 2) {
      return {
        recommendation: 'throttle',
        reason: 'Multiple providers showing poor performance',
        suggestedProviders: providers
          .filter((p) => !poorPerformingProviders.includes(p))
          .map((p) => p.id),
      };
    }
    return {
      recommendation: 'fallback',
      reason: 'System operating normally',
      suggestedProviders: providers.map((p) => p.id),
    };
  }

  /**
   * Recommend a selection strategy from observed metrics.
   *
   * Only enabled providers that have served at least one query are
   * considered, for both the averages and the variances; providers still
   * holding placeholder metrics would otherwise distort the variance.
   * This method does not change `config.strategy`.
   *
   * @returns {{recommendedStrategy: string, reason: string, confidence: number}}
   */
  adjustStrategyDynamically() {
    const sampled = this.getAvailableProviders()
      .map((provider) => this.performanceMetrics.get(provider.id))
      .filter((metrics) => metrics && metrics.totalQueries > 0);
    if (sampled.length === 0) {
      return {
        recommendedStrategy: 'weighted',
        reason: 'Insufficient data for optimization',
        confidence: 0.5,
      };
    }
    const mean = (pick) =>
      sampled.reduce((sum, metrics) => sum + pick(metrics), 0) / sampled.length;
    const variance = (pick, average) =>
      sampled.reduce((sum, metrics) => {
        const diff = pick(metrics) - average;
        return sum + diff * diff;
      }, 0) / sampled.length;

    const avgSuccessRate = mean((m) => m.successRate);
    const avgLatency = mean((m) => m.averageLatency);
    const avgCost = mean((m) => m.costPerToken);

    // High performance variance -> use performance-based selection
    if (variance((m) => m.successRate, avgSuccessRate) > 0.1) {
      return {
        recommendedStrategy: 'performance',
        reason: 'High performance variance detected',
        confidence: 0.8,
      };
    }
    // High cost variance (std-dev above half the mean) -> cost optimization
    if (variance((m) => m.costPerToken, avgCost) > avgCost * avgCost * 0.25) {
      return {
        recommendedStrategy: 'cost-optimized',
        reason: 'Significant cost differences between providers',
        confidence: 0.7,
      };
    }
    // Consistent performance -> use round-robin for fair distribution
    if (avgSuccessRate > 0.9 && avgLatency < 3000) {
      return {
        recommendedStrategy: 'round-robin',
        reason: 'All providers performing well - distribute load evenly',
        confidence: 0.9,
      };
    }
    return {
      recommendedStrategy: 'weighted',
      reason: 'Balanced approach recommended',
      confidence: 0.6,
    };
  }

  /**
   * Reset round-robin index (useful for testing).
   */
  resetRoundRobin() {
    this.roundRobinIndex = 0;
  }

  /**
   * Recommend a provider and explain the choice.
   *
   * @param {object} [context] - Query context; `type` enables contextual scoring.
   * @param {object} [options] - Optional `considerLatency`, `considerCost`,
   *   `considerAccuracy` flags that raise the corresponding weight to 0.4.
   * @returns {{provider: object, reasoning: string[], confidence: number,
   *   alternatives: Array<{provider: object, score: number}>}} The best
   *   provider, human-readable reasons, a confidence derived from the score
   *   margin over the runner-up (capped at 0.9), and the other providers in
   *   descending score order.
   * @throws {Error} If no provider is enabled.
   */
  getSelectionRecommendation(context, options) {
    const availableProviders = this.getAvailableProviders();
    if (availableProviders.length === 0) {
      throw new Error('No providers available');
    }
    const weights = {
      performance: options?.considerLatency ? 0.4 : 0.3,
      cost: options?.considerCost ? 0.4 : 0.2,
      reliability: options?.considerAccuracy ? 0.4 : 0.3,
      contextual: context?.type ? 0.3 : 0.2,
    };
    const providerScores = availableProviders.map((provider) => ({
      provider,
      score: this.getProviderScore(provider, weights, context),
      metrics: this.performanceMetrics.get(provider.id),
    }));
    providerScores.sort((a, b) => b.score - a.score);
    const [bestProvider, secondBest] = providerScores;
    const bestMetrics = bestProvider.metrics;

    const reasoning = [
      `Selected ${bestProvider.provider.id} with score ${(bestProvider.score * 100).toFixed(1)}%`,
    ];
    if (bestMetrics.successRate > 0.9) {
      reasoning.push(
        `High reliability: ${(bestMetrics.successRate * 100).toFixed(1)}% success rate`,
      );
    }
    if (bestMetrics.averageLatency < 2000) {
      reasoning.push(
        `Fast response: ${bestMetrics.averageLatency}ms average latency`,
      );
    }
    if (context?.type) {
      const contextWeight = this.getContextualWeight(
        bestProvider.provider,
        context.type,
      );
      if (contextWeight > 1.2) {
        reasoning.push(`Well-suited for ${context.type} queries`);
      }
    }
    // Confidence grows with the margin over the runner-up.
    const scoreDifference = secondBest
      ? bestProvider.score - secondBest.score
      : bestProvider.score;
    const confidence = Math.min(0.9, 0.5 + scoreDifference);
    return {
      provider: bestProvider.provider,
      reasoning,
      confidence,
      alternatives: providerScores
        .slice(1)
        .map(({ provider, score }) => ({ provider, score })),
    };
  }
}
//# sourceMappingURL=model-selector.js.map