UNPKG

@huggingface/hub

Version:

Utilities to interact with the Hugging Face hub

191 lines (174 loc) 5.54 kB
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";

/**
 * API keys that `listModels` always requests through `expand` query parameters.
 * They back the base properties of {@link ModelEntry}.
 */
export const MODEL_EXPAND_KEYS = [
	"pipeline_tag",
	"private",
	"gated",
	"downloads",
	"likes",
	"lastModified",
] as const satisfies readonly (keyof ApiModelInfo)[];

/**
 * Every `ApiModelInfo` key that this module allows to be requested through an
 * `expand` query parameter.
 */
export const MODEL_EXPANDABLE_KEYS = [
	"author",
	"cardData",
	"config",
	"createdAt",
	"disabled",
	"downloads",
	"downloadsAllTime",
	"gated",
	"gitalyUid",
	"inferenceProviderMapping",
	"lastModified",
	"library_name",
	"likes",
	"model-index",
	"pipeline_tag",
	"private",
	"safetensors",
	"sha",
	"spaces",
	"tags",
	"transformersInfo",
] as const satisfies readonly (keyof ApiModelInfo)[];

/**
 * Fields that `listModels` computes client-side from raw API data instead of
 * returning an API key verbatim.
 */
export interface ModelDerivedFields {
	/** The `rfilename` of every entry in the API's `siblings` array (empty when absent). */
	filePaths: string[];
}

/**
 * For each derived field, the API key that must be expanded so the field can be computed.
 */
export const MODEL_DERIVED_FIELD_TO_API_KEY: Record<keyof ModelDerivedFields, keyof ApiModelInfo> = {
	filePaths: "siblings",
};

/**
 * Fields a caller may request via `additionalFields`: any expandable key that is
 * not already part of the base expand set, plus every derived field.
 */
export type ModelAdditionalField =
	| Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]>
	| keyof ModelDerivedFields;

/**
 * Resolves a union of requested additional fields to their value types, taken
 * from `ApiModelInfo` for API keys and from `ModelDerivedFields` for derived ones.
 */
export type ResolveModelAdditionalFields<T extends ModelAdditionalField> = Pick<
	ApiModelInfo,
	T & keyof ApiModelInfo
> &
	Pick<ModelDerivedFields, T & keyof ModelDerivedFields>;

/**
 * Base shape of every item yielded by {@link listModels}.
 */
export interface ModelEntry {
	/** Copied from the API's `_id` field. */
	id: string;
	/** Copied from the API's `id` field. */
	name: string;
	private: boolean;
	gated: false | "auto" | "manual";
	/** Copied from the API's `pipeline_tag` field. */
	task?: PipelineType;
	likes: number;
	downloads: number;
	/** The API's `lastModified` value parsed with `new Date(...)`. */
	updatedAt: Date;
}

/**
 * Lists models from the Hub's `/api/models` endpoint, page by page.
 *
 * It keeps fetching while the response carries a `Link` header with a `next`
 * URL, and stops once `limit` models have been yielded.
 *
 * @param params.search - Optional filters, translated to `author`,
 *   `pipeline_tag`, `search`, `inference_provider`, `apps` and repeated `filter`
 *   query parameters.
 * @param params.hubUrl - Base Hub URL; falls back to `HUB_URL` when unset or empty.
 * @param params.additionalFields - Extra fields to expand and attach to each
 *   entry (see {@link ModelAdditionalField}).
 * @param params.limit - Maximum number of models to yield. `0` or a negative
 *   value yields nothing and sends no request.
 * @param params.sort - Sort key forwarded as the `sort` query parameter.
 * @param params.fetch - Custom `fetch` implementation used for every request.
 * @returns An async generator of {@link ModelEntry} objects merged with the
 *   requested additional fields.
 * @throws The error produced by `createApiError` when a response is not ok.
 */
export async function* listModels<const T extends ModelAdditionalField = never>(
	params?: {
		search?: {
			/**
			 * Will search in the model name for matches
			 */
			query?: string;
			owner?: string;
			task?: PipelineType;
			tags?: string[];
			/**
			 * Will search for models that have one of the inference providers in the list.
			 */
			inferenceProviders?: string[];
			/**
			 * Will search for models that support at least one of those local apps (eg "lmstudio", "mlx-lm", ...)
			 */
			apps?: string[];
		};
		hubUrl?: string;
		additionalFields?: T[];
		/**
		 * Set to limit the number of models returned.
		 */
		limit?: number;
		/**
		 * Sort models by a specific field.
		 */
		sort?:
			| "createdAt"
			| "downloads"
			| "likes"
			| "lastModified"
			| "likes30d"
			| "trendingScore"
			| "num_parameters"
			| "mainSize"
			| "id";
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>,
): AsyncGenerator<ModelEntry & ResolveModelAdditionalFields<T>> {
	const accessToken = params && checkCredentials(params);
	let totalToFetch = params?.limit ?? Infinity;

	// The remaining budget is only checked *after* an item is yielded in the loop
	// below. A non-positive limit must therefore short-circuit here. Otherwise
	// `limit: 0` would still yield one model, and a negative limit would be sent
	// to the API as `limit=-N`.
	if (totalToFetch <= 0) {
		return;
	}

	// Derived fields (e.g. `filePaths`) must expand their backing API key
	// (e.g. `siblings`). Every other field is expanded under its own name.
	const additionalExpandKeys =
		params?.additionalFields?.map(
			(field) => MODEL_DERIVED_FIELD_TO_API_KEY[field as keyof ModelDerivedFields] ?? field,
		) ?? [];

	const search = new URLSearchParams([
		...Object.entries({
			// The page size is capped at 500; later pages reuse the server-provided `next` URL.
			limit: String(Math.min(totalToFetch, 500)),
			...(params?.search?.owner ? { author: params.search.owner } : undefined),
			...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
			...(params?.search?.query ? { search: params.search.query } : undefined),
			...(params?.search?.inferenceProviders
				? { inference_provider: params.search.inferenceProviders.join(",") }
				: undefined),
			...(params?.search?.apps ? { apps: params.search.apps.join(",") } : undefined),
			...(params?.sort ? { sort: params.sort } : undefined),
		}),
		...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
		...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
		...additionalExpandKeys.map((val) => ["expand", val] satisfies [string, string]),
	]).toString();

	let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;

	while (url) {
		const res: Response = await (params?.fetch ?? fetch)(url, {
			headers: {
				accept: "application/json",
				...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
			},
		});

		if (!res.ok) {
			throw await createApiError(res);
		}

		const items: ApiModelInfo[] = await res.json();

		for (const item of items) {
			const additional: Record<string, unknown> = {};
			if (params?.additionalFields) {
				for (const field of params.additionalFields) {
					if (field === "filePaths") {
						additional.filePaths = (item.siblings ?? []).map((s) => s.rfilename);
					} else if (field === "inferenceProviderMapping" && item.inferenceProviderMapping) {
						additional.inferenceProviderMapping = normalizeInferenceProviderMapping(
							item.id,
							item.inferenceProviderMapping,
						);
					} else {
						additional[field] = item[field as keyof ApiModelInfo];
					}
				}
			}

			// Base fields are spread last, so they take precedence over any additional field.
			yield {
				...additional,
				id: item._id,
				name: item.id,
				private: item.private,
				task: item.pipeline_tag,
				downloads: item.downloads,
				gated: item.gated,
				likes: item.likes,
				updatedAt: new Date(item.lastModified),
			} as ModelEntry & ResolveModelAdditionalFields<T>;

			totalToFetch--;
			if (totalToFetch <= 0) {
				return;
			}
		}

		const linkHeader = res.headers.get("Link");
		url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
		// Could update url to reduce the limit if we don't need the whole 500 of the next batch.
	}
}