UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

349 lines 14.9 kB
"use strict"; var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { if (k2 === undefined) k2 = k; var desc = Object.getOwnPropertyDescriptor(m, k); if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { desc = { enumerable: true, get: function() { return m[k]; } }; } Object.defineProperty(o, k2, desc); }) : (function(o, m, k, k2) { if (k2 === undefined) k2 = k; o[k2] = m[k]; })); var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { Object.defineProperty(o, "default", { enumerable: true, value: v }); }) : function(o, v) { o["default"] = v; }); var __importStar = (this && this.__importStar) || (function () { var ownKeys = function(o) { ownKeys = Object.getOwnPropertyNames || function (o) { var ar = []; for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k; return ar; }; return ownKeys(o); }; return function (mod) { if (mod && mod.__esModule) return mod; var result = {}; if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]); __setModuleDefault(result, mod); return result; }; })(); Object.defineProperty(exports, "__esModule", { value: true }); exports.isProviderKeyValid = isProviderKeyValid; exports.getProviderFromModel = getProviderFromModel; exports.getModelFromAgent = getModelFromAgent; exports.getModelFromClass = getModelFromClass; exports.getModelProvider = getModelProvider; exports.canRunAgent = canRunAgent; const external_models_js_1 = require("../utils/external_models.cjs"); const openai_js_1 = require("./openai.cjs"); const claude_js_1 = require("./claude.cjs"); const gemini_js_1 = require("./gemini.cjs"); const grok_js_1 = require("./grok.cjs"); const deepseek_js_1 = require("./deepseek.cjs"); const test_provider_js_1 = require("./test_provider.cjs"); const openrouter_js_1 = require("./openrouter.cjs"); const elevenlabs_js_1 = require("./elevenlabs.cjs"); const model_data_js_1 = require("../data/model_data.cjs"); const MODEL_PROVIDER_MAP = { 'gpt-': openai_js_1.openaiProvider, o1: openai_js_1.openaiProvider, o3: openai_js_1.openaiProvider, o4: openai_js_1.openaiProvider, 'text-': openai_js_1.openaiProvider, 'computer-use-preview': openai_js_1.openaiProvider, 'dall-e': openai_js_1.openaiProvider, 'gpt-image': openai_js_1.openaiProvider, 'tts-': openai_js_1.openaiProvider, 'codex-': openai_js_1.openaiProvider, 'claude-': claude_js_1.claudeProvider, 'gemini-': gemini_js_1.geminiProvider, 'imagen-': gemini_js_1.geminiProvider, 'grok-': grok_js_1.grokProvider, 'deepseek-': deepseek_js_1.deepSeekProvider, eleven_: elevenlabs_js_1.elevenLabsProvider, 'elevenlabs-': elevenlabs_js_1.elevenLabsProvider, 'test-': test_provider_js_1.testProvider, }; function isProviderKeyValid(provider) { switch (provider) { case 'openai': return !!process.env.OPENAI_API_KEY && process.env.OPENAI_API_KEY.startsWith('sk-'); case 'anthropic': return !!process.env.ANTHROPIC_API_KEY && process.env.ANTHROPIC_API_KEY.startsWith('sk-ant-'); case 'google': return !!process.env.GOOGLE_API_KEY; case 'xai': return !!process.env.XAI_API_KEY && process.env.XAI_API_KEY.startsWith('xai-'); case 'deepseek': return !!process.env.DEEPSEEK_API_KEY && process.env.DEEPSEEK_API_KEY.startsWith('sk-'); case 'openrouter': return !!process.env.OPENROUTER_API_KEY; case 'elevenlabs': return !!process.env.ELEVENLABS_API_KEY; case 'test': return true; default: { const externalProvider = (0, external_models_js_1.getExternalProvider)(provider); if (externalProvider) { return true; } return false; } } } function getProviderFromModel(model) { if ((0, external_models_js_1.isExternalModel)(model)) { const externalModel = (0, external_models_js_1.getExternalModel)(model); if (externalModel) { return externalModel.provider; } } if (model.startsWith('gpt-') || model.startsWith('o1') || model.startsWith('o3') || model.startsWith('o4') || model.startsWith('text-') || model.startsWith('computer-use-preview') || model.startsWith('dall-e') || model.startsWith('gpt-image') || model.startsWith('tts-')) { return 'openai'; } else if (model.startsWith('claude-')) { return 'anthropic'; } else if (model.startsWith('gemini-') || model.startsWith('imagen-')) { return 'google'; } else if (model.startsWith('grok-')) { return 'xai'; } else if (model.startsWith('deepseek-')) { return 'deepseek'; } else if (model.startsWith('eleven_') || model.startsWith('elevenlabs-')) { return 'elevenlabs'; } else if (model.startsWith('test-')) { return 'test'; } return 'openrouter'; } function filterModelsWithFallback(models, excludeModels, disabledModels) { const allExcluded = [...(excludeModels || []), ...(disabledModels || [])]; if (allExcluded.length === 0) { return models; } const originalModels = [...models]; const filteredModels = models.filter(model => !allExcluded.includes(model)); if (filteredModels.length === 0) { const lastUsedModel = [...(excludeModels || [])] .reverse() .find(excludedModel => originalModels.includes(excludedModel)); if (lastUsedModel) { let nextIndex = (originalModels.indexOf(lastUsedModel) + 1) % originalModels.length; let attempts = 0; while (attempts < originalModels.length) { const nextModel = originalModels[nextIndex]; if (!disabledModels?.includes(nextModel)) { return [nextModel]; } nextIndex = (nextIndex + 1) % originalModels.length; attempts++; } } const firstNonDisabled = originalModels.find(m => !disabledModels?.includes(m)); if (firstNonDisabled) { return [firstNonDisabled]; } if (originalModels.length > 0) { return [originalModels[0]]; } } return filteredModels; } function selectWeightedModel(models, scores) { if (!scores || models.length === 0) { return models[Math.floor(Math.random() * models.length)]; } const modelWeights = models .map(model => ({ model, weight: scores[model] !== undefined ? scores[model] : 50, })) .filter(m => m.weight > 0); if (modelWeights.length === 0) { return models[Math.floor(Math.random() * models.length)]; } const totalWeight = modelWeights.reduce((sum, m) => sum + m.weight, 0); if (totalWeight === 0) { return modelWeights[0].model; } let random = Math.random() * totalWeight; for (const { model, weight } of modelWeights) { random -= weight; if (random <= 0) { return model; } } return modelWeights[0].model; } async function getModelFromAgent(agent, defaultClass, excludeModels) { return (agent.model || (await getModelFromClass(agent.modelClass || defaultClass, excludeModels, agent.disabledModels, agent.modelScores))); } async function getModelFromClass(modelClass, excludeModels, disabledModels, modelScores) { const { quotaTracker } = await Promise.resolve().then(() => __importStar(require("../utils/quota_tracker.cjs"))); const modelClassStr = modelClass; const modelGroup = modelClassStr && modelClassStr in model_data_js_1.MODEL_CLASSES ? modelClassStr : 'standard'; if (modelGroup in model_data_js_1.MODEL_CLASSES) { const override = (0, external_models_js_1.getModelClassOverride)(modelGroup); let modelClassConfig = model_data_js_1.MODEL_CLASSES[modelGroup]; if (override) { modelClassConfig = { ...modelClassConfig, ...override, }; } let models = [...(override?.models || modelClassConfig.models)]; models = filterModelsWithFallback(models, excludeModels, disabledModels); const shouldRandomize = override?.random ?? ('random' in modelClassConfig && modelClassConfig.random); const validModels = [...models]; const modelsWithKeyAndQuota = validModels.filter(model => { const provider = getProviderFromModel(model); return isProviderKeyValid(provider) && quotaTracker.hasQuota(provider, model); }); if (modelsWithKeyAndQuota.length > 0) { const selectedModel = shouldRandomize && !modelScores ? modelsWithKeyAndQuota[Math.floor(Math.random() * modelsWithKeyAndQuota.length)] : selectWeightedModel(modelsWithKeyAndQuota, modelScores); console.log(`Using '${selectedModel}' model for '${modelGroup}' class.`); return selectedModel; } const modelsWithKey = validModels.filter(model => { const provider = getProviderFromModel(model); return isProviderKeyValid(provider); }); if (modelsWithKey.length > 0) { const selectedModel = shouldRandomize && !modelScores ? modelsWithKey[Math.floor(Math.random() * modelsWithKey.length)] : selectWeightedModel(modelsWithKey, modelScores); console.log(`Using '${selectedModel}' model for '${modelGroup}' class (may exceed quota).`); return selectedModel; } } if (modelGroup !== 'standard' && 'standard' in model_data_js_1.MODEL_CLASSES) { let standardModels = model_data_js_1.MODEL_CLASSES['standard'].models; standardModels = filterModelsWithFallback(standardModels, excludeModels, disabledModels); const standardModelsWithKeyAndQuota = standardModels.filter(model => { const provider = getProviderFromModel(model); return isProviderKeyValid(provider) && quotaTracker.hasQuota(provider, model); }); if (standardModelsWithKeyAndQuota.length > 0) { const selectedModel = selectWeightedModel(standardModelsWithKeyAndQuota, modelScores); console.warn(`Falling back to 'standard' class with model '${selectedModel}'.`); return selectedModel; } const standardModelsWithKey = standardModels.filter(model => { const provider = getProviderFromModel(model); return isProviderKeyValid(provider); }); if (standardModelsWithKey.length > 0) { const selectedModel = selectWeightedModel(standardModelsWithKey, modelScores); console.log(`Falling back to 'standard' class with model '${selectedModel}' (may exceed quota).`); return selectedModel; } } let defaultModel = 'gpt-4.1'; if (modelGroup in model_data_js_1.MODEL_CLASSES) { const models = model_data_js_1.MODEL_CLASSES[modelGroup].models; if (models.length > 0) { defaultModel = models[0]; } } console.log(`No valid API key found for any model in class ${modelGroup}, using default: ${defaultModel}`); return defaultModel; } function getModelProvider(model) { if (model) { if ((0, external_models_js_1.isExternalModel)(model)) { const externalModel = (0, external_models_js_1.getExternalModel)(model); if (externalModel) { const externalProvider = (0, external_models_js_1.getExternalProvider)(externalModel.provider); if (externalProvider) { return externalProvider; } } } for (const [prefix, provider] of Object.entries(MODEL_PROVIDER_MAP)) { if (model.startsWith(prefix)) { const providerName = getProviderFromModel(model); if (!isProviderKeyValid(providerName)) { throw new Error(`API key for ${providerName} provider is missing or invalid. Please set ${providerName.toUpperCase()}_API_KEY environment variable.`); } return provider; } } } if (!isProviderKeyValid(getProviderFromModel('openrouter'))) { throw new Error(`No valid provider found for the model ${model}. Please check your API keys.`); } return openrouter_js_1.openRouterProvider; } async function canRunAgent(agent) { if (agent.model) { const provider = getProviderFromModel(agent.model); const hasKey = isProviderKeyValid(provider); return { canRun: hasKey, model: agent.model, provider, missingProvider: hasKey ? undefined : provider, reason: hasKey ? undefined : `Missing API key for provider: ${provider}`, }; } if (agent.modelClass) { const modelClassStr = agent.modelClass; const modelGroup = modelClassStr && modelClassStr in model_data_js_1.MODEL_CLASSES ? modelClassStr : 'standard'; const override = (0, external_models_js_1.getModelClassOverride)(modelGroup); let modelClassConfig = model_data_js_1.MODEL_CLASSES[modelGroup]; if (override) { modelClassConfig = { ...modelClassConfig, ...override, }; } const models = [...(override?.models || modelClassConfig.models)]; const availableModels = []; const unavailableModels = []; const missingProviders = new Set(); for (const model of models) { const provider = getProviderFromModel(model); if (isProviderKeyValid(provider)) { availableModels.push(model); } else { unavailableModels.push(model); missingProviders.add(provider); } } return { canRun: availableModels.length > 0, availableModels, unavailableModels, missingProvider: availableModels.length === 0 && missingProviders.size === 1 ? Array.from(missingProviders)[0] : undefined, reason: availableModels.length === 0 ? `No API keys found for any models in class: ${agent.modelClass}. Missing providers: ${Array.from(missingProviders).join(', ')}` : undefined, }; } return canRunAgent({ modelClass: 'standard' }); } //# sourceMappingURL=model_provider.js.map