@burncloud/inference
Version:
TypeScript client for the Hugging Face Inference Providers and Inference Endpoints
82 lines (81 loc) • 4.13 kB
JavaScript
;
// CommonJS module wiring (transpiled output): mark as ES-module interop and
// pre-declare the named exports before their definitions below.
Object.defineProperty(exports, "__esModule", { value: true });
exports.inferenceProviderMappingCache = void 0;
exports.fetchInferenceProviderMappingForModel = fetchInferenceProviderMappingForModel;
exports.getInferenceProviderMapping = getInferenceProviderMapping;
exports.resolveProvider = resolveProvider;
const config_js_1 = require("../config.js");
const consts_js_1 = require("../providers/consts.js");
const hf_inference_js_1 = require("../providers/hf-inference.js");
const typedInclude_js_1 = require("../utils/typedInclude.js");
// Process-wide memoization of Hub lookups: modelId -> inference provider mapping.
// Populated by fetchInferenceProviderMappingForModel; never invalidated.
exports.inferenceProviderMappingCache = new Map();
/**
 * Fetch the inference provider mapping for a model from the Hugging Face Hub.
 *
 * Results are memoized in `inferenceProviderMappingCache` (keyed by model id),
 * so the Hub is queried at most once per model per process.
 *
 * @param {string} modelId - Hub model id (e.g. "org/model-name").
 * @param {string} [accessToken] - Optional token; only forwarded as an
 *   Authorization header when it looks like an HF token (starts with "hf_").
 * @param {{fetch?: typeof fetch}} [options] - Optional custom fetch implementation.
 * @returns {Promise<object>} Mapping of provider name -> provider model info.
 * @throws {Error} If the model does not exist (404), the Hub request fails
 *   with another HTTP error, or the response carries no mapping.
 */
async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
    if (exports.inferenceProviderMappingCache.has(modelId)) {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return exports.inferenceProviderMappingCache.get(modelId);
    }
    const resp = await (options?.fetch ?? fetch)(`${config_js_1.HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`, {
        headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
    });
    if (resp.status === 404) {
        throw new Error(`Model ${modelId} does not exist`);
    }
    // Surface other HTTP failures (401, 429, 5xx, ...) explicitly instead of
    // letting them fall through to the misleading "no mapping found" error below.
    if (!resp.ok) {
        throw new Error(`Failed to fetch inference provider mapping for model ${modelId}: HTTP ${resp.status}`);
    }
    const inferenceProviderMapping = await resp
        .json()
        .then((json) => json.inferenceProviderMapping)
        .catch(() => null); // malformed/non-JSON body -> treated as "no mapping"
    if (!inferenceProviderMapping) {
        throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
    }
    exports.inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
    return inferenceProviderMapping;
}
/**
 * Resolve the provider-specific mapping entry for a (provider, model, task) triple.
 *
 * Hardcoded overrides from HARDCODED_MODEL_INFERENCE_MAPPING take precedence
 * over the mapping fetched from the Hub.
 *
 * @param {{provider: string, modelId: string, accessToken?: string, task: string}} params
 * @param {{fetch?: typeof fetch}} [options] - Optional custom fetch implementation.
 * @returns {Promise<object|null>} The provider mapping (augmented with
 *   `hfModelId`), or null when the provider does not serve this model.
 * @throws {Error} If the model/provider pair does not support the requested task.
 */
async function getInferenceProviderMapping(params, options) {
    // Optional chaining guards against providers with no hardcoded entries at
    // all (the bare double-index would throw a TypeError in that case).
    const hardcoded = consts_js_1.HARDCODED_MODEL_INFERENCE_MAPPING[params.provider]?.[params.modelId];
    if (hardcoded) {
        return hardcoded;
    }
    const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(params.modelId, params.accessToken, options);
    const providerMapping = inferenceProviderMapping[params.provider];
    if (!providerMapping) {
        return null;
    }
    // For "hf-inference", any task in EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS is
    // accepted interchangeably when the requested task is one of them; all
    // other providers require an exact task match.
    const equivalentTasks = params.provider === "hf-inference" && (0, typedInclude_js_1.typedInclude)(hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
        ? hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
        : [params.task];
    if (!(0, typedInclude_js_1.typedInclude)(equivalentTasks, providerMapping.task)) {
        throw new Error(`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`);
    }
    if (providerMapping.status === "staging") {
        console.warn(`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`);
    }
    return { ...providerMapping, hfModelId: params.modelId };
}
/**
 * Determine which inference provider to use for a request.
 *
 * Rules (in order):
 *  - If `endpointUrl` is given, `provider` must not be set; the request is
 *    routed through the "hf-inference" helpers/API.
 *  - If no provider is given, default to "auto".
 *  - "auto" picks the first provider in the model's mapping (the Hub returns
 *    them sorted by the user's order in https://hf.co/settings/inference-providers).
 *
 * @param {string} [provider] - Explicit provider name, or "auto", or undefined.
 * @param {string} [modelId] - Hub model id; required when provider is "auto".
 * @param {string} [endpointUrl] - Dedicated endpoint URL, mutually exclusive with `provider`.
 * @returns {Promise<string>} The resolved provider name.
 * @throws {Error} On conflicting arguments, missing model for "auto", or when
 *   no provider serves the model.
 */
async function resolveProvider(provider, modelId, endpointUrl) {
    if (endpointUrl) {
        if (provider) {
            throw new Error("Specifying both endpointUrl and provider is not supported.");
        }
        /// Defaulting to hf-inference helpers / API
        return "hf-inference";
    }
    if (!provider) {
        // console.warn (not .log) for consistency with the staging-mode warning
        // emitted elsewhere in this module.
        console.warn("Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.");
        provider = "auto";
    }
    if (provider === "auto") {
        if (!modelId) {
            throw new Error("Specifying a model is required when provider is 'auto'");
        }
        const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
        provider = Object.keys(inferenceProviderMapping)[0];
    }
    if (!provider) {
        throw new Error(`No Inference Provider available for model ${modelId}.`);
    }
    return provider;
}