UNPKG

@huggingface/inference

Version:

TypeScript client for the Hugging Face Inference Providers and Inference Endpoints

109 lines (108 loc) 3.78 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.BaseTextGenerationTask = exports.BaseConversationalTask = exports.TaskProviderHelper = void 0;
const config_js_1 = require("../config.js");
const errors_js_1 = require("../errors.js");
const toArray_js_1 = require("../utils/toArray.js");
/**
 * Base class for task-specific provider helpers.
 *
 * Implements the shared URL, body and header construction. Subclasses are
 * expected to supply `makeRoute(params)`, `preparePayload(params)` and
 * `getResponse(response)`, which this class calls but does not define.
 */
class TaskProviderHelper {
    provider;
    baseUrl;
    clientSideRoutingOnly;
    /**
     * @param provider - Provider identifier; used as the path segment after the HF router URL.
     * @param baseUrl - The provider's own base URL, used when calling it directly with a provider key.
     * @param [clientSideRoutingOnly=false] - Stored on the instance; not read by this class itself
     *   (presumably consulted by callers when choosing an auth method — confirm).
     */
    constructor(provider, baseUrl, clientSideRoutingOnly = false) {
        this.provider = provider;
        this.baseUrl = baseUrl;
        this.clientSideRoutingOnly = clientSideRoutingOnly;
    }
    /**
     * Prepare the base URL for the request.
     *
     * Any auth method other than "provider-key" is routed through the Hugging Face
     * router (`${HF_ROUTER_URL}/${provider}`); a provider key goes straight to `baseUrl`.
     */
    makeBaseUrl(params) {
        return params.authMethod !== "provider-key" ? `${config_js_1.HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
    }
    /**
     * Prepare the body for the request.
     *
     * A truthy `params.args.data` (raw payload) is returned untouched; otherwise the
     * subclass-built payload from `preparePayload` is serialized to a JSON string.
     */
    makeBody(params) {
        if ("data" in params.args && !!params.args.data) {
            return params.args.data;
        }
        return JSON.stringify(this.preparePayload(params));
    }
    /**
     * Prepare the full URL for the request: `${baseUrl}/${route}`.
     *
     * Leading slashes are stripped from the route so it does not add a second
     * separator after the one inserted here.
     */
    makeUrl(params) {
        const baseUrl = this.makeBaseUrl(params);
        const route = this.makeRoute(params).replace(/^\/+/, "");
        return `${baseUrl}/${route}`;
    }
    /**
     * Prepare the headers for the request.
     *
     * @param params - Reads `params.accessToken`, which may be absent when no token is used.
     * @param {boolean} isBinary - When true the body is raw data, so no JSON Content-Type is set.
     * @returns {Record<string, string>} Header name → value.
     */
    prepareHeaders(params, isBinary) {
        const headers = {};
        // Only send a bearer token when one is present; interpolating a missing
        // token would otherwise emit the literal header value "Bearer undefined".
        if (params.accessToken) {
            headers.Authorization = `Bearer ${params.accessToken}`;
        }
        if (!isBinary) {
            headers["Content-Type"] = "application/json";
        }
        return headers;
    }
}
exports.TaskProviderHelper = TaskProviderHelper;
// BASE IMPLEMENTATIONS FOR COMMON PATTERNS
/**
 * Base helper for providers exposing a `v1/chat/completions` route.
 */
class BaseConversationalTask extends TaskProviderHelper {
    constructor(provider, baseUrl, clientSideRoutingOnly = false) {
        super(provider, baseUrl, clientSideRoutingOnly);
    }
    /** Route appended to the base URL by `makeUrl`. */
    makeRoute() {
        return "v1/chat/completions";
    }
    /**
     * Build the request payload: the caller's args with `model` overridden by
     * the resolved `params.model`.
     */
    preparePayload(params) {
        return {
            ...params.args,
            model: params.model,
        };
    }
    /**
     * Validate the raw provider response as a chat-completion output.
     *
     * @returns The response unchanged when it has the expected shape.
     * @throws {InferenceClientProviderOutputError} When a required field is missing or mistyped.
     */
    async getResponse(response) {
        if (typeof response === "object" &&
            Array.isArray(response?.choices) &&
            typeof response?.created === "number" &&
            typeof response?.id === "string" &&
            typeof response?.model === "string" &&
            /// Together.ai and Nebius do not output a system_fingerprint
            (response.system_fingerprint === undefined ||
                response.system_fingerprint === null ||
                typeof response.system_fingerprint === "string") &&
            // NOTE(review): `typeof null === "object"`, so a null `usage` is also accepted — confirm intended.
            typeof response?.usage === "object") {
            return response;
        }
        throw new errors_js_1.InferenceClientProviderOutputError("Expected ChatCompletionOutput");
    }
}
exports.BaseConversationalTask = BaseConversationalTask;
/**
 * Base helper for providers exposing a `v1/completions` text-generation route.
 */
class BaseTextGenerationTask extends TaskProviderHelper {
    constructor(provider, baseUrl, clientSideRoutingOnly = false) {
        super(provider, baseUrl, clientSideRoutingOnly);
    }
    /**
     * Build the request payload: the caller's args with `model` overridden by
     * the resolved `params.model`.
     */
    preparePayload(params) {
        return {
            ...params.args,
            model: params.model,
        };
    }
    /** Route appended to the base URL by `makeUrl`. */
    makeRoute() {
        return "v1/completions";
    }
    /**
     * Validate the raw provider response as text-generation output.
     *
     * The response is normalized with `toArray`. Every element must be a non-null
     * object with a string `generated_text`; only the first element is returned.
     *
     * @throws {InferenceClientProviderOutputError} When the array is empty or any element is malformed.
     */
    async getResponse(response) {
        const res = (0, toArray_js_1.toArray)(response);
        if (Array.isArray(res) &&
            res.length > 0 &&
            res.every((x) => typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string")) {
            return res[0];
        }
        throw new errors_js_1.InferenceClientProviderOutputError("Expected Array<{generated_text: string}>");
    }
}
exports.BaseTextGenerationTask = BaseTextGenerationTask;