UNPKG

@just-every/ensemble

Version:

LLM provider abstraction layer with unified streaming interface

130 lines 6.14 kB
import { BaseModelProvider } from './base_provider.js';
import { costTracker } from '../utils/cost_tracker.js';
import { fetchWithTimeout } from '../utils/fetch_with_timeout.js';
import { log_llm_error, log_llm_request, log_llm_response } from '../utils/llm_logger.js';

/** Base URL of the image API; overridable through the MJ_API_BASE environment variable. */
const KIE_BASE = process.env.MJ_API_BASE || 'https://api.kie.ai';

/** Timeout (ms) passed to fetchWithTimeout for the task-creation request. */
const CREATE_TIMEOUT_MS = 20000;
/** Timeout (ms) passed to fetchWithTimeout for each status poll. */
const POLL_REQUEST_TIMEOUT_MS = 15000;
/** Pause (ms) between two consecutive status polls. */
const POLL_INTERVAL_MS = 1000;
/** Overall budget (ms) for a generation before polling gives up. */
const GENERATION_TIMEOUT_MS = 300000;
/** Number of 404 poll answers tolerated before the task is declared missing. */
const MAX_NOT_FOUND_POLLS = 5;
/** Upper bound on the number of image URLs returned by one call. */
const MAX_IMAGES = 4;

/**
 * Translate a size keyword or pixel-dimension string into an aspect-ratio string.
 *
 * @param {*} size - Size hint such as 'square', 'landscape', 'portrait' or '1024x1024'.
 * @returns {string|undefined} '1:1', '16:9' or '9:16'; undefined for missing or unrecognised input.
 */
function mapMJAspect(size) {
    if (!size) return undefined;
    const s = String(size);
    if (s === 'square' || s === '1024x1024' || s === '512x512' || s === '256x256') return '1:1';
    if (s === 'landscape' || s === '1792x1024' || s === '1536x1024') return '16:9';
    if (s === 'portrait' || s === '1024x1792' || s === '1024x1536') return '9:16';
    return undefined;
}

/**
 * Read a response body exactly once and decode it as JSON.
 *
 * A response body can only be consumed a single time, so the text is read first and
 * parsed afterwards; that keeps the raw text available for the fallback object when the
 * payload is not valid JSON.
 *
 * @param {{ status: number, text: () => Promise<string> }} res - Response returned by fetchWithTimeout.
 * @returns {Promise<*>} The parsed JSON value, or `{ code, msg }` built from the HTTP status and raw text.
 */
async function readJsonBody(res) {
    const text = await res.text();
    try {
        return JSON.parse(text);
    } catch {
        return { code: res.status, msg: text };
    }
}

/**
 * Keep only the http(s) URLs from the caller-supplied source images.
 *
 * Entries may be plain strings or objects carrying the value in `data`; anything that is
 * not an http(s) URL is ignored.
 *
 * @param {*} sourceImages - A single entry or an array of entries; falsy means none.
 * @returns {string[]} The http(s) URLs found, in input order.
 */
function collectSourceUrls(sourceImages) {
    if (!sourceImages) return [];
    const entries = Array.isArray(sourceImages) ? sourceImages : [sourceImages];
    const found = [];
    for (const entry of entries) {
        const value = typeof entry === 'string' ? entry : entry?.data || '';
        if (typeof value === 'string' && /^https?:\/\//i.test(value)) found.push(value);
    }
    return found;
}

/**
 * Pull the result URLs out of a poll payload, checking each location the code supports.
 *
 * @param {*} info - Decoded poll response.
 * @returns {Array} Truthy URL entries; empty when none are present or the field is not an array.
 */
function extractResultUrls(info) {
    const list = info?.data?.resultInfoJson?.resultUrls
        || info?.data?.resultUrls
        || info?.resultInfoJson?.resultUrls
        || info?.resultUrls
        || [];
    // Guard against a non-array field so a malformed payload cannot crash on `.map`.
    if (!Array.isArray(list)) return [];
    return list
        .map((u) => (typeof u === 'string' ? u : u?.resultUrl))
        .filter(Boolean);
}

/**
 * Poll the record-info endpoint until the task yields result URLs, fails or times out.
 *
 * @param {string} taskId - Identifier returned by the generate endpoint.
 * @param {string} apiKey - Bearer token sent in the Authorization header.
 * @returns {Promise<Array>} All result URLs reported for the task.
 * @throws {Error} On a non-200/non-404 code, too many 404 answers, a reported failure,
 *   a success report without URLs, or when the overall time budget is exceeded.
 */
async function pollForResultUrls(taskId, apiKey) {
    const started = Date.now();
    let notFoundCount = 0;
    for (;;) {
        const r = await fetchWithTimeout(`${KIE_BASE}/api/v1/mj/record-info?taskId=${encodeURIComponent(taskId)}`, {
            headers: { Authorization: `Bearer ${apiKey}` },
        }, POLL_REQUEST_TIMEOUT_MS);
        const info = await readJsonBody(r);
        const code = info?.code ?? r.status;
        const status = info?.data?.status || info?.status;
        if (code && code !== 200) {
            if (code !== 404) {
                throw new Error(`Midjourney poll error: code=${code} msg=${info?.msg ?? ''}`);
            }
            // A 404 is tolerated a limited number of times before giving up on the task.
            notFoundCount++;
            if (notFoundCount > MAX_NOT_FOUND_POLLS) {
                throw new Error(`Midjourney record not found for task ${taskId}`);
            }
        }
        const urls = extractResultUrls(info);
        if (status === 'SUCCESS' || info?.successFlag === 1 || urls.length > 0) {
            if (!urls.length) throw new Error('Midjourney: no result URLs');
            return urls;
        }
        if (info?.successFlag === 2 || status === 'FAILED') throw new Error('Midjourney generation failed');
        if (Date.now() - started > GENERATION_TIMEOUT_MS) throw new Error('Midjourney generation timed out');
        await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS));
    }
}

/**
 * Image-only provider that creates a generation task over HTTP and polls for its result.
 */
export class MidjourneyProvider extends BaseModelProvider {
    constructor() {
        super('midjourney');
    }

    /**
     * Text streaming is not available for this provider.
     *
     * @throws {Error} Always.
     */
    async *createResponseStream() {
        throw new Error('Midjourney provider does not support text streaming');
    }

    /**
     * Generate images for a prompt and return their URLs.
     *
     * @param {string} prompt - Prompt text sent to the generate endpoint.
     * @param {string} model - Model name recorded in the request log and the cost tracker.
     * @param {object} [agent] - Caller descriptor; `agent_id` labels the log entry ('default' when absent).
     * @param {object} [opts] - Optional settings:
     *   `size` (mapped to an aspect ratio, default '1:1'),
     *   `n` (number of URLs to return, clamped to 1..4),
     *   `source_images` (http(s) URLs switch the task to image-to-image),
     *   `request_id` (forwarded to the cost tracker).
     * @returns {Promise<Array>} Up to `n` result URLs.
     * @throws {Error} When no API key is configured, the create call is rejected, no task id
     *   is returned, or polling fails or times out.
     */
    async createImage(prompt, model, agent, opts) {
        const apiKey = process.env.MIDJOURNEY_API_KEY || process.env.MJ_API_KEY || process.env.KIE_API_KEY;
        if (!apiKey) throw new Error('Midjourney provider: MIDJOURNEY_API_KEY (or KIE_API_KEY) is not set');
        if (!process.env.MIDJOURNEY_API_KEY && process.env.MJ_API_KEY) {
            console.warn('[Midjourney] MJ_API_KEY is deprecated. Please set MIDJOURNEY_API_KEY instead.');
        }
        // Optional chaining: a missing agent must not crash before the request is even logged.
        const requestId = log_llm_request(agent?.agent_id || 'default', 'midjourney', model, { prompt, opts }, new Date());
        let succeeded = false;
        try {
            const aspect = mapMJAspect(opts?.size) || '1:1';
            // A non-numeric `n` would become NaN and silently yield zero images; fall back to 1.
            const requested = Number(opts?.n);
            const n = Number.isFinite(requested) ? Math.max(1, Math.min(MAX_IMAGES, requested)) : 1;

            const fileUrls = collectSourceUrls(opts?.source_images);
            const taskType = fileUrls.length > 0 ? 'mj_img2img' : 'mj_txt2img';
            const body = {
                taskType,
                version: '7',
                prompt,
                aspectRatio: aspect,
                speed: 'fast',
                stylization: 100,
                weirdness: 0,
                variety: 0,
            };
            if (taskType === 'mj_img2img') {
                body.fileUrls = fileUrls;
            }
            const headers = {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${apiKey}`,
            };
            const res = await fetchWithTimeout(`${KIE_BASE}/api/v1/mj/generate`, {
                method: 'POST',
                headers,
                body: JSON.stringify(body),
            }, CREATE_TIMEOUT_MS);
            const data = await readJsonBody(res);
            if (!res.ok || (data?.code && data.code !== 200)) {
                throw new Error(`MJ create failed: code=${data?.code ?? res.status} msg=${data?.msg ?? ''}`);
            }
            const taskId = data?.data?.taskId || data?.taskId || data?.id;
            if (!taskId) throw new Error('Midjourney: missing taskId');

            const urls = await pollForResultUrls(taskId, apiKey);
            const images = urls.slice(0, n);
            costTracker.addUsage({
                model,
                image_count: images.length,
                request_id: opts?.request_id,
                metadata: { source: 'kie' },
            });
            succeeded = true;
            return images;
        } catch (err) {
            log_llm_error(requestId, err);
            throw err;
        } finally {
            // Report the real outcome instead of unconditionally claiming success.
            log_llm_response(requestId, { ok: succeeded });
        }
    }
}

export const midjourneyProvider = new MidjourneyProvider();
//# sourceMappingURL=midjourney.js.map