@just-every/ensemble
Version:
LLM provider abstraction layer with unified streaming interface
308 lines • 12.7 kB
JavaScript
import { isExternalModel, getExternalModel, getExternalProvider, getModelClassOverride, } from '../utils/external_models.js';
import { openaiProvider } from './openai.js';
import { claudeProvider } from './claude.js';
import { geminiProvider } from './gemini.js';
import { grokProvider } from './grok.js';
import { deepSeekProvider } from './deepseek.js';
import { testProvider } from './test_provider.js';
import { openRouterProvider } from './openrouter.js';
import { elevenLabsProvider } from './elevenlabs.js';
import { MODEL_CLASSES } from '../data/model_data.js';
/**
 * Maps a model-id prefix to the provider implementation that serves it.
 *
 * getModelProvider first resolves registered external models, then iterates
 * these entries in insertion order and uses the first prefix the model id
 * starts with; ids matching no prefix fall back to openRouterProvider.
 * getModelProvider validates the API key using the provider *name* returned
 * by getProviderFromModel, so the prefixes listed here should be kept in
 * sync with the ones checked there.
 */
const MODEL_PROVIDER_MAP = {
// OpenAI. 'gpt-image' is already matched by the earlier 'gpt-' entry (same
// provider), so that entry never decides a lookup on its own.
'gpt-': openaiProvider,
o1: openaiProvider,
o3: openaiProvider,
o4: openaiProvider,
'text-': openaiProvider,
'computer-use-preview': openaiProvider,
'dall-e': openaiProvider,
'gpt-image': openaiProvider,
'tts-': openaiProvider,
'codex-': openaiProvider,
// Anthropic
'claude-': claudeProvider,
// Google
'gemini-': geminiProvider,
'imagen-': geminiProvider,
// xAI
'grok-': grokProvider,
// DeepSeek
'deepseek-': deepSeekProvider,
// ElevenLabs (both id spellings map to the same provider)
eleven_: elevenLabsProvider,
'elevenlabs-': elevenLabsProvider,
// Test provider
'test-': testProvider,
};
/**
 * Per-provider API-key requirements: the environment variable that holds the
 * key and, for providers whose keys have a recognizable format, the prefix
 * the key must start with. A Map is used so that inherited object keys
 * (e.g. 'constructor') can never be mistaken for a provider.
 */
const PROVIDER_KEY_RULES = new Map([
    ['openai', { envVar: 'OPENAI_API_KEY', prefix: 'sk-' }],
    ['anthropic', { envVar: 'ANTHROPIC_API_KEY', prefix: 'sk-ant-' }],
    ['google', { envVar: 'GOOGLE_API_KEY' }],
    ['xai', { envVar: 'XAI_API_KEY', prefix: 'xai-' }],
    ['deepseek', { envVar: 'DEEPSEEK_API_KEY', prefix: 'sk-' }],
    ['openrouter', { envVar: 'OPENROUTER_API_KEY' }],
    ['elevenlabs', { envVar: 'ELEVENLABS_API_KEY' }],
]);
/**
 * Reports whether a usable API key is configured for `provider`.
 *
 * Built-in providers need a non-empty key in their environment variable
 * (plus the expected prefix where one is defined). The 'test' provider is
 * always usable. Any other name is usable only if it has been registered as
 * an external provider.
 *
 * @param {string} provider - Provider name, e.g. 'openai' or 'anthropic'.
 * @returns {boolean} True when the provider can be called.
 */
export function isProviderKeyValid(provider) {
    if (provider === 'test') {
        return true;
    }
    const rule = PROVIDER_KEY_RULES.get(provider);
    if (rule === undefined) {
        // Not a built-in provider: accept it only when registered externally.
        return Boolean(getExternalProvider(provider));
    }
    const apiKey = process.env[rule.envVar];
    if (!apiKey) {
        return false;
    }
    return rule.prefix === undefined || apiKey.startsWith(rule.prefix);
}
/**
 * Ordered `[prefix, providerName]` pairs consulted by getProviderFromModel;
 * the first prefix the model id starts with wins.
 *
 * These prefixes must mirror the keys of MODEL_PROVIDER_MAP: getModelProvider
 * picks the provider implementation from that map but validates the API key
 * of the provider name returned by getProviderFromModel. A prefix missing
 * here (as 'codex-' used to be) makes it check the wrong provider's key.
 */
const PROVIDER_NAME_PREFIXES = [
    ['gpt-', 'openai'],
    ['o1', 'openai'],
    ['o3', 'openai'],
    ['o4', 'openai'],
    ['text-', 'openai'],
    ['computer-use-preview', 'openai'],
    ['dall-e', 'openai'],
    ['gpt-image', 'openai'],
    ['tts-', 'openai'],
    ['codex-', 'openai'],
    ['claude-', 'anthropic'],
    ['gemini-', 'google'],
    ['imagen-', 'google'],
    ['grok-', 'xai'],
    ['deepseek-', 'deepseek'],
    ['eleven_', 'elevenlabs'],
    ['elevenlabs-', 'elevenlabs'],
    ['test-', 'test'],
];
/**
 * Maps a model id to the provider name used for API-key and quota checks.
 *
 * Registered external models report their own provider. Otherwise the name
 * is chosen by model-id prefix, and any id without a known prefix is
 * attributed to 'openrouter'.
 *
 * @param {string} model - Model id, e.g. 'gpt-4.1' or 'claude-3-7-sonnet'.
 * @returns {string} Provider name, e.g. 'openai', 'anthropic', 'openrouter'.
 */
export function getProviderFromModel(model) {
    if (isExternalModel(model)) {
        const externalModel = getExternalModel(model);
        if (externalModel) {
            return externalModel.provider;
        }
    }
    const match = PROVIDER_NAME_PREFIXES.find(([prefix]) => model.startsWith(prefix));
    // Unknown prefixes are served through OpenRouter.
    return match ? match[1] : 'openrouter';
}
/**
 * Removes excluded and disabled models from `models`, but never empties a
 * non-empty list.
 *
 * When filtering would remove every model, a single model is returned instead:
 *   1. the model after the last entry of `excludeModels` that appears in
 *      `models`, rotating through `models` and skipping disabled entries;
 *   2. otherwise the first model that is not disabled;
 *   3. otherwise the first model, even though it is disabled.
 *
 * @param {string[]} models - Candidate model ids, in preference order.
 * @param {string[]} [excludeModels] - Models to skip for this selection.
 * @param {string[]} [disabledModels] - Models that should not be chosen.
 * @returns {string[]} `models` itself when nothing is excluded, otherwise the
 *   surviving models or a one-element fallback (empty only if `models` is).
 */
function filterModelsWithFallback(models, excludeModels, disabledModels) {
    const blocked = [...(excludeModels || []), ...(disabledModels || [])];
    if (!blocked.length) {
        return models;
    }
    const survivors = models.filter(candidate => !blocked.includes(candidate));
    if (survivors.length !== 0) {
        return survivors;
    }
    const isDisabled = candidate => Boolean(disabledModels?.includes(candidate));
    // Find the most recent (last listed) excluded model that is in the pool.
    const excludedInOrder = [...(excludeModels || [])];
    let anchor;
    for (let i = excludedInOrder.length - 1; i >= 0; i -= 1) {
        if (models.includes(excludedInOrder[i])) {
            anchor = excludedInOrder[i];
            break;
        }
    }
    if (anchor) {
        // Rotate to the next model after the anchor, wrapping around.
        const start = models.indexOf(anchor) + 1;
        for (let offset = 0; offset < models.length; offset += 1) {
            const candidate = models[(start + offset) % models.length];
            if (!isDisabled(candidate)) {
                return [candidate];
            }
        }
    }
    const firstEnabled = models.find(candidate => !isDisabled(candidate));
    if (firstEnabled) {
        return [firstEnabled];
    }
    return models.length > 0 ? [models[0]] : survivors;
}
/**
 * Picks one model from `models` with probability proportional to its score.
 *
 * Models absent from `scores` get a default weight of 50, and models whose
 * weight is not positive are never chosen. Falls back to a uniform random
 * pick when `scores` is absent or no model has a positive weight.
 *
 * @param {string[]} models - Candidate model ids.
 * @param {Record<string, number>} [scores] - Optional per-model weights.
 * @param {() => number} [random=Math.random] - Source of uniform numbers in
 *   [0, 1); injectable so selection can be made deterministic.
 * @returns {string | undefined} The chosen model, or undefined when `models` is empty.
 */
function selectWeightedModel(models, scores, random = Math.random) {
    const pickUniform = () => models[Math.floor(random() * models.length)];
    if (!scores || models.length === 0) {
        return pickUniform();
    }
    const weighted = models
        .map(model => ({
            model,
            weight: scores[model] !== undefined ? scores[model] : 50,
        }))
        .filter(entry => entry.weight > 0);
    if (weighted.length === 0) {
        return pickUniform();
    }
    // All weights are positive here, so the total is positive as well.
    const totalWeight = weighted.reduce((sum, entry) => sum + entry.weight, 0);
    // Each model owns the half-open interval [start, start + weight) of
    // [0, totalWeight). Strict `<` puts a draw that lands exactly on a
    // boundary into the next model's interval, not the previous one's.
    let remaining = random() * totalWeight;
    for (const { model, weight } of weighted) {
        remaining -= weight;
        if (remaining < 0) {
            return model;
        }
    }
    // Only reachable through floating-point rounding at the very top of the
    // range, which belongs to the last interval.
    return weighted[weighted.length - 1].model;
}
/**
 * Resolves the model an agent should use.
 *
 * An explicitly configured `agent.model` wins. Otherwise a model is chosen
 * via getModelFromClass from `agent.modelClass` (or `defaultClass` when the
 * agent has none), passing along the agent's disabled models and scores.
 *
 * @param {object} agent - Agent definition; reads `model`, `modelClass`,
 *   `disabledModels` and `modelScores`.
 * @param {string} [defaultClass] - Class to use when the agent names none.
 * @param {string[]} [excludeModels] - Models to avoid for this selection.
 * @returns {Promise<string>} The chosen model id.
 */
export async function getModelFromAgent(agent, defaultClass, excludeModels) {
    const explicitModel = agent.model;
    if (explicitModel) {
        return explicitModel;
    }
    const targetClass = agent.modelClass || defaultClass;
    const { disabledModels, modelScores } = agent;
    return await getModelFromClass(targetClass, excludeModels, disabledModels, modelScores);
}
/**
 * True when `name` is an own key of MODEL_CLASSES. An own-key check (rather
 * than `in`) keeps inherited names such as 'toString' from being treated as
 * model classes.
 */
function hasModelClass(name) {
    return Object.prototype.hasOwnProperty.call(MODEL_CLASSES, name);
}
/**
 * Returns a fresh copy of the model list for `group` and whether the class
 * asks for random selection, with any getModelClassOverride(group) applied.
 * Callers must check hasModelClass(group) first.
 */
function resolveModelClass(group) {
    const base = MODEL_CLASSES[group];
    const override = getModelClassOverride(group);
    const config = override ? { ...base, ...override } : base;
    return {
        models: [...(override?.models || base.models)],
        random: override?.random ?? ('random' in config && config.random),
    };
}
/**
 * Chooses one model from a non-empty candidate list. The choice is weighted
 * by `modelScores` when scores are given, uniformly random when the class is
 * marked random, and otherwise the first candidate (the class's own order).
 */
function pickModel(candidates, modelScores, shouldRandomize) {
    if (modelScores) {
        return selectWeightedModel(candidates, modelScores);
    }
    if (shouldRandomize) {
        return candidates[Math.floor(Math.random() * candidates.length)];
    }
    return candidates[0];
}
/**
 * Picks a candidate whose provider has a valid API key, preferring ones the
 * quota tracker still reports quota for. Returns `{ model, withinQuota }`,
 * or undefined when no candidate has a valid key.
 */
function selectAvailableModel(candidates, modelScores, shouldRandomize, quotaTracker) {
    const withKey = candidates.filter(model => isProviderKeyValid(getProviderFromModel(model)));
    const withQuota = withKey.filter(model => quotaTracker.hasQuota(getProviderFromModel(model), model));
    if (withQuota.length > 0) {
        return { model: pickModel(withQuota, modelScores, shouldRandomize), withinQuota: true };
    }
    if (withKey.length > 0) {
        return { model: pickModel(withKey, modelScores, shouldRandomize), withinQuota: false };
    }
    return undefined;
}
/**
 * Selects a concrete model id for a model class.
 *
 * The class's models (with any override applied) are filtered through
 * filterModelsWithFallback. A model whose provider has a valid API key and
 * remaining quota is preferred; failing that, any model with a valid key is
 * used. If the requested class yields nothing, the 'standard' class is tried
 * the same way. As a last resort, the class's first model (or 'gpt-4.1') is
 * returned even though no key was found.
 *
 * Without `modelScores`, selection is random only for classes marked
 * `random`; other classes return their first available model, so the flag
 * actually controls selection.
 *
 * @param {string} [modelClass] - Class name; unknown or falsy values use 'standard'.
 * @param {string[]} [excludeModels] - Models to avoid for this selection.
 * @param {string[]} [disabledModels] - Models that should not be chosen.
 * @param {Record<string, number>} [modelScores] - Optional selection weights.
 * @returns {Promise<string>} The chosen model id.
 */
export async function getModelFromClass(modelClass, excludeModels, disabledModels, modelScores) {
    const { quotaTracker } = await import('../utils/quota_tracker.js');
    const modelGroup = modelClass && hasModelClass(modelClass) ? modelClass : 'standard';
    const chooseFrom = group => {
        const { models, random } = resolveModelClass(group);
        const candidates = filterModelsWithFallback(models, excludeModels, disabledModels);
        return selectAvailableModel(candidates, modelScores, random, quotaTracker);
    };
    if (hasModelClass(modelGroup)) {
        const choice = chooseFrom(modelGroup);
        if (choice) {
            const note = choice.withinQuota ? '' : ' (may exceed quota)';
            console.log(`Using '${choice.model}' model for '${modelGroup}' class${note}.`);
            return choice.model;
        }
    }
    if (modelGroup !== 'standard' && hasModelClass('standard')) {
        const choice = chooseFrom('standard');
        if (choice) {
            const note = choice.withinQuota ? '' : ' (may exceed quota)';
            console.warn(`Falling back to 'standard' class with model '${choice.model}'${note}.`);
            return choice.model;
        }
    }
    const classModels = hasModelClass(modelGroup) ? resolveModelClass(modelGroup).models : [];
    const defaultModel = classModels.length > 0 ? classModels[0] : 'gpt-4.1';
    console.log(`No valid API key found for any model in class ${modelGroup}, using default: ${defaultModel}`);
    return defaultModel;
}
/**
 * Resolves the provider implementation that should serve `model`.
 *
 * Resolution order:
 *   1. a registered external model whose provider is registered;
 *   2. the first MODEL_PROVIDER_MAP prefix the model id starts with, after
 *      checking the API key of the provider named by getProviderFromModel;
 *   3. openRouterProvider, for unknown ids or when `model` is falsy.
 *
 * @param {string} [model] - Model id.
 * @returns {object} The provider implementation.
 * @throws {Error} When the matched provider's API key is missing or invalid,
 *   or when falling back to OpenRouter without a valid OpenRouter key.
 */
export function getModelProvider(model) {
    if (model) {
        if (isExternalModel(model)) {
            const externalModel = getExternalModel(model);
            const externalProvider = externalModel && getExternalProvider(externalModel.provider);
            if (externalProvider) {
                return externalProvider;
            }
        }
        for (const [prefix, provider] of Object.entries(MODEL_PROVIDER_MAP)) {
            if (!model.startsWith(prefix)) {
                continue;
            }
            const providerName = getProviderFromModel(model);
            if (!isProviderKeyValid(providerName)) {
                throw new Error(`API key for ${providerName} provider is missing or invalid. Please set ${providerName.toUpperCase()}_API_KEY environment variable.`);
            }
            return provider;
        }
    }
    // Check the OpenRouter key directly rather than routing the literal
    // 'openrouter' through model-name resolution (and the external registry).
    if (!isProviderKeyValid('openrouter')) {
        throw new Error(`No valid provider found for the model ${model}. Please check your API keys.`);
    }
    return openRouterProvider;
}
/**
 * Reports whether `agent` can run with the API keys currently configured.
 *
 * - With `agent.model`: checks the key of that model's provider and returns
 *   `{ canRun, model, provider, missingProvider, reason }`.
 * - With `agent.modelClass`: splits the class's models (with any override
 *   applied; unknown classes use 'standard') into available and unavailable,
 *   and returns `{ canRun, availableModels, unavailableModels,
 *   missingProvider, reason }`. `missingProvider` is set only when nothing is
 *   available and exactly one provider is missing.
 * - With neither: evaluates the 'standard' class.
 *
 * @param {object} agent - Agent definition; reads `model` and `modelClass`.
 * @returns {Promise<object>} The availability report described above.
 */
export async function canRunAgent(agent) {
    const directModel = agent.model;
    if (directModel) {
        const provider = getProviderFromModel(directModel);
        if (isProviderKeyValid(provider)) {
            return {
                canRun: true,
                model: directModel,
                provider,
                missingProvider: undefined,
                reason: undefined,
            };
        }
        return {
            canRun: false,
            model: directModel,
            provider,
            missingProvider: provider,
            reason: `Missing API key for provider: ${provider}`,
        };
    }
    const requestedClass = agent.modelClass;
    if (!requestedClass) {
        return canRunAgent({ modelClass: 'standard' });
    }
    const group = requestedClass in MODEL_CLASSES ? requestedClass : 'standard';
    const override = getModelClassOverride(group);
    const baseConfig = MODEL_CLASSES[group];
    const mergedConfig = override ? { ...baseConfig, ...override } : baseConfig;
    const candidates = [...(override?.models || mergedConfig.models)];
    const available = [];
    const unavailable = [];
    const missing = new Set();
    for (const candidate of candidates) {
        const provider = getProviderFromModel(candidate);
        if (!isProviderKeyValid(provider)) {
            unavailable.push(candidate);
            missing.add(provider);
            continue;
        }
        available.push(candidate);
    }
    const canRun = available.length > 0;
    const missingList = [...missing];
    return {
        canRun,
        availableModels: available,
        unavailableModels: unavailable,
        missingProvider: !canRun && missing.size === 1 ? missingList[0] : undefined,
        reason: canRun
            ? undefined
            : `No API keys found for any models in class: ${requestedClass}. Missing providers: ${missingList.join(', ')}`,
    };
}
//# sourceMappingURL=model_provider.js.map