UNPKG

@burncloud/inference

Version:

TypeScript client for the Hugging Face Inference Providers and Inference Endpoints

82 lines (81 loc) 4.13 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.inferenceProviderMappingCache = void 0;
exports.fetchInferenceProviderMappingForModel = fetchInferenceProviderMappingForModel;
exports.getInferenceProviderMapping = getInferenceProviderMapping;
exports.resolveProvider = resolveProvider;
const config_js_1 = require("../config.js");
const consts_js_1 = require("../providers/consts.js");
const hf_inference_js_1 = require("../providers/hf-inference.js");
const typedInclude_js_1 = require("../utils/typedInclude.js");
// Per-process cache: modelId -> inference-provider mapping fetched from the Hub.
// Only truthy (successfully fetched) mappings are ever stored.
exports.inferenceProviderMappingCache = new Map();
/**
 * Fetch the inference provider mapping for a model from the HF Hub API,
 * with an in-memory cache keyed by modelId.
 *
 * @param {string} modelId - Hub model id (e.g. "org/name").
 * @param {string} [accessToken] - Only forwarded as a Bearer header when it
 *   looks like a Hugging Face token (starts with "hf_").
 * @param {{ fetch?: typeof fetch }} [options] - Optional custom fetch implementation.
 * @returns {Promise<object>} mapping of provider name -> provider mapping entry.
 * @throws {Error} if the model does not exist (HTTP 404) or no mapping is available.
 */
async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
    if (exports.inferenceProviderMappingCache.has(modelId)) {
        // Cache hit: only truthy mappings are stored, so this is always valid.
        return exports.inferenceProviderMappingCache.get(modelId);
    }
    const resp = await (options?.fetch ?? fetch)(`${config_js_1.HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`, {
        // Only attach credentials when the token is an HF token; other token
        // flavors (e.g. provider keys) must not be sent to the Hub.
        headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
    });
    if (resp.status === 404) {
        throw new Error(`Model ${modelId} does not exist`);
    }
    // Any parse failure or missing field collapses to null and is reported below.
    const inferenceProviderMapping = await resp
        .json()
        .then((json) => json.inferenceProviderMapping)
        .catch(() => null);
    if (!inferenceProviderMapping) {
        throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
    }
    exports.inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
    return inferenceProviderMapping;
}
/**
 * Resolve the provider-specific mapping entry for a (model, provider, task) triple.
 *
 * Hardcoded overrides take precedence over the Hub-provided mapping.
 *
 * @param {{ provider: string, modelId: string, task: string, accessToken?: string }} params
 * @param {{ fetch?: typeof fetch }} [options]
 * @returns {Promise<object|null>} the mapping entry augmented with `hfModelId`,
 *   or null when the provider does not serve this model.
 * @throws {Error} when the provider serves the model but for a different task.
 */
async function getInferenceProviderMapping(params, options) {
    // FIX: optional chaining on the provider key — the original indexed
    // HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId] directly
    // and threw a TypeError for any provider absent from the hardcoded map,
    // instead of falling through to the Hub-fetched mapping.
    const hardcoded = consts_js_1.HARDCODED_MODEL_INFERENCE_MAPPING[params.provider]?.[params.modelId];
    if (hardcoded) {
        return hardcoded;
    }
    const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(params.modelId, params.accessToken, options);
    const providerMapping = inferenceProviderMapping[params.provider];
    if (!providerMapping) {
        // Provider does not serve this model at all.
        return null;
    }
    // hf-inference treats the sentence-transformers task family as interchangeable
    // with the requested task; every other provider requires an exact task match.
    const equivalentTasks = params.provider === "hf-inference" &&
        (0, typedInclude_js_1.typedInclude)(hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
        ? hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
        : [params.task];
    if (!(0, typedInclude_js_1.typedInclude)(equivalentTasks, providerMapping.task)) {
        throw new Error(`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`);
    }
    if (providerMapping.status === "staging") {
        console.warn(`Model ${params.modelId} is in staging mode for provider ${params.provider}. 
Meant for test purposes only.`);
    }
    return { ...providerMapping, hfModelId: params.modelId };
}
/**
 * Resolve the provider to use for a request.
 *
 * Rules, in order: an explicit endpointUrl forces "hf-inference" (and is
 * mutually exclusive with an explicit provider); no provider defaults to
 * "auto"; "auto" picks the first provider in the model's Hub mapping.
 *
 * @param {string} [provider]
 * @param {string} [modelId] - required when provider resolves to "auto".
 * @param {string} [endpointUrl]
 * @returns {Promise<string>} the resolved provider name.
 * @throws {Error} on conflicting arguments or when no provider is available.
 */
async function resolveProvider(provider, modelId, endpointUrl) {
    if (endpointUrl) {
        if (provider) {
            throw new Error("Specifying both endpointUrl and provider is not supported.");
        }
        /// Defaulting to hf-inference helpers / API
        return "hf-inference";
    }
    if (!provider) {
        console.log("Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.");
        provider = "auto";
    }
    if (provider === "auto") {
        if (!modelId) {
            throw new Error("Specifying a model is required when provider is 'auto'");
        }
        const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
        // First key wins: the Hub returns providers sorted by the user's
        // configured preference order.
        provider = Object.keys(inferenceProviderMapping)[0];
    }
    if (!provider) {
        // Object.keys(...)[0] is undefined when the mapping is empty.
        throw new Error(`No Inference Provider available for model ${modelId}.`);
    }
    return provider;
}