UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

308 lines 12.7 kB
import {
  isExternalModel,
  getExternalModel,
  getExternalProvider,
  getModelClassOverride,
} from '../utils/external_models.js';
import { openaiProvider } from './openai.js';
import { claudeProvider } from './claude.js';
import { geminiProvider } from './gemini.js';
import { grokProvider } from './grok.js';
import { deepSeekProvider } from './deepseek.js';
import { testProvider } from './test_provider.js';
import { openRouterProvider } from './openrouter.js';
import { elevenLabsProvider } from './elevenlabs.js';
import { MODEL_CLASSES } from '../data/model_data.js';

/**
 * Single table mapping model-name prefixes to the provider that serves them.
 *
 * `name` is the provider id understood by isProviderKeyValid(); `provider` is
 * the implementation returned by getModelProvider(). Keeping both in one
 * table stops the two lookups from drifting apart (e.g. a prefix routed to
 * the OpenAI implementation but reported as 'openrouter', which made the
 * wrong API key get checked). Entries are tested in order; first match wins.
 */
const MODEL_PREFIXES = [
  { prefix: 'gpt-', name: 'openai', provider: openaiProvider },
  { prefix: 'o1', name: 'openai', provider: openaiProvider },
  { prefix: 'o3', name: 'openai', provider: openaiProvider },
  { prefix: 'o4', name: 'openai', provider: openaiProvider },
  { prefix: 'text-', name: 'openai', provider: openaiProvider },
  { prefix: 'computer-use-preview', name: 'openai', provider: openaiProvider },
  { prefix: 'dall-e', name: 'openai', provider: openaiProvider },
  { prefix: 'gpt-image', name: 'openai', provider: openaiProvider },
  { prefix: 'tts-', name: 'openai', provider: openaiProvider },
  { prefix: 'codex-', name: 'openai', provider: openaiProvider },
  { prefix: 'claude-', name: 'anthropic', provider: claudeProvider },
  { prefix: 'gemini-', name: 'google', provider: geminiProvider },
  { prefix: 'imagen-', name: 'google', provider: geminiProvider },
  { prefix: 'grok-', name: 'xai', provider: grokProvider },
  { prefix: 'deepseek-', name: 'deepseek', provider: deepSeekProvider },
  { prefix: 'eleven_', name: 'elevenlabs', provider: elevenLabsProvider },
  { prefix: 'elevenlabs-', name: 'elevenlabs', provider: elevenLabsProvider },
  { prefix: 'test-', name: 'test', provider: testProvider },
];

/** Return the first MODEL_PREFIXES entry whose prefix `model` starts with, or undefined. */
function findPrefixEntry(model) {
  return MODEL_PREFIXES.find(({ prefix }) => model.startsWith(prefix));
}

/** True when `value` is a non-empty string beginning with `prefix`. */
function hasKeyWithPrefix(value, prefix) {
  return typeof value === 'string' && value.startsWith(prefix);
}

/** Pick a uniformly random element (undefined for an empty list). */
function pickRandom(list) {
  return list[Math.floor(Math.random() * list.length)];
}

/**
 * True when `name` is an own key of MODEL_CLASSES. An own-property check is
 * used so that prototype names such as 'toString' are not treated as classes.
 */
function isKnownModelClass(name) {
  return Boolean(name) && Object.prototype.hasOwnProperty.call(MODEL_CLASSES, name);
}

/**
 * Resolve a model class's candidate models and random flag, applying any
 * override from getModelClassOverride().
 *
 * @param {string} modelGroup - Model class name.
 * @returns {{ models: string[], random: boolean }} `models` is a fresh copy
 *   (override models win when provided); `random` prefers the override's flag,
 *   then the merged config's `random` field, else false.
 */
function resolveModelClass(modelGroup) {
  // `?? {}` keeps a missing class from crashing; it just yields no models.
  const base = MODEL_CLASSES[modelGroup] ?? {};
  const override = getModelClassOverride(modelGroup);
  const config = override ? { ...base, ...override } : base;
  return {
    // Read base.models (not config.models) so an override carrying an
    // explicit `models: undefined` cannot clobber the class's own list.
    models: [...(override?.models || base.models || [])],
    random: override?.random ?? ('random' in config && config.random),
  };
}

/**
 * Check whether the environment holds a usable API key for `provider`.
 *
 * For built-in providers this reads process.env (some keys must carry a
 * known prefix). The 'test' provider is always valid. Any other name is
 * valid only if it is a registered external provider.
 *
 * @param {string} provider - Provider id (e.g. 'openai', 'anthropic').
 * @returns {boolean}
 */
export function isProviderKeyValid(provider) {
  const env = process.env;
  switch (provider) {
    case 'openai':
      return hasKeyWithPrefix(env.OPENAI_API_KEY, 'sk-');
    case 'anthropic':
      return hasKeyWithPrefix(env.ANTHROPIC_API_KEY, 'sk-ant-');
    case 'google':
      return Boolean(env.GOOGLE_API_KEY);
    case 'xai':
      return hasKeyWithPrefix(env.XAI_API_KEY, 'xai-');
    case 'deepseek':
      return hasKeyWithPrefix(env.DEEPSEEK_API_KEY, 'sk-');
    case 'openrouter':
      return Boolean(env.OPENROUTER_API_KEY);
    case 'elevenlabs':
      return Boolean(env.ELEVENLABS_API_KEY);
    case 'test':
      return true;
    default:
      return Boolean(getExternalProvider(provider));
  }
}

/**
 * Map a model id to its provider id.
 *
 * Registered external models report their own provider. Other ids are
 * matched by prefix via MODEL_PREFIXES. Anything unmatched is assumed to be
 * served by 'openrouter'.
 *
 * @param {string} model - Model id.
 * @returns {string} Provider id.
 */
export function getProviderFromModel(model) {
  if (isExternalModel(model)) {
    const externalModel = getExternalModel(model);
    if (externalModel) {
      return externalModel.provider;
    }
  }
  return findPrefixEntry(model)?.name ?? 'openrouter';
}

/**
 * Remove excluded and disabled models from `models`, never returning an empty
 * list while `models` has entries.
 *
 * When every model is filtered out, it rotates to the model after the most
 * recently excluded one (the last entry of `excludeModels` present in
 * `models`), skipping disabled models. If that fails, it returns the first
 * non-disabled model, and finally the first model even if disabled.
 *
 * @param {string[]} models
 * @param {string[]} [excludeModels]
 * @param {string[]} [disabledModels]
 * @returns {string[]}
 */
function filterModelsWithFallback(models, excludeModels, disabledModels) {
  const excluded = excludeModels || [];
  const disabled = disabledModels || [];
  if (excluded.length === 0 && disabled.length === 0) {
    return models;
  }

  const blocked = new Set([...excluded, ...disabled]);
  const remaining = models.filter(model => !blocked.has(model));
  if (remaining.length > 0) {
    return remaining;
  }

  const isDisabled = model => disabled.includes(model);

  // Rotate past the most recently excluded model so repeated retries cycle
  // through the class instead of failing outright.
  const lastUsed = [...excluded].reverse().find(model => models.includes(model));
  if (lastUsed) {
    const start = models.indexOf(lastUsed);
    for (let step = 1; step <= models.length; step++) {
      const candidate = models[(start + step) % models.length];
      if (!isDisabled(candidate)) {
        return [candidate];
      }
    }
  }

  const firstEnabled = models.find(model => !isDisabled(model));
  if (firstEnabled) {
    return [firstEnabled];
  }
  return models.length > 0 ? [models[0]] : remaining;
}

/**
 * Choose a model, weighted by `scores`.
 *
 * Without scores (or with no models) the choice is uniformly random. Models
 * with no score get weight 50. Non-positive weights are dropped. If nothing
 * survives, the choice is uniformly random over `models`.
 *
 * @param {string[]} models
 * @param {Record<string, number>} [scores]
 * @returns {string|undefined}
 */
function selectWeightedModel(models, scores) {
  if (!scores || models.length === 0) {
    return pickRandom(models);
  }
  const weighted = models
    .map(model => ({
      model,
      weight: scores[model] !== undefined ? scores[model] : 50,
    }))
    .filter(({ weight }) => weight > 0);
  if (weighted.length === 0) {
    return pickRandom(models);
  }
  const totalWeight = weighted.reduce((sum, { weight }) => sum + weight, 0);
  let remainingMass = Math.random() * totalWeight;
  for (const { model, weight } of weighted) {
    remainingMass -= weight;
    if (remainingMass <= 0) {
      return model;
    }
  }
  // Only reached through floating-point rounding residue, which lies at the
  // top of the range, i.e. inside the last bucket.
  return weighted[weighted.length - 1].model;
}

/**
 * Resolve the model an agent should use: its explicit `model`, otherwise one
 * picked from its `modelClass` (or `defaultClass`) via getModelFromClass().
 */
export async function getModelFromAgent(agent, defaultClass, excludeModels) {
  return (agent.model ||
    (await getModelFromClass(agent.modelClass || defaultClass, excludeModels, agent.disabledModels, agent.modelScores)));
}

/**
 * Pick a model from a model class.
 *
 * Unknown or missing classes are treated as 'standard'. Candidates come from
 * the class (override-aware), filtered by filterModelsWithFallback(). The
 * search order is:
 *   1. class models whose provider has a valid key and remaining quota;
 *   2. class models whose provider has a valid key;
 *   3. the same two passes over the 'standard' class (only when the requested
 *      class was not 'standard');
 *   4. the class's first model, or 'gpt-4.1' if there is none.
 *
 * @param {string} [modelClass]
 * @param {string[]} [excludeModels]
 * @param {string[]} [disabledModels]
 * @param {Record<string, number>} [modelScores] - Weights for selectWeightedModel().
 * @returns {Promise<string>}
 */
export async function getModelFromClass(modelClass, excludeModels, disabledModels, modelScores) {
  const { quotaTracker } = await import('../utils/quota_tracker.js');
  const modelGroup = isKnownModelClass(modelClass) ? modelClass : 'standard';
  const groupIsKnown = isKnownModelClass(modelGroup);

  const hasKey = model => isProviderKeyValid(getProviderFromModel(model));
  const hasKeyAndQuota = model => {
    const provider = getProviderFromModel(model);
    return isProviderKeyValid(provider) && quotaTracker.hasQuota(provider, model);
  };

  const classConfig = groupIsKnown ? resolveModelClass(modelGroup) : { models: [], random: false };

  if (groupIsKnown) {
    const candidates = filterModelsWithFallback(classConfig.models, excludeModels, disabledModels);
    // NOTE: selectWeightedModel() is itself uniformly random when no scores
    // are given, so the `random` flag only matters together with modelScores.
    const choose = list =>
      classConfig.random && !modelScores ? pickRandom(list) : selectWeightedModel(list, modelScores);

    const withKeyAndQuota = candidates.filter(hasKeyAndQuota);
    if (withKeyAndQuota.length > 0) {
      const selectedModel = choose(withKeyAndQuota);
      console.log(`Using '${selectedModel}' model for '${modelGroup}' class.`);
      return selectedModel;
    }
    const withKey = candidates.filter(hasKey);
    if (withKey.length > 0) {
      const selectedModel = choose(withKey);
      console.log(`Using '${selectedModel}' model for '${modelGroup}' class (may exceed quota).`);
      return selectedModel;
    }
  }

  if (modelGroup !== 'standard' && isKnownModelClass('standard')) {
    // Honor any 'standard' override here too, as the primary path does.
    const standardModels = filterModelsWithFallback(
      resolveModelClass('standard').models,
      excludeModels,
      disabledModels,
    );
    const standardWithKeyAndQuota = standardModels.filter(hasKeyAndQuota);
    if (standardWithKeyAndQuota.length > 0) {
      const selectedModel = selectWeightedModel(standardWithKeyAndQuota, modelScores);
      console.warn(`Falling back to 'standard' class with model '${selectedModel}'.`);
      return selectedModel;
    }
    const standardWithKey = standardModels.filter(hasKey);
    if (standardWithKey.length > 0) {
      const selectedModel = selectWeightedModel(standardWithKey, modelScores);
      console.log(`Falling back to 'standard' class with model '${selectedModel}' (may exceed quota).`);
      return selectedModel;
    }
  }

  const defaultModel = classConfig.models.length > 0 ? classConfig.models[0] : 'gpt-4.1';
  console.log(`No valid API key found for any model in class ${modelGroup}, using default: ${defaultModel}`);
  return defaultModel;
}

/**
 * Return the provider implementation for `model`.
 *
 * A registered external provider is used first. Otherwise the provider comes
 * from the first matching prefix in MODEL_PREFIXES, and OpenRouter is used
 * for unmatched or missing models.
 *
 * @param {string} [model]
 * @returns {object} Provider implementation.
 * @throws {Error} When the matched provider's API key is missing/invalid, or
 *   when falling back to OpenRouter without an OpenRouter key.
 */
export function getModelProvider(model) {
  if (model) {
    if (isExternalModel(model)) {
      const externalModel = getExternalModel(model);
      const externalProvider = externalModel && getExternalProvider(externalModel.provider);
      if (externalProvider) {
        return externalProvider;
      }
    }
    const entry = findPrefixEntry(model);
    if (entry) {
      // getProviderFromModel (not entry.name) so an external model's declared
      // provider name is what gets validated and reported.
      const providerName = getProviderFromModel(model);
      if (!isProviderKeyValid(providerName)) {
        throw new Error(`API key for ${providerName} provider is missing or invalid. Please set ${providerName.toUpperCase()}_API_KEY environment variable.`);
      }
      return entry.provider;
    }
  }
  if (!isProviderKeyValid('openrouter')) {
    throw new Error(`No valid provider found for the model ${model}. Please check your API keys.`);
  }
  return openRouterProvider;
}

/**
 * Report whether an agent can run with the API keys currently available.
 *
 * With an explicit `agent.model`, only that model's provider is checked.
 * With `agent.modelClass`, every model of the (override-aware) class is
 * classified as available or unavailable. Unknown classes are treated as
 * 'standard'. With neither, the 'standard' class is checked.
 *
 * @param {{ model?: string, modelClass?: string }} agent
 * @returns {Promise<object>}
 */
export async function canRunAgent(agent) {
  if (agent.model) {
    const provider = getProviderFromModel(agent.model);
    const hasKey = isProviderKeyValid(provider);
    return {
      canRun: hasKey,
      model: agent.model,
      provider,
      missingProvider: hasKey ? undefined : provider,
      reason: hasKey ? undefined : `Missing API key for provider: ${provider}`,
    };
  }

  if (!agent.modelClass) {
    return canRunAgent({ modelClass: 'standard' });
  }

  const modelGroup = isKnownModelClass(agent.modelClass) ? agent.modelClass : 'standard';
  const { models } = resolveModelClass(modelGroup);
  const availableModels = [];
  const unavailableModels = [];
  const missingProviders = new Set();
  for (const model of models) {
    const provider = getProviderFromModel(model);
    if (isProviderKeyValid(provider)) {
      availableModels.push(model);
    } else {
      unavailableModels.push(model);
      missingProviders.add(provider);
    }
  }

  const missing = [...missingProviders];
  const canRun = availableModels.length > 0;
  return {
    canRun,
    availableModels,
    unavailableModels,
    missingProvider: !canRun && missing.length === 1 ? missing[0] : undefined,
    reason: canRun
      ? undefined
      : `No API keys found for any models in class: ${agent.modelClass}. Missing providers: ${missing.join(', ')}`,
  };
}
//# sourceMappingURL=model_provider.js.map