@huggingface/inference
TypeScript client for Hugging Face Inference Providers and Inference Endpoints
"use strict";
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
var __export = (target, all) => {
for (var name2 in all)
__defProp(target, name2, { get: all[name2], enumerable: true });
};
var __copyProps = (to, from, except, desc) => {
if (from && typeof from === "object" || typeof from === "function") {
for (let key of __getOwnPropNames(from))
if (!__hasOwnProp.call(to, key) && key !== except)
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
}
return to;
};
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
// src/index.ts
var src_exports = {};
__export(src_exports, {
HfInference: () => HfInference,
INFERENCE_PROVIDERS: () => INFERENCE_PROVIDERS,
InferenceClient: () => InferenceClient,
InferenceClientEndpoint: () => InferenceClientEndpoint,
InferenceOutputError: () => InferenceOutputError,
PROVIDERS_OR_POLICIES: () => PROVIDERS_OR_POLICIES,
audioClassification: () => audioClassification,
audioToAudio: () => audioToAudio,
automaticSpeechRecognition: () => automaticSpeechRecognition,
chatCompletion: () => chatCompletion,
chatCompletionStream: () => chatCompletionStream,
documentQuestionAnswering: () => documentQuestionAnswering,
featureExtraction: () => featureExtraction,
fillMask: () => fillMask,
imageClassification: () => imageClassification,
imageSegmentation: () => imageSegmentation,
imageToImage: () => imageToImage,
imageToText: () => imageToText,
objectDetection: () => objectDetection,
questionAnswering: () => questionAnswering,
request: () => request,
sentenceSimilarity: () => sentenceSimilarity,
snippets: () => snippets_exports,
streamingRequest: () => streamingRequest,
summarization: () => summarization,
tableQuestionAnswering: () => tableQuestionAnswering,
tabularClassification: () => tabularClassification,
tabularRegression: () => tabularRegression,
textClassification: () => textClassification,
textGeneration: () => textGeneration,
textGenerationStream: () => textGenerationStream,
textToImage: () => textToImage,
textToSpeech: () => textToSpeech,
textToVideo: () => textToVideo,
tokenClassification: () => tokenClassification,
translation: () => translation,
visualQuestionAnswering: () => visualQuestionAnswering,
zeroShotClassification: () => zeroShotClassification,
zeroShotImageClassification: () => zeroShotImageClassification
});
module.exports = __toCommonJS(src_exports);
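/**
* Illustrative consumer usage of this CommonJS build (a sketch assuming the
* package's documented client API; the token and model id are hypothetical):
*
*   const { InferenceClient } = require("@huggingface/inference");
*   const client = new InferenceClient("hf_xxx");
*   const out = await client.chatCompletion({
*     provider: "auto",
*     model: "meta-llama/Llama-3.1-8B-Instruct",
*     messages: [{ role: "user", content: "Hello!" }],
*   });
*   console.log(out.choices[0].message);
*/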
// src/tasks/index.ts
var tasks_exports = {};
__export(tasks_exports, {
audioClassification: () => audioClassification,
audioToAudio: () => audioToAudio,
automaticSpeechRecognition: () => automaticSpeechRecognition,
chatCompletion: () => chatCompletion,
chatCompletionStream: () => chatCompletionStream,
documentQuestionAnswering: () => documentQuestionAnswering,
featureExtraction: () => featureExtraction,
fillMask: () => fillMask,
imageClassification: () => imageClassification,
imageSegmentation: () => imageSegmentation,
imageToImage: () => imageToImage,
imageToText: () => imageToText,
objectDetection: () => objectDetection,
questionAnswering: () => questionAnswering,
request: () => request,
sentenceSimilarity: () => sentenceSimilarity,
streamingRequest: () => streamingRequest,
summarization: () => summarization,
tableQuestionAnswering: () => tableQuestionAnswering,
tabularClassification: () => tabularClassification,
tabularRegression: () => tabularRegression,
textClassification: () => textClassification,
textGeneration: () => textGeneration,
textGenerationStream: () => textGenerationStream,
textToImage: () => textToImage,
textToSpeech: () => textToSpeech,
textToVideo: () => textToVideo,
tokenClassification: () => tokenClassification,
translation: () => translation,
visualQuestionAnswering: () => visualQuestionAnswering,
zeroShotClassification: () => zeroShotClassification,
zeroShotImageClassification: () => zeroShotImageClassification
});
// src/config.ts
var HF_HUB_URL = "https://huggingface.co";
var HF_ROUTER_URL = "https://router.huggingface.co";
var HF_HEADER_X_BILL_TO = "X-HF-Bill-To";
// src/providers/consts.ts
var HARDCODED_MODEL_INFERENCE_MAPPING = {
/**
* "HF model ID" => "Model ID on Inference Provider's side"
*
* Example:
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
*/
"black-forest-labs": {},
cerebras: {},
cohere: {},
"fal-ai": {},
"featherless-ai": {},
"fireworks-ai": {},
groq: {},
"hf-inference": {},
hyperbolic: {},
nebius: {},
novita: {},
nscale: {},
openai: {},
ovhcloud: {},
replicate: {},
sambanova: {},
together: {}
};
// src/lib/InferenceOutputError.ts
var InferenceOutputError = class extends TypeError {
constructor(message) {
super(
`Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
);
this.name = "InferenceOutputError";
}
};
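/**
* A sketch of the fallback the error message above points to, using this file's
* exported `request` helper (call shape assumed; token and model are hypothetical):
*
*   const raw = await request(
*     { accessToken: "hf_xxx", model: "distilbert-base-uncased-finetuned-sst-2-english", inputs: "I like you" },
*     { task: "text-classification" }
*   );
*   // `raw` is the provider payload, returned without the checks that throw InferenceOutputError.
*/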
// src/utils/toArray.ts
function toArray(obj) {
if (Array.isArray(obj)) {
return obj;
}
return [obj];
}
// src/providers/providerHelper.ts
var TaskProviderHelper = class {
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
this.provider = provider;
this.baseUrl = baseUrl;
this.clientSideRoutingOnly = clientSideRoutingOnly;
}
/**
* Prepare the base URL for the request
*/
makeBaseUrl(params) {
return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
}
/**
* Prepare the body for the request
*/
makeBody(params) {
if ("data" in params.args && !!params.args.data) {
return params.args.data;
}
return JSON.stringify(this.preparePayload(params));
}
/**
* Prepare the URL for the request
*/
makeUrl(params) {
const baseUrl = this.makeBaseUrl(params);
const route = this.makeRoute(params).replace(/^\/+/, "");
return `${baseUrl}/${route}`;
}
/**
* Prepare the headers for the request
*/
prepareHeaders(params, isBinary) {
const headers = { Authorization: `Bearer ${params.accessToken}` };
if (!isBinary) {
headers["Content-Type"] = "application/json";
}
return headers;
}
};
var BaseConversationalTask = class extends TaskProviderHelper {
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
super(provider, baseUrl, clientSideRoutingOnly);
}
makeRoute() {
return "v1/chat/completions";
}
preparePayload(params) {
return {
...params.args,
model: params.model
};
}
async getResponse(response) {
if (typeof response === "object" && Array.isArray(response?.choices) && typeof response?.created === "number" && typeof response?.id === "string" && typeof response?.model === "string" && /// Together.ai and Nebius do not output a system_fingerprint
(response.system_fingerprint === void 0 || response.system_fingerprint === null || typeof response.system_fingerprint === "string") && typeof response?.usage === "object") {
return response;
}
throw new InferenceOutputError("Expected ChatCompletionOutput");
}
};
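/**
* Illustrative URL resolution, following makeBaseUrl/makeUrl and the route above
* (the provider name is only an example): with routed authentication (authMethod !==
* "provider-key"), a conversational call for provider "together" is sent to
*
*   https://router.huggingface.co/together/v1/chat/completions
*
* whereas a provider key keeps the provider's own baseUrl.
*/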
var BaseTextGenerationTask = class extends TaskProviderHelper {
constructor(provider, baseUrl, clientSideRoutingOnly = false) {
super(provider, baseUrl, clientSideRoutingOnly);
}
preparePayload(params) {
return {
...params.args,
model: params.model
};
}
makeRoute() {
return "v1/completions";
}
async getResponse(response) {
const res = toArray(response);
if (Array.isArray(res) && res.length > 0 && res.every(
(x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
)) {
return res[0];
}
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
}
};
// src/providers/hf-inference.ts
var EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
var HFInferenceTask = class extends TaskProviderHelper {
constructor() {
super("hf-inference", `${HF_ROUTER_URL}/hf-inference`);
}
preparePayload(params) {
return params.args;
}
makeUrl(params) {
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
return params.model;
}
return super.makeUrl(params);
}
makeRoute(params) {
if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
return `models/${params.model}/pipeline/${params.task}`;
}
return `models/${params.model}`;
}
async getResponse(response) {
return response;
}
};
var HFInferenceTextToImageTask = class extends HFInferenceTask {
async getResponse(response, url, headers, outputType) {
if (!response) {
throw new InferenceOutputError("response is undefined");
}
if (typeof response == "object") {
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
const base64Data = response.data[0].b64_json;
if (outputType === "url") {
return `data:image/jpeg;base64,${base64Data}`;
}
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
return await base64Response.blob();
}
if ("output" in response && Array.isArray(response.output)) {
if (outputType === "url") {
return response.output[0];
}
const urlResponse = await fetch(response.output[0]);
const blob = await urlResponse.blob();
return blob;
}
}
if (response instanceof Blob) {
if (outputType === "url") {
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
return `data:image/jpeg;base64,${b64}`;
}
return response;
}
throw new InferenceOutputError("Expected a Blob ");
}
};
var HFInferenceConversationalTask = class extends HFInferenceTask {
makeUrl(params) {
let url;
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
url = params.model.trim();
} else {
url = `${this.makeBaseUrl(params)}/models/${params.model}`;
}
url = url.replace(/\/+$/, "");
if (url.endsWith("/v1")) {
url += "/chat/completions";
} else if (!url.endsWith("/chat/completions")) {
url += "/v1/chat/completions";
}
return url;
}
preparePayload(params) {
return {
...params.args,
model: params.model
};
}
async getResponse(response) {
return response;
}
};
var HFInferenceTextGenerationTask = class extends HFInferenceTask {
async getResponse(response) {
const res = toArray(response);
if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
return res?.[0];
}
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
}
};
var HFInferenceAudioClassificationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every(
(x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number"
)) {
return response;
}
throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format");
}
};
var HFInferenceAutomaticSpeechRecognitionTask = class extends HFInferenceTask {
async getResponse(response) {
return response;
}
};
var HFInferenceAudioToAudioTask = class extends HFInferenceTask {
async getResponse(response) {
if (!Array.isArray(response)) {
throw new InferenceOutputError("Expected Array");
}
if (!response.every((elem) => {
return typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string";
})) {
throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>");
}
return response;
}
};
var HFInferenceDocumentQuestionAnsweringTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every(
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined")
)) {
return response[0];
}
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
}
};
var HFInferenceFeatureExtractionTask = class extends HFInferenceTask {
async getResponse(response) {
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
if (curDepth > maxDepth)
return false;
if (arr.every((x) => Array.isArray(x))) {
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
} else {
return arr.every((x) => typeof x === "number");
}
};
if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
return response;
}
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
}
};
var HFInferenceImageClassificationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
return response;
}
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
}
};
var HFInferenceImageSegmentationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number")) {
return response;
}
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
}
};
var HFInferenceImageToTextTask = class extends HFInferenceTask {
async getResponse(response) {
if (typeof response?.generated_text !== "string") {
throw new InferenceOutputError("Expected {generated_text: string}");
}
return response;
}
};
var HFInferenceImageToImageTask = class extends HFInferenceTask {
async getResponse(response) {
if (response instanceof Blob) {
return response;
}
throw new InferenceOutputError("Expected Blob");
}
};
var HFInferenceObjectDetectionTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every(
(x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"
)) {
return response;
}
throw new InferenceOutputError(
"Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>"
);
}
};
var HFInferenceZeroShotImageClassificationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
return response;
}
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
}
};
var HFInferenceTextClassificationTask = class extends HFInferenceTask {
async getResponse(response) {
const output = response?.[0];
if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
return output;
}
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
}
};
var HFInferenceQuestionAnsweringTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) ? response.every(
(elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number"
) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number") {
return Array.isArray(response) ? response[0] : response;
}
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
}
};
var HFInferenceFillMaskTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every(
(x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"
)) {
return response;
}
throw new InferenceOutputError(
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
);
}
};
var HFInferenceZeroShotClassificationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every(
(x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string"
)) {
return response;
}
throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
}
};
var HFInferenceSentenceSimilarityTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
return response;
}
throw new InferenceOutputError("Expected Array<number>");
}
};
var HFInferenceTableQuestionAnsweringTask = class extends HFInferenceTask {
static validate(elem) {
return typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x) => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every(
(coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")
);
}
async getResponse(response) {
if (Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response)) {
return Array.isArray(response) ? response[0] : response;
}
throw new InferenceOutputError(
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
);
}
};
var HFInferenceTokenClassificationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every(
(x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"
)) {
return response;
}
throw new InferenceOutputError(
"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
);
}
};
var HFInferenceTranslationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
return response?.length === 1 ? response?.[0] : response;
}
throw new InferenceOutputError("Expected Array<{translation_text: string}>");
}
};
var HFInferenceSummarizationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
return response?.[0];
}
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
}
};
var HFInferenceTextToSpeechTask = class extends HFInferenceTask {
async getResponse(response) {
return response;
}
};
var HFInferenceTabularClassificationTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
return response;
}
throw new InferenceOutputError("Expected Array<number>");
}
};
var HFInferenceVisualQuestionAnsweringTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every(
(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
)) {
return response[0];
}
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
}
};
var HFInferenceTabularRegressionTask = class extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
return response;
}
throw new InferenceOutputError("Expected Array<number>");
}
};
var HFInferenceTextToAudioTask = class extends HFInferenceTask {
async getResponse(response) {
return response;
}
};
// src/utils/typedInclude.ts
function typedInclude(arr, v) {
return arr.includes(v);
}
// src/lib/getInferenceProviderMapping.ts
var inferenceProviderMappingCache = /* @__PURE__ */ new Map();
async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
let inferenceProviderMapping;
if (inferenceProviderMappingCache.has(modelId)) {
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId);
} else {
const resp = await (options?.fetch ?? fetch)(
`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
{
headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {}
}
);
if (resp.status === 404) {
throw new Error(`Model ${modelId} does not exist`);
}
inferenceProviderMapping = await resp.json().then((json) => json.inferenceProviderMapping).catch(() => null);
if (inferenceProviderMapping) {
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
}
}
if (!inferenceProviderMapping) {
throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
}
return inferenceProviderMapping;
}
async function getInferenceProviderMapping(params, options) {
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
}
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
params.modelId,
params.accessToken,
options
);
const providerMapping = inferenceProviderMapping[params.provider];
if (providerMapping) {
const equivalentTasks = params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task];
if (!typedInclude(equivalentTasks, providerMapping.task)) {
throw new Error(
`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
);
}
if (providerMapping.status === "staging") {
console.warn(
`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
);
}
return { ...providerMapping, hfModelId: params.modelId };
}
return null;
}
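/**
* Illustrative shape of an inferenceProviderMapping entry as consumed above
* (field names taken from the accesses in this file; values are hypothetical):
*
*   {
*     "fal-ai": {
*       providerId: "fal-ai/flux/dev",
*       task: "text-to-image",
*       status: "live",                          // "staging" triggers the console warning
*       adapter: "lora",                         // optional
*       adapterWeightsPath: "lora.safetensors"   // optional
*     }
*   }
*
* getInferenceProviderMapping() returns the selected entry with hfModelId added.
*/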
async function resolveProvider(provider, modelId, endpointUrl) {
if (endpointUrl) {
if (provider) {
throw new Error("Specifying both endpointUrl and provider is not supported.");
}
return "hf-inference";
}
if (!provider) {
console.log(
"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
);
provider = "auto";
}
if (provider === "auto") {
if (!modelId) {
throw new Error("Specifying a model is required when provider is 'auto'");
}
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
provider = Object.keys(inferenceProviderMapping)[0];
}
if (!provider) {
throw new Error(`No Inference Provider available for model ${modelId}.`);
}
return provider;
}
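/**
* Example of the resolution order implemented above (model id hypothetical):
* an explicit endpointUrl always resolves to "hf-inference"; otherwise a missing
* provider falls back to "auto", which picks the first provider key returned by
* fetchInferenceProviderMappingForModel(modelId), e.g.
*
*   await resolveProvider(undefined, "black-forest-labs/FLUX.1-dev");
*   // -> first key of that model's inferenceProviderMapping, e.g. "fal-ai"
*/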
// src/utils/delay.ts
function delay(ms) {
return new Promise((resolve) => {
setTimeout(() => resolve(), ms);
});
}
// src/utils/pick.ts
function pick(o, props) {
return Object.assign(
{},
...props.map((prop) => {
if (o[prop] !== void 0) {
return { [prop]: o[prop] };
}
})
);
}
// src/utils/omit.ts
function omit(o, props) {
const propsArr = Array.isArray(props) ? props : [props];
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop));
return pick(o, letsKeep);
}
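/**
* Quick illustration of the two object helpers above:
*
*   pick({ a: 1, b: undefined, c: 3 }, ["a", "b"]); // -> { a: 1 }  (undefined props are skipped)
*   omit({ a: 1, b: 2 }, "b");                      // -> { a: 1 }
*/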
// src/providers/black-forest-labs.ts
var BLACK_FOREST_LABS_AI_API_BASE_URL = "https://api.us1.bfl.ai";
var BlackForestLabsTextToImageTask = class extends TaskProviderHelper {
constructor() {
super("black-forest-labs", BLACK_FOREST_LABS_AI_API_BASE_URL);
}
preparePayload(params) {
return {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
prompt: params.args.inputs
};
}
prepareHeaders(params, binary) {
const headers = {
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `X-Key ${params.accessToken}`
};
if (!binary) {
headers["Content-Type"] = "application/json";
}
return headers;
}
makeRoute(params) {
if (!params) {
throw new Error("Params are required");
}
return `/v1/${params.model}`;
}
async getResponse(response, url, headers, outputType) {
const urlObj = new URL(response.polling_url);
for (let step = 0; step < 5; step++) {
await delay(1e3);
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
urlObj.searchParams.set("attempt", step.toString(10));
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
if (!resp.ok) {
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
}
const payload = await resp.json();
if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
if (outputType === "url") {
return payload.result.sample;
}
const image = await fetch(payload.result.sample);
return await image.blob();
}
}
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
}
};
// src/providers/cerebras.ts
var CerebrasConversationalTask = class extends BaseConversationalTask {
constructor() {
super("cerebras", "https://api.cerebras.ai");
}
};
// src/providers/cohere.ts
var CohereConversationalTask = class extends BaseConversationalTask {
constructor() {
super("cohere", "https://api.cohere.com");
}
makeRoute() {
return "/compatibility/v1/chat/completions";
}
};
// src/lib/isUrl.ts
function isUrl(modelOrUrl) {
return /^http(s?):/.test(modelOrUrl) || modelOrUrl.startsWith("/");
}
// src/providers/fal-ai.ts
var FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];
var FalAITask = class extends TaskProviderHelper {
constructor(url) {
super("fal-ai", url || "https://fal.run");
}
preparePayload(params) {
return params.args;
}
makeRoute(params) {
return `/${params.model}`;
}
prepareHeaders(params, binary) {
const headers = {
Authorization: params.authMethod !== "provider-key" ? `Bearer ${params.accessToken}` : `Key ${params.accessToken}`
};
if (!binary) {
headers["Content-Type"] = "application/json";
}
return headers;
}
};
function buildLoraPath(modelId, adapterWeightsPath) {
return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
}
var FalAITextToImageTask = class extends FalAITask {
preparePayload(params) {
const payload = {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
sync_mode: true,
prompt: params.args.inputs
};
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
payload.loras = [
{
path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
scale: 1
}
];
if (params.mapping.providerId === "fal-ai/lora") {
payload.model_name = "stabilityai/stable-diffusion-xl-base-1.0";
}
}
return payload;
}
async getResponse(response, outputType) {
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images.length > 0 && "url" in response.images[0] && typeof response.images[0].url === "string") {
if (outputType === "url") {
return response.images[0].url;
}
const urlResponse = await fetch(response.images[0].url);
return await urlResponse.blob();
}
throw new InferenceOutputError("Expected Fal.ai text-to-image response format");
}
};
var FalAITextToVideoTask = class extends FalAITask {
constructor() {
super("https://queue.fal.run");
}
makeRoute(params) {
if (params.authMethod !== "provider-key") {
return `/${params.model}?_subdomain=queue`;
}
return `/${params.model}`;
}
preparePayload(params) {
return {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
prompt: params.args.inputs
};
}
async getResponse(response, url, headers) {
if (!url || !headers) {
throw new InferenceOutputError("URL and headers are required for text-to-video task");
}
const requestId = response.request_id;
if (!requestId) {
throw new InferenceOutputError("No request ID found in the response");
}
let status = response.status;
const parsedUrl = new URL(url);
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""}`;
const modelId = new URL(response.response_url).pathname;
const queryParams = parsedUrl.search;
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
while (status !== "COMPLETED") {
await delay(500);
const statusResponse = await fetch(statusUrl, { headers });
if (!statusResponse.ok) {
throw new InferenceOutputError("Failed to fetch response status from fal-ai API");
}
try {
status = (await statusResponse.json()).status;
} catch (error) {
throw new InferenceOutputError("Failed to parse status response from fal-ai API");
}
}
const resultResponse = await fetch(resultUrl, { headers });
let result;
try {
result = await resultResponse.json();
} catch (error) {
throw new InferenceOutputError("Failed to parse result response from fal-ai API");
}
if (typeof result === "object" && !!result && "video" in result && typeof result.video === "object" && !!result.video && "url" in result.video && typeof result.video.url === "string" && isUrl(result.video.url)) {
const urlResponse = await fetch(result.video.url);
return await urlResponse.blob();
} else {
throw new InferenceOutputError(
"Expected { video: { url: string } } result format, got instead: " + JSON.stringify(result)
);
}
}
};
var FalAIAutomaticSpeechRecognitionTask = class extends FalAITask {
prepareHeaders(params, binary) {
const headers = super.prepareHeaders(params, binary);
headers["Content-Type"] = "application/json";
return headers;
}
async getResponse(response) {
const res = response;
if (typeof res?.text !== "string") {
throw new InferenceOutputError(
`Expected { text: string } format from Fal.ai Automatic Speech Recognition, got: ${JSON.stringify(response)}`
);
}
return { text: res.text };
}
};
var FalAITextToSpeechTask = class extends FalAITask {
preparePayload(params) {
return {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
text: params.args.inputs
};
}
async getResponse(response) {
const res = response;
if (typeof res?.audio?.url !== "string") {
throw new InferenceOutputError(
`Expected { audio: { url: string } } format from Fal.ai Text-to-Speech, got: ${JSON.stringify(response)}`
);
}
try {
const urlResponse = await fetch(res.audio.url);
if (!urlResponse.ok) {
throw new Error(`Failed to fetch audio from ${res.audio.url}: ${urlResponse.statusText}`);
}
return await urlResponse.blob();
} catch (error) {
throw new InferenceOutputError(
`Error fetching or processing audio from Fal.ai Text-to-Speech URL: ${res.audio.url}. ${error instanceof Error ? error.message : String(error)}`
);
}
}
};
// src/providers/featherless-ai.ts
var FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
var FeatherlessAIConversationalTask = class extends BaseConversationalTask {
constructor() {
super("featherless-ai", FEATHERLESS_API_BASE_URL);
}
};
var FeatherlessAITextGenerationTask = class extends BaseTextGenerationTask {
constructor() {
super("featherless-ai", FEATHERLESS_API_BASE_URL);
}
preparePayload(params) {
return {
...params.args,
...params.args.parameters,
model: params.model,
prompt: params.args.inputs
};
}
async getResponse(response) {
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
const completion = response.choices[0];
return {
generated_text: completion.text
};
}
throw new InferenceOutputError("Expected Featherless AI text generation response format");
}
};
// src/providers/fireworks-ai.ts
var FireworksConversationalTask = class extends BaseConversationalTask {
constructor() {
super("fireworks-ai", "https://api.fireworks.ai");
}
makeRoute() {
return "/inference/v1/chat/completions";
}
};
// src/providers/groq.ts
var GROQ_API_BASE_URL = "https://api.groq.com";
var GroqTextGenerationTask = class extends BaseTextGenerationTask {
constructor() {
super("groq", GROQ_API_BASE_URL);
}
makeRoute() {
return "/openai/v1/chat/completions";
}
};
var GroqConversationalTask = class extends BaseConversationalTask {
constructor() {
super("groq", GROQ_API_BASE_URL);
}
makeRoute() {
return "/openai/v1/chat/completions";
}
};
// src/providers/hyperbolic.ts
var HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";
var HyperbolicConversationalTask = class extends BaseConversationalTask {
constructor() {
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
}
};
var HyperbolicTextGenerationTask = class extends BaseTextGenerationTask {
constructor() {
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
}
makeRoute() {
return "v1/chat/completions";
}
preparePayload(params) {
return {
messages: [{ content: params.args.inputs, role: "user" }],
...params.args.parameters ? {
max_tokens: params.args.parameters.max_new_tokens,
...omit(params.args.parameters, "max_new_tokens")
} : void 0,
...omit(params.args, ["inputs", "parameters"]),
model: params.model
};
}
async getResponse(response) {
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
const completion = response.choices[0];
return {
generated_text: completion.message.content
};
}
throw new InferenceOutputError("Expected Hyperbolic text generation response format");
}
};
var HyperbolicTextToImageTask = class extends TaskProviderHelper {
constructor() {
super("hyperbolic", HYPERBOLIC_API_BASE_URL);
}
makeRoute(params) {
return `/v1/images/generations`;
}
preparePayload(params) {
return {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
prompt: params.args.inputs,
model_name: params.model
};
}
async getResponse(response, url, headers, outputType) {
if (typeof response === "object" && "images" in response && Array.isArray(response.images) && response.images[0] && typeof response.images[0].image === "string") {
if (outputType === "url") {
return `data:image/jpeg;base64,${response.images[0].image}`;
}
return fetch(`data:image/jpeg;base64,${response.images[0].image}`).then((res) => res.blob());
}
throw new InferenceOutputError("Expected Hyperbolic text-to-image response format");
}
};
// src/providers/nebius.ts
var NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
var NebiusConversationalTask = class extends BaseConversationalTask {
constructor() {
super("nebius", NEBIUS_API_BASE_URL);
}
};
var NebiusTextGenerationTask = class extends BaseTextGenerationTask {
constructor() {
super("nebius", NEBIUS_API_BASE_URL);
}
};
var NebiusTextToImageTask = class extends TaskProviderHelper {
constructor() {
super("nebius", NEBIUS_API_BASE_URL);
}
preparePayload(params) {
return {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
response_format: "b64_json",
prompt: params.args.inputs,
model: params.model
};
}
makeRoute(params) {
return "v1/images/generations";
}
async getResponse(response, url, headers, outputType) {
if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
const base64Data = response.data[0].b64_json;
if (outputType === "url") {
return `data:image/jpeg;base64,${base64Data}`;
}
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
}
throw new InferenceOutputError("Expected Nebius text-to-image response format");
}
};
// src/providers/novita.ts
var NOVITA_API_BASE_URL = "https://api.novita.ai";
var NovitaTextGenerationTask = class extends BaseTextGenerationTask {
constructor() {
super("novita", NOVITA_API_BASE_URL);
}
makeRoute() {
return "/v3/openai/chat/completions";
}
};
var NovitaConversationalTask = class extends BaseConversationalTask {
constructor() {
super("novita", NOVITA_API_BASE_URL);
}
makeRoute() {
return "/v3/openai/chat/completions";
}
};
// src/providers/nscale.ts
var NSCALE_API_BASE_URL = "https://inference.api.nscale.com";
var NscaleConversationalTask = class extends BaseConversationalTask {
constructor() {
super("nscale", NSCALE_API_BASE_URL);
}
};
var NscaleTextToImageTask = class extends TaskProviderHelper {
constructor() {
super("nscale", NSCALE_API_BASE_URL);
}
preparePayload(params) {
return {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
response_format: "b64_json",
prompt: params.args.inputs,
model: params.model
};
}
makeRoute() {
return "v1/images/generations";
}
async getResponse(response, url, headers, outputType) {
if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
const base64Data = response.data[0].b64_json;
if (outputType === "url") {
return `data:image/jpeg;base64,${base64Data}`;
}
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
}
throw new InferenceOutputError("Expected Nscale text-to-image response format");
}
};
// src/providers/openai.ts
var OPENAI_API_BASE_URL = "https://api.openai.com";
var OpenAIConversationalTask = class extends BaseConversationalTask {
constructor() {
super("openai", OPENAI_API_BASE_URL, true);
}
};
// src/providers/ovhcloud.ts
var OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
var OvhCloudConversationalTask = class extends BaseConversationalTask {
constructor() {
super("ovhcloud", OVHCLOUD_API_BASE_URL);
}
};
var OvhCloudTextGenerationTask = class extends BaseTextGenerationTask {
constructor() {
super("ovhcloud", OVHCLOUD_API_BASE_URL);
}
preparePayload(params) {
return {
model: params.model,
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters ? {
max_tokens: params.args.parameters.max_new_tokens,
...omit(params.args.parameters, "max_new_tokens")
} : void 0,
prompt: params.args.inputs
};
}
async getResponse(response) {
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
const completion = response.choices[0];
return {
generated_text: completion.text
};
}
throw new InferenceOutputError("Expected OVHcloud text generation response format");
}
};
// src/providers/replicate.ts
var ReplicateTask = class extends TaskProviderHelper {
constructor(url) {
super("replicate", url || "https://api.replicate.com");
}
makeRoute(params) {
if (params.model.includes(":")) {
return "v1/predictions";
}
return `v1/models/${params.model}/predictions`;
}
preparePayload(params) {
return {
input: {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
prompt: params.args.inputs
},
version: params.model.includes(":") ? params.model.split(":")[1] : void 0
};
}
prepareHeaders(params, binary) {
const headers = { Authorization: `Bearer ${params.accessToken}`, Prefer: "wait" };
if (!binary) {
headers["Content-Type"] = "application/json";
}
return headers;
}
makeUrl(params) {
const baseUrl = this.makeBaseUrl(params);
if (params.model.includes(":")) {
return `${baseUrl}/v1/predictions`;
}
return `${baseUrl}/v1/models/${params.model}/predictions`;
}
};
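/**
* Illustrative Replicate routing (model ids hypothetical): a ":"-versioned model id
* is posted to the generic predictions endpoint with its version extracted, while an
* unversioned id targets the model-scoped endpoint.
*
*   "owner/model"          -> POST {baseUrl}/v1/models/owner/model/predictions
*   "owner/model:abc123"   -> POST {baseUrl}/v1/predictions  (body includes version: "abc123")
*/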
var ReplicateTextToImageTask = class extends ReplicateTask {
preparePayload(params) {
return {
input: {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
prompt: params.args.inputs,
lora_weights: params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath ? `https://huggingface.co/${params.mapping.hfModelId}` : void 0
},
version: params.model.includes(":") ? params.model.split(":")[1] : void 0
};
}
async getResponse(res, url, headers, outputType) {
if (typeof res === "object" && "output" in res && Array.isArray(res.output) && res.output.length > 0 && typeof res.output[0] === "string") {
if (outputType === "url") {
return res.output[0];
}
const urlResponse = await fetch(res.output[0]);
return await urlResponse.blob();
}
throw new InferenceOutputError("Expected Replicate text-to-image response format");
}
};
var ReplicateTextToSpeechTask = class extends ReplicateTask {
preparePayload(params) {
const payload = super.preparePayload(params);
const input = payload["input"];
if (typeof input === "object" && input !== null && "prompt" in input) {
const inputObj = input;
inputObj["text"] = inputObj["prompt"];
delete inputObj["prompt"];
}
return payload;
}
async getResponse(response) {
if (response instanceof Blob) {
return response;
}
if (response && typeof response === "object") {
if ("output" in response) {
if (typeof response.output === "string") {
const urlResponse = await fetch(response.output);
return await urlResponse.blob();
} else if (Array.isArray(response.output)) {
const urlResponse = await fetch(response.output[0]);
return await urlResponse.blob();
}
}
}
throw new InferenceOutputError("Expected Blob or object with output");
}
};
var ReplicateTextToVideoTask = class extends ReplicateTask {
async getResponse(response) {
if (typeof response === "object" && !!response && "output" in response && typeof response.output === "string" && isUrl(response.output)) {
const urlResponse = await fetch(response.output);
return await urlResponse.blob();
}
throw new InferenceOutputError("Expected { output: string }");
}
};
// src/providers/sambanova.ts
var SambanovaConversationalTask = class extends BaseConversationalTask {
constructor() {
super("sambanova", "https://api.sambanova.ai");
}
};
var SambanovaFeatureExtractionTask = class extends TaskProviderHelper {
constructor() {
super("sambanova", "https://api.sambanova.ai");
}
makeRoute() {
return `/v1/embeddings`;
}
async getResponse(response) {
if (typeof response === "object" && "data" in response && Array.isArray(response.data)) {
return response.data.map((item) => item.embedding);
}
throw new InferenceOutputError(
"Expected Sambanova feature-extraction (embeddings) response format to be {'data' : list of {'embedding' : number[]}}"
);
}
preparePayload(params) {
return {
model: params.model,
input: params.args.inputs,
...params.args
};
}
};
// src/providers/together.ts
var TOGETHER_API_BASE_URL = "https://api.together.xyz";
var TogetherConversationalTask = class extends BaseConversationalTask {
constructor() {
super("together", TOGETHER_API_BASE_URL);
}
};
var TogetherTextGenerationTask = class extends BaseTextGenerationTask {
constructor() {
super("together", TOGETHER_API_BASE_URL);
}
preparePayload(params) {
return {
model: params.model,
...params.args,
prompt: params.args.inputs
};
}
async getResponse(response) {
if (typeof response === "object" && "choices" in response && Array.isArray(response?.choices) && typeof response?.model === "string") {
const completion = response.choices[0];
return {
generated_text: completion.text
};
}
throw new InferenceOutputError("Expected Together text generation response format");
}
};
var TogetherTextToImageTask = class extends TaskProviderHelper {
constructor() {
super("together", TOGETHER_API_BASE_URL);
}
makeRoute() {
return "v1/images/generations";
}
preparePayload(params) {
return {
...omit(params.args, ["inputs", "parameters"]),
...params.args.parameters,
prompt: params.args.inputs,
response_format: "base64",
model: params.model
};
}
async getResponse(response, outputType) {
if (typeof response === "object" && "data" in response && Array.isArray(response.data) && response.data.length > 0 && "b64_json" in response.data[0] && typeof response.data[0].b64_json === "string") {
const base64Data = response.data[0].b64_json;
if (outputType === "url") {
return `data:image/jpeg;base64,${base64Data}`;
}
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
}
throw new InferenceOutputError("Expected Together text-to-image response format");
}
};
// src/lib/getProviderHelper.ts
var PROVIDERS = {
"black-forest-labs": {
"text-to-image": new BlackForestLabsTextToImageTask()
},
cerebras: {
conversational: new CerebrasConversationalTask()
},
cohere: {
conversational: new CohereConversationalTask()
},
"fal-ai": {
"text-to-image": new FalAITextToImageTask(),
"text-to-speech": new FalAITextToSpeechTask(),
"text-to-video": new FalAITextToVideoTask(),
"automatic-speech-recognition": new FalAIAutomaticSpeechRecognitionTask()
},
"featherless-ai": {
conversational: new FeatherlessAIConversationalTask(),
"text-generation": new FeatherlessAITextGenerationTask()
},
"hf-inference": {
"text-to-image": new HFInferenceTextToImageTask(),
conversational: new HFInferenceConversationalTask(),
"text-generation": new HFInferenceTextGenerationTask(),
"text-classification": new HFInferenceTextClassificationTask(),
"question-answering": new HFInferenceQuestionAnsweringTask(),
"audio-classification": new HFInferenceAudioClassificationTask(),
"automatic-speech-recognition": new HFInferenceAutomaticSpeechRecognitionTask(),
"fill-mask": new HFInferenceFillMaskTask(),
"feature-extraction": new HFInferenceFeatureExtractionTask(),
"image-classification": new HFInferenceImageClassificationTask(),
"image-segmentation": new HFInferenceImageSegmentationTask(),
"document-question-answering": new HFInferenceDocumentQuestionAnsweringTask(),
"image-to-text": new HFInferenceImageToTextTask(),
"object-detection": new HFInferenceObjectDetectionTask(),
"audio-to-audio": new HFInferenceAudioToAudioTask(),
"zero-shot-image-classification": new HFInferenceZeroShotImageClassificationTask(),
"zero-shot-classification": new HFInferenceZeroShotClassificationTas