UNPKG

@huggingface/inference

Version:

Typescript client for the Hugging Face Inference Providers and Inference Endpoints

224 lines (201 loc) 6.66 kB
import type { TextToImageArgs } from "../tasks/cv/textToImage.js";
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
import type { TextToVideoArgs } from "../tasks/cv/textToVideo.js";
import type { ImageToVideoArgs } from "../tasks/cv/imageToVideo.js";
import type { BodyParams, RequestArgs, UrlParams } from "../types.js";
import { delay } from "../utils/delay.js";
import { omit } from "../utils/omit.js";
import { base64FromBytes } from "../utils/base64FromBytes.js";
import type {
	TextToImageTaskHelper,
	TextToVideoTaskHelper,
	ImageToImageTaskHelper,
	ImageToVideoTaskHelper,
} from "./providerHelper.js";
import { TaskProviderHelper } from "./providerHelper.js";
import {
	InferenceClientInputError,
	InferenceClientProviderApiError,
	InferenceClientProviderOutputError,
} from "../errors.js";

const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai";

/** Host of the Hugging Face router; WaveSpeed AI traffic through it lives under the "/wavespeed" prefix. */
const HF_ROUTER_HOST = "router.huggingface.co";

/** Delay between two consecutive task-status polls, in milliseconds. */
const POLL_INTERVAL_MS = 500;

/**
 * Response structure for task status and results
 */
interface WaveSpeedAITaskResponse {
	id: string;
	model: string;
	outputs: string[];
	urls: {
		get: string;
	};
	has_nsfw_contents: boolean[];
	status: "created" | "processing" | "completed" | "failed";
	created_at: string;
	error: string;
	executionTime: number;
	timings: {
		inference: number;
	};
}

/**
 * Response structure for initial task submission
 */
interface WaveSpeedAISubmitResponse {
	id: string;
	urls: {
		get: string;
	};
}

/**
 * Response structure for WaveSpeed AI API
 */
interface WaveSpeedAIResponse {
	code: number;
	message: string;
	data: WaveSpeedAITaskResponse;
}

/**
 * Response structure for WaveSpeed AI API with submit response data
 */
interface WaveSpeedAISubmitTaskResponse {
	code: number;
	message: string;
	data: WaveSpeedAISubmitResponse;
}

/**
 * Encodes the primary input image as base64 and decides the `images` list to send.
 *
 * @param inputs - The primary input image, as a Blob or raw ArrayBuffer.
 * @param hasImages - A caller-supplied `images` value (may be anything). It is used verbatim only
 *   when it is an array whose every element is a string; otherwise the encoded primary image is
 *   sent as the sole entry.
 * @returns `base`, the base64-encoded primary image, and `images`, the list to send.
 */
async function buildImagesField(
	inputs: Blob | ArrayBuffer,
	hasImages: unknown
): Promise<{ base: string; images: string[] }> {
	const base = base64FromBytes(
		new Uint8Array(inputs instanceof ArrayBuffer ? inputs : await (inputs as Blob).arrayBuffer())
	);
	const images =
		Array.isArray(hasImages) && hasImages.every((value): value is string => typeof value === "string")
			? hasImages
			: [base];
	return { base, images };
}

/**
 * Shared `preparePayloadAsync` implementation for image-conditioned tasks (image-to-image and
 * image-to-video).
 *
 * The binary `inputs` image is moved into `image` (base64). An `images` list is attached, taken from
 * a top-level `images` arg or `parameters.images` when it is a string array, or else defaulting to
 * `[image]`. `parameters.prompt` becomes the new `inputs`, so that `preparePayload` later maps it to
 * `prompt`.
 */
async function prepareImageConditionedArgs(args: ImageToImageArgs | ImageToVideoArgs): Promise<RequestArgs> {
	const explicitImages =
		(args as { images?: unknown }).images ?? (args.parameters as Record<string, unknown> | undefined)?.images;
	const { base, images } = await buildImagesField(args.inputs as Blob | ArrayBuffer, explicitImages);
	return { ...args, inputs: args.parameters?.prompt, image: base, images };
}

/**
 * Resolves the URL used to poll a submitted task's status.
 *
 * The submit response carries an absolute polling URL (`data.urls.get`). Only its path is kept, and
 * it is re-rooted on the host the original request was sent to. When that host is the Hugging Face
 * router, the "/wavespeed" prefix is inserted so polling also goes through the router.
 *
 * @throws InferenceClientProviderOutputError if the submit response has no polling URL string.
 */
function buildResultUrl(requestUrl: string, response: WaveSpeedAISubmitTaskResponse): string {
	const getUrl: unknown = response?.data?.urls?.get;
	if (typeof getUrl !== "string") {
		throw new InferenceClientProviderOutputError(
			"Received malformed response from WaveSpeed AI API: No result URL in submit response"
		);
	}
	const parsedUrl = new URL(requestUrl);
	const resultPath = new URL(getUrl).pathname;
	const routerPrefix = parsedUrl.host === HF_ROUTER_HOST ? "/wavespeed" : "";
	return `${parsedUrl.protocol}//${parsedUrl.host}${routerPrefix}${resultPath}`;
}

/**
 * Downloads the generated media referenced by a completed task's output URL.
 *
 * @throws InferenceClientProviderApiError if the download returns a non-OK HTTP status.
 */
async function downloadOutput(outputUrl: string): Promise<Blob> {
	const mediaResponse = await fetch(outputUrl);
	if (!mediaResponse.ok) {
		throw new InferenceClientProviderApiError(
			"Failed to fetch generation output from WaveSpeed AI API",
			{ url: outputUrl, method: "GET" },
			{
				requestId: mediaResponse.headers.get("x-request-id") ?? "",
				status: mediaResponse.status,
				body: await mediaResponse.text(),
			}
		);
	}
	return await mediaResponse.blob();
}

/**
 * Base helper for all WaveSpeed AI tasks. It builds the `/api/v3/{model}` route and the JSON payload,
 * then polls the asynchronous task until it completes and returns the generated media.
 */
abstract class WavespeedAITask extends TaskProviderHelper {
	constructor(url?: string) {
		super("wavespeed", url || WAVESPEEDAI_API_BASE_URL);
	}

	makeRoute(params: UrlParams): string {
		return `/api/v3/${params.model}`;
	}

	/**
	 * Flattens the request args into the WaveSpeed AI body.
	 *
	 * Top-level args (except `inputs`/`parameters`) and `parameters` (except `images`) are merged.
	 * `inputs` is then sent as `prompt`. When the model mapping declares a LoRA adapter, the mapped
	 * Hugging Face model id is attached as a single LoRA with scale 1.
	 */
	preparePayload(
		params: BodyParams<ImageToImageArgs | TextToImageArgs | TextToVideoArgs | ImageToVideoArgs>
	): Record<string, unknown> {
		const payload: Record<string, unknown> = {
			...omit(params.args, ["inputs", "parameters"]),
			...(params.args.parameters
				? omit(params.args.parameters as Record<string, unknown>, ["images"])
				: undefined),
			prompt: params.args.inputs,
		};
		// Add LoRA support if adapter is specified in the mapping
		if (params.mapping?.adapter === "lora") {
			payload.loras = [
				{
					path: params.mapping.hfModelId,
					scale: 1, // Default scale value
				},
			];
		}
		return payload;
	}

	/**
	 * Polls the submitted task until it reaches a terminal state, then downloads its first output.
	 *
	 * @param response - Parsed body of the task-submission request.
	 * @param url - URL the submission was sent to; its host decides where polling goes.
	 * @param headers - Headers (including auth) reused for every poll request.
	 * @returns The generated media as a Blob.
	 * @throws InferenceClientInputError if `url` or `headers` is missing.
	 * @throws InferenceClientProviderApiError if a poll or the download returns a non-OK status.
	 * @throws InferenceClientProviderOutputError if a response is malformed or the task failed.
	 */
	override async getResponse(
		response: WaveSpeedAISubmitTaskResponse,
		url?: string,
		headers?: Record<string, string>
	): Promise<Blob> {
		if (!url || !headers) {
			throw new InferenceClientInputError("URL and headers are required for WaveSpeed AI API calls");
		}

		const resultUrl = buildResultUrl(url, response);

		// NOTE(review): polling has no upper bound; a task that never leaves a non-terminal
		// state keeps this loop running indefinitely.
		for (;;) {
			const resultResponse = await fetch(resultUrl, { headers });
			if (!resultResponse.ok) {
				throw new InferenceClientProviderApiError(
					"Failed to fetch response status from WaveSpeed AI API",
					{ url: resultUrl, method: "GET" },
					{
						requestId: resultResponse.headers.get("x-request-id") ?? "",
						status: resultResponse.status,
						body: await resultResponse.text(),
					}
				);
			}

			const result = (await resultResponse.json()) as Partial<WaveSpeedAIResponse> | null;
			const taskResult = result?.data;
			if (!taskResult || typeof taskResult.status !== "string") {
				throw new InferenceClientProviderOutputError(
					"Received malformed response from WaveSpeed AI API: No task status in poll response"
				);
			}

			switch (taskResult.status) {
				case "completed": {
					const outputUrl = taskResult.outputs?.[0];
					if (!outputUrl) {
						throw new InferenceClientProviderOutputError(
							"Received malformed response from WaveSpeed AI API: No output URL in completed response"
						);
					}
					return await downloadOutput(outputUrl);
				}
				case "failed":
					// `||` on purpose: an empty error string also falls back to the generic message.
					throw new InferenceClientProviderOutputError(taskResult.error || "Task failed");
				default:
					// "created" / "processing" (or any not-yet-terminal status): wait, then poll again.
					await delay(POLL_INTERVAL_MS);
			}
		}
	}
}

export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper {
	constructor() {
		super(WAVESPEEDAI_API_BASE_URL);
	}
}

export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper {
	constructor() {
		super(WAVESPEEDAI_API_BASE_URL);
	}
}

export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper {
	constructor() {
		super(WAVESPEEDAI_API_BASE_URL);
	}

	/** Moves the binary input image into `image`/`images` and uses `parameters.prompt` as `inputs`. */
	async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
		return prepareImageConditionedArgs(args);
	}
}

export class WavespeedAIImageToVideoTask extends WavespeedAITask implements ImageToVideoTaskHelper {
	constructor() {
		super(WAVESPEEDAI_API_BASE_URL);
	}

	/** Moves the binary input image into `image`/`images` and uses `parameters.prompt` as `inputs`. */
	async preparePayloadAsync(args: ImageToVideoArgs): Promise<RequestArgs> {
		return prepareImageConditionedArgs(args);
	}
}