@probelabs/probe
Version:
Node.js wrapper for the probe code search tool
546 lines (470 loc) • 18.4 kB
JavaScript
/**
* FallbackManager - Handles provider and model fallback configuration
*
* Provides flexible fallback strategies for AI API calls:
* - Same model across different providers (Azure Claude → Bedrock Claude)
* - Different models on same provider (Claude 3.7 → Claude 3.5)
* - Cross-provider fallback (Anthropic → OpenAI → Google)
* - Custom fallback chains with full configuration
*/
import { createAnthropic } from '@ai-sdk/anthropic';
import { createOpenAI } from '@ai-sdk/openai';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock';
/**
* Fallback strategies
*/
export const FALLBACK_STRATEGIES = {
SAME_MODEL: 'same-model', // Try same model on different providers
SAME_PROVIDER: 'same-provider', // Try different models on same provider
ANY: 'any', // Try any available provider/model
CUSTOM: 'custom' // Use custom provider list
};
/**
* Provider configuration schema
* @typedef {Object} ProviderConfig
* @property {string} provider - Provider name: 'anthropic', 'openai', 'google', 'bedrock'
* @property {string} [model] - Model name
* @property {string} [apiKey] - API key
* @property {string} [baseURL] - Custom API endpoint
* @property {number} [maxRetries] - Max retries for this provider (overrides global)
* @property {string} [region] - AWS region (for Bedrock)
* @property {string} [accessKeyId] - AWS access key ID (for Bedrock)
* @property {string} [secretAccessKey] - AWS secret access key (for Bedrock)
* @property {string} [sessionToken] - AWS session token (for Bedrock)
*/
/**
* Default model mappings for each provider
*/
const DEFAULT_MODELS = {
anthropic: 'claude-sonnet-4-5-20250929',
openai: 'gpt-4o',
google: 'gemini-2.0-flash-exp',
bedrock: 'anthropic.claude-sonnet-4-20250514-v1:0'
};
/**
* FallbackManager class for handling provider and model fallback
*/
export class FallbackManager {
/**
* Create a new FallbackManager
* @param {Object} options - Configuration options
* @param {string} [options.strategy='any'] - Fallback strategy
* @param {Array<string>} [options.models] - List of models for same-provider fallback
* @param {Array<ProviderConfig>} [options.providers] - List of provider configurations
* @param {boolean} [options.stopOnSuccess=true] - Stop on first success
* @param {boolean} [options.continueOnNonRetryableError=false] - Continue to fallback on non-retryable errors
* @param {number} [options.maxTotalAttempts=10] - Maximum total attempts across all providers
* @param {boolean} [options.debug=false] - Enable debug logging
*/
constructor(options = {}) {
this.strategy = options.strategy || FALLBACK_STRATEGIES.ANY;
this.models = Array.isArray(options.models) ? options.models : [];
this.providers = Array.isArray(options.providers) ? options.providers : [];
this.stopOnSuccess = options.stopOnSuccess ?? true;
this.continueOnNonRetryableError = options.continueOnNonRetryableError ?? false;
this.debug = options.debug ?? false;
// Validate maxTotalAttempts
const maxAttempts = options.maxTotalAttempts ?? 10;
if (typeof maxAttempts !== 'number' || isNaN(maxAttempts) || maxAttempts < 1 || maxAttempts > 100) {
throw new Error(`FallbackManager: maxTotalAttempts must be a number between 1 and 100, got: ${maxAttempts}`);
}
this.maxTotalAttempts = maxAttempts;
// Statistics
this.stats = {
totalAttempts: 0,
providerAttempts: {},
successfulProvider: null,
failedProviders: []
};
// Validate configuration
this._validateConfiguration();
}
/**
* Validate the fallback configuration
* @private
*/
_validateConfiguration() {
if (this.strategy === FALLBACK_STRATEGIES.SAME_PROVIDER && this.models.length === 0) {
throw new Error('FallbackManager: strategy "same-provider" requires models list');
}
if (this.strategy === FALLBACK_STRATEGIES.CUSTOM && this.providers.length === 0) {
throw new Error('FallbackManager: strategy "custom" requires providers list');
}
// Validate provider configurations
for (const config of this.providers) {
if (!config.provider) {
throw new Error('FallbackManager: Each provider config must have a "provider" field');
}
if (!['anthropic', 'openai', 'google', 'bedrock'].includes(config.provider)) {
throw new Error(`FallbackManager: Invalid provider "${config.provider}". Must be: anthropic, openai, google, or bedrock`);
}
// Validate Bedrock configuration
if (config.provider === 'bedrock') {
const hasCredentials = config.accessKeyId && config.secretAccessKey && config.region;
const hasApiKey = config.apiKey;
if (!hasCredentials && !hasApiKey) {
throw new Error('FallbackManager: Bedrock provider requires either (accessKeyId, secretAccessKey, region) or apiKey');
}
} else {
// Other providers require apiKey
if (!config.apiKey) {
throw new Error(`FallbackManager: Provider "${config.provider}" requires apiKey`);
}
}
}
}
/**
* Create a provider instance from configuration
* @param {ProviderConfig} config - Provider configuration
* @returns {Object} - Provider instance
* @throws {Error} - If provider creation fails
* @private
*/
_createProviderInstance(config) {
try {
switch (config.provider) {
case 'anthropic':
return createAnthropic({
apiKey: config.apiKey,
...(config.baseURL && { baseURL: config.baseURL })
});
case 'openai':
return createOpenAI({
compatibility: 'strict',
apiKey: config.apiKey,
...(config.baseURL && { baseURL: config.baseURL })
});
case 'google':
return createGoogleGenerativeAI({
apiKey: config.apiKey,
...(config.baseURL && { baseURL: config.baseURL })
});
case 'bedrock': {
const bedrockConfig = {};
if (config.apiKey) {
bedrockConfig.apiKey = config.apiKey;
} else if (config.accessKeyId && config.secretAccessKey) {
bedrockConfig.accessKeyId = config.accessKeyId;
bedrockConfig.secretAccessKey = config.secretAccessKey;
if (config.sessionToken) {
bedrockConfig.sessionToken = config.sessionToken;
}
}
if (config.region) {
bedrockConfig.region = config.region;
}
if (config.baseURL) {
bedrockConfig.baseURL = config.baseURL;
}
return createAmazonBedrock(bedrockConfig);
}
default:
throw new Error(`FallbackManager: Unknown provider "${config.provider}"`);
}
} catch (error) {
// Re-throw with more context
const providerName = this._getProviderDisplayName(config);
throw new Error(`Failed to create provider instance for ${providerName}: ${error.message}`);
}
}
/**
* Get the model name for a provider configuration
* @param {ProviderConfig} config - Provider configuration
* @returns {string} - Model name
* @private
*/
_getModelName(config) {
return config.model || DEFAULT_MODELS[config.provider];
}
/**
* Get provider display name for logging
* @param {ProviderConfig} config - Provider configuration
* @returns {string} - Display name
* @private
*/
_getProviderDisplayName(config) {
const model = this._getModelName(config);
const provider = config.provider;
const url = config.baseURL ? ` (${config.baseURL})` : '';
return `${provider}/${model}${url}`;
}
/**
* Execute a function with fallback support
* @param {Function} fn - Function that takes (provider, model, config) and returns a Promise
* @returns {Promise<*>} - Result from the function
* @throws {Error} - If all fallbacks are exhausted
*/
async executeWithFallback(fn) {
if (this.providers.length === 0) {
throw new Error('FallbackManager: No providers configured for fallback');
}
let lastError = null;
let totalAttempts = 0;
for (const config of this.providers) {
if (totalAttempts >= this.maxTotalAttempts) {
if (this.debug) {
console.log(`[FallbackManager] ⚠️ Max total attempts (${this.maxTotalAttempts}) reached`);
}
break;
}
totalAttempts++;
this.stats.totalAttempts++;
const providerName = this._getProviderDisplayName(config);
this.stats.providerAttempts[providerName] = (this.stats.providerAttempts[providerName] || 0) + 1;
try {
if (this.debug) {
console.log(`[FallbackManager] Attempting provider: ${providerName} (attempt ${totalAttempts}/${this.maxTotalAttempts})`);
}
const provider = this._createProviderInstance(config);
const model = this._getModelName(config);
const result = await fn(provider, model, config);
// Success!
this.stats.successfulProvider = providerName;
if (this.debug) {
console.log(`[FallbackManager] ✅ Success with provider: ${providerName}`);
}
return result;
} catch (error) {
lastError = error;
const errorInfo = {
message: error.message || error.toString(),
type: error.type || error.constructor.name,
statusCode: error.statusCode || error.status
};
this.stats.failedProviders.push({
provider: providerName,
error: errorInfo
});
if (this.debug) {
console.log(`[FallbackManager] ❌ Failed with provider: ${providerName}`, errorInfo);
}
// Check if we should continue to next provider
// If error is not retryable and continueOnNonRetryableError is false, stop
if (!this.continueOnNonRetryableError && error.nonRetryable) {
if (this.debug) {
console.log(`[FallbackManager] Non-retryable error, stopping fallback chain`);
}
throw error;
}
// Continue to next provider
if (this.debug) {
const remaining = this.providers.length - (this.providers.indexOf(config) + 1);
console.log(`[FallbackManager] Trying next provider (${remaining} remaining)...`);
}
}
}
// All providers failed
if (this.debug) {
console.log(`[FallbackManager] ❌ All providers exhausted. Total attempts: ${totalAttempts}`);
}
// Enhance error with fallback context
const fallbackError = new Error(
`All provider fallbacks exhausted after ${totalAttempts} attempts. Last error: ${lastError?.message || 'Unknown error'}`
);
fallbackError.cause = lastError;
fallbackError.stats = this.getStats();
fallbackError.allProvidersFailed = true;
throw fallbackError;
}
/**
* Get fallback statistics
* @returns {Object} - Statistics object
*/
getStats() {
return {
...this.stats,
providerAttempts: { ...this.stats.providerAttempts },
failedProviders: [...this.stats.failedProviders]
};
}
/**
* Reset statistics
*/
resetStats() {
this.stats = {
totalAttempts: 0,
providerAttempts: {},
successfulProvider: null,
failedProviders: []
};
}
}
/**
* Create a FallbackManager from environment variables
* @param {boolean} [debug=false] - Enable debug logging
* @returns {FallbackManager|null} - Configured FallbackManager instance or null if no fallback config
*/
export function createFallbackManagerFromEnv(debug = false) {
const fallbackProvidersEnv = process.env.FALLBACK_PROVIDERS;
const fallbackModelsEnv = process.env.FALLBACK_MODELS;
// If no fallback configuration, return null
if (!fallbackProvidersEnv && !fallbackModelsEnv) {
return null;
}
let providers = [];
let models = [];
let strategy = FALLBACK_STRATEGIES.ANY;
// Parse providers configuration
if (fallbackProvidersEnv) {
try {
// Validate input is a string and not suspiciously long
if (typeof fallbackProvidersEnv !== 'string' || fallbackProvidersEnv.length > 10000) {
console.error('[FallbackManager] FALLBACK_PROVIDERS must be a valid JSON string under 10KB');
return null;
}
const parsed = JSON.parse(fallbackProvidersEnv);
// Validate parsed result is an array
if (!Array.isArray(parsed)) {
console.error('[FallbackManager] FALLBACK_PROVIDERS must be a JSON array');
return null;
}
providers = parsed;
strategy = FALLBACK_STRATEGIES.CUSTOM;
} catch (error) {
console.error('[FallbackManager] Failed to parse FALLBACK_PROVIDERS:', error.message);
return null;
}
}
// Parse models configuration
if (fallbackModelsEnv) {
try {
// Validate input is a string and not suspiciously long
if (typeof fallbackModelsEnv !== 'string' || fallbackModelsEnv.length > 10000) {
console.error('[FallbackManager] FALLBACK_MODELS must be a valid JSON string under 10KB');
return null;
}
const parsed = JSON.parse(fallbackModelsEnv);
// Validate parsed result is an array
if (!Array.isArray(parsed)) {
console.error('[FallbackManager] FALLBACK_MODELS must be a JSON array');
return null;
}
models = parsed;
strategy = FALLBACK_STRATEGIES.SAME_PROVIDER;
} catch (error) {
console.error('[FallbackManager] Failed to parse FALLBACK_MODELS:', error.message);
return null;
}
}
const maxTotalAttempts = process.env.FALLBACK_MAX_TOTAL_ATTEMPTS
? (() => {
const val = parseInt(process.env.FALLBACK_MAX_TOTAL_ATTEMPTS, 10);
if (isNaN(val) || val < 1 || val > 100) {
console.warn('[FallbackManager] FALLBACK_MAX_TOTAL_ATTEMPTS must be between 1 and 100, using default: 10');
return 10;
}
return val;
})()
: 10;
return new FallbackManager({
strategy,
providers,
models,
maxTotalAttempts,
debug
});
}
/**
* Build a fallback provider list from current environment
* @param {Object} options - Options for building the list
* @param {string} [options.primaryProvider] - Primary provider to try first
* @param {string} [options.primaryModel] - Primary model to use
* @returns {Array<ProviderConfig>} - List of provider configurations
*/
export function buildFallbackProvidersFromEnv(options = {}) {
const providers = [];
// Get all available API keys from environment
const anthropicApiKey = process.env.ANTHROPIC_API_KEY || process.env.ANTHROPIC_AUTH_TOKEN;
const openaiApiKey = process.env.OPENAI_API_KEY;
const googleApiKey = process.env.GOOGLE_GENERATIVE_AI_API_KEY || process.env.GOOGLE_API_KEY;
const awsAccessKeyId = process.env.AWS_ACCESS_KEY_ID;
const awsSecretAccessKey = process.env.AWS_SECRET_ACCESS_KEY;
const awsRegion = process.env.AWS_REGION;
const awsApiKey = process.env.AWS_BEDROCK_API_KEY;
// Get custom URLs
const llmBaseUrl = process.env.LLM_BASE_URL;
const anthropicApiUrl = process.env.ANTHROPIC_API_URL || process.env.ANTHROPIC_BASE_URL || llmBaseUrl;
const openaiApiUrl = process.env.OPENAI_API_URL || llmBaseUrl;
const googleApiUrl = process.env.GOOGLE_API_URL || llmBaseUrl;
const awsBedrockBaseUrl = process.env.AWS_BEDROCK_BASE_URL || llmBaseUrl;
// Build primary provider config
const primaryProvider = options.primaryProvider?.toLowerCase();
const primaryModel = options.primaryModel;
// Add primary provider first if specified
if (primaryProvider === 'anthropic' && anthropicApiKey) {
providers.push({
provider: 'anthropic',
apiKey: anthropicApiKey,
...(anthropicApiUrl && { baseURL: anthropicApiUrl }),
...(primaryModel && { model: primaryModel })
});
} else if (primaryProvider === 'openai' && openaiApiKey) {
providers.push({
provider: 'openai',
apiKey: openaiApiKey,
...(openaiApiUrl && { baseURL: openaiApiUrl }),
...(primaryModel && { model: primaryModel })
});
} else if (primaryProvider === 'google' && googleApiKey) {
providers.push({
provider: 'google',
apiKey: googleApiKey,
...(googleApiUrl && { baseURL: googleApiUrl }),
...(primaryModel && { model: primaryModel })
});
} else if (primaryProvider === 'bedrock' && ((awsAccessKeyId && awsSecretAccessKey && awsRegion) || awsApiKey)) {
const config = { provider: 'bedrock' };
if (awsApiKey) {
config.apiKey = awsApiKey;
} else {
config.accessKeyId = awsAccessKeyId;
config.secretAccessKey = awsSecretAccessKey;
config.region = awsRegion;
if (process.env.AWS_SESSION_TOKEN) {
config.sessionToken = process.env.AWS_SESSION_TOKEN;
}
}
if (awsBedrockBaseUrl) config.baseURL = awsBedrockBaseUrl;
if (primaryModel) config.model = primaryModel;
providers.push(config);
}
// Add remaining available providers as fallbacks
if (anthropicApiKey && primaryProvider !== 'anthropic') {
providers.push({
provider: 'anthropic',
apiKey: anthropicApiKey,
...(anthropicApiUrl && { baseURL: anthropicApiUrl })
});
}
if (openaiApiKey && primaryProvider !== 'openai') {
providers.push({
provider: 'openai',
apiKey: openaiApiKey,
...(openaiApiUrl && { baseURL: openaiApiUrl })
});
}
if (googleApiKey && primaryProvider !== 'google') {
providers.push({
provider: 'google',
apiKey: googleApiKey,
...(googleApiUrl && { baseURL: googleApiUrl })
});
}
if (((awsAccessKeyId && awsSecretAccessKey && awsRegion) || awsApiKey) && primaryProvider !== 'bedrock') {
const config = { provider: 'bedrock' };
if (awsApiKey) {
config.apiKey = awsApiKey;
} else {
config.accessKeyId = awsAccessKeyId;
config.secretAccessKey = awsSecretAccessKey;
config.region = awsRegion;
if (process.env.AWS_SESSION_TOKEN) {
config.sessionToken = process.env.AWS_SESSION_TOKEN;
}
}
if (awsBedrockBaseUrl) config.baseURL = awsBedrockBaseUrl;
providers.push(config);
}
return providers;
}