UNPKG

@huggingface/inference

Version:

TypeScript client for Hugging Face Inference Providers and Inference Endpoints

355 lines (307 loc) 12.9 kB
import type {
	AudioClassificationInput,
	AudioClassificationOutput,
	AutomaticSpeechRecognitionInput,
	AutomaticSpeechRecognitionOutput,
	ChatCompletionInput,
	ChatCompletionOutput,
	DocumentQuestionAnsweringInput,
	DocumentQuestionAnsweringOutput,
	FeatureExtractionInput,
	FeatureExtractionOutput,
	FillMaskInput,
	FillMaskOutput,
	ImageClassificationInput,
	ImageClassificationOutput,
	ImageSegmentationInput,
	ImageSegmentationOutput,
	ImageToImageInput,
	ImageToTextInput,
	ImageToTextOutput,
	ObjectDetectionInput,
	ObjectDetectionOutput,
	QuestionAnsweringInput,
	QuestionAnsweringOutput,
	SentenceSimilarityInput,
	SentenceSimilarityOutput,
	SummarizationInput,
	SummarizationOutput,
	TableQuestionAnsweringInput,
	TableQuestionAnsweringOutput,
	TextClassificationInput,
	TextClassificationOutput,
	TextGenerationInput,
	TextGenerationOutput,
	TextToImageInput,
	TextToSpeechInput,
	TextToVideoInput,
	TokenClassificationInput,
	TokenClassificationOutput,
	TranslationInput,
	TranslationOutput,
	VisualQuestionAnsweringInput,
	VisualQuestionAnsweringOutput,
	ZeroShotClassificationInput,
	ZeroShotClassificationOutput,
	ZeroShotImageClassificationInput,
	ZeroShotImageClassificationOutput,
} from "@huggingface/tasks";
import { HF_ROUTER_URL } from "../config";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio";
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, UrlParams } from "../types";
import { toArray } from "../utils/toArray";

/**
 * Base class for task-specific provider helpers.
 *
 * A helper builds the URL, headers and body of a request for one provider/task
 * pair, and validates/converts the provider's raw response. Subclasses supply
 * `makeRoute`, `preparePayload` and `getResponse`; the rest has defaults here.
 */
export abstract class TaskProviderHelper {
	/**
	 * @param provider - Provider identifier; also used as the path segment appended
	 *   to `HF_ROUTER_URL` when the request is routed through Hugging Face.
	 * @param baseUrl - The provider's own base URL, used only when the request is
	 *   authenticated with a provider key (see `makeBaseUrl`).
	 * @param clientSideRoutingOnly - Exposed for consumers of the helper; not read
	 *   by this class itself.
	 */
	constructor(
		readonly provider: InferenceProvider,
		private readonly baseUrl: string,
		readonly clientSideRoutingOnly: boolean = false
	) {}

	/**
	 * Validate the raw provider response and return it in the task's output format.
	 * Must be implemented by subclasses.
	 */
	abstract getResponse(
		response: unknown,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<unknown>;

	/**
	 * Return the route (path relative to the base URL) for the request.
	 * Must be implemented by subclasses.
	 */
	abstract makeRoute(params: UrlParams): string;

	/**
	 * Build the JSON payload for the request.
	 * Must be implemented by subclasses.
	 */
	abstract preparePayload(params: BodyParams): unknown;

	/**
	 * Return the base URL for the request.
	 *
	 * Requests go through the Hugging Face router (`${HF_ROUTER_URL}/${provider}`)
	 * unless they are authenticated with a provider key, in which case the
	 * provider's own `baseUrl` is used.
	 */
	makeBaseUrl(params: UrlParams): string {
		return params.authMethod !== "provider-key" ? `${HF_ROUTER_URL}/${this.provider}` : this.baseUrl;
	}

	/**
	 * Build the request body.
	 *
	 * If the caller passed a truthy `data` argument it is sent as-is (raw body);
	 * otherwise the result of `preparePayload` is JSON-serialized.
	 */
	makeBody(params: BodyParams): BodyInit {
		if ("data" in params.args && !!params.args.data) {
			return params.args.data as BodyInit;
		}
		return JSON.stringify(this.preparePayload(params));
	}

	/**
	 * Build the full request URL as `<baseUrl>/<route>`.
	 *
	 * Trailing slashes on the base URL and leading slashes on the route are
	 * trimmed so the two are always joined by exactly one "/".
	 */
	makeUrl(params: UrlParams): string {
		const baseUrl = this.makeBaseUrl(params).replace(/\/+$/, "");
		const route = this.makeRoute(params).replace(/^\/+/, "");
		return `${baseUrl}/${route}`;
	}

	/**
	 * Build the request headers.
	 *
	 * @param params - Header parameters; `accessToken` becomes a Bearer token when present.
	 * @param isBinary - When true, no JSON `Content-Type` is set (the body is raw data).
	 * @returns A plain header record.
	 */
	prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> {
		const headers: Record<string, string> = {};
		// Only send Authorization when a token is actually available; interpolating a
		// missing token would otherwise send the literal value "Bearer undefined".
		if (params.accessToken) {
			headers["Authorization"] = `Bearer ${params.accessToken}`;
		}
		if (!isBinary) {
			headers["Content-Type"] = "application/json";
		}
		return headers;
	}
}

// PER-TASK PROVIDER HELPER INTERFACES
// Each interface narrows `getResponse`'s return type and `preparePayload`'s input
// type for one task.

// CV Tasks
export interface TextToImageTaskHelper {
	getResponse(
		response: unknown,
		url?: string,
		headers?: HeadersInit,
		outputType?: "url" | "blob"
	): Promise<string | Blob>;
	preparePayload(params: BodyParams<TextToImageInput & BaseArgs>): Record<string, unknown>;
}

export interface TextToVideoTaskHelper {
	getResponse(response: unknown, url?: string, headers?: Record<string, string>): Promise<Blob>;
	preparePayload(params: BodyParams<TextToVideoInput & BaseArgs>): Record<string, unknown>;
}

export interface ImageToImageTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
	preparePayload(params: BodyParams<ImageToImageInput & BaseArgs>): Record<string, unknown>;
}

export interface ImageSegmentationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageSegmentationOutput>;
	preparePayload(params: BodyParams<ImageSegmentationInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

export interface ImageClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageClassificationOutput>;
	preparePayload(params: BodyParams<ImageClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

export interface ObjectDetectionTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ObjectDetectionOutput>;
	preparePayload(params: BodyParams<ObjectDetectionInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

export interface ImageToTextTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageToTextOutput>;
	preparePayload(params: BodyParams<ImageToTextInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

export interface ZeroShotImageClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ZeroShotImageClassificationOutput>;
	preparePayload(params: BodyParams<ZeroShotImageClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

// NLP Tasks
export interface TextGenerationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TextGenerationOutput>;
	preparePayload(params: BodyParams<TextGenerationInput & BaseArgs>): Record<string, unknown>;
}

export interface ConversationalTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ChatCompletionOutput>;
	preparePayload(params: BodyParams<ChatCompletionInput & BaseArgs>): Record<string, unknown>;
}

export interface TextClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TextClassificationOutput>;
	// Text classification takes its own input type (previously mistyped as ZeroShotClassificationInput).
	preparePayload(params: BodyParams<TextClassificationInput & BaseArgs>): Record<string, unknown>;
}

export interface QuestionAnsweringTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<QuestionAnsweringOutput[number]>;
	preparePayload(params: BodyParams<QuestionAnsweringInput & BaseArgs>): Record<string, unknown>;
}

export interface FillMaskTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<FillMaskOutput>;
	preparePayload(params: BodyParams<FillMaskInput & BaseArgs>): Record<string, unknown>;
}

export interface ZeroShotClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ZeroShotClassificationOutput>;
	preparePayload(params: BodyParams<ZeroShotClassificationInput & BaseArgs>): Record<string, unknown>;
}

export interface SentenceSimilarityTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<SentenceSimilarityOutput>;
	preparePayload(params: BodyParams<SentenceSimilarityInput & BaseArgs>): Record<string, unknown>;
}

export interface TableQuestionAnsweringTaskHelper {
	getResponse(
		response: unknown,
		url?: string,
		headers?: HeadersInit
	): Promise<TableQuestionAnsweringOutput[number]>;
	preparePayload(params: BodyParams<TableQuestionAnsweringInput & BaseArgs>): Record<string, unknown>;
}

export interface TokenClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TokenClassificationOutput>;
	preparePayload(params: BodyParams<TokenClassificationInput & BaseArgs>): Record<string, unknown>;
}

export interface TranslationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<TranslationOutput>;
	preparePayload(params: BodyParams<TranslationInput & BaseArgs>): Record<string, unknown>;
}

export interface SummarizationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<SummarizationOutput>;
	preparePayload(params: BodyParams<SummarizationInput & BaseArgs>): Record<string, unknown>;
}

// Audio Tasks
export interface TextToSpeechTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
	preparePayload(params: BodyParams<TextToSpeechInput & BaseArgs>): Record<string, unknown>;
}

export interface TextToAudioTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
	preparePayload(params: BodyParams<Record<string, unknown> & BaseArgs>): Record<string, unknown>;
}

export interface AudioToAudioTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AudioToAudioOutput[]>;
	preparePayload(
		params: BodyParams<BaseArgs & { inputs: Blob } & Record<string, unknown>>
	): Record<string, unknown> | BodyInit;
}

export interface AutomaticSpeechRecognitionTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AutomaticSpeechRecognitionOutput>;
	preparePayload(params: BodyParams<AutomaticSpeechRecognitionInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

export interface AudioClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AudioClassificationOutput>;
	preparePayload(params: BodyParams<AudioClassificationInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

// Multimodal Tasks
export interface DocumentQuestionAnsweringTaskHelper {
	getResponse(
		response: unknown,
		url?: string,
		headers?: HeadersInit
	): Promise<DocumentQuestionAnsweringOutput[number]>;
	preparePayload(params: BodyParams<DocumentQuestionAnsweringInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

export interface FeatureExtractionTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<FeatureExtractionOutput>;
	preparePayload(params: BodyParams<FeatureExtractionInput & BaseArgs>): Record<string, unknown>;
}

export interface VisualQuestionAnsweringTaskHelper {
	getResponse(
		response: unknown,
		url?: string,
		headers?: HeadersInit
	): Promise<VisualQuestionAnsweringOutput[number]>;
	preparePayload(params: BodyParams<VisualQuestionAnsweringInput & BaseArgs>): Record<string, unknown> | BodyInit;
}

export interface TabularClassificationTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<number[]>;
	preparePayload(
		params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>
	): Record<string, unknown> | BodyInit;
}

export interface TabularRegressionTaskHelper {
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<number[]>;
	preparePayload(
		params: BodyParams<BaseArgs & { inputs: { data: Record<string, string[]> } } & Record<string, unknown>>
	): Record<string, unknown> | BodyInit;
}

// BASE IMPLEMENTATIONS FOR COMMON PATTERNS

/**
 * Helper for chat-completion requests sent to the `v1/chat/completions` route.
 */
export class BaseConversationalTask extends TaskProviderHelper implements ConversationalTaskHelper {
	constructor(provider: InferenceProvider, baseUrl: string, clientSideRoutingOnly: boolean = false) {
		super(provider, baseUrl, clientSideRoutingOnly);
	}

	makeRoute(): string {
		return "v1/chat/completions";
	}

	/**
	 * Forward the caller's arguments unchanged and set `model` to the resolved
	 * model id. `model` is spread last, so it overrides any `model` in `args`.
	 */
	preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			...params.args,
			model: params.model,
		};
	}

	/**
	 * Shape-check the response as a ChatCompletionOutput and return it unchanged.
	 *
	 * @throws InferenceOutputError if `choices`, `created`, `id`, `model`,
	 *   `system_fingerprint` or `usage` do not have the expected types.
	 */
	async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
		if (
			typeof response === "object" &&
			Array.isArray(response?.choices) &&
			typeof response?.created === "number" &&
			typeof response?.id === "string" &&
			typeof response?.model === "string" &&
			// Together.ai and Nebius do not output a system_fingerprint
			(response.system_fingerprint === undefined ||
				response.system_fingerprint === null ||
				typeof response.system_fingerprint === "string") &&
			typeof response?.usage === "object"
		) {
			return response;
		}

		throw new InferenceOutputError("Expected ChatCompletionOutput");
	}
}

/**
 * Helper for text-generation requests sent to the `v1/completions` route.
 */
export class BaseTextGenerationTask extends TaskProviderHelper implements TextGenerationTaskHelper {
	constructor(provider: InferenceProvider, baseUrl: string, clientSideRoutingOnly: boolean = false) {
		super(provider, baseUrl, clientSideRoutingOnly);
	}

	/**
	 * Forward the caller's arguments unchanged and set `model` to the resolved
	 * model id. `model` is spread last, so it overrides any `model` in `args`.
	 */
	preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			...params.args,
			model: params.model,
		};
	}

	makeRoute(): string {
		return "v1/completions";
	}

	/**
	 * Normalize the response to an array with `toArray` and return its first element.
	 *
	 * @throws InferenceOutputError unless the array is non-empty and every element
	 *   is an object with a string `generated_text`.
	 */
	async getResponse(response: unknown): Promise<TextGenerationOutput> {
		const res = toArray(response);
		if (
			Array.isArray(res) &&
			res.length > 0 &&
			res.every(
				(x): x is { generated_text: string } =>
					typeof x === "object" && !!x && "generated_text" in x && typeof x.generated_text === "string"
			)
		) {
			return res[0];
		}
		throw new InferenceOutputError("Expected Array<{generated_text: string}>");
	}
}