@huggingface/inference
TypeScript client for the Hugging Face Inference Providers and Inference Endpoints
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.WavespeedAIImageToVideoTask = exports.WavespeedAIImageToImageTask = exports.WavespeedAITextToVideoTask = exports.WavespeedAITextToImageTask = void 0;
const delay_js_1 = require("../utils/delay.js");
const omit_js_1 = require("../utils/omit.js");
const base64FromBytes_js_1 = require("../utils/base64FromBytes.js");
const providerHelper_js_1 = require("./providerHelper.js");
const errors_js_1 = require("../errors.js");
const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai";
async function buildImagesField(inputs, hasImages) {
const base = (0, base64FromBytes_js_1.base64FromBytes)(new Uint8Array(inputs instanceof ArrayBuffer ? inputs : await inputs.arrayBuffer()));
const images = Array.isArray(hasImages) && hasImages.every((value) => typeof value === "string")
? hasImages
: [base];
return { base, images };
}
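// Illustrative sketch of the helper above (values are hypothetical): raw bytes
// are base64-encoded into `base`; if the caller already supplied an array of
// image strings it is used as `images`, otherwise `images` falls back to [base].
//
//   const bytes = new Uint8Array([137, 80, 78, 71]).buffer;
//   const { base, images } = await buildImagesField(bytes, undefined);
//   // base   -> base64 string of the input bytes
//   // images -> [base]
//   const { images: kept } = await buildImagesField(bytes, ["https://example.com/ref.png"]);
//   // kept -> ["https://example.com/ref.png"]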
class WavespeedAITask extends providerHelper_js_1.TaskProviderHelper {
constructor(url) {
super("wavespeed", url || WAVESPEEDAI_API_BASE_URL);
}
makeRoute(params) {
return `/api/v3/${params.model}`;
}
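// e.g. makeRoute({ model: "wavespeed-ai/flux-dev" }) -> "/api/v3/wavespeed-ai/flux-dev"
// (model id is a placeholder for illustration)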
preparePayload(params) {
const payload = {
...(0, omit_js_1.omit)(params.args, ["inputs", "parameters"]),
...(params.args.parameters ? (0, omit_js_1.omit)(params.args.parameters, ["images"]) : undefined),
prompt: params.args.inputs,
};
// Add LoRA support if adapter is specified in the mapping
if (params.mapping?.adapter === "lora") {
payload.loras = [
{
path: params.mapping.hfModelId,
scale: 1, // Default scale value
},
];
}
return payload;
}
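// Illustrative result of preparePayload (hypothetical args): for
//   args    = { inputs: "a cat on the moon", parameters: { seed: 42 } }
//   mapping = { adapter: "lora", hfModelId: "user/my-lora" }
// the payload sent to WaveSpeed is roughly:
//   { seed: 42, prompt: "a cat on the moon", loras: [{ path: "user/my-lora", scale: 1 }] }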
async getResponse(response, url, headers) {
if (!url || !headers) {
throw new errors_js_1.InferenceClientInputError("URL and headers are required for WaveSpeed AI API calls");
}
const parsedUrl = new URL(url);
const resultPath = new URL(response.data.urls.get).pathname;
// Override the base URL to use router.huggingface.co when going through the Hugging Face router
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/wavespeed" : ""}`;
const resultUrl = `${baseUrl}${resultPath}`;
// Poll for results until completion
while (true) {
const resultResponse = await fetch(resultUrl, { headers });
if (!resultResponse.ok) {
throw new errors_js_1.InferenceClientProviderApiError("Failed to fetch response status from WaveSpeed AI API", { url: resultUrl, method: "GET" }, {
requestId: resultResponse.headers.get("x-request-id") ?? "",
status: resultResponse.status,
body: await resultResponse.text(),
});
}
const result = await resultResponse.json();
const taskResult = result.data;
switch (taskResult.status) {
case "completed": {
// Get the media data from the first output URL
if (!taskResult.outputs?.[0]) {
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from WaveSpeed AI API: No output URL in completed response");
}
const mediaResponse = await fetch(taskResult.outputs[0]);
if (!mediaResponse.ok) {
throw new errors_js_1.InferenceClientProviderApiError("Failed to fetch generation output from WaveSpeed AI API", { url: taskResult.outputs[0], method: "GET" }, {
requestId: mediaResponse.headers.get("x-request-id") ?? "",
status: mediaResponse.status,
body: await mediaResponse.text(),
});
}
return await mediaResponse.blob();
}
case "failed": {
throw new errors_js_1.InferenceClientProviderOutputError(taskResult.error || "Task failed");
}
default: {
// Wait before polling again
await (0, delay_js_1.delay)(500);
continue;
}
}
}
}
}
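// Shape of the WaveSpeed envelopes assumed by getResponse above (field names
// inferred from this code, not from official provider docs):
//   submit response: { data: { urls: { get: "<status URL>" } } }
//   poll response:   { data: { status: "completed" | "failed" | <pending>,
//                              outputs?: string[], error?: string } }
// Any status other than "completed"/"failed" is treated as pending and
// re-polled every 500 ms.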
class WavespeedAITextToImageTask extends WavespeedAITask {
constructor() {
super(WAVESPEEDAI_API_BASE_URL);
}
}
exports.WavespeedAITextToImageTask = WavespeedAITextToImageTask;
class WavespeedAITextToVideoTask extends WavespeedAITask {
constructor() {
super(WAVESPEEDAI_API_BASE_URL);
}
}
exports.WavespeedAITextToVideoTask = WavespeedAITextToVideoTask;
class WavespeedAIImageToImageTask extends WavespeedAITask {
constructor() {
super(WAVESPEEDAI_API_BASE_URL);
}
async preparePayloadAsync(args) {
const hasImages = args.images ?? args.parameters?.images;
const { base, images } = await buildImagesField(args.inputs, hasImages);
return { ...args, inputs: args.parameters?.prompt, image: base, images };
}
}
exports.WavespeedAIImageToImageTask = WavespeedAIImageToImageTask;
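// Illustrative preparePayloadAsync result (hypothetical args): the text prompt
// moves from parameters.prompt into `inputs`, and the encoded input image is
// exposed both as a single `image` field and inside the `images` array:
//   await task.preparePayloadAsync({ inputs: imageBlob, parameters: { prompt: "make it night" } })
//   // -> { ..., inputs: "make it night", image: "<base64>", images: ["<base64>"] }
// WavespeedAIImageToVideoTask below applies the identical transformation.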
class WavespeedAIImageToVideoTask extends WavespeedAITask {
constructor() {
super(WAVESPEEDAI_API_BASE_URL);
}
async preparePayloadAsync(args) {
const hasImages = args.images ?? args.parameters?.images;
const { base, images } = await buildImagesField(args.inputs, hasImages);
return { ...args, inputs: args.parameters?.prompt, image: base, images };
}
}
exports.WavespeedAIImageToVideoTask = WavespeedAIImageToVideoTask;
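// Usage sketch (assumes this provider is selected through the public
// InferenceClient API of @huggingface/inference; the model id below is a
// placeholder, not a tested value):
//
//   const { InferenceClient } = require("@huggingface/inference");
//   const client = new InferenceClient(process.env.HF_TOKEN);
//   const image = await client.textToImage({
//     provider: "wavespeed",
//     model: "wavespeed-ai/flux-dev",
//     inputs: "An astronaut riding a horse on Mars",
//   });
//   // `image` is the Blob returned by WavespeedAITask.getResponse after polling.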