UNPKG

@huggingface/inference

Version:

TypeScript client for the Hugging Face Inference Providers and Inference Endpoints

164 lines (163 loc) 6.98 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.makeRequestOptions = makeRequestOptions;
exports.makeRequestOptionsFromResolvedModel = makeRequestOptionsFromResolvedModel;
const config_js_1 = require("../config.js");
const package_js_1 = require("../package.js");
const getInferenceProviderMapping_js_1 = require("./getInferenceProviderMapping.js");
const isUrl_js_1 = require("./isUrl.js");
const errors_js_1 = require("../errors.js");
/**
 * Lazy-loaded from huggingface.co/api/tasks when needed
 * Used to determine the default model to use when it's not user defined
 */
let tasks = null;
/**
 * Helper that prepares request arguments.
 * This async version handles the model ID resolution step: it validates the
 * inputs, resolves the model (explicit `args.model`, or the default model of
 * `options.task`), obtains the provider mapping and then delegates to
 * {@link makeRequestOptionsFromResolvedModel}.
 *
 * @param args Request arguments; `model`, `endpointUrl` and `accessToken` are read here.
 * @param providerHelper Provider helper; `provider` and `clientSideRoutingOnly` are read here.
 * @param options Optional settings; `task` and `fetch` are read here.
 * @returns A promise of the `{ url, info }` pair built by makeRequestOptionsFromResolvedModel.
 * @throws InferenceClientInputError when `endpointUrl` is combined with a provider other
 *   than "hf-inference", when `model` is a URL, when neither model nor task is given,
 *   when a client-side-routing-only provider is used without a model, or when no
 *   provider mapping is found for the model.
 */
async function makeRequestOptions(args, providerHelper, options) {
    const { model: maybeModel } = args;
    const provider = providerHelper.provider;
    const { task } = options ?? {};
    // Validate inputs
    if (args.endpointUrl && provider !== "hf-inference") {
        throw new errors_js_1.InferenceClientInputError(`Cannot use endpointUrl with a third-party provider.`);
    }
    if (maybeModel && (0, isUrl_js_1.isUrl)(maybeModel)) {
        throw new errors_js_1.InferenceClientInputError(`Model URLs are no longer supported. Use endpointUrl instead.`);
    }
    if (args.endpointUrl) {
        // No need to have maybeModel, or to load default model for a task
        return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, undefined, options);
    }
    if (!maybeModel && !task) {
        throw new errors_js_1.InferenceClientInputError("No model provided, and no task has been specified.");
    }
    // Reject a missing model for client-side-routing-only providers BEFORE looking up
    // the task's default model: the lookup is a network call whose result would be
    // discarded, and whose failure would hide this more relevant error.
    if (providerHelper.clientSideRoutingOnly && !maybeModel) {
        throw new errors_js_1.InferenceClientInputError(`Provider ${provider} requires a model ID to be passed directly.`);
    }
    const hfModel = maybeModel ?? (await loadDefaultModel(task));
    const inferenceProviderMapping = providerHelper.clientSideRoutingOnly
        ? // The mapping is built locally from the "<provider>/<id>" model string.
            {
                provider: provider,
                providerId: removeProviderPrefix(maybeModel, provider),
                hfModelId: maybeModel,
                status: "live",
                task: task,
            }
        : await (0, getInferenceProviderMapping_js_1.getInferenceProviderMapping)({
            modelId: hfModel,
            task: task,
            provider,
            accessToken: args.accessToken,
        }, { fetch: options?.fetch });
    if (!inferenceProviderMapping) {
        throw new errors_js_1.InferenceClientInputError(`We have not been able to find inference provider information for model ${hfModel}.`);
    }
    // Use the sync version with the resolved model
    return makeRequestOptionsFromResolvedModel(inferenceProviderMapping.providerId, providerHelper, args, inferenceProviderMapping, options);
}
/**
 * Helper that prepares request arguments. - for internal use only
 * This sync version skips the model ID resolution step
 *
 * @param resolvedModel Model identifier passed to `makeBody`, and to `makeUrl` when no
 *   `endpointUrl` is set.
 * @param providerHelper Provider helper supplying `provider`, `clientSideRoutingOnly`,
 *   `makeUrl`, `prepareHeaders` and `makeBody`.
 * @param args Request arguments; `accessToken`, `endpointUrl`, `provider` and `model` are
 *   stripped, everything else is forwarded to `makeBody`.
 * @param mapping Provider mapping forwarded to `makeBody` (may be undefined).
 * @param options Optional settings; `includeCredentials`, `task`, `signal` and `billTo` are read.
 * @returns `{ url, info }` where `info` holds the POST headers, body, signal and, when
 *   applicable, credentials.
 * @throws InferenceClientInputError when an "hf_"-prefixed token is used with a
 *   client-side-routing-only provider.
 */
function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
    const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
    // `model` and `provider` are destructured only to keep them out of `remainingArgs`.
    void model;
    void maybeProvider;
    const provider = providerHelper.provider;
    const { includeCredentials, task, signal, billTo } = options ?? {};
    const authMethod = (() => {
        if (providerHelper.clientSideRoutingOnly) {
            // Closed-source providers require an accessToken (cannot be routed).
            if (accessToken && accessToken.startsWith("hf_")) {
                throw new errors_js_1.InferenceClientInputError(`Provider ${provider} is closed-source and does not support HF tokens.`);
            }
        }
        if (accessToken) {
            return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
        }
        // NOTE(review): only the string "include" is recognised here, whereas the
        // `credentials` handling below also maps `includeCredentials === true` to
        // "include" - confirm whether `true` should yield "credentials-include" too.
        if (includeCredentials === "include") {
            // If accessToken is passed, it should take precedence over includeCredentials
            return "credentials-include";
        }
        return "none";
    })();
    // Make URL
    const modelId = endpointUrl ?? resolvedModel;
    const url = providerHelper.makeUrl({
        authMethod,
        model: modelId,
        task,
    });
    // Make headers
    // The second argument is true when `args` carries a truthy `data` property
    // (presumably a raw/binary payload - confirm against the provider helpers).
    const headers = providerHelper.prepareHeaders({
        accessToken,
        authMethod,
    }, "data" in args && !!args.data);
    if (billTo) {
        headers[config_js_1.HF_HEADER_X_BILL_TO] = billTo;
    }
    // Add user-agent to headers
    // e.g. @huggingface/inference/3.1.3
    const ownUserAgent = `${package_js_1.PACKAGE_NAME}/${package_js_1.PACKAGE_VERSION}`;
    const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
        .filter((x) => x !== undefined)
        .join(" ");
    headers["User-Agent"] = userAgent;
    // Make body
    const body = providerHelper.makeBody({
        args: remainingArgs,
        model: resolvedModel,
        task,
        mapping,
    });
    /**
     * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
     */
    let credentials;
    if (typeof includeCredentials === "string") {
        credentials = includeCredentials;
    }
    else if (includeCredentials === true) {
        credentials = "include";
    }
    const info = {
        headers,
        method: "POST",
        body: body,
        ...(credentials ? { credentials } : undefined),
        signal,
    };
    return { url, info };
}
/**
 * Returns the id of the first model listed for `task` in the task definitions,
 * loading and caching those definitions in the module-level `tasks` on first use.
 *
 * @param task Task name used as key into the task definitions.
 * @returns A promise of the default model id for the task.
 * @throws InferenceClientInputError when the task has no (or an empty) model list.
 */
async function loadDefaultModel(task) {
    if (!tasks) {
        tasks = await loadTaskInfo();
    }
    const taskInfo = tasks[task];
    // `models?.` guards against an entry without a model list, so that case raises the
    // explicit input error below rather than a raw TypeError.
    if ((taskInfo?.models?.length ?? 0) <= 0) {
        throw new errors_js_1.InferenceClientInputError(`No default model defined for task ${task}, please define the model explicitly.`);
    }
    return taskInfo.models[0].id;
}
/**
 * Fetches the task definitions from `${HF_HUB_URL}/api/tasks` using the global `fetch`.
 *
 * @returns A promise of the parsed JSON response.
 * @throws InferenceClientHubApiError when the response status is not OK.
 */
async function loadTaskInfo() {
    const url = `${config_js_1.HF_HUB_URL}/api/tasks`;
    const res = await fetch(url);
    if (!res.ok) {
        throw new errors_js_1.InferenceClientHubApiError("Failed to load tasks definitions from Hugging Face Hub.", { url, method: "GET" }, { requestId: res.headers.get("x-request-id") ?? "", status: res.status, body: await res.text() });
    }
    return await res.json();
}
/**
 * Strips the leading "<provider>/" from a model id.
 *
 * @param model Model id expected to start with `${provider}/`.
 * @param provider Provider name.
 * @returns The model id without the provider prefix.
 * @throws InferenceClientInputError when the model id lacks the prefix.
 */
function removeProviderPrefix(model, provider) {
    if (!model.startsWith(`${provider}/`)) {
        throw new errors_js_1.InferenceClientInputError(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
    }
    return model.slice(provider.length + 1);
}