@just-every/ensemble
Version:
LLM provider abstraction layer with unified streaming interface
189 lines • 8.55 kB
JavaScript
import { BaseModelProvider } from './base_provider.js';
import { costTracker } from '../utils/cost_tracker.js';
import { fetchWithTimeout } from '../utils/fetch_with_timeout.js';
import { log_llm_error, log_llm_request, log_llm_response } from '../utils/llm_logger.js';
// Base URL of the Fireworks REST API; the Kontext workflow create/poll URLs are built under it.
const FW_BASE = 'https://api.fireworks.ai';
/**
 * Translate an image size (named orientation or OpenAI-style "WxH") into the
 * aspect-ratio string sent to Fireworks.
 *
 * Note that 1536x1024 / 1024x1536 (3:2 shapes) are approximated as 16:9 / 9:16.
 *
 * @param {*} size - e.g. 'square', 'landscape', 'portrait', '1024x1024', '1792x1024'.
 * @returns {'1:1'|'16:9'|'9:16'|undefined} undefined for a falsy or unrecognized size.
 */
function mapAspect(size) {
    if (!size)
        return undefined;
    switch (String(size)) {
        case 'square':
        case '1024x1024':
        case '512x512':
        case '256x256':
            return '1:1';
        case 'landscape':
        case '1792x1024':
        case '1536x1024':
            return '16:9';
        case 'portrait':
        case '1024x1792':
        case '1024x1536':
            return '9:16';
        default:
            return undefined;
    }
}
/**
 * Image-generation provider backed by Fireworks' FLUX Kontext workflow API,
 * with an optional fal.ai fallback when FAL_KEY is set in the environment.
 * Text streaming is not supported.
 */
export class FireworksProvider extends BaseModelProvider {
    constructor() {
        super('fireworks');
    }

    /**
     * Text streaming is not offered by this provider.
     * @throws {Error} Always.
     */
    async *createResponseStream() {
        throw new Error('Fireworks provider does not support text streaming');
    }

    /**
     * Whether a model name refers to a FLUX Kontext variant (case-insensitive).
     * @param {string} model
     * @returns {boolean}
     */
    isKontext(model) {
        return model.toLowerCase().includes('kontext');
    }

    /**
     * Map a user-facing model name to a Fireworks model id.
     * Kontext names map to `flux-kontext-max` / `flux-kontext-pro`, other `pro`
     * names to `flux-pro-1.1`, `schnell` names to `flux-schnell`; anything else
     * is returned unchanged.
     * @param {string} model
     * @returns {string}
     */
    fireworksModelId(model) {
        const m = model.toLowerCase();
        if (m.includes('kontext') && m.includes('max'))
            return 'flux-kontext-max';
        if (m.includes('kontext'))
            return 'flux-kontext-pro';
        if (m.includes('pro'))
            return 'flux-pro-1.1';
        if (m.includes('schnell'))
            return 'flux-schnell';
        return model;
    }

    /**
     * Generate images for `prompt` and return their URLs (or data: URLs).
     *
     * Every request goes to a Fireworks FLUX Kontext workflow. Kontext models
     * use their own workflow; any other model is routed to `flux-kontext-pro`
     * and billed under that name. If the Fireworks attempt fails for any
     * reason (including a missing FIREWORKS_API_KEY) and FAL_KEY is set,
     * exactly one fallback attempt is made through fal.ai using `model`.
     *
     * @param {string} prompt - Text prompt.
     * @param {string} model - Requested model name.
     * @param {object} [agent] - Calling agent; only `agent_id` is read (for logging).
     * @param {object} [opts] - Reads `size`, `source_images` and `request_id`.
     * @returns {Promise<string[]>} Image URLs.
     * @throws {Error} The original Fireworks error when Fireworks fails and no fallback succeeds.
     */
    async createImage(prompt, model, agent, opts = {}) {
        const requestId = log_llm_request(agent?.agent_id || 'default', 'fireworks', model, { prompt, opts }, new Date());
        try {
            const apiKey = process.env.FIREWORKS_API_KEY;
            if (!apiKey)
                throw new Error('FIREWORKS_API_KEY is not set');
            const modelId = this.fireworksModelId(model);
            const kontextModel = this.isKontext(modelId);
            // Only Kontext workflows are called here; other FLUX variants run on Kontext Pro.
            const workflowId = kontextModel ? modelId : 'flux-kontext-pro';
            const billedModel = kontextModel ? model : 'flux-kontext-pro';
            const urls = await this.runKontextWorkflow(prompt, workflowId, apiKey, opts);
            costTracker.addUsage({ model: billedModel, image_count: urls.length, request_id: opts?.request_id, metadata: { source: 'fireworks', model: workflowId } });
            log_llm_response(requestId, { ok: true });
            return urls;
        }
        catch (err) {
            log_llm_error(requestId, err);
            if (!process.env.FAL_KEY)
                throw err;
            try {
                const urls = await this.fallbackToFAL(prompt, model, opts);
                log_llm_response(requestId, { ok: true, fallback: 'fal' });
                return urls;
            }
            catch (falErr) {
                // Record the fallback failure too, but surface the original Fireworks error.
                log_llm_error(requestId, falErr);
                throw err;
            }
        }
    }

    /**
     * Submit a Kontext workflow job and poll `get_result` every 1.5 s until it
     * finishes.
     * @param {string} prompt
     * @param {string} modelId - Fireworks Kontext model id (e.g. 'flux-kontext-pro').
     * @param {string} apiKey - Fireworks API key.
     * @param {object} [opts] - Reads `size` and `source_images`.
     * @returns {Promise<string[]>} Image URLs (deduplicated).
     * @throws {Error} On HTTP errors, a failed or moderated job, an empty result, or after ~240 s.
     */
    async runKontextWorkflow(prompt, modelId, apiKey, opts = {}) {
        const workflowUrl = `${FW_BASE}/inference/v1/workflows/accounts/fireworks/models/${modelId}`;
        const headers = {
            'Content-Type': 'application/json',
            Authorization: `Bearer ${apiKey}`,
        };
        const payload = { prompt, aspect_ratio: mapAspect(opts?.size), output_format: 'png' };
        const input_image = this.firstSourceImage(opts?.source_images);
        if (input_image)
            payload.input_image = input_image;
        const res = await fetchWithTimeout(workflowUrl, {
            method: 'POST',
            headers,
            body: JSON.stringify(payload),
        }, 60000);
        if (!res.ok)
            throw new Error(`Fireworks Kontext create failed: ${res.status} ${await res.text()}`);
        const data = await res.json();
        const id = data.request_id || data.id;
        if (!id)
            throw new Error('Fireworks Kontext: missing request id');
        const pollUrl = `${workflowUrl}/get_result`;
        const started = Date.now();
        const timeoutMs = 240000;
        for (;;) {
            const r = await fetchWithTimeout(pollUrl, {
                method: 'POST',
                headers,
                body: JSON.stringify({ id }),
            }, 60000);
            if (!r.ok)
                throw new Error(`Fireworks Kontext poll failed: ${r.status} ${await r.text()}`);
            const out = await r.json();
            const status = String(out.status || '').toLowerCase();
            if (status.includes('ready') || status === 'succeeded' || status === 'completed') {
                const urls = this.extractKontextUrls(out.result);
                if (!urls.length)
                    throw new Error('Fireworks Kontext: no image result found');
                return urls;
            }
            // Moderated jobs never become ready. NOTE(review): assumes BFL-style
            // statuses such as "Request Moderated" / "Content Moderated"; confirm against Fireworks docs.
            if (status.includes('error') || status.includes('failed') || status.includes('moderated')) {
                throw new Error(`Fireworks Kontext failed: ${status}`);
            }
            if (Date.now() - started > timeoutMs)
                throw new Error('Fireworks Kontext timed out');
            await new Promise(resolve => setTimeout(resolve, 1500));
        }
    }

    /**
     * Pick the first source image as a string (URL or encoded data) for Kontext.
     * Accepts a string, an object with a string `data` field, or an array of
     * either; anything else yields undefined.
     * @param {*} sourceImages
     * @returns {string|undefined}
     */
    firstSourceImage(sourceImages) {
        if (!sourceImages)
            return undefined;
        const first = Array.isArray(sourceImages) ? sourceImages[0] : sourceImages;
        const value = typeof first === 'string' ? first : first?.data;
        return typeof value === 'string' ? value : undefined;
    }

    /**
     * Collect image URLs from the various result shapes the poll endpoint may
     * return: `images[].url`, `image.url`, `url`, `sample`, `image_base64`
     * (as a PNG data URL), or a bare http(s) string.
     * @param {*} result - `out.result` from the poll response.
     * @returns {string[]} Unique URLs, in discovery order.
     */
    extractKontextUrls(result) {
        const urls = [];
        if (Array.isArray(result?.images)) {
            for (const im of result.images)
                if (im?.url)
                    urls.push(im.url);
        }
        if (result?.image?.url)
            urls.push(result.image.url);
        if (result?.url)
            urls.push(result.url);
        if (result?.sample)
            urls.push(result.sample);
        if (!urls.length && result?.image_base64)
            urls.push(`data:image/png;base64,${result.image_base64}`);
        if (!urls.length && typeof result === 'string' && /^https?:\/\//.test(result))
            urls.push(result);
        // Several fields can carry the same URL; dedupe so image_count is not inflated.
        return [...new Set(urls)];
    }

    /**
     * Generate images through fal.ai's synchronous `fal.run` endpoint.
     * The endpoint is chosen from `model`: `schnell` → flux/schnell,
     * `dev` → flux/dev, `pro`/`kontext` → flux-pro/kontext, otherwise flux/schnell.
     * Reads FAL_KEY from the environment.
     * @param {string} prompt
     * @param {string} model
     * @param {object} [opts] - Only `request_id` is read (for cost tracking).
     * @returns {Promise<string[]>} Image URLs.
     * @throws {Error} On a non-OK response or when no image URL is returned; a
     *   timeout surfaces as whatever fetchWithTimeout throws.
     */
    async fallbackToFAL(prompt, model, opts = {}) {
        const lower = model.toLowerCase();
        let endpoint = 'fal-ai/flux/schnell';
        if (lower.includes('schnell'))
            endpoint = 'fal-ai/flux/schnell';
        else if (lower.includes('dev'))
            endpoint = 'fal-ai/flux/dev';
        else if (lower.includes('pro') || lower.includes('kontext'))
            endpoint = 'fal-ai/flux-pro/kontext';
        // Bounded like the Fireworks calls so a stalled fallback cannot hang forever.
        const res = await fetchWithTimeout(`https://fal.run/${endpoint}`, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Key ${process.env.FAL_KEY}`,
            },
            // The prompt is sent both top-level and under `input`, presumably to
            // cover both request shapes. NOTE(review): confirm against fal docs.
            body: JSON.stringify({
                prompt,
                input: { prompt },
            }),
        }, 120000);
        if (!res.ok)
            throw new Error(`FAL fallback failed: ${res.status} ${await res.text()}`);
        const data = await res.json();
        const images = [];
        const arr = data?.images || data?.output?.images || [];
        for (const im of arr)
            if (im?.url)
                images.push(im.url);
        if (!images.length && data?.url)
            images.push(data.url);
        if (!images.length)
            throw new Error('FAL fallback: no image url');
        costTracker.addUsage({ model, image_count: images.length, request_id: opts?.request_id, metadata: { source: 'fal-fallback' } });
        return images;
    }
}
// Module-level FireworksProvider instance exported for reuse by importers.
export const fireworksProvider = new FireworksProvider();
//# sourceMappingURL=fireworks.js.map