@huggingface/inference

Typescript client for the Hugging Face Inference Providers and Inference Endpoints

120 lines (119 loc) 6.41 kB
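
Before the file contents, a minimal usage sketch of the package as a whole (not part of this file). It assumes a recent release that exports InferenceClient with a chatCompletion() helper; the token, provider, and model id are placeholders.

// Minimal sketch, assuming a recent @huggingface/inference release that exports
// InferenceClient with a chatCompletion() helper. Token, provider and model id
// below are placeholders, not values taken from this file.
const { InferenceClient } = require("@huggingface/inference");

const client = new InferenceClient(process.env.HF_TOKEN);

client
  .chatCompletion({
    provider: "auto", // let the library pick the first provider available for the model
    model: "meta-llama/Llama-3.1-8B-Instruct",
    messages: [{ role: "user", content: "Hello!" }],
  })
  .then((result) => console.log(result.choices[0].message));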
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.inferenceProviderMappingCache = void 0; exports.fetchInferenceProviderMappingForModel = fetchInferenceProviderMappingForModel; exports.getInferenceProviderMapping = getInferenceProviderMapping; exports.resolveProvider = resolveProvider; const config_js_1 = require("../config.js"); const consts_js_1 = require("../providers/consts.js"); const hf_inference_js_1 = require("../providers/hf-inference.js"); const typedInclude_js_1 = require("../utils/typedInclude.js"); const errors_js_1 = require("../errors.js"); exports.inferenceProviderMappingCache = new Map(); /** * Normalize inferenceProviderMapping to always return an array format. * This provides backward and forward compatibility for the API changes. * * Vendored from @huggingface/hub to avoid extra dependency. */ function normalizeInferenceProviderMapping(modelId, inferenceProviderMapping) { if (!inferenceProviderMapping) { return []; } // If it's already an array, return it as is if (Array.isArray(inferenceProviderMapping)) { return inferenceProviderMapping; } // Convert mapping to array format return Object.entries(inferenceProviderMapping).map(([provider, mapping]) => ({ provider, hfModelId: modelId, providerId: mapping.providerId, status: mapping.status, task: mapping.task, adapter: mapping.adapter, adapterWeightsPath: mapping.adapterWeightsPath, })); } async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) { let inferenceProviderMapping; if (exports.inferenceProviderMappingCache.has(modelId)) { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion inferenceProviderMapping = exports.inferenceProviderMappingCache.get(modelId); } else { const url = `${config_js_1.HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`; const resp = await (options?.fetch ?? fetch)(url, { headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {}, }); if (!resp.ok) { if (resp.headers.get("Content-Type")?.startsWith("application/json")) { const error = await resp.json(); if ("error" in error && typeof error.error === "string") { throw new errors_js_1.InferenceClientHubApiError(`Failed to fetch inference provider mapping for model ${modelId}: ${error.error}`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }); } } else { throw new errors_js_1.InferenceClientHubApiError(`Failed to fetch inference provider mapping for model ${modelId}`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }); } } let payload = null; try { payload = await resp.json(); } catch { throw new errors_js_1.InferenceClientHubApiError(`Failed to fetch inference provider mapping for model ${modelId}: malformed API response, invalid JSON`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }); } if (!payload?.inferenceProviderMapping) { throw new errors_js_1.InferenceClientHubApiError(`We have not been able to find inference provider information for model ${modelId}.`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? 
"", status: resp.status, body: await resp.text() }); } inferenceProviderMapping = normalizeInferenceProviderMapping(modelId, payload.inferenceProviderMapping); exports.inferenceProviderMappingCache.set(modelId, inferenceProviderMapping); } return inferenceProviderMapping; } async function getInferenceProviderMapping(params, options) { if (consts_js_1.HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) { return consts_js_1.HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]; } const mappings = await fetchInferenceProviderMappingForModel(params.modelId, params.accessToken, options); const providerMapping = mappings.find((mapping) => mapping.provider === params.provider); if (providerMapping) { const equivalentTasks = params.provider === "hf-inference" && (0, typedInclude_js_1.typedInclude)(hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task) ? hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS : [params.task]; if (!(0, typedInclude_js_1.typedInclude)(equivalentTasks, providerMapping.task)) { throw new errors_js_1.InferenceClientInputError(`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`); } if (providerMapping.status === "staging") { console.warn(`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`); } return providerMapping; } return null; } async function resolveProvider(provider, modelId, endpointUrl) { if (endpointUrl) { if (provider) { throw new errors_js_1.InferenceClientInputError("Specifying both endpointUrl and provider is not supported."); } /// Defaulting to hf-inference helpers / API return "hf-inference"; } if (!provider) { console.log("Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."); provider = "auto"; } if (provider === "auto") { if (!modelId) { throw new errors_js_1.InferenceClientInputError("Specifying a model is required when provider is 'auto'"); } const mappings = await fetchInferenceProviderMappingForModel(modelId); provider = mappings[0]?.provider; console.log("Auto selected provider:", provider); } if (!provider) { throw new errors_js_1.InferenceClientInputError(`No Inference Provider available for model ${modelId}.`); } return provider; }