UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

428 lines 17 kB
import {
    isExternalModel,
    getExternalModel,
    getExternalProvider,
    getModelClassOverride,
} from '../utils/external_models.js';
import { openaiProvider } from './openai.js';
import { claudeProvider } from './claude.js';
import { geminiProvider } from './gemini.js';
import { grokProvider } from './grok.js';
import { deepSeekProvider } from './deepseek.js';
import { testProvider } from './test_provider.js';
import { openRouterProvider } from './openrouter.js';
import { elevenLabsProvider } from './elevenlabs.js';
import { lumaProvider } from './luma.js';
import { ideogramProvider } from './ideogram.js';
import { midjourneyProvider } from './midjourney.js';
import { fireworksProvider } from './fireworks.js';
import { stabilityProvider } from './stability.js';
import { falProvider } from './fal.js';
import { runwayProvider } from './runway.js';
import { bytedanceProvider } from './bytedance.js';
import { MODEL_CLASSES, findModel } from '../data/model_data.js';

/**
 * Provider implementations keyed by provider id. These ids are the strings
 * returned by getProviderFromModel() and accepted by isProviderKeyValid().
 */
const PROVIDER_BY_ID = {
    openai: openaiProvider,
    anthropic: claudeProvider,
    google: geminiProvider,
    xai: grokProvider,
    deepseek: deepSeekProvider,
    openrouter: openRouterProvider,
    elevenlabs: elevenLabsProvider,
    luma: lumaProvider,
    ideogram: ideogramProvider,
    midjourney: midjourneyProvider,
    stability: stabilityProvider,
    fireworks: fireworksProvider,
    fal: falProvider,
    runway: runwayProvider,
    bytedance: bytedanceProvider,
    test: testProvider,
};

/** Reverse index of PROVIDER_BY_ID: provider implementation -> provider id. */
const PROVIDER_ID_BY_IMPL = new Map(
    Object.entries(PROVIDER_BY_ID).map(([id, provider]) => [provider, id]),
);

/**
 * Model-name prefix -> provider implementation, used by getModelProvider()
 * for models that findModel() does not know about.
 *
 * Entries are tested in insertion order and the first matching prefix wins,
 * so specific prefixes ('text-embedding-004', 'gpt-oss-') must stay ahead of
 * the general prefixes they overlap with ('text-', 'gpt-').
 */
const MODEL_PROVIDER_MAP = {
    'text-embedding-004': geminiProvider,
    'gpt-oss-': openRouterProvider,
    'gpt-': openaiProvider,
    o1: openaiProvider,
    o3: openaiProvider,
    o4: openaiProvider,
    'text-': openaiProvider,
    'computer-use-preview': openaiProvider,
    'dall-e': openaiProvider,
    'gpt-image': openaiProvider,
    'tts-': openaiProvider,
    'codex-': openaiProvider,
    'claude-': claudeProvider,
    'gemini-': geminiProvider,
    'imagen-': geminiProvider,
    'luma-': lumaProvider,
    'ideogram-': ideogramProvider,
    'midjourney-': midjourneyProvider,
    'flux-': fireworksProvider,
    'fireworks-': fireworksProvider,
    'stability-': stabilityProvider,
    'sdxl-': stabilityProvider,
    sd3: stabilityProvider,
    'runway-': runwayProvider,
    'recraft-': falProvider,
    'fal-': falProvider,
    'seedream-': bytedanceProvider,
    'bytedance-': bytedanceProvider,
    'byteplus-': bytedanceProvider,
    'grok-': grokProvider,
    'deepseek-': deepSeekProvider,
    eleven_: elevenLabsProvider,
    'elevenlabs-': elevenLabsProvider,
    'test-': testProvider,
};

/**
 * Ordered [prefix, providerId] pairs used by getProviderFromModel() for
 * models that are neither external nor known to findModel(). The first
 * matching prefix wins.
 *
 * Every prefix in MODEL_PROVIDER_MAP appears here with the same provider, so
 * API-key checks made on getProviderFromModel()'s answer agree with the
 * provider getModelProvider() actually routes to.
 *
 * NOTE(review): 'firefly-', 'replicate-' and 'runwayml-' exist only in this
 * table. getModelProvider() has no entry for them, so it routes them through
 * its OpenRouter/fal fallback — confirm this is intended.
 */
const MODEL_PREFIX_TO_PROVIDER_ID = [
    ['text-embedding-004', 'google'],
    ['gpt-oss-', 'openrouter'],
    ['gpt-', 'openai'],
    ['o1', 'openai'],
    ['o3', 'openai'],
    ['o4', 'openai'],
    ['text-', 'openai'],
    ['computer-use-preview', 'openai'],
    ['dall-e', 'openai'],
    ['gpt-image', 'openai'],
    ['tts-', 'openai'],
    ['codex-', 'openai'],
    ['claude-', 'anthropic'],
    ['gemini-', 'google'],
    ['imagen-', 'google'],
    ['firefly-', 'openrouter'],
    ['replicate-', 'runway'],
    ['luma-', 'luma'],
    ['ideogram-', 'ideogram'],
    ['midjourney-', 'midjourney'],
    ['flux-', 'fireworks'],
    ['fireworks-', 'fireworks'],
    ['stability-', 'stability'],
    ['sdxl-', 'stability'],
    ['sd3', 'stability'],
    ['runway-', 'runway'],
    ['runwayml-', 'runway'],
    ['seedream-', 'bytedance'],
    ['bytedance-', 'bytedance'],
    ['byteplus-', 'bytedance'],
    ['recraft-', 'fal'],
    ['fal-', 'fal'],
    ['grok-', 'xai'],
    ['deepseek-', 'deepseek'],
    ['eleven_', 'elevenlabs'],
    ['elevenlabs-', 'elevenlabs'],
    ['test-', 'test'],
];

/**
 * Environment variables named in "missing key" errors for providers whose
 * variable is not simply `${ID}_API_KEY`. Kept in sync with the checks in
 * isProviderKeyValid().
 */
const API_KEY_ENV_HINTS = {
    fal: 'FAL_KEY',
    midjourney: 'MIDJOURNEY_API_KEY, MJ_API_KEY or KIE_API_KEY',
    bytedance: 'ARK_API_KEY, BYTEPLUS_API_KEY or BYTEDANCE_API_KEY',
};

/**
 * Build the error thrown when a provider's API key is missing or invalid.
 * @param {string} providerName - Provider id.
 * @returns {Error}
 */
function missingApiKeyError(providerName) {
    const envVar = API_KEY_ENV_HINTS[providerName] ?? `${providerName.toUpperCase()}_API_KEY`;
    return new Error(`API key for ${providerName} provider is missing or invalid. Please set ${envVar} environment variable.`);
}

/**
 * Check whether the environment holds a usable API key for a provider.
 *
 * Some providers also require a key prefix: openai/deepseek 'sk-',
 * anthropic 'sk-ant-', xai 'xai-', runway 'key_'. The 'test' provider is
 * always valid. Unknown ids are valid only if getExternalProvider() returns
 * an implementation for them.
 *
 * @param {string} provider - Provider id.
 * @returns {boolean}
 */
export function isProviderKeyValid(provider) {
    const env = process.env;
    switch (provider) {
        case 'openai':
            return !!env.OPENAI_API_KEY && env.OPENAI_API_KEY.startsWith('sk-');
        case 'anthropic':
            return !!env.ANTHROPIC_API_KEY && env.ANTHROPIC_API_KEY.startsWith('sk-ant-');
        case 'google':
            return !!env.GOOGLE_API_KEY;
        case 'xai':
            return !!env.XAI_API_KEY && env.XAI_API_KEY.startsWith('xai-');
        case 'deepseek':
            return !!env.DEEPSEEK_API_KEY && env.DEEPSEEK_API_KEY.startsWith('sk-');
        case 'openrouter':
            return !!env.OPENROUTER_API_KEY;
        case 'elevenlabs':
            return !!env.ELEVENLABS_API_KEY;
        case 'luma':
            return !!env.LUMA_API_KEY;
        case 'ideogram':
            return !!env.IDEOGRAM_API_KEY;
        case 'midjourney':
            return !!(env.MIDJOURNEY_API_KEY || env.MJ_API_KEY || env.KIE_API_KEY);
        case 'test':
            return true;
        case 'stability':
            return !!env.STABILITY_API_KEY;
        case 'fireworks':
            return !!env.FIREWORKS_API_KEY;
        case 'fal':
            return !!env.FAL_KEY;
        case 'bytedance':
            return !!(env.ARK_API_KEY || env.BYTEPLUS_API_KEY || env.BYTEDANCE_API_KEY);
        case 'runway':
            return !!env.RUNWAY_API_KEY && env.RUNWAY_API_KEY.startsWith('key_');
        default:
            return !!getExternalProvider(provider);
    }
}

/**
 * Resolve the provider id for a model name.
 *
 * Resolution order:
 *   1. the external-model registry;
 *   2. findModel() metadata;
 *   3. the first matching prefix in MODEL_PREFIX_TO_PROVIDER_ID;
 *   4. otherwise 'openrouter'.
 *
 * @param {string} model - Model name.
 * @returns {string} Provider id.
 */
export function getProviderFromModel(model) {
    if (isExternalModel(model)) {
        const externalModel = getExternalModel(model);
        if (externalModel) {
            return externalModel.provider;
        }
    }
    const registeredModel = findModel(model);
    if (registeredModel) {
        return registeredModel.provider;
    }
    const match = MODEL_PREFIX_TO_PROVIDER_ID.find(([prefix]) => model.startsWith(prefix));
    return match ? match[1] : 'openrouter';
}

/**
 * Drop excluded and disabled models from a candidate list, never returning
 * an empty list when the input was non-empty.
 *
 * If filtering removes everything, a single-model list is returned instead:
 *   1. the model after the last entry of `excludeModels` found in `models`
 *      (rotating through the list and skipping disabled models; presumably
 *      that entry is the most recently tried model — confirm with callers);
 *   2. otherwise the first non-disabled model;
 *   3. otherwise the first model.
 *
 * @param {string[]} models
 * @param {string[]} [excludeModels]
 * @param {string[]} [disabledModels]
 * @returns {string[]}
 */
function filterModelsWithFallback(models, excludeModels, disabledModels) {
    const excluded = excludeModels || [];
    const disabled = disabledModels || [];
    if (excluded.length === 0 && disabled.length === 0) {
        return models;
    }
    const filteredModels = models.filter(model => !excluded.includes(model) && !disabled.includes(model));
    if (filteredModels.length > 0 || models.length === 0) {
        return filteredModels;
    }
    const lastUsedModel = [...excluded].reverse().find(model => models.includes(model));
    if (lastUsedModel) {
        const start = models.indexOf(lastUsedModel);
        // Walk the whole list once, starting just after lastUsedModel and
        // ending on lastUsedModel itself.
        for (let offset = 1; offset <= models.length; offset++) {
            const candidate = models[(start + offset) % models.length];
            if (!disabled.includes(candidate)) {
                return [candidate];
            }
        }
    }
    const firstNonDisabled = models.find(model => !disabled.includes(model));
    return [firstNonDisabled ?? models[0]];
}

/**
 * Pick one model, weighted by `scores`.
 *
 * Models without a score get weight 50; models whose weight is not > 0 are
 * skipped. The pick is uniform when `scores` is absent, when `models` is
 * empty (yielding undefined), or when no model has a positive weight.
 *
 * @param {string[]} models
 * @param {Record<string, number>} [scores]
 * @returns {string|undefined}
 */
function selectWeightedModel(models, scores) {
    const pickUniform = () => models[Math.floor(Math.random() * models.length)];
    if (!scores || models.length === 0) {
        return pickUniform();
    }
    const weighted = models
        .map(model => ({ model, weight: scores[model] !== undefined ? scores[model] : 50 }))
        .filter(({ weight }) => weight > 0);
    if (weighted.length === 0) {
        return pickUniform();
    }
    const totalWeight = weighted.reduce((sum, { weight }) => sum + weight, 0);
    let remaining = Math.random() * totalWeight;
    for (const { model, weight } of weighted) {
        remaining -= weight;
        if (remaining <= 0) {
            return model;
        }
    }
    // Guard against floating-point rounding leaving `remaining` just above zero.
    return weighted[0].model;
}

/** Model-name suffixes preserved across canonicalisation in getModelFromAgent(). */
const MODEL_VARIANT_SUFFIXES = ['-low', '-medium', '-high', '-max'];

/**
 * Determine the concrete model an agent should use.
 *
 * Uses `agent.model` if set, otherwise picks from `agent.modelClass`, falling
 * back to `defaultClass`. A trailing '-low' / '-medium' / '-high' / '-max' is
 * stripped before the name is canonicalised with findModel(), then appended
 * to the canonical id again.
 *
 * @param {object} agent
 * @param {string} [defaultClass]
 * @param {string[]} [excludeModels]
 * @returns {Promise<string>}
 */
export async function getModelFromAgent(agent, defaultClass, excludeModels) {
    const model = agent.model ||
        (await getModelFromClass(agent.modelClass || defaultClass, excludeModels, agent.disabledModels, agent.modelScores));
    const suffix = MODEL_VARIANT_SUFFIXES.find(s => model.endsWith(s)) ?? '';
    const baseModel = suffix ? model.slice(0, -suffix.length) : model;
    const modelEntry = findModel(baseModel);
    return modelEntry?.id ? modelEntry.id + suffix : model;
}

/**
 * Map a requested model class to a key of MODEL_CLASSES.
 * Unknown or empty classes map to 'standard'.
 */
function resolveModelGroup(modelClass) {
    return modelClass && modelClass in MODEL_CLASSES ? modelClass : 'standard';
}

/**
 * Merge a model class's built-in config with any registered override.
 *
 * @param {string} modelGroup - A MODEL_CLASSES key.
 * @returns {{config: object|undefined, override: object|undefined, models: string[]}}
 *   `config` is the merged config. `models` is a fresh copy of the override's
 *   models, else the built-in models, else [].
 */
function resolveModelClassConfig(modelGroup) {
    const baseConfig = MODEL_CLASSES[modelGroup];
    const override = getModelClassOverride(modelGroup);
    const config = override ? { ...baseConfig, ...override } : baseConfig;
    const models = [...(override?.models || baseConfig?.models || [])];
    return { config, override, models };
}

/**
 * Choose a model from a model class.
 *
 * Candidates are the class's models (override-aware) after exclusion
 * filtering. Preference order:
 *   1. a model whose provider has a valid key and remaining quota;
 *   2. a model whose provider has a valid key;
 *   3. the same two steps on the 'standard' class (override-aware);
 *   4. the class's first built-in model, or 'gpt-5.2-chat-latest'.
 *
 * @param {string} [modelClass]
 * @param {string[]} [excludeModels]
 * @param {string[]} [disabledModels]
 * @param {Record<string, number>} [modelScores]
 * @returns {Promise<string>}
 */
export async function getModelFromClass(modelClass, excludeModels, disabledModels, modelScores) {
    const { quotaTracker } = await import('../utils/quota_tracker.js');
    const hasKey = model => isProviderKeyValid(getProviderFromModel(model));
    const hasKeyAndQuota = model => {
        const provider = getProviderFromModel(model);
        return isProviderKeyValid(provider) && quotaTracker.hasQuota(provider, model);
    };
    const modelGroup = resolveModelGroup(modelClass);

    if (modelGroup in MODEL_CLASSES) {
        const { config, override, models } = resolveModelClassConfig(modelGroup);
        const candidates = filterModelsWithFallback(models, excludeModels, disabledModels);
        const shouldRandomize = override?.random ?? ('random' in config && config.random);
        // NOTE(review): without modelScores, selectWeightedModel already picks
        // uniformly at random, so `random` currently has no observable effect.
        const pick = pool => (shouldRandomize && !modelScores
            ? pool[Math.floor(Math.random() * pool.length)]
            : selectWeightedModel(pool, modelScores));

        const withKeyAndQuota = candidates.filter(hasKeyAndQuota);
        if (withKeyAndQuota.length > 0) {
            const selectedModel = pick(withKeyAndQuota);
            console.log(`Using '${selectedModel}' model for '${modelGroup}' class.`);
            return selectedModel;
        }
        const withKey = candidates.filter(hasKey);
        if (withKey.length > 0) {
            const selectedModel = pick(withKey);
            console.log(`Using '${selectedModel}' model for '${modelGroup}' class (may exceed quota).`);
            return selectedModel;
        }
    }

    if (modelGroup !== 'standard' && 'standard' in MODEL_CLASSES) {
        // Honour a 'standard' override here too, matching the primary path
        // above and canRunAgent().
        const { models: standardModels } = resolveModelClassConfig('standard');
        const candidates = filterModelsWithFallback(standardModels, excludeModels, disabledModels);
        const withKeyAndQuota = candidates.filter(hasKeyAndQuota);
        if (withKeyAndQuota.length > 0) {
            const selectedModel = selectWeightedModel(withKeyAndQuota, modelScores);
            console.warn(`Falling back to 'standard' class with model '${selectedModel}'.`);
            return selectedModel;
        }
        const withKey = candidates.filter(hasKey);
        if (withKey.length > 0) {
            const selectedModel = selectWeightedModel(withKey, modelScores);
            console.log(`Falling back to 'standard' class with model '${selectedModel}' (may exceed quota).`);
            return selectedModel;
        }
    }

    let defaultModel = 'gpt-5.2-chat-latest';
    if (modelGroup in MODEL_CLASSES) {
        const models = MODEL_CLASSES[modelGroup].models;
        if (models.length > 0) {
            defaultModel = models[0];
        }
    }
    console.log(`No valid API key found for any model in class ${modelGroup}, using default: ${defaultModel}`);
    return defaultModel;
}

/**
 * Return the provider implementation that should serve `model`.
 *
 * Resolution order:
 *   1. external providers;
 *   2. findModel() metadata;
 *   3. MODEL_PROVIDER_MAP prefixes;
 *   4. OpenRouter, then fal, as catch-alls when their keys are valid.
 *
 * @param {string} [model]
 * @returns {object} Provider implementation.
 * @throws {Error} When the resolved provider has no implementation or no
 *   valid API key, or when no fallback provider is configured.
 */
export function getModelProvider(model) {
    if (model) {
        if (isExternalModel(model)) {
            const externalModel = getExternalModel(model);
            const externalProvider = externalModel && getExternalProvider(externalModel.provider);
            if (externalProvider) {
                return externalProvider;
            }
        }
        const registeredModel = findModel(model);
        if (registeredModel) {
            const providerName = registeredModel.provider;
            const provider = PROVIDER_BY_ID[providerName];
            if (!provider) {
                throw new Error(`No provider implementation found for ${providerName}.`);
            }
            if (!isProviderKeyValid(providerName)) {
                throw missingApiKeyError(providerName);
            }
            return provider;
        }
        for (const [prefix, provider] of Object.entries(MODEL_PROVIDER_MAP)) {
            if (model.startsWith(prefix)) {
                // Validate the key of the provider actually being returned.
                const providerName = PROVIDER_ID_BY_IMPL.get(provider);
                if (!isProviderKeyValid(providerName)) {
                    throw missingApiKeyError(providerName);
                }
                return provider;
            }
        }
    }
    if (isProviderKeyValid('openrouter')) {
        return openRouterProvider;
    }
    if (isProviderKeyValid('fal')) {
        return falProvider;
    }
    throw new Error(`No valid provider found for the model ${model}. Please check your API keys.`);
}

/**
 * Report whether an agent can run with the API keys currently configured.
 *
 * - With `agent.model`: checks that model's provider key.
 * - With `agent.modelClass`: splits the (override-aware) class models into
 *   available/unavailable by provider key.
 * - With neither: evaluates the 'standard' class.
 *
 * @param {object} agent
 * @returns {Promise<object>} Result with `canRun`, plus model/provider details
 *   and a human-readable `reason` when it cannot run.
 */
export async function canRunAgent(agent) {
    if (agent.model) {
        const provider = getProviderFromModel(agent.model);
        const hasKey = isProviderKeyValid(provider);
        return {
            canRun: hasKey,
            model: agent.model,
            provider,
            missingProvider: hasKey ? undefined : provider,
            reason: hasKey ? undefined : `Missing API key for provider: ${provider}`,
        };
    }
    if (agent.modelClass) {
        const { models } = resolveModelClassConfig(resolveModelGroup(agent.modelClass));
        const availableModels = [];
        const unavailableModels = [];
        const missingProviders = new Set();
        for (const model of models) {
            const provider = getProviderFromModel(model);
            if (isProviderKeyValid(provider)) {
                availableModels.push(model);
            } else {
                unavailableModels.push(model);
                missingProviders.add(provider);
            }
        }
        const canRun = availableModels.length > 0;
        const missing = [...missingProviders];
        return {
            canRun,
            availableModels,
            unavailableModels,
            missingProvider: !canRun && missing.length === 1 ? missing[0] : undefined,
            reason: canRun
                ? undefined
                : `No API keys found for any models in class: ${agent.modelClass}. Missing providers: ${missing.join(', ')}`,
        };
    }
    return canRunAgent({ modelClass: 'standard' });
}
//# sourceMappingURL=model_provider.js.map