UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

126 lines 5.77 kB
import { BaseModelProvider } from './base_provider.js';
import { costTracker } from '../utils/cost_tracker.js';
import { log_llm_error, log_llm_request, log_llm_response } from '../utils/llm_logger.js';

const STABILITY_BASE = 'https://api.stability.ai/v2beta';

// Strength value sent with every image-to-image request.
const IMAGE_TO_IMAGE_STRENGTH = '0.75';

// Captures the MIME type and the base64 payload of an image data URL.
const IMAGE_DATA_URL_PATTERN = /^data:(image\/[^;]+);base64,(.+)$/;

// Checked in order: 'sd3.5-large-turbo' must precede 'sd3.5-large' because the
// match is a substring test and the shorter name is a prefix of the longer one.
const SD3_VARIANTS = ['sd3.5-large-turbo', 'sd3.5-large', 'sd3.5-medium', 'sd3.5-flash'];

// Size keywords / pixel dimensions accepted by callers, mapped to aspect ratios.
const ASPECT_BY_SIZE = new Map([
    ['square', '1:1'],
    ['1024x1024', '1:1'],
    ['512x512', '1:1'],
    ['256x256', '1:1'],
    ['landscape', '16:9'],
    ['1792x1024', '16:9'],
    ['1536x1024', '16:9'],
    ['portrait', '9:16'],
    ['1024x1792', '9:16'],
    ['1024x1536', '9:16'],
]);

/**
 * Translate a caller-supplied size into an aspect-ratio string.
 *
 * @param {*} size - Size keyword or "WxH" string; falsy values mean "not specified".
 * @returns {string|undefined} '1:1', '16:9' or '9:16', or undefined when the
 *   size is missing or not one of the recognised values.
 */
function mapAspect(size) {
    if (!size) return undefined;
    return ASPECT_BY_SIZE.get(String(size));
}

/**
 * Find the SD3 variant name contained in a model identifier.
 *
 * @param {string} model - Model identifier (matched case-insensitively).
 * @returns {string|undefined} The first matching entry of SD3_VARIANTS, if any.
 */
function pickSd3Variant(model) {
    const lowered = model.toLowerCase();
    return SD3_VARIANTS.find((variant) => lowered.includes(variant));
}

/**
 * Attach a source image to the form and switch the request to image-to-image mode.
 *
 * String sources are handled in two ways: `data:image/...` URLs are decoded
 * locally, any other string is fetched. Non-string sources are ignored.
 *
 * @param {FormData} form - Request form; mutated in place.
 * @param {*} source - First entry of `opts.source_images`.
 * @returns {Promise<boolean>} true when an image was attached.
 * @throws {Error} When a data URL cannot be parsed or the download fails.
 */
async function attachSourceImage(form, source) {
    if (typeof source !== 'string') return false;

    if (source.startsWith('data:image/')) {
        const [, mime, b64] = source.match(IMAGE_DATA_URL_PATTERN) || [];
        if (!b64) throw new Error('Invalid base64 source image');
        const bytes = Buffer.from(b64, 'base64');
        form.set('image', new Blob([bytes], { type: mime || 'image/png' }), 'image.png');
    } else {
        const download = await fetch(source);
        // Without this check an HTTP error page would be uploaded as if it were the image.
        if (!download.ok) {
            throw new Error(`Stability: failed to fetch source image: ${download.status}`);
        }
        const contentType = download.headers.get('content-type') || 'image/png';
        const bytes = new Uint8Array(await download.arrayBuffer());
        form.set('image', new Blob([bytes], { type: contentType }), 'image');
    }

    form.set('mode', 'image-to-image');
    form.set('strength', IMAGE_TO_IMAGE_STRENGTH);
    return true;
}

/**
 * Convert a successful generation response into a list of data URLs.
 *
 * A binary `image/*` body yields a single data URL using the response's MIME
 * type. Otherwise the body is parsed as JSON and base64 payloads are collected
 * from `artifacts[].base64`, `image` and `images[0].base64`, each labelled
 * as image/png.
 *
 * @param {Response} res - Response whose `ok` flag has already been checked.
 * @returns {Promise<string[]>} Data URLs; empty when the JSON held no image.
 */
async function readImages(res) {
    const contentType = res.headers.get('content-type') || '';
    if (contentType.startsWith('image/')) {
        const b64 = Buffer.from(await res.arrayBuffer()).toString('base64');
        const mime = contentType.split(';')[0];
        return [`data:${mime};base64,${b64}`];
    }

    const json = await res.json();
    const images = [];
    if (Array.isArray(json?.artifacts)) {
        for (const artifact of json.artifacts) {
            if (artifact?.base64) images.push(`data:image/png;base64,${artifact.base64}`);
        }
    }
    if (json?.image) images.push(`data:image/png;base64,${json.image}`);
    if (json?.images?.[0]?.base64) images.push(`data:image/png;base64,${json.images[0].base64}`);
    return images;
}

/**
 * Image-generation provider backed by the Stability AI HTTP API.
 * Text streaming is not supported.
 */
export class StabilityProvider extends BaseModelProvider {
    constructor() {
        super('stability');
    }

    /**
     * Always fails: this provider only generates images.
     *
     * @throws {Error} On first iteration of the returned generator.
     */
    async *createResponseStream() {
        throw new Error('Stability provider does not support text streaming');
    }

    /**
     * Choose the generation endpoint for a model identifier.
     *
     * @param {string} model - Model identifier (matched case-insensitively).
     * @returns {string} The 'ultra' or 'core' endpoint when the name contains
     *   that word, otherwise the 'sd3' endpoint.
     */
    endpointFor(model) {
        const lowered = model.toLowerCase();
        if (lowered.includes('ultra')) return `${STABILITY_BASE}/stable-image/generate/ultra`;
        if (lowered.includes('core')) return `${STABILITY_BASE}/stable-image/generate/core`;
        return `${STABILITY_BASE}/stable-image/generate/sd3`;
    }

    /**
     * Generate one or more images from a prompt, optionally seeded by a source image.
     *
     * Reads the API key from the STABILITY_API_KEY environment variable, logs the
     * request/response through the LLM logger and records usage in the cost tracker.
     *
     * @param {string} prompt - Text prompt sent as the `prompt` form field.
     * @param {string} model - Model identifier; selects the endpoint and, for
     *   SD3.5 variants, the `model` form field.
     * @param {object} agent - Calling agent; only `agent_id` is read (falls back to 'default').
     * @param {object} [opts] - Options. Fields read here: `response_format`,
     *   `source_images` (string or array; only the first entry is used),
     *   `size` (ignored when a source image is attached) and `request_id`.
     * @returns {Promise<string[]>} Generated images as data URLs.
     * @throws {Error} When the API key is missing, the source image is invalid or
     *   cannot be downloaded, the API responds with a non-OK status, or the
     *   response contains no image.
     */
    async createImage(prompt, model, agent, opts = {}) {
        const apiKey = process.env.STABILITY_API_KEY;
        const requestId = log_llm_request(agent?.agent_id || 'default', 'stability', model, { prompt, opts }, new Date());
        // Tracks the real outcome so the response log does not claim success after a failure.
        let succeeded = false;
        try {
            if (!apiKey) throw new Error('STABILITY_API_KEY is not set');

            const endpoint = this.endpointFor(model);
            const form = new FormData();
            form.set('prompt', prompt);
            if (opts?.response_format === 'url') form.set('output_format', 'png');

            const variant = pickSd3Variant(model);
            if (variant) form.set('model', variant);

            let isImageToImage = false;
            if (opts?.source_images) {
                const sources = Array.isArray(opts.source_images) ? opts.source_images : [opts.source_images];
                isImageToImage = await attachSourceImage(form, sources[0]);
            }
            if (!isImageToImage) {
                // `opts` may be an explicit null (the default only covers undefined).
                const aspect = mapAspect(opts?.size);
                if (aspect) form.set('aspect_ratio', aspect);
            }

            const res = await fetch(endpoint, {
                method: 'POST',
                headers: {
                    Authorization: `Bearer ${apiKey}`,
                    Accept: 'application/json',
                },
                body: form,
            });
            if (!res.ok) throw new Error(`Stability create failed: ${res.status} ${await res.text()}`);

            const images = await readImages(res);
            if (!images.length) throw new Error('Stability: no image in response');

            costTracker.addUsage({
                model,
                image_count: images.length,
                request_id: opts?.request_id,
                metadata: { source: 'stability' },
            });
            succeeded = true;
            return images;
        } catch (err) {
            log_llm_error(requestId, err);
            throw err;
        } finally {
            log_llm_response(requestId, { ok: succeeded });
        }
    }
}

export const stabilityProvider = new StabilityProvider();
//# sourceMappingURL=stability.js.map