@huggingface/inference
Version:
Typescript client for the Hugging Face Inference Providers and Inference Endpoints
430 lines (429 loc) • 22.2 kB
JavaScript
Object.defineProperty(exports, "__esModule", { value: true });
exports.HFInferenceTextToAudioTask = exports.HFInferenceTabularRegressionTask = exports.HFInferenceVisualQuestionAnsweringTask = exports.HFInferenceTabularClassificationTask = exports.HFInferenceTextToSpeechTask = exports.HFInferenceSummarizationTask = exports.HFInferenceTranslationTask = exports.HFInferenceTokenClassificationTask = exports.HFInferenceTableQuestionAnsweringTask = exports.HFInferenceSentenceSimilarityTask = exports.HFInferenceZeroShotClassificationTask = exports.HFInferenceFillMaskTask = exports.HFInferenceQuestionAnsweringTask = exports.HFInferenceTextClassificationTask = exports.HFInferenceZeroShotImageClassificationTask = exports.HFInferenceObjectDetectionTask = exports.HFInferenceImageToImageTask = exports.HFInferenceImageToTextTask = exports.HFInferenceImageSegmentationTask = exports.HFInferenceImageClassificationTask = exports.HFInferenceFeatureExtractionTask = exports.HFInferenceDocumentQuestionAnsweringTask = exports.HFInferenceAudioToAudioTask = exports.HFInferenceAutomaticSpeechRecognitionTask = exports.HFInferenceAudioClassificationTask = exports.HFInferenceTextGenerationTask = exports.HFInferenceConversationalTask = exports.HFInferenceTextToImageTask = exports.HFInferenceTask = exports.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = void 0;
const config_js_1 = require("../config.js");
const errors_js_1 = require("../errors.js");
const toArray_js_1 = require("../utils/toArray.js");
const providerHelper_js_1 = require("./providerHelper.js");
const base64FromBytes_js_1 = require("../utils/base64FromBytes.js");
const omit_js_1 = require("../utils/omit.js");
exports.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"];
class HFInferenceTask extends providerHelper_js_1.TaskProviderHelper {
constructor() {
super("hf-inference", `${config_js_1.HF_ROUTER_URL}/hf-inference`);
}
preparePayload(params) {
return params.args;
}
makeUrl(params) {
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
return params.model;
}
return super.makeUrl(params);
}
makeRoute(params) {
if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
// when deployed on hf-inference, those two tasks are automatically compatible with one another.
return `models/${params.model}/pipeline/${params.task}`;
}
return `models/${params.model}`;
}
async getResponse(response) {
return response;
}
}
exports.HFInferenceTask = HFInferenceTask;
class HFInferenceTextToImageTask extends HFInferenceTask {
async getResponse(response, url, headers, outputType) {
if (!response) {
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference text-to-image API: response is undefined");
}
if (typeof response == "object") {
if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) {
const base64Data = response.data[0].b64_json;
if (outputType === "url") {
return `data:image/jpeg;base64,${base64Data}`;
}
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
return await base64Response.blob();
}
if ("output" in response && Array.isArray(response.output)) {
if (outputType === "url") {
return response.output[0];
}
const urlResponse = await fetch(response.output[0]);
const blob = await urlResponse.blob();
return blob;
}
}
if (response instanceof Blob) {
if (outputType === "url") {
const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64"));
return `data:image/jpeg;base64,${b64}`;
}
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference text-to-image API: expected a Blob");
}
}
exports.HFInferenceTextToImageTask = HFInferenceTextToImageTask;
class HFInferenceConversationalTask extends HFInferenceTask {
makeUrl(params) {
let url;
if (params.model.startsWith("http://") || params.model.startsWith("https://")) {
url = params.model.trim();
}
else {
url = `${this.makeBaseUrl(params)}/models/${params.model}`;
}
url = url.replace(/\/+$/, "");
if (url.endsWith("/v1")) {
url += "/chat/completions";
}
else if (!url.endsWith("/chat/completions")) {
url += "/v1/chat/completions";
}
return url;
}
preparePayload(params) {
return {
...params.args,
model: params.model,
};
}
async getResponse(response) {
return response;
}
}
exports.HFInferenceConversationalTask = HFInferenceConversationalTask;
class HFInferenceTextGenerationTask extends HFInferenceTask {
async getResponse(response) {
const res = (0, toArray_js_1.toArray)(response);
if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) {
return res?.[0];
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference text generation API: expected Array<{generated_text: string}>");
}
}
exports.HFInferenceTextGenerationTask = HFInferenceTextGenerationTask;
class HFInferenceAudioClassificationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) &&
response.every((x) => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference audio-classification API: expected Array<{label: string, score: number}> but received different format");
}
}
exports.HFInferenceAudioClassificationTask = HFInferenceAudioClassificationTask;
class HFInferenceAutomaticSpeechRecognitionTask extends HFInferenceTask {
async getResponse(response) {
return response;
}
async preparePayloadAsync(args) {
return "data" in args
? args
: {
...(0, omit_js_1.omit)(args, "inputs"),
data: args.inputs,
};
}
}
exports.HFInferenceAutomaticSpeechRecognitionTask = HFInferenceAutomaticSpeechRecognitionTask;
class HFInferenceAudioToAudioTask extends HFInferenceTask {
async getResponse(response) {
if (!Array.isArray(response)) {
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference audio-to-audio API: expected Array");
}
if (!response.every((elem) => {
return (typeof elem === "object" &&
elem &&
"label" in elem &&
typeof elem.label === "string" &&
"content-type" in elem &&
typeof elem["content-type"] === "string" &&
"blob" in elem &&
typeof elem.blob === "string");
})) {
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference audio-to-audio API: expected Array<{label: string, audio: Blob}>");
}
return response;
}
}
exports.HFInferenceAudioToAudioTask = HFInferenceAudioToAudioTask;
class HFInferenceDocumentQuestionAnsweringTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) &&
response.every((elem) => typeof elem === "object" &&
!!elem &&
typeof elem?.answer === "string" &&
(typeof elem.end === "number" || typeof elem.end === "undefined") &&
(typeof elem.score === "number" || typeof elem.score === "undefined") &&
(typeof elem.start === "number" || typeof elem.start === "undefined"))) {
return response[0];
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference document-question-answering API: expected Array<{answer: string, end: number, score: number, start: number}>");
}
}
exports.HFInferenceDocumentQuestionAnsweringTask = HFInferenceDocumentQuestionAnsweringTask;
class HFInferenceFeatureExtractionTask extends HFInferenceTask {
async getResponse(response) {
const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {
if (curDepth > maxDepth)
return false;
if (arr.every((x) => Array.isArray(x))) {
return arr.every((x) => isNumArrayRec(x, maxDepth, curDepth + 1));
}
else {
return arr.every((x) => typeof x === "number");
}
};
if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference feature-extraction API: expected Array<number[][][] | number[][] | number[] | number>");
}
}
exports.HFInferenceFeatureExtractionTask = HFInferenceFeatureExtractionTask;
class HFInferenceImageClassificationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference image-classification API: expected Array<{label: string, score: number}>");
}
}
exports.HFInferenceImageClassificationTask = HFInferenceImageClassificationTask;
class HFInferenceImageSegmentationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) &&
response.every((x) => typeof x.label === "string" &&
typeof x.mask === "string" &&
(x.score === undefined || typeof x.score === "number"))) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference image-segmentation API: expected Array<{label: string, mask: string, score: number}>");
}
}
exports.HFInferenceImageSegmentationTask = HFInferenceImageSegmentationTask;
class HFInferenceImageToTextTask extends HFInferenceTask {
async getResponse(response) {
if (typeof response?.generated_text !== "string") {
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference image-to-text API: expected {generated_text: string}");
}
return response;
}
}
exports.HFInferenceImageToTextTask = HFInferenceImageToTextTask;
class HFInferenceImageToImageTask extends HFInferenceTask {
async preparePayloadAsync(args) {
if (!args.parameters) {
return {
...args,
model: args.model,
data: args.inputs,
};
}
else {
return {
...args,
inputs: (0, base64FromBytes_js_1.base64FromBytes)(new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())),
};
}
}
async getResponse(response) {
if (response instanceof Blob) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference image-to-image API: expected Blob");
}
}
exports.HFInferenceImageToImageTask = HFInferenceImageToImageTask;
class HFInferenceObjectDetectionTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) &&
response.every((x) => typeof x.label === "string" &&
typeof x.score === "number" &&
typeof x.box.xmin === "number" &&
typeof x.box.ymin === "number" &&
typeof x.box.xmax === "number" &&
typeof x.box.ymax === "number")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference object-detection API: expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>");
}
}
exports.HFInferenceObjectDetectionTask = HFInferenceObjectDetectionTask;
class HFInferenceZeroShotImageClassificationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference zero-shot-image-classification API: expected Array<{label: string, score: number}>");
}
}
exports.HFInferenceZeroShotImageClassificationTask = HFInferenceZeroShotImageClassificationTask;
class HFInferenceTextClassificationTask extends HFInferenceTask {
async getResponse(response) {
const output = response?.[0];
if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) {
return output;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference text-classification API: expected Array<{label: string, score: number}>");
}
}
exports.HFInferenceTextClassificationTask = HFInferenceTextClassificationTask;
class HFInferenceQuestionAnsweringTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response)
? response.every((elem) => typeof elem === "object" &&
!!elem &&
typeof elem.answer === "string" &&
typeof elem.end === "number" &&
typeof elem.score === "number" &&
typeof elem.start === "number")
: typeof response === "object" &&
!!response &&
typeof response.answer === "string" &&
typeof response.end === "number" &&
typeof response.score === "number" &&
typeof response.start === "number") {
return Array.isArray(response) ? response[0] : response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference question-answering API: expected Array<{answer: string, end: number, score: number, start: number}>");
}
}
exports.HFInferenceQuestionAnsweringTask = HFInferenceQuestionAnsweringTask;
class HFInferenceFillMaskTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) &&
response.every((x) => typeof x.score === "number" &&
typeof x.sequence === "string" &&
typeof x.token === "number" &&
typeof x.token_str === "string")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference fill-mask API: expected Array<{score: number, sequence: string, token: number, token_str: string}>");
}
}
exports.HFInferenceFillMaskTask = HFInferenceFillMaskTask;
class HFInferenceZeroShotClassificationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) &&
response.every((x) => Array.isArray(x.labels) &&
x.labels.every((_label) => typeof _label === "string") &&
Array.isArray(x.scores) &&
x.scores.every((_score) => typeof _score === "number") &&
typeof x.sequence === "string")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference zero-shot-classification API: expected Array<{labels: string[], scores: number[], sequence: string}>");
}
}
exports.HFInferenceZeroShotClassificationTask = HFInferenceZeroShotClassificationTask;
class HFInferenceSentenceSimilarityTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference sentence-similarity API: expected Array<number>");
}
}
exports.HFInferenceSentenceSimilarityTask = HFInferenceSentenceSimilarityTask;
class HFInferenceTableQuestionAnsweringTask extends HFInferenceTask {
static validate(elem) {
return (typeof elem === "object" &&
!!elem &&
"aggregator" in elem &&
typeof elem.aggregator === "string" &&
"answer" in elem &&
typeof elem.answer === "string" &&
"cells" in elem &&
Array.isArray(elem.cells) &&
elem.cells.every((x) => typeof x === "string") &&
"coordinates" in elem &&
Array.isArray(elem.coordinates) &&
elem.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")));
}
async getResponse(response) {
if (Array.isArray(response) && Array.isArray(response)
? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem))
: HFInferenceTableQuestionAnsweringTask.validate(response)) {
return Array.isArray(response) ? response[0] : response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference table-question-answering API: expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}");
}
}
exports.HFInferenceTableQuestionAnsweringTask = HFInferenceTableQuestionAnsweringTask;
class HFInferenceTokenClassificationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) &&
response.every((x) => typeof x.end === "number" &&
typeof x.entity_group === "string" &&
typeof x.score === "number" &&
typeof x.start === "number" &&
typeof x.word === "string")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference token-classification API: expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>");
}
}
exports.HFInferenceTokenClassificationTask = HFInferenceTokenClassificationTask;
class HFInferenceTranslationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) {
return response?.length === 1 ? response?.[0] : response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference translation API: expected Array<{translation_text: string}>");
}
}
exports.HFInferenceTranslationTask = HFInferenceTranslationTask;
class HFInferenceSummarizationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) {
return response?.[0];
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference summarization API: expected Array<{summary_text: string}>");
}
}
exports.HFInferenceSummarizationTask = HFInferenceSummarizationTask;
class HFInferenceTextToSpeechTask extends HFInferenceTask {
async getResponse(response) {
return response;
}
}
exports.HFInferenceTextToSpeechTask = HFInferenceTextToSpeechTask;
class HFInferenceTabularClassificationTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference tabular-classification API: expected Array<number>");
}
}
exports.HFInferenceTabularClassificationTask = HFInferenceTabularClassificationTask;
class HFInferenceVisualQuestionAnsweringTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) &&
response.every((elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number")) {
return response[0];
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference visual-question-answering API: expected Array<{answer: string, score: number}>");
}
}
exports.HFInferenceVisualQuestionAnsweringTask = HFInferenceVisualQuestionAnsweringTask;
class HFInferenceTabularRegressionTask extends HFInferenceTask {
async getResponse(response) {
if (Array.isArray(response) && response.every((x) => typeof x === "number")) {
return response;
}
throw new errors_js_1.InferenceClientProviderOutputError("Received malformed response from HF-Inference tabular-regression API: expected Array<number>");
}
}
exports.HFInferenceTabularRegressionTask = HFInferenceTabularRegressionTask;
class HFInferenceTextToAudioTask extends HFInferenceTask {
async getResponse(response) {
return response;
}
}
exports.HFInferenceTextToAudioTask = HFInferenceTextToAudioTask;
;