// @just-every/ensemble
// Version: (not specified)
// LLM provider abstraction layer with unified streaming interface
// 349 lines • 14.9 kB
// JavaScript
// Module-interop helper (looks compiler-emitted): exposes property `key` of
// `source` on `target`, optionally under a different name `alias`.
// A helper already present on `this` is reused; `var` is kept on purpose so the
// declaration tolerates being repeated in a shared scope.
var __createBinding = (this && this.__createBinding) || (Object.create
    ? function (target, source, key, alias) {
        var exportName = alias === undefined ? key : alias;
        var descriptor = Object.getOwnPropertyDescriptor(source, key);
        // A forwarding getter is installed when the property is missing, when it
        // is an accessor on a module not flagged `__esModule`, or when it is a
        // data property that could still change (writable or configurable).
        var needsGetter = true;
        if (descriptor) {
            needsGetter = "get" in descriptor
                ? !source.__esModule
                : Boolean(descriptor.writable || descriptor.configurable);
        }
        if (needsGetter) {
            descriptor = {
                enumerable: true,
                get: function () {
                    return source[key];
                },
            };
        }
        Object.defineProperty(target, exportName, descriptor);
    }
    // Fallback used when `Object.create` is unavailable: plain value copy.
    : function (target, source, key, alias) {
        target[alias === undefined ? key : alias] = source[key];
    });
// Module-interop helper (looks compiler-emitted): stores `value` as the
// `default` member of `target`. A helper already present on `this` is reused.
var __setModuleDefault = (this && this.__setModuleDefault) || (function () {
    if (Object.create) {
        // Defined via a descriptor: enumerable, and (by descriptor defaults)
        // neither writable nor configurable.
        return function (target, value) {
            Object.defineProperty(target, "default", { enumerable: true, value: value });
        };
    }
    // Fallback used when `Object.create` is unavailable: plain assignment.
    return function (target, value) {
        target["default"] = value;
    };
})();
// Module-interop helper (looks compiler-emitted): turns a required module into
// a namespace-style object. Modules flagged `__esModule` are returned as-is;
// otherwise every own property except "default" is bound onto a fresh object
// and the module itself becomes that object's `default`.
// A helper already present on `this` is reused.
var __importStar = (this && this.__importStar) || (function () {
    // Key-listing strategy, chosen once on first use.
    var keyLister = null;
    function ownKeys(target) {
        if (keyLister === null) {
            keyLister = Object.getOwnPropertyNames || function (obj) {
                var names = [];
                for (var name in obj) {
                    if (Object.prototype.hasOwnProperty.call(obj, name)) {
                        names.push(name);
                    }
                }
                return names;
            };
        }
        return keyLister(target);
    }
    return function (mod) {
        if (mod && mod.__esModule) {
            return mod;
        }
        var namespace = {};
        if (mod != null) {
            ownKeys(mod).forEach(function (key) {
                if (key !== "default") {
                    __createBinding(namespace, mod, key);
                }
            });
        }
        __setModuleDefault(namespace, mod);
        return namespace;
    };
})();
Object.defineProperty(exports, "__esModule", { value: true });
exports.isProviderKeyValid = isProviderKeyValid;
exports.getProviderFromModel = getProviderFromModel;
exports.getModelFromAgent = getModelFromAgent;
exports.getModelFromClass = getModelFromClass;
exports.getModelProvider = getModelProvider;
exports.canRunAgent = canRunAgent;
const external_models_js_1 = require("../utils/external_models.cjs");
const openai_js_1 = require("./openai.cjs");
const claude_js_1 = require("./claude.cjs");
const gemini_js_1 = require("./gemini.cjs");
const grok_js_1 = require("./grok.cjs");
const deepseek_js_1 = require("./deepseek.cjs");
const test_provider_js_1 = require("./test_provider.cjs");
const openrouter_js_1 = require("./openrouter.cjs");
const elevenlabs_js_1 = require("./elevenlabs.cjs");
const model_data_js_1 = require("../data/model_data.cjs");
// Model-name prefix -> provider implementation.
// Built from per-provider prefix lists; the resulting key order matches the
// order of the lists below, which is the order lookups iterate in.
const MODEL_PROVIDER_MAP = Object.fromEntries([
    [openai_js_1.openaiProvider, [
        'gpt-',
        'o1',
        'o3',
        'o4',
        'text-',
        'computer-use-preview',
        'dall-e',
        'gpt-image',
        'tts-',
        'codex-',
    ]],
    [claude_js_1.claudeProvider, ['claude-']],
    [gemini_js_1.geminiProvider, ['gemini-', 'imagen-']],
    [grok_js_1.grokProvider, ['grok-']],
    [deepseek_js_1.deepSeekProvider, ['deepseek-']],
    [elevenlabs_js_1.elevenLabsProvider, ['eleven_', 'elevenlabs-']],
    [test_provider_js_1.testProvider, ['test-']],
].flatMap(([provider, prefixes]) => prefixes.map(prefix => [prefix, provider])));
/**
 * Reports whether credentials for a provider look usable.
 *
 * Built-in providers are checked against their environment variable (and, for
 * some, a required key prefix). The `test` provider is always valid. Any other
 * name is valid only if an external provider is registered under it.
 *
 * @param {string} provider - Provider name, e.g. 'openai' or 'anthropic'.
 * @returns {boolean} True when the provider can be used.
 */
function isProviderKeyValid(provider) {
    if (provider === 'test') {
        return true;
    }
    // An empty prefix means "any non-empty key is accepted".
    const keyRules = new Map([
        ['openai', { envVar: 'OPENAI_API_KEY', prefix: 'sk-' }],
        ['anthropic', { envVar: 'ANTHROPIC_API_KEY', prefix: 'sk-ant-' }],
        ['google', { envVar: 'GOOGLE_API_KEY', prefix: '' }],
        ['xai', { envVar: 'XAI_API_KEY', prefix: 'xai-' }],
        ['deepseek', { envVar: 'DEEPSEEK_API_KEY', prefix: 'sk-' }],
        ['openrouter', { envVar: 'OPENROUTER_API_KEY', prefix: '' }],
        ['elevenlabs', { envVar: 'ELEVENLABS_API_KEY', prefix: '' }],
    ]);
    const rule = keyRules.get(provider);
    if (rule) {
        const apiKey = process.env[rule.envVar];
        return Boolean(apiKey) && apiKey.startsWith(rule.prefix);
    }
    // Unknown name: valid only when an external provider is registered for it.
    return Boolean((0, external_models_js_1.getExternalProvider)(provider));
}
// Provider name -> model-name prefixes that belong to it.
// Must stay in step with MODEL_PROVIDER_MAP: every prefix routed to a provider
// implementation there needs the matching provider name here, otherwise key
// validation is performed against the wrong provider.
const PROVIDER_NAME_PREFIXES = [
    ['openai', [
        'gpt-',
        'o1',
        'o3',
        'o4',
        'text-',
        'computer-use-preview',
        'dall-e',
        'gpt-image',
        'tts-',
        // Previously missing: 'codex-' models are served by the OpenAI provider
        // but were reported as 'openrouter'.
        'codex-',
    ]],
    ['anthropic', ['claude-']],
    ['google', ['gemini-', 'imagen-']],
    ['xai', ['grok-']],
    ['deepseek', ['deepseek-']],
    ['elevenlabs', ['eleven_', 'elevenlabs-']],
    ['test', ['test-']],
];
/**
 * Resolves the provider name responsible for a model.
 *
 * Registered external models take precedence and report their own provider.
 * Otherwise the model name is matched against known prefixes; anything
 * unrecognised is attributed to 'openrouter'.
 *
 * @param {string} model - Model identifier, e.g. 'gpt-4.1' or 'claude-...'.
 * @returns {string} Provider name.
 */
function getProviderFromModel(model) {
    if ((0, external_models_js_1.isExternalModel)(model)) {
        const externalModel = (0, external_models_js_1.getExternalModel)(model);
        if (externalModel) {
            return externalModel.provider;
        }
    }
    for (const [providerName, prefixes] of PROVIDER_NAME_PREFIXES) {
        if (prefixes.some(prefix => model.startsWith(prefix))) {
            return providerName;
        }
    }
    return 'openrouter';
}
/**
 * Removes excluded and disabled models from a candidate list, but never
 * returns an empty list when candidates were supplied.
 *
 * If everything is filtered out, one model is picked instead:
 *   1. the next non-disabled model after the most recently excluded one
 *      (wrapping around the list),
 *   2. else the first non-disabled model,
 *   3. else the first model.
 *
 * @param {string[]} models - Candidate models, in preference order.
 * @param {string[]} [excludeModels] - Models to skip; later entries are
 *   treated as more recently used.
 * @param {string[]} [disabledModels] - Models that must not be chosen unless
 *   nothing else exists.
 * @returns {string[]} The filtered list, or a single-element fallback. The
 *   input array itself is returned when there is nothing to exclude.
 */
function filterModelsWithFallback(models, excludeModels, disabledModels) {
    const recentlyUsed = excludeModels || [];
    const blocked = [...recentlyUsed, ...(disabledModels || [])];
    if (blocked.length === 0) {
        return models;
    }
    const remaining = models.filter(candidate => !blocked.includes(candidate));
    if (remaining.length > 0 || models.length === 0) {
        return remaining;
    }
    const isEnabled = candidate => !disabledModels?.includes(candidate);
    // Walk the exclusion list backwards to find the latest entry that is one
    // of the candidates.
    let lastUsed;
    for (let i = recentlyUsed.length - 1; i >= 0; i--) {
        if (models.includes(recentlyUsed[i])) {
            lastUsed = recentlyUsed[i];
            break;
        }
    }
    if (lastUsed) {
        // Rotate through the list starting right after the last used model.
        const start = models.indexOf(lastUsed);
        for (let step = 1; step <= models.length; step++) {
            const candidate = models[(start + step) % models.length];
            if (isEnabled(candidate)) {
                return [candidate];
            }
        }
    }
    const firstEnabled = models.find(candidate => isEnabled(candidate));
    if (firstEnabled) {
        return [firstEnabled];
    }
    return [models[0]];
}
/**
 * Picks one model at random, optionally biased by per-model scores.
 *
 * Without scores (or with no models) the choice is uniform. With scores, each
 * model's chance is proportional to its score; models without a score entry
 * use 50, and models whose score is not greater than zero are left out. If
 * every model is left out the choice is uniform over the full list again.
 *
 * @param {string[]} models - Candidate models.
 * @param {Object<string, number>} [scores] - Weight per model name.
 * @returns {string|undefined} The chosen model, or undefined for an empty list.
 */
function selectWeightedModel(models, scores) {
    const pickUniformly = () => models[Math.floor(Math.random() * models.length)];
    if (!scores || models.length === 0) {
        return pickUniformly();
    }
    const weighted = [];
    for (const model of models) {
        const score = scores[model];
        // Only a missing entry gets the default; an explicit null stays null
        // and is dropped by the `> 0` check below.
        const weight = score !== undefined ? score : 50;
        if (weight > 0) {
            weighted.push({ model, weight });
        }
    }
    if (weighted.length === 0) {
        return pickUniformly();
    }
    let totalWeight = 0;
    for (const entry of weighted) {
        totalWeight += entry.weight;
    }
    if (totalWeight === 0) {
        return weighted[0].model;
    }
    // Roulette-wheel selection: walk the entries until the drawn value is used up.
    let remaining = Math.random() * totalWeight;
    for (const entry of weighted) {
        remaining -= entry.weight;
        if (remaining <= 0) {
            return entry.model;
        }
    }
    return weighted[0].model;
}
/**
 * Resolves the model an agent should use.
 *
 * An explicit `agent.model` wins. Otherwise a model is chosen from the agent's
 * model class (or `defaultClass` when the agent has none), forwarding the
 * agent's disabled models and model scores to the class-based selection.
 *
 * @param {object} agent - Agent definition; reads `model`, `modelClass`,
 *   `disabledModels` and `modelScores`.
 * @param {string} [defaultClass] - Class used when the agent specifies none.
 * @param {string[]} [excludeModels] - Models to avoid during selection.
 * @returns {Promise<string>} The selected model name.
 */
async function getModelFromAgent(agent, defaultClass, excludeModels) {
    if (agent.model) {
        return agent.model;
    }
    const modelClass = agent.modelClass || defaultClass;
    const selected = await getModelFromClass(modelClass, excludeModels, agent.disabledModels, agent.modelScores);
    return selected;
}
/**
 * Chooses a concrete model for a model class.
 *
 * Selection order:
 *   1. models of the requested class (unknown classes are treated as
 *      'standard') whose provider has a valid key and remaining quota;
 *   2. models of that class whose provider has a valid key, quota ignored;
 *   3. the same two steps against the built-in 'standard' class, when the
 *      requested class was something else;
 *   4. the first model of the built-in class definition, or 'gpt-4.1'.
 *
 * A registered class override can replace the class's model list and its
 * `random` flag for steps 1-2. Excluded and disabled models are filtered via
 * filterModelsWithFallback before steps 1-3.
 *
 * @param {string} [modelClass] - Requested model class.
 * @param {string[]} [excludeModels] - Models to avoid.
 * @param {string[]} [disabledModels] - Models that should not be used.
 * @param {Object<string, number>} [modelScores] - Weights for weighted choice.
 * @returns {Promise<string>} The selected model name.
 */
async function getModelFromClass(modelClass, excludeModels, disabledModels, modelScores) {
    const { quotaTracker } = await Promise.resolve().then(() => __importStar(require("../utils/quota_tracker.cjs")));
    const classes = model_data_js_1.MODEL_CLASSES;
    const modelGroup = modelClass && modelClass in classes ? modelClass : 'standard';

    // Predicates shared by the class-specific and the 'standard' fallback pass.
    const hasKeyAndQuota = model => {
        const provider = getProviderFromModel(model);
        return isProviderKeyValid(provider) && quotaTracker.hasQuota(provider, model);
    };
    const hasKey = model => isProviderKeyValid(getProviderFromModel(model));
    const choose = (candidates, uniform) => uniform
        ? candidates[Math.floor(Math.random() * candidates.length)]
        : selectWeightedModel(candidates, modelScores);

    if (modelGroup in classes) {
        const override = (0, external_models_js_1.getModelClassOverride)(modelGroup);
        const classConfig = override
            ? { ...classes[modelGroup], ...override }
            : classes[modelGroup];
        const candidates = filterModelsWithFallback([...(override?.models || classConfig.models)], excludeModels, disabledModels);
        const shouldRandomize = override?.random ?? ('random' in classConfig && classConfig.random);
        // Explicit scores take priority over the class's `random` flag.
        const uniform = shouldRandomize && !modelScores;

        const usable = candidates.filter(model => hasKeyAndQuota(model));
        if (usable.length > 0) {
            const selectedModel = choose(usable, uniform);
            console.log(`Using '${selectedModel}' model for '${modelGroup}' class.`);
            return selectedModel;
        }
        const keyed = candidates.filter(model => hasKey(model));
        if (keyed.length > 0) {
            const selectedModel = choose(keyed, uniform);
            console.log(`Using '${selectedModel}' model for '${modelGroup}' class (may exceed quota).`);
            return selectedModel;
        }
    }

    if (modelGroup !== 'standard' && 'standard' in classes) {
        const standardCandidates = filterModelsWithFallback(classes['standard'].models, excludeModels, disabledModels);
        const usable = standardCandidates.filter(model => hasKeyAndQuota(model));
        if (usable.length > 0) {
            const selectedModel = selectWeightedModel(usable, modelScores);
            console.warn(`Falling back to 'standard' class with model '${selectedModel}'.`);
            return selectedModel;
        }
        const keyed = standardCandidates.filter(model => hasKey(model));
        if (keyed.length > 0) {
            const selectedModel = selectWeightedModel(keyed, modelScores);
            console.log(`Falling back to 'standard' class with model '${selectedModel}' (may exceed quota).`);
            return selectedModel;
        }
    }

    // Last resort: no provider key is usable, so return a model name anyway.
    const builtInModels = modelGroup in classes ? classes[modelGroup].models : [];
    const defaultModel = builtInModels.length > 0 ? builtInModels[0] : 'gpt-4.1';
    console.log(`No valid API key found for any model in class ${modelGroup}, using default: ${defaultModel}`);
    return defaultModel;
}
/**
 * Returns the provider implementation that should serve a model.
 *
 * Resolution order:
 *   1. a registered external provider, when the model is an external model;
 *   2. the first MODEL_PROVIDER_MAP entry whose prefix the model starts with;
 *   3. the OpenRouter provider as a catch-all (also used when no model is
 *      given).
 *
 * @param {string} [model] - Model identifier.
 * @returns {object} Provider implementation.
 * @throws {Error} When the matched built-in provider has no valid API key, or
 *   when falling back to OpenRouter without a valid OpenRouter key.
 */
function getModelProvider(model) {
    if (model) {
        if ((0, external_models_js_1.isExternalModel)(model)) {
            const externalModel = (0, external_models_js_1.getExternalModel)(model);
            if (externalModel) {
                const externalProvider = (0, external_models_js_1.getExternalProvider)(externalModel.provider);
                if (externalProvider) {
                    return externalProvider;
                }
            }
        }
        for (const [prefix, provider] of Object.entries(MODEL_PROVIDER_MAP)) {
            if (!model.startsWith(prefix)) {
                continue;
            }
            const providerName = getProviderFromModel(model);
            if (!isProviderKeyValid(providerName)) {
                throw new Error(`API key for ${providerName} provider is missing or invalid. Please set ${providerName.toUpperCase()}_API_KEY environment variable.`);
            }
            return provider;
        }
    }
    // Validate the OpenRouter key directly. The previous code passed the
    // provider name 'openrouter' through getProviderFromModel as if it were a
    // model name, which checks the wrong provider if an external model is ever
    // registered under that name.
    if (!isProviderKeyValid('openrouter')) {
        throw new Error(`No valid provider found for the model ${model}. Please check your API keys.`);
    }
    return openrouter_js_1.openRouterProvider;
}
/**
 * Reports whether an agent can run with the API keys currently available.
 *
 * - With `agent.model`: checks that model's provider and returns
 *   `{ canRun, model, provider, missingProvider, reason }`.
 * - With `agent.modelClass`: checks every model of the class (unknown classes
 *   are treated as 'standard'; a registered class override may supply the
 *   model list) and returns
 *   `{ canRun, availableModels, unavailableModels, missingProvider, reason }`.
 *   `missingProvider` is only set when nothing is available and exactly one
 *   provider is missing.
 * - With neither: evaluates the 'standard' class.
 *
 * @param {object} agent - Agent definition; reads `model` and `modelClass`.
 * @returns {Promise<object>} Availability report as described above.
 */
async function canRunAgent(agent) {
    if (agent.model) {
        const provider = getProviderFromModel(agent.model);
        const hasKey = isProviderKeyValid(provider);
        const report = {
            canRun: hasKey,
            model: agent.model,
            provider,
            missingProvider: undefined,
            reason: undefined,
        };
        if (!hasKey) {
            report.missingProvider = provider;
            report.reason = `Missing API key for provider: ${provider}`;
        }
        return report;
    }
    if (!agent.modelClass) {
        return canRunAgent({ modelClass: 'standard' });
    }

    const requestedClass = agent.modelClass;
    const classes = model_data_js_1.MODEL_CLASSES;
    const modelGroup = requestedClass && requestedClass in classes ? requestedClass : 'standard';
    const override = (0, external_models_js_1.getModelClassOverride)(modelGroup);
    const classConfig = override
        ? { ...classes[modelGroup], ...override }
        : classes[modelGroup];
    const candidates = [...(override?.models || classConfig.models)];

    // Split the class's models by whether their provider has a valid key.
    const availableModels = [];
    const unavailableModels = [];
    const missingProviders = new Set();
    candidates.forEach(candidate => {
        const provider = getProviderFromModel(candidate);
        if (isProviderKeyValid(provider)) {
            availableModels.push(candidate);
            return;
        }
        unavailableModels.push(candidate);
        missingProviders.add(provider);
    });

    const nothingAvailable = availableModels.length === 0;
    const missingList = Array.from(missingProviders);
    return {
        canRun: !nothingAvailable,
        availableModels,
        unavailableModels,
        missingProvider: nothingAvailable && missingList.length === 1
            ? missingList[0]
            : undefined,
        reason: nothingAvailable
            ? `No API keys found for any models in class: ${agent.modelClass}. Missing providers: ${missingList.join(', ')}`
            : undefined,
    };
}
//# sourceMappingURL=model_provider.js.map
;