// @huggingface/inference
// TypeScript client for the Hugging Face Inference Providers and Inference Endpoints
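// Usage sketch (illustrative, not part of the bundle): the task helpers defined
// below are exported at the package root. The model id and token are
// placeholders; option names beyond those visible in this file are assumptions
// based on the task signatures.
//
//   import { chatCompletion } from "@huggingface/inference";
//
//   const output = await chatCompletion({
//     accessToken: "hf_...",                       // placeholder HF token
//     provider: "auto",                            // first available provider for the model
//     model: "meta-llama/Llama-3.1-8B-Instruct",   // hypothetical model id
//     messages: [{ role: "user", content: "Hello!" }],
//   });
//   console.log(output.choices[0].message.content);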
var __defProp = Object.defineProperty; var __export = (target, all) => { for (var name2 in all) __defProp(target, name2, { get: all[name2], enumerable: true }); }; // src/tasks/index.ts var tasks_exports = {}; __export(tasks_exports, { audioClassification: () => audioClassification, audioToAudio: () => audioToAudio, automaticSpeechRecognition: () => automaticSpeechRecognition, chatCompletion: () => chatCompletion, chatCompletionStream: () => chatCompletionStream, documentQuestionAnswering: () => documentQuestionAnswering, featureExtraction: () => featureExtraction, fillMask: () => fillMask, imageClassification: () => imageClassification, imageSegmentation: () => imageSegmentation, imageToImage: () => imageToImage, imageToText: () => imageToText, objectDetection: () => objectDetection, questionAnswering: () => questionAnswering, request: () => request, sentenceSimilarity: () => sentenceSimilarity, streamingRequest: () => streamingRequest, summarization: () => summarization, tableQuestionAnswering: () => tableQuestionAnswering, tabularClassification: () => tabularClassification, tabularRegression: () => tabularRegression, textClassification: () => textClassification, textGeneration: () => textGeneration, textGenerationStream: () => textGenerationStream, textToImage: () => textToImage, textToSpeech: () => textToSpeech, textToVideo: () => textToVideo, tokenClassification: () => tokenClassification, translation: () => translation, visualQuestionAnswering: () => visualQuestionAnswering, zeroShotClassification: () => zeroShotClassification, zeroShotImageClassification: () => zeroShotImageClassification }); // src/config.ts var HF_HUB_URL = "https://huggingface.co"; var HF_ROUTER_URL = "https://router.huggingface.co"; var HF_HEADER_X_BILL_TO = "X-HF-Bill-To"; // src/providers/consts.ts var HARDCODED_MODEL_INFERENCE_MAPPING = { /** * "HF model ID" => "Model ID on Inference Provider's side" * * Example: * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", */ "black-forest-labs": {}, cerebras: {}, cohere: {}, "fal-ai": {}, "featherless-ai": {}, "fireworks-ai": {}, groq: {}, "hf-inference": {}, hyperbolic: {}, nebius: {}, novita: {}, nscale: {}, openai: {}, ovhcloud: {}, replicate: {}, sambanova: {}, together: {} }; // src/lib/InferenceOutputError.ts var InferenceOutputError = class extends TypeError { constructor(message) { super( `Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.` ); this.name = "InferenceOutputError"; } }; // src/utils/toArray.ts function toArray(obj) { if (Array.isArray(obj)) { return obj; } return [obj]; } // src/providers/providerHelper.ts var TaskProviderHelper = class { constructor(provider, baseUrl, clientSideRoutingOnly = false) { this.provider = provider; this.baseUrl = baseUrl; this.clientSideRoutingOnly = clientSideRoutingOnly; } /** * Prepare the base URL for the request */ makeBaseUrl(params) { return params.authMethod !== "provider-key" ? 
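/* Routing note: with an HF token (authMethod !== "provider-key"), calls are
   proxied through the Hugging Face router, e.g. for provider "fal-ai":
     https://router.huggingface.co/fal-ai/<route>
   With a provider's own API key, the provider baseUrl is used directly. */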
`${HF_ROUTER_URL}/${this.provider}` : this.baseUrl; } /** * Prepare the body for the request */ makeBody(params) { if ("data" in params.args && !!params.args.data) { return params.args.data; } return JSON.stringify(this.preparePayload(params)); } /** * Prepare the URL for the request */ makeUrl(params) { const baseUrl = this.makeBaseUrl(params); const route = this.makeRoute(params).replace(/^\/+/, ""); return `${baseUrl}/${route}`; } /** * Prepare the headers for the request */ prepareHeaders(params, isBinary) { const headers = { Authorization: `Bearer ${params.accessToken}` }; if (!isBinary) { headers["Content-Type"] = "application/json"; } return headers; } }; var BaseConversationalTask = class extends TaskProviderHelper { constructor(provider, baseUrl, clientSideRoutingOnly = false) { super(provider, baseUrl, clientSideRoutingOnly); } makeRoute() { return "v1/chat/completions"; } preparePayload(params) { return { ...params.args, model: params.model }; } async getResponse(response) { if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint (response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") { return response; } throw new InferenceOutputError("Expected ChatCompletionOutput"); } }; var BaseTextGenerationTask = class extends TaskProviderHelper { constructor(provider, baseUrl, clientSideRoutingOnly = false) { super(provider, baseUrl, clientSideRoutingOnly); } preparePayload(params) { return { ...params.args, model: params.model }; } makeRoute() { return "v1/completions"; } async getResponse(response) { const res = toArray(response); if (Array.isArray(res) && res.length > 0 && res.every( (x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string" )) { return res[0]; } throw new InferenceOutputError("Expected Array<{generated_text: string}>"); } }; // src/providers/hf-inference.ts var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"]; var HFInferenceTask = class extends TaskProviderHelper { constructor() { super("hf-inference", `${HF_ROUTER_URL}/hf-inference`); } preparePayload(params) { return params.args; } makeUrl(params) { if (params.model.startsWith("http://") || params.model.startsWith("https://")) { return params.model; } return super.makeUrl(params); } makeRoute(params) { if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) { return `models/${params.model}/pipeline/${params.task}`; } return `models/${params.model}`; } async getResponse(response) { return response; } }; var HFInferenceTextToImageTask = class extends HFInferenceTask { async getResponse(response, url, headers, outputType) { if (!response) { throw new InferenceOutputError("response is undefined"); } if (typeof response == "object") { if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) { const base64Data = response.data[0].b64_json; if (outputType === "url") { return `data:image/jpeg;base64,${base64Data}`; } const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`); return await base64Response.blob(); } if ("output" in response && Array.isArray(response.output)) { if (outputType === "url") { return response.output[0]; } const urlResponse = await 
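/* Response-shape note: this text-to-image handler accepts both an OpenAI-style
   body ({ data: [{ b64_json }] }) and a URL-array body ({ output: [url] }).
   Base64 payloads become Blobs via the data-URL trick, illustratively:
     const blob = await (await fetch("data:image/jpeg;base64," + b64)).blob();
*/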
fetch(response.output[0]); const blob = await urlResponse.blob(); return blob; } } if (response instanceof Blob) { if (outputType === "url") { const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64")); return `data:image/jpeg;base64,${b64}`; } return response; } throw new InferenceOutputError("Expected a Blob "); } }; var HFInferenceConversationalTask = class extends HFInferenceTask { makeUrl(params) { let url; if (params.model.startsWith("http://") || params.model.startsWith("https://")) { url = params.model.trim(); } else { url = `${this.makeBaseUrl(params)}/models/${params.model}`; } url = url.replace(/\/+$/, ""); if (url.endsWith("/v1")) { url += "/chat/completions"; } else if (!url.endsWith("/chat/completions")) { url += "/v1/chat/completions"; } return url; } preparePayload(params) { return { ...params.args, model: params.model }; } async getResponse(response) { return response; } }; var HFInferenceTextGenerationTask = class extends HFInferenceTask { async getResponse(response) { const res = toArray(response); if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) { return res?.[0]; } throw new InferenceOutputError("Expected Array<{generated_text: string}>"); } }; var HFInferenceAudioClassificationTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every( (x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number" )) { return response; } throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format"); } }; var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask { async getResponse(response) { return response; } }; var HFInferenceAudioToAudioTask = class extends HFInferenceTask { async getResponse(response) { if (!Array.isArray(response)) { throw new InferenceOutputError("Expected Array"); } if (!response.every((elem) => { return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string"; })) { throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>"); } return response; } }; var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every( (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined") )) { return response[0]; } throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); } }; var HFInferenceFeatureExtractionTask = class extends HFInferenceTask { async getResponse(response) { const isNumArrayRec = (arr, maxDepth, curDepth = 0) => { if (curDepth > maxDepth) return false; if (arr.every((x) => Array.isArray(x))) { return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1)); } else { return arr.every((x) => typeof x === "number"); } }; if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) { return response; } throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>"); } }; var HFInferenceImageClassificationTask = class extends 
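/* Validation note: feature extraction (above) accepts numeric arrays nested up
   to three levels deep, i.e. number[] through number[][][], via the recursive
   isNumArrayRec check. */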
HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) { return response; } throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); } }; var HFInferenceImageSegmentationTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) { return response; } throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>"); } }; var HFInferenceImageToTextTask = class extends HFInferenceTask { async getResponse(response) { if (typeof response?.generated_text !== "string") { throw new InferenceOutputError("Expected {generated_text: string}"); } return response; } }; var HFInferenceImageToImageTask = class extends HFInferenceTask { async getResponse(response) { if (response instanceof Blob) { return response; } throw new InferenceOutputError("Expected Blob"); } }; var HFInferenceObjectDetectionTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every( (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number" )) { return response; } throw new InferenceOutputError( "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>" ); } }; var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) { return response; } throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); } }; var HFInferenceTextClassificationTask = class extends HFInferenceTask { async getResponse(response) { const output = response?.[0]; if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) { return output; } throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); } }; var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) ? response.every( (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number" ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") { return Array.isArray(response) ? 
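/* Normalization note: question answering may come back as a single object or
   as an array; the helper always returns one { answer, start, end, score }
   object, taking the first element when batched. */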
response[0] : response; } throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); } }; var HFInferenceFillMaskTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every( (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string" )) { return response; } throw new InferenceOutputError( "Expected Array<{score: number, sequence: string, token: number, token_str: string}>" ); } }; var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every( (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string" )) { return response; } throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>"); } }; var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every((x) => typeof x === "number")) { return response; } throw new InferenceOutputError("Expected Array<number>"); } }; var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask { static validate(elem) { return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every( (coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number") ); } async getResponse(response) { if (Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) { return Array.isArray(response) ? response[0] : response; } throw new InferenceOutputError( "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}" ); } }; var HFInferenceTokenClassificationTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every( (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string" )) { return response; } throw new InferenceOutputError( "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>" ); } }; var HFInferenceTranslationTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) { return response?.length === 1 ? 
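/* Quirk: translation unwraps single-element results, so callers get one
   { translation_text } object for a single input but an array for batches. */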
response?.[0] : response; } throw new InferenceOutputError("Expected Array<{translation_text: string}>"); } }; var HFInferenceSummarizationTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) { return response?.[0]; } throw new InferenceOutputError("Expected Array<{summary_text: string}>"); } }; var HFInferenceTextToSpeechTask = class extends HFInferenceTask { async getResponse(response) { return response; } }; var HFInferenceTabularClassificationTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every((x) => typeof x === "number")) { return response; } throw new InferenceOutputError("Expected Array<number>"); } }; var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every( (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number" )) { return response[0]; } throw new InferenceOutputError("Expected Array<{answer: string, score: number}>"); } }; var HFInferenceTabularRegressionTask = class extends HFInferenceTask { async getResponse(response) { if (Array.isArray(response) && response.every((x) => typeof x === "number")) { return response; } throw new InferenceOutputError("Expected Array<number>"); } }; var HFInferenceTextToAudioTask = class extends HFInferenceTask { async getResponse(response) { return response; } }; // src/utils/typedInclude.ts function typedInclude(arr, v) { return arr.includes(v); } // src/lib/getInferenceProviderMapping.ts var inferenceProviderMappingCache = /* @__PURE__ */ new Map(); async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) { let inferenceProviderMapping; if (inferenceProviderMappingCache.has(modelId)) { inferenceProviderMapping = inferenceProviderMappingCache.get(modelId); } else { const resp = await (options?.fetch ?? fetch)( `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`, { headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {} } ); if (resp.status === 404) { throw new Error(`Model ${modelId} does not exist`); } inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null); if (inferenceProviderMapping) { inferenceProviderMappingCache.set(modelId, inferenceProviderMapping); } } if (!inferenceProviderMapping) { throw new Error(`We have not been able to find inference provider information for model ${modelId}.`); } return inferenceProviderMapping; } async function getInferenceProviderMapping(params, options) { if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) { return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]; } const inferenceProviderMapping = await fetchInferenceProviderMappingForModel( params.modelId, params.accessToken, options ); const providerMapping = inferenceProviderMapping[params.provider]; if (providerMapping) { const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task]; if (!typedInclude(equivalentTasks, providerMapping.task)) { throw new Error( `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. 
Supported task: ${providerMapping.task}.` ); } if (providerMapping.status === "staging") { console.warn( `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.` ); } return { ...providerMapping, hfModelId: params.modelId }; } return null; } async function resolveProvider(provider, modelId, endpointUrl) { if (endpointUrl) { if (provider) { throw new Error("Specifying both endpointUrl and provider is not supported."); } return "hf-inference"; } if (!provider) { console.log( "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers." ); provider = "auto"; } if (provider === "auto") { if (!modelId) { throw new Error("Specifying a model is required when provider is 'auto'"); } const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId); provider = Object.keys(inferenceProviderMapping)[0]; } if (!provider) { throw new Error(`No Inference Provider available for model ${modelId}.`); } return provider; } // src/utils/delay.ts function delay(ms) { return new Promise((resolve) => { setTimeout(() => resolve(), ms); }); } // src/utils/pick.ts function pick(o, props) { return Object.assign( {}, ...props.map((prop) => { if (o[prop] !== void 0) { return { [prop]: o[prop] }; } }) ); } // src/utils/omit.ts function omit(o, props) { const propsArr = Array.isArray(props) ? props : [props]; const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop)); return pick(o, letsKeep); } // src/providers/black-forest-labs.ts var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai"; var BlackForestLabsTextToImageTask = class extends TaskProviderHelper { constructor() { super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL); } preparePayload(params) { return { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, prompt: params.args.inputs }; } prepareHeaders(params, binary) { const headers = { Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}` }; if (!binary) { headers["Content-Type"] = "application/json"; } return headers; } makeRoute(params) { if (!params) { throw new Error("Params are required"); } return `/v1/${params.model}`; } async getResponse(response, url, headers, outputType) { const urlObj = new URL(response.polling_url); for (let step = 0; step < 5; step++) { await delay(1e3); console.debug(`Polling Black Forest Labs API for the result... 
${step + 1}/5`); urlObj.searchParams.set("attempt", step.toString(10)); const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } }); if (!resp.ok) { throw new InferenceOutputError("Failed to fetch result from black forest labs API"); } const payload = await resp.json(); if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") { if (outputType === "url") { return payload.result.sample; } const image = await fetch(payload.result.sample); return await image.blob(); } } throw new InferenceOutputError("Failed to fetch result from black forest labs API"); } }; // src/providers/cerebras.ts var CerebrasConversationalTask = class extends BaseConversationalTask { constructor() { super("cerebras", "https://api.cerebras.ai"); } }; // src/providers/cohere.ts var CohereConversationalTask = class extends BaseConversationalTask { constructor() { super("cohere", "https://api.cohere.com"); } makeRoute() { return "/compatibility/v1/chat/completions"; } }; // src/lib/isUrl.ts function isUrl(modelOrUrl) { return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/"); } // src/providers/fal-ai.ts var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"]; var FalAITask = class extends TaskProviderHelper { constructor(url) { super("fal-ai", url || "https://fal.run"); } preparePayload(params) { return params.args; } makeRoute(params) { return `/${params.model}`; } prepareHeaders(params, binary) { const headers = { Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}` }; if (!binary) { headers["Content-Type"] = "application/json"; } return headers; } }; function buildLoraPath(modelId, adapterWeightsPath) { return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`; } var FalAITextToImageTask = class extends FalAITask { preparePayload(params) { const payload = { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, sync_mode: true, prompt: params.args.inputs }; if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) { payload.loras = [ { path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath), scale: 1 } ]; if (params.mapping.providerId === "fal-ai/lora") { payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0"; } } return payload; } async getResponse(response, outputType) { if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") { if (outputType === "url") { return response.images[0].url; } const urlResponse = await fetch(response.images[0].url); return await urlResponse.blob(); } throw new InferenceOutputError("Expected Fal.ai text-to-image response format"); } }; var FalAITextToVideoTask = class extends FalAITask { constructor() { super("https://queue.fal.run"); } makeRoute(params) { if (params.authMethod !== "provider-key") { return `/${params.model}?_subdomain=queue`; } return `/${params.model}`; } preparePayload(params) { return { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, prompt: params.args.inputs }; } async getResponse(response, url, headers) { if (!url || !headers) { throw new 
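/* fal.ai queue flow (summary of the logic below): the initial response holds a
   request_id and a response_url; the client polls <base>/<model>/status every
   500 ms until status === "COMPLETED", then fetches <base>/<model> for a
   { video: { url } } payload and downloads it as a Blob. Note the polling loop
   has no retry cap. */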
InferenceOutputError("URL and headers are required for text-to-video task"); } const requestId = response.request_id; if (!requestId) { throw new InferenceOutputError("No request ID found in the response"); } let status = response.status; const parsedUrl = new URL(url); const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`; const modelId = new URL(response.response_url).pathname; const queryParams = parsedUrl.search; const statusUrl = `${baseUrl}${modelId}/status${queryParams}`; const resultUrl = `${baseUrl}${modelId}${queryParams}`; while (status !== "COMPLETED") { await delay(500); const statusResponse = await fetch(statusUrl, { headers }); if (!statusResponse.ok) { throw new InferenceOutputError("Failed to fetch response status from fal-ai API"); } try { status = (await statusResponse.json()).status; } catch (error) { throw new InferenceOutputError("Failed to parse status response from fal-ai API"); } } const resultResponse = await fetch(resultUrl, { headers }); let result; try { result = await resultResponse.json(); } catch (error) { throw new InferenceOutputError("Failed to parse result response from fal-ai API"); } if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) { const urlResponse = await fetch(result.video.url); return await urlResponse.blob(); } else { throw new InferenceOutputError( "Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result) ); } } }; var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask { prepareHeaders(params, binary) { const headers = super.prepareHeaders(params, binary); headers["Content-Type"] = "application/json"; return headers; } async getResponse(response) { const res = response; if (typeof res?.text !== "string") { throw new InferenceOutputError( `Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}` ); } return { text: res.text }; } }; var FalAITextToSpeechTask = class extends FalAITask { preparePayload(params) { return { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, text: params.args.inputs }; } async getResponse(response) { const res = response; if (typeof res?.audio?.url !== "string") { throw new InferenceOutputError( `Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}` ); } try { const urlResponse = await fetch(res.audio.url); if (!urlResponse.ok) { throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`); } return await urlResponse.blob(); } catch (error) { throw new InferenceOutputError( `Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? 
error.message : String(error)}` ); } } }; // src/providers/featherless-ai.ts var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai"; var FeatherlessAIConversationalTask = class extends BaseConversationalTask { constructor() { super("featherless-ai", FEATHERLESS_API_BASE_URL); } }; var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask { constructor() { super("featherless-ai", FEATHERLESS_API_BASE_URL); } preparePayload(params) { return { ...params.args, ...params.args.parameters, model: params.model, prompt: params.args.inputs }; } async getResponse(response) { if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") { const completion = response.choices[0]; return { generated_text: completion.text }; } throw new InferenceOutputError("Expected Featherless AI text generation response format"); } }; // src/providers/fireworks-ai.ts var FireworksConversationalTask = class extends BaseConversationalTask { constructor() { super("fireworks-ai", "https://api.fireworks.ai"); } makeRoute() { return "/inference/v1/chat/completions"; } }; // src/providers/groq.ts var GROQ_API_BASE_URL = "https://api.groq.com"; var GroqTextGenerationTask = class extends BaseTextGenerationTask { constructor() { super("groq", GROQ_API_BASE_URL); } makeRoute() { return "/openai/v1/chat/completions"; } }; var GroqConversationalTask = class extends BaseConversationalTask { constructor() { super("groq", GROQ_API_BASE_URL); } makeRoute() { return "/openai/v1/chat/completions"; } }; // src/providers/hyperbolic.ts var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz"; var HyperbolicConversationalTask = class extends BaseConversationalTask { constructor() { super("hyperbolic", HYPERBOLIC_API_BASE_URL); } }; var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask { constructor() { super("hyperbolic", HYPERBOLIC_API_BASE_URL); } makeRoute() { return "v1/chat/completions"; } preparePayload(params) { return { messages: [{ content: params.args.inputs, role: "user" }], ...params.args.parameters ? 
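/* Parameter mapping: Hyperbolic exposes an OpenAI-compatible chat route, so
   the HF-style max_new_tokens parameter is renamed to max_tokens and the raw
   prompt is wrapped as a single user message. */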
{ max_tokens: params.args.parameters.max_new_tokens, ...omit(params.args.parameters, "max_new_tokens") } : void 0, ...omit(params.args, ["inputs", "parameters"]), model: params.model }; } async getResponse(response) { if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") { const completion = response.choices[0]; return { generated_text: completion.message.content }; } throw new InferenceOutputError("Expected Hyperbolic text generation response format"); } }; var HyperbolicTextToImageTask = class extends TaskProviderHelper { constructor() { super("hyperbolic", HYPERBOLIC_API_BASE_URL); } makeRoute(params) { return `/v1/images/generations`; } preparePayload(params) { return { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, prompt: params.args.inputs, model_name: params.model }; } async getResponse(response, url, headers, outputType) { if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") { if (outputType === "url") { return `data:image/jpeg;base64,${response.images[0].image}`; } return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob()); } throw new InferenceOutputError("Expected Hyperbolic text-to-image response format"); } }; // src/providers/nebius.ts var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai"; var NebiusConversationalTask = class extends BaseConversationalTask { constructor() { super("nebius", NEBIUS_API_BASE_URL); } }; var NebiusTextGenerationTask = class extends BaseTextGenerationTask { constructor() { super("nebius", NEBIUS_API_BASE_URL); } }; var NebiusTextToImageTask = class extends TaskProviderHelper { constructor() { super("nebius", NEBIUS_API_BASE_URL); } preparePayload(params) { return { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, response_format: "b64_json", prompt: params.args.inputs, model: params.model }; } makeRoute(params) { return "v1/images/generations"; } async getResponse(response, url, headers, outputType) { if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") { const base64Data = response.data[0].b64_json; if (outputType === "url") { return `data:image/jpeg;base64,${base64Data}`; } return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob()); } throw new InferenceOutputError("Expected Nebius text-to-image response format"); } }; // src/providers/novita.ts var NOVITA_API_BASE_URL = "https://api.novita.ai"; var NovitaTextGenerationTask = class extends BaseTextGenerationTask { constructor() { super("novita", NOVITA_API_BASE_URL); } makeRoute() { return "/v3/openai/chat/completions"; } }; var NovitaConversationalTask = class extends BaseConversationalTask { constructor() { super("novita", NOVITA_API_BASE_URL); } makeRoute() { return "/v3/openai/chat/completions"; } }; // src/providers/nscale.ts var NSCALE_API_BASE_URL = "https://inference.api.nscale.com"; var NscaleConversationalTask = class extends BaseConversationalTask { constructor() { super("nscale", NSCALE_API_BASE_URL); } }; var NscaleTextToImageTask = class extends TaskProviderHelper { constructor() { super("nscale", NSCALE_API_BASE_URL); } preparePayload(params) { return { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, response_format: 
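/* Shared pattern: Nebius, Nscale, and Together all target an OpenAI-style
   v1/images/generations route and request base64-encoded output, decoded with
   the same data-URL technique noted earlier. */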
"b64_json", prompt: params.args.inputs, model: params.model }; } makeRoute() { return "v1/images/generations"; } async getResponse(response, url, headers, outputType) { if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") { const base64Data = response.data[0].b64_json; if (outputType === "url") { return `data:image/jpeg;base64,${base64Data}`; } return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob()); } throw new InferenceOutputError("Expected Nscale text-to-image response format"); } }; // src/providers/openai.ts var OPENAI_API_BASE_URL = "https://api.openai.com"; var OpenAIConversationalTask = class extends BaseConversationalTask { constructor() { super("openai", OPENAI_API_BASE_URL, true); } }; // src/providers/ovhcloud.ts var OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net"; var OvhCloudConversationalTask = class extends BaseConversationalTask { constructor() { super("ovhcloud", OVHCLOUD_API_BASE_URL); } }; var OvhCloudTextGenerationTask = class extends BaseTextGenerationTask { constructor() { super("ovhcloud", OVHCLOUD_API_BASE_URL); } preparePayload(params) { return { model: params.model, ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters ? { max_tokens: params.args.parameters.max_new_tokens, ...omit(params.args.parameters, "max_new_tokens") } : void 0, prompt: params.args.inputs }; } async getResponse(response) { if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") { const completion = response.choices[0]; return { generated_text: completion.text }; } throw new InferenceOutputError("Expected OVHcloud text generation response format"); } }; // src/providers/replicate.ts var ReplicateTask = class extends TaskProviderHelper { constructor(url) { super("replicate", url || "https://api.replicate.com"); } makeRoute(params) { if (params.model.includes(":")) { return "v1/predictions"; } return `v1/models/${params.model}/predictions`; } preparePayload(params) { return { input: { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, prompt: params.args.inputs }, version: params.model.includes(":") ? params.model.split(":")[1] : void 0 }; } prepareHeaders(params, binary) { const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" }; if (!binary) { headers["Content-Type"] = "application/json"; } return headers; } makeUrl(params) { const baseUrl = this.makeBaseUrl(params); if (params.model.includes(":")) { return `${baseUrl}/v1/predictions`; } return `${baseUrl}/v1/models/${params.model}/predictions`; } }; var ReplicateTextToImageTask = class extends ReplicateTask { preparePayload(params) { return { input: { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, prompt: params.args.inputs, lora_weights: params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath ? `https://huggingface.co/${params.mapping.hfModelId}` : void 0 }, version: params.model.includes(":") ? 
params.model.split(":")[1] : void 0 }; } async getResponse(res, url, headers, outputType) { if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") { if (outputType === "url") { return res.output[0]; } const urlResponse = await fetch(res.output[0]); return await urlResponse.blob(); } throw new InferenceOutputError("Expected Replicate text-to-image response format"); } }; var ReplicateTextToSpeechTask = class extends ReplicateTask { preparePayload(params) { const payload = super.preparePayload(params); const input = payload["input"]; if (typeof input === "object" && input !== null && "prompt" in input) { const inputObj = input; inputObj["text"] = inputObj["prompt"]; delete inputObj["prompt"]; } return payload; } async getResponse(response) { if (response instanceof Blob) { return response; } if (response && typeof response === "object") { if ("output" in response) { if (typeof response.output === "string") { const urlResponse = await fetch(response.output); return await urlResponse.blob(); } else if (Array.isArray(response.output)) { const urlResponse = await fetch(response.output[0]); return await urlResponse.blob(); } } } throw new InferenceOutputError("Expected Blob or object with output"); } }; var ReplicateTextToVideoTask = class extends ReplicateTask { async getResponse(response) { if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) { const urlResponse = await fetch(response.output); return await urlResponse.blob(); } throw new InferenceOutputError("Expected { output: string }"); } }; // src/providers/sambanova.ts var SambanovaConversationalTask = class extends BaseConversationalTask { constructor() { super("sambanova", "https://api.sambanova.ai"); } }; var SambanovaFeatureExtractionTask = class extends TaskProviderHelper { constructor() { super("sambanova", "https://api.sambanova.ai"); } makeRoute() { return `/v1/embeddings`; } async getResponse(response) { if (typeof response === "object" && "data" in response && Array.isArray(response.data)) { return response.data.map((item) => item.embedding); } throw new InferenceOutputError( "Expected Sambanova feature-extraction (embeddings) response format to be {'data' : list of {'embedding' : number[]}}" ); } preparePayload(params) { return { model: params.model, input: params.args.inputs, ...params.args }; } }; // src/providers/together.ts var TOGETHER_API_BASE_URL = "https://api.together.xyz"; var TogetherConversationalTask = class extends BaseConversationalTask { constructor() { super("together", TOGETHER_API_BASE_URL); } }; var TogetherTextGenerationTask = class extends BaseTextGenerationTask { constructor() { super("together", TOGETHER_API_BASE_URL); } preparePayload(params) { return { model: params.model, ...params.args, prompt: params.args.inputs }; } async getResponse(response) { if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") { const completion = response.choices[0]; return { generated_text: completion.text }; } throw new InferenceOutputError("Expected Together text generation response format"); } }; var TogetherTextToImageTask = class extends TaskProviderHelper { constructor() { super("together", TOGETHER_API_BASE_URL); } makeRoute() { return "v1/images/generations"; } preparePayload(params) { return { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, 
prompt: params.args.inputs, response_format: "base64", model: params.model }; } async getResponse(response, outputType) { if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") { const base64Data = response.data[0].b64_json; if (outputType === "url") { return `data:image/jpeg;base64,${base64Data}`; } return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob()); } throw new InferenceOutputError("Expected Together text-to-image response format"); } }; // src/lib/getProviderHelper.ts var PROVIDERS = { "black-forest-labs": { "text-to-image": new BlackForestLabsTextToImageTask() }, cerebras: { conversational: new CerebrasConversationalTask() }, cohere: { conversational: new CohereConversationalTask() }, "fal-ai": { "text-to-image": new FalAITextToImageTask(), "text-to-speech": new FalAITextToSpeechTask(), "text-to-video": new FalAITextToVideoTask(), "automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask() }, "featherless-ai": { conversational: new FeatherlessAIConversationalTask(), "text-generation": new FeatherlessAITextGenerationTask() }, "hf-inference": { "text-to-image": new HFInferenceTextToImageTask(), conversational: new HFInferenceConversationalTask(), "text-generation": new HFInferenceTextGenerationTask(), "text-classification": new HFInferenceTextClassificationTask(), "question-answering": new HFInferenceQuestionAnsweringTask(), "audio-classification": new HFInferenceAudioClassificationTask(), "automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(), "fill-mask": new HFInferenceFillMaskTask(), "feature-extraction": new HFInferenceFeatureExtractionTask(), "image-classification": new HFInferenceImageClassificationTask(), "image-segmentation": new HFInferenceImageSegmentationTask(), "document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(), "image-to-text": new HFInferenceImageToTextTask(), "object-detection": new HFInferenceObjectDetectionTask(), "audio-to-audio": new HFInferenceAudioToAudioTask(), "zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(), "zero-shot-classification": new HFInferenceZeroShotClassificationTask(), "image-to-image": new HFInferenceImageToImageTask(), "sentence-similarity": new HFInferenceSentenceSimilarityTask(), "table-question-answering": new HFInferenceTableQuestionAnsweringTask(), "tabular-classification": new HFInferenceTabularClassificationTask(), "text-to-speech": new HFInferenceTextToSpeechTask(), "token-classification": new HFInferenceTokenClassificationTask(), translation: new HFInferenceTranslationTask(), summarization: new HFInferenceSummarizationTask(), "visual-question-answering": new HFInferenceVisualQuestionAnsweringTask(), "tabular-regression": new HFInferenceTabularRegressionTask(), "text-to-audio": new HFInferenceTextToAudioTask() }, "fireworks-ai": { conversational: new FireworksConversationalTask() }, groq: { conversational: new GroqConversationalTask(), "text-generation": new GroqTextGenerationTask() }, hyperbolic: { "text-to-image": new HyperbolicTextToImageTask(), conversational: new HyperbolicConversationalTask(), "text-generation": new HyperbolicTextGenerationTask() }, nebius: { "text-to-image": new NebiusTextToImageTask(), conversational: new NebiusConversationalTask(), "text-generation": new NebiusTextGenerationTask() }, novita: { conversational: new NovitaConversationalTask(), 
"text-generation": new NovitaTextGenerationTask() }, nscale: { "text-to-image": new NscaleTextToImageTask(), conversational: new NscaleConversationalTask() }, openai: { conversational: new OpenAIConversationalTask() }, ovhcloud: { conversational: new OvhCloudConversationalTask(), "text-generation": new OvhCloudTextGenerationTask() }, replicate: { "text-to-image": new ReplicateTextToImageTask(), "text-to-speech": new ReplicateTextToSpeechTask(), "text-to-video": new ReplicateTextToVideoTask() }, sambanova: { conversational: new SambanovaConversationalTask(), "feature-extraction": new SambanovaFeatureExtractionTask() }, together: { "text-to-image": new TogetherTextToImageTask(), conversational: new TogetherConversationalTask(), "text-generation": new TogetherTextGenerationTask() } }; function getProviderHelper(provider, task) { if (provider === "hf-inference") { if (!task) { return new HFInferenceTask(); } } if (!task) { throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'"); } if (!(provider in PROVIDERS))