UNPKG

@huggingface/inference

Version:

Typescript client for the Hugging Face Inference Providers and Inference Endpoints

571 lines (530 loc) 19.8 kB
/** * HF-Inference do not have a mapping since all models use IDs from the Hub. * * If you want to try to run inference for a new model locally before it's registered on huggingface.co, * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. * * - If you work at HF and want to update this mapping, please use the model mapping API we provide on huggingface.co * - If you're a community member and want to add a new supported HF model to HF, please open an issue on the present repo * and we will tag HF team members. * * Thanks! */ import type { AudioClassificationOutput, AutomaticSpeechRecognitionOutput, ChatCompletionOutput, DocumentQuestionAnsweringOutput, FeatureExtractionOutput, FillMaskOutput, ImageClassificationOutput, ImageSegmentationOutput, ImageToTextOutput, ObjectDetectionOutput, QuestionAnsweringOutput, SentenceSimilarityOutput, SummarizationOutput, TableQuestionAnsweringOutput, TextClassificationOutput, TextGenerationOutput, TokenClassificationOutput, TranslationOutput, VisualQuestionAnsweringOutput, ZeroShotClassificationOutput, ZeroShotImageClassificationOutput, } from "@huggingface/tasks"; import { HF_ROUTER_URL } from "../config"; import { InferenceOutputError } from "../lib/InferenceOutputError"; import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification"; import type { BodyParams, UrlParams } from "../types"; import { toArray } from "../utils/toArray"; import type { AudioClassificationTaskHelper, AudioToAudioTaskHelper, AutomaticSpeechRecognitionTaskHelper, ConversationalTaskHelper, DocumentQuestionAnsweringTaskHelper, FeatureExtractionTaskHelper, FillMaskTaskHelper, ImageClassificationTaskHelper, ImageSegmentationTaskHelper, ImageToImageTaskHelper, ImageToTextTaskHelper, ObjectDetectionTaskHelper, QuestionAnsweringTaskHelper, SentenceSimilarityTaskHelper, SummarizationTaskHelper, TableQuestionAnsweringTaskHelper, TabularClassificationTaskHelper, TabularRegressionTaskHelper, TextClassificationTaskHelper, TextGenerationTaskHelper, TextToAudioTaskHelper, TextToImageTaskHelper, TextToSpeechTaskHelper, TokenClassificationTaskHelper, TranslationTaskHelper, VisualQuestionAnsweringTaskHelper, ZeroShotClassificationTaskHelper, ZeroShotImageClassificationTaskHelper, } from "./providerHelper"; import { TaskProviderHelper } from "./providerHelper"; interface Base64ImageGeneration { data: Array<{ b64_json: string; }>; } interface OutputUrlImageGeneration { output: string[]; } interface AudioToAudioOutput { blob: string; "content-type": string; label: string; } export const EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS = ["feature-extraction", "sentence-similarity"] as const; export class HFInferenceTask extends TaskProviderHelper { constructor() { super("hf-inference", `${HF_ROUTER_URL}/hf-inference`); } preparePayload(params: BodyParams): Record<string, unknown> { return params.args; } override makeUrl(params: UrlParams): string { if (params.model.startsWith("http://") || params.model.startsWith("https://")) { return params.model; } return super.makeUrl(params); } makeRoute(params: UrlParams): string { if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) { // when deployed on hf-inference, those two tasks are automatically compatible with one another. return `models/${params.model}/pipeline/${params.task}`; } return `models/${params.model}`; } override async getResponse(response: unknown): Promise<unknown> { return response; } } export class HFInferenceTextToImageTask extends HFInferenceTask implements TextToImageTaskHelper { override async getResponse( response: Base64ImageGeneration | OutputUrlImageGeneration, url?: string, headers?: HeadersInit, outputType?: "url" | "blob" ): Promise<string | Blob> { if (!response) { throw new InferenceOutputError("response is undefined"); } if (typeof response == "object") { if ("data" in response && Array.isArray(response.data) && response.data[0].b64_json) { const base64Data = response.data[0].b64_json; if (outputType === "url") { return `data:image/jpeg;base64,${base64Data}`; } const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`); return await base64Response.blob(); } if ("output" in response && Array.isArray(response.output)) { if (outputType === "url") { return response.output[0]; } const urlResponse = await fetch(response.output[0]); const blob = await urlResponse.blob(); return blob; } } if (response instanceof Blob) { if (outputType === "url") { const b64 = await response.arrayBuffer().then((buf) => Buffer.from(buf).toString("base64")); return `data:image/jpeg;base64,${b64}`; } return response; } throw new InferenceOutputError("Expected a Blob "); } } export class HFInferenceConversationalTask extends HFInferenceTask implements ConversationalTaskHelper { override makeUrl(params: UrlParams): string { let url: string; if (params.model.startsWith("http://") || params.model.startsWith("https://")) { url = params.model.trim(); } else { url = `${this.makeBaseUrl(params)}/models/${params.model}`; } url = url.replace(/\/+$/, ""); if (url.endsWith("/v1")) { url += "/chat/completions"; } else if (!url.endsWith("/chat/completions")) { url += "/v1/chat/completions"; } return url; } override preparePayload(params: BodyParams): Record<string, unknown> { return { ...params.args, model: params.model, }; } override async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> { return response; } } export class HFInferenceTextGenerationTask extends HFInferenceTask implements TextGenerationTaskHelper { override async getResponse(response: TextGenerationOutput | TextGenerationOutput[]): Promise<TextGenerationOutput> { const res = toArray(response); if (Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string")) { return (res as TextGenerationOutput[])?.[0]; } throw new InferenceOutputError("Expected Array<{generated_text: string}>"); } } export class HFInferenceAudioClassificationTask extends HFInferenceTask implements AudioClassificationTaskHelper { override async getResponse(response: unknown): Promise<AudioClassificationOutput> { // Add type checking/validation for the 'unknown' input if ( Array.isArray(response) && response.every( (x): x is { label: string; score: number } => typeof x === "object" && x !== null && typeof x.label === "string" && typeof x.score === "number" ) ) { // If validation passes, it's safe to return as AudioClassificationOutput return response; } // If validation fails, throw an error throw new InferenceOutputError("Expected Array<{label: string, score: number}> but received different format"); } } export class HFInferenceAutomaticSpeechRecognitionTask extends HFInferenceTask implements AutomaticSpeechRecognitionTaskHelper { override async getResponse(response: AutomaticSpeechRecognitionOutput): Promise<AutomaticSpeechRecognitionOutput> { return response; } } export class HFInferenceAudioToAudioTask extends HFInferenceTask implements AudioToAudioTaskHelper { override async getResponse(response: AudioToAudioOutput[]): Promise<AudioToAudioOutput[]> { if (!Array.isArray(response)) { throw new InferenceOutputError("Expected Array"); } if ( !response.every((elem): elem is AudioToAudioOutput => { return ( typeof elem === "object" && elem && "label" in elem && typeof elem.label === "string" && "content-type" in elem && typeof elem["content-type"] === "string" && "blob" in elem && typeof elem.blob === "string" ); }) ) { throw new InferenceOutputError("Expected Array<{label: string, audio: Blob}>"); } return response; } } export class HFInferenceDocumentQuestionAnsweringTask extends HFInferenceTask implements DocumentQuestionAnsweringTaskHelper { override async getResponse( response: DocumentQuestionAnsweringOutput ): Promise<DocumentQuestionAnsweringOutput[number]> { if ( Array.isArray(response) && response.every( (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && (typeof elem.end === "number" || typeof elem.end === "undefined") && (typeof elem.score === "number" || typeof elem.score === "undefined") && (typeof elem.start === "number" || typeof elem.start === "undefined") ) ) { return response[0]; } throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); } } export class HFInferenceFeatureExtractionTask extends HFInferenceTask implements FeatureExtractionTaskHelper { override async getResponse(response: FeatureExtractionOutput): Promise<FeatureExtractionOutput> { const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => { if (curDepth > maxDepth) return false; if (arr.every((x) => Array.isArray(x))) { return arr.every((x) => isNumArrayRec(x as unknown[], maxDepth, curDepth + 1)); } else { return arr.every((x) => typeof x === "number"); } }; if (Array.isArray(response) && isNumArrayRec(response, 3, 0)) { return response; } throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>"); } } export class HFInferenceImageClassificationTask extends HFInferenceTask implements ImageClassificationTaskHelper { override async getResponse(response: ImageClassificationOutput): Promise<ImageClassificationOutput> { if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) { return response; } throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); } } export class HFInferenceImageSegmentationTask extends HFInferenceTask implements ImageSegmentationTaskHelper { override async getResponse(response: ImageSegmentationOutput): Promise<ImageSegmentationOutput> { if ( Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number") ) { return response; } throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>"); } } export class HFInferenceImageToTextTask extends HFInferenceTask implements ImageToTextTaskHelper { override async getResponse(response: ImageToTextOutput): Promise<ImageToTextOutput> { if (typeof response?.generated_text !== "string") { throw new InferenceOutputError("Expected {generated_text: string}"); } return response; } } export class HFInferenceImageToImageTask extends HFInferenceTask implements ImageToImageTaskHelper { override async getResponse(response: Blob): Promise<Blob> { if (response instanceof Blob) { return response; } throw new InferenceOutputError("Expected Blob"); } } export class HFInferenceObjectDetectionTask extends HFInferenceTask implements ObjectDetectionTaskHelper { override async getResponse(response: ObjectDetectionOutput): Promise<ObjectDetectionOutput> { if ( Array.isArray(response) && response.every( (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number" ) ) { return response; } throw new InferenceOutputError( "Expected Array<{label: string, score: number, box: {xmin: number, ymin: number, xmax: number, ymax: number}}>" ); } } export class HFInferenceZeroShotImageClassificationTask extends HFInferenceTask implements ZeroShotImageClassificationTaskHelper { override async getResponse(response: ZeroShotImageClassificationOutput): Promise<ZeroShotImageClassificationOutput> { if (Array.isArray(response) && response.every((x) => typeof x.label === "string" && typeof x.score === "number")) { return response; } throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); } } export class HFInferenceTextClassificationTask extends HFInferenceTask implements TextClassificationTaskHelper { override async getResponse(response: TextClassificationOutput): Promise<TextClassificationOutput> { const output = response?.[0]; if (Array.isArray(output) && output.every((x) => typeof x?.label === "string" && typeof x.score === "number")) { return output; } throw new InferenceOutputError("Expected Array<{label: string, score: number}>"); } } export class HFInferenceQuestionAnsweringTask extends HFInferenceTask implements QuestionAnsweringTaskHelper { override async getResponse( response: QuestionAnsweringOutput | QuestionAnsweringOutput[number] ): Promise<QuestionAnsweringOutput[number]> { if ( Array.isArray(response) ? response.every( (elem) => typeof elem === "object" && !!elem && typeof elem.answer === "string" && typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number" ) : typeof response === "object" && !!response && typeof response.answer === "string" && typeof response.end === "number" && typeof response.score === "number" && typeof response.start === "number" ) { return Array.isArray(response) ? response[0] : response; } throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); } } export class HFInferenceFillMaskTask extends HFInferenceTask implements FillMaskTaskHelper { override async getResponse(response: FillMaskOutput): Promise<FillMaskOutput> { if ( Array.isArray(response) && response.every( (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string" ) ) { return response; } throw new InferenceOutputError( "Expected Array<{score: number, sequence: string, token: number, token_str: string}>" ); } } export class HFInferenceZeroShotClassificationTask extends HFInferenceTask implements ZeroShotClassificationTaskHelper { override async getResponse(response: ZeroShotClassificationOutput): Promise<ZeroShotClassificationOutput> { if ( Array.isArray(response) && response.every( (x) => Array.isArray(x.labels) && x.labels.every((_label) => typeof _label === "string") && Array.isArray(x.scores) && x.scores.every((_score) => typeof _score === "number") && typeof x.sequence === "string" ) ) { return response; } throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>"); } } export class HFInferenceSentenceSimilarityTask extends HFInferenceTask implements SentenceSimilarityTaskHelper { override async getResponse(response: SentenceSimilarityOutput): Promise<SentenceSimilarityOutput> { if (Array.isArray(response) && response.every((x) => typeof x === "number")) { return response; } throw new InferenceOutputError("Expected Array<number>"); } } export class HFInferenceTableQuestionAnsweringTask extends HFInferenceTask implements TableQuestionAnsweringTaskHelper { static validate(elem: unknown): elem is TableQuestionAnsweringOutput[number] { return ( typeof elem === "object" && !!elem && "aggregator" in elem && typeof elem.aggregator === "string" && "answer" in elem && typeof elem.answer === "string" && "cells" in elem && Array.isArray(elem.cells) && elem.cells.every((x: unknown): x is string => typeof x === "string") && "coordinates" in elem && Array.isArray(elem.coordinates) && elem.coordinates.every( (coord: unknown): coord is number[] => Array.isArray(coord) && coord.every((x) => typeof x === "number") ) ); } override async getResponse(response: TableQuestionAnsweringOutput): Promise<TableQuestionAnsweringOutput[number]> { if ( Array.isArray(response) && Array.isArray(response) ? response.every((elem) => HFInferenceTableQuestionAnsweringTask.validate(elem)) : HFInferenceTableQuestionAnsweringTask.validate(response) ) { return Array.isArray(response) ? response[0] : response; } throw new InferenceOutputError( "Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}" ); } } export class HFInferenceTokenClassificationTask extends HFInferenceTask implements TokenClassificationTaskHelper { override async getResponse(response: TokenClassificationOutput): Promise<TokenClassificationOutput> { if ( Array.isArray(response) && response.every( (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string" ) ) { return response; } throw new InferenceOutputError( "Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>" ); } } export class HFInferenceTranslationTask extends HFInferenceTask implements TranslationTaskHelper { override async getResponse(response: TranslationOutput): Promise<TranslationOutput> { if (Array.isArray(response) && response.every((x) => typeof x?.translation_text === "string")) { return response?.length === 1 ? response?.[0] : response; } throw new InferenceOutputError("Expected Array<{translation_text: string}>"); } } export class HFInferenceSummarizationTask extends HFInferenceTask implements SummarizationTaskHelper { override async getResponse(response: SummarizationOutput): Promise<SummarizationOutput> { if (Array.isArray(response) && response.every((x) => typeof x?.summary_text === "string")) { return response?.[0]; } throw new InferenceOutputError("Expected Array<{summary_text: string}>"); } } export class HFInferenceTextToSpeechTask extends HFInferenceTask implements TextToSpeechTaskHelper { override async getResponse(response: Blob): Promise<Blob> { return response; } } export class HFInferenceTabularClassificationTask extends HFInferenceTask implements TabularClassificationTaskHelper { override async getResponse(response: TabularClassificationOutput): Promise<TabularClassificationOutput> { if (Array.isArray(response) && response.every((x) => typeof x === "number")) { return response; } throw new InferenceOutputError("Expected Array<number>"); } } export class HFInferenceVisualQuestionAnsweringTask extends HFInferenceTask implements VisualQuestionAnsweringTaskHelper { override async getResponse(response: VisualQuestionAnsweringOutput): Promise<VisualQuestionAnsweringOutput[number]> { if ( Array.isArray(response) && response.every( (elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number" ) ) { return response[0]; } throw new InferenceOutputError("Expected Array<{answer: string, score: number}>"); } } export class HFInferenceTabularRegressionTask extends HFInferenceTask implements TabularRegressionTaskHelper { override async getResponse(response: number[]): Promise<number[]> { if (Array.isArray(response) && response.every((x) => typeof x === "number")) { return response; } throw new InferenceOutputError("Expected Array<number>"); } } export class HFInferenceTextToAudioTask extends HFInferenceTask implements TextToAudioTaskHelper { override async getResponse(response: Blob): Promise<Blob> { return response; } }