UNPKG

@huggingface/inference

Version:

Typescript client for the Hugging Face Inference Providers and Inference Endpoints

123 lines (122 loc) 5.44 kB
/** * See the registered mapping of HF model ID => ZAI model ID here: * * https://huggingface.co/api/partners/zai-org/models * * This is a publicly available mapping. * * If you want to try to run inference for a new model locally before it's registered on huggingface.co, * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. * * - If you work at zai and want to update this mapping, please use the model mapping API we provide on huggingface.co * - If you're a community member and want to add a new supported HF model to zai, please open an issue on the present repo * and we will tag zai team members. * * Thanks! */ import { InferenceClientInputError, InferenceClientProviderApiError, InferenceClientProviderOutputError, } from "../errors.js"; import { isUrl } from "../lib/isUrl.js"; import { delay } from "../utils/delay.js"; import { omit } from "../utils/omit.js"; import { BaseConversationalTask, TaskProviderHelper } from "./providerHelper.js"; const ZAI_API_BASE_URL = "https://api.z.ai"; export class ZaiConversationalTask extends BaseConversationalTask { constructor() { super("zai-org", ZAI_API_BASE_URL); } prepareHeaders(params, binary) { const headers = super.prepareHeaders(params, binary); headers["x-source-channel"] = "hugging_face"; headers["accept-language"] = "en-US,en"; return headers; } makeRoute() { return "/api/paas/v4/chat/completions"; } } const MAX_POLL_ATTEMPTS = 60; const POLL_INTERVAL_MS = 5000; export class ZaiTextToImageTask extends TaskProviderHelper { constructor() { super("zai-org", ZAI_API_BASE_URL); } prepareHeaders(params, binary) { const headers = { Authorization: `Bearer ${params.accessToken}`, "x-source-channel": "hugging_face", "accept-language": "en-US,en", }; if (!binary) { headers["Content-Type"] = "application/json"; } return headers; } makeRoute() { return "/api/paas/v4/async/images/generations"; } preparePayload(params) { return { ...omit(params.args, ["inputs", "parameters"]), ...params.args.parameters, model: params.model, prompt: params.args.inputs, }; } async getResponse(response, url, headers, outputType) { if (!url || !headers) { throw new InferenceClientInputError(`URL and headers are required for 'text-to-image' task`); } if (typeof response !== "object" || !response || !("task_status" in response) || !("id" in response) || typeof response.id !== "string") { throw new InferenceClientProviderOutputError(`Received malformed response from ZAI text-to-image API: expected { id: string, task_status: string }, got: ${JSON.stringify(response)}`); } if (response.task_status === "FAIL") { throw new InferenceClientProviderOutputError("ZAI API returned task status: FAIL"); } const taskId = response.id; const parsedUrl = new URL(url); const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/zai-org" : ""}`; const pollUrl = `${baseUrl}/api/paas/v4/async-result/${taskId}`; const pollHeaders = { ...headers, "x-source-channel": "hugging_face", "accept-language": "en-US,en", }; for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt++) { await delay(POLL_INTERVAL_MS); const resp = await fetch(pollUrl, { method: "GET", headers: pollHeaders, }); if (!resp.ok) { throw new InferenceClientProviderApiError(`Failed to fetch result from ZAI text-to-image API: ${resp.status}`, { url: pollUrl, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }); } const result = await resp.json(); if (result.task_status === "FAIL") { throw new InferenceClientProviderOutputError("ZAI text-to-image API task failed"); } if (result.task_status === "SUCCESS") { if (!result.image_result || !Array.isArray(result.image_result) || result.image_result.length === 0 || typeof result.image_result[0]?.url !== "string" || !isUrl(result.image_result[0].url)) { throw new InferenceClientProviderOutputError(`Received malformed response from ZAI text-to-image API: expected { image_result: Array<{ url: string }> }, got: ${JSON.stringify(result)}`); } const imageUrl = result.image_result[0].url; if (outputType === "json") { return { ...result }; } if (outputType === "url") { return imageUrl; } const imageResponse = await fetch(imageUrl); return await imageResponse.blob(); } } throw new InferenceClientProviderOutputError(`Timed out while waiting for the result from ZAI API - aborting after ${MAX_POLL_ATTEMPTS} attempts`); } }