@huggingface/hub
Version:
Utilities to interact with the Hugging Face hub
191 lines (174 loc) • 5.54 kB
text/typescript
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";
/**
 * Model fields that `listModels` always requests through the Hub's `expand` query parameter.
 *
 * They back the base {@link ModelEntry} properties (`task`, `private`, `gated`, `downloads`,
 * `likes`, `updatedAt`), which is why they are excluded from {@link ModelAdditionalField}.
 * The `satisfies` clause keeps every entry checked against `ApiModelInfo`.
 */
export const MODEL_EXPAND_KEYS = [
"pipeline_tag",
"private",
"gated",
"downloads",
"likes",
"lastModified",
] as const satisfies readonly (keyof ApiModelInfo)[];
/**
 * Every `ApiModelInfo` key that may be requested through the `expand` query parameter.
 *
 * Keys that are not already in `MODEL_EXPAND_KEYS` can be passed to `listModels` via its
 * `additionalFields` option (see `ModelAdditionalField`). The `satisfies` clause keeps this
 * list checked against `ApiModelInfo`.
 */
export const MODEL_EXPANDABLE_KEYS = [
"author",
"cardData",
"config",
"createdAt",
"disabled",
"downloads",
"downloadsAllTime",
"gated",
"gitalyUid",
"inferenceProviderMapping",
"lastModified",
"library_name",
"likes",
"model-index",
"pipeline_tag",
"private",
"safetensors",
"sha",
"spaces",
"tags",
"transformersInfo",
] as const satisfies readonly (keyof ApiModelInfo)[];
/**
 * Fields that are not returned as-is by the Hub API but are computed client-side
 * from another API field.
 */
export interface ModelDerivedFields {
/** The `rfilename` of each entry in the API's `siblings` array (empty when `siblings` is missing). */
filePaths: string[];
}
/**
 * Maps each derived field to the `ApiModelInfo` key that must be requested via `expand`
 * so the derived value can be computed.
 */
export const MODEL_DERIVED_FIELD_TO_API_KEY: Record<keyof ModelDerivedFields, keyof ApiModelInfo> = {
filePaths: "siblings",
};
/**
 * Names accepted by the `additionalFields` option of `listModels`: the expandable API keys
 * that are not already requested by default, plus the client-side derived fields.
 */
export type ModelAdditionalField =
| Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]>
| keyof ModelDerivedFields;
/**
 * Object type contributed by a set of additional fields `T`: the matching properties of
 * `ApiModelInfo` for API keys, and of `ModelDerivedFields` for derived fields.
 */
export type ResolveModelAdditionalFields<T extends ModelAdditionalField> = Pick<ApiModelInfo, T & keyof ApiModelInfo> &
Pick<ModelDerivedFields, T & keyof ModelDerivedFields>;
/**
 * Base shape of each model yielded by `listModels`; the comments below describe how
 * `listModels` fills each property from the raw `ApiModelInfo`.
 */
export interface ModelEntry {
/** Taken from the API's `_id` field. */
id: string;
/** Taken from the API's `id` field. */
name: string;
private: boolean;
/** `false` when not gated, otherwise the gating mode. */
gated: false | "auto" | "manual";
/** Taken from the API's `pipeline_tag` field; may be absent. */
task?: PipelineType;
likes: number;
downloads: number;
/** Parsed into a `Date` from the API's `lastModified` field. */
updatedAt: Date;
}
/**
 * List models on the Hugging Face Hub that match the given filters.
 *
 * Results are streamed. The Hub is queried page by page, following the `next` URL from each
 * response's `Link` header, and every model is yielded as soon as its page has been received.
 *
 * @param params - Optional search filters, sort order, limit, extra fields and credentials.
 * @param params.search - Filters. `tags` are sent as repeated `filter` parameters;
 *   `inferenceProviders` and `apps` are sent as comma-separated lists (empty lists are ignored).
 * @param params.hubUrl - Base URL of the Hub; falls back to `HUB_URL` when missing or empty.
 * @param params.additionalFields - Extra fields to request and attach to each yielded entry.
 * @param params.limit - Maximum number of models to yield. A limit of 0 or less yields nothing.
 * @param params.sort - Field to sort by.
 * @param params.fetch - Custom fetch implementation (e.g. to use a proxy or edit headers).
 * @returns An async generator of model entries, each extended with the requested additional fields.
 * @throws The error built by `createApiError` when the Hub responds with a non-OK status.
 */
export async function* listModels<const T extends ModelAdditionalField = never>(
	params?: {
		search?: {
			/**
			 * Will search in the model name for matches
			 */
			query?: string;
			owner?: string;
			task?: PipelineType;
			tags?: string[];
			/**
			 * Will search for models that have one of the inference providers in the list.
			 */
			inferenceProviders?: string[];
			/**
			 * Will search for models that support at least one of those local apps (eg "lmstudio", "mlx-lm", ...)
			 */
			apps?: string[];
		};
		hubUrl?: string;
		additionalFields?: T[];
		/**
		 * Set to limit the number of models returned.
		 */
		limit?: number;
		/**
		 * Sort models by a specific field.
		 */
		sort?:
			| "createdAt"
			| "downloads"
			| "likes"
			| "lastModified"
			| "likes30d"
			| "trendingScore"
			| "num_parameters"
			| "mainSize"
			| "id";
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>,
): AsyncGenerator<ModelEntry & ResolveModelAdditionalFields<T>> {
	const accessToken = params && checkCredentials(params);
	let totalToFetch = params?.limit ?? Infinity;
	// Nothing was requested. Return before sending `limit=0` (or a negative limit) to the Hub;
	// otherwise one model would be yielded before the countdown below is first checked.
	if (totalToFetch <= 0) {
		return;
	}

	// Derived fields (e.g. `filePaths`) are computed from another API key; that key is what gets expanded.
	const additionalExpandKeys =
		params?.additionalFields?.map(
			(field) => MODEL_DERIVED_FIELD_TO_API_KEY[field as keyof ModelDerivedFields] ?? field,
		) ?? [];

	const search = new URLSearchParams([
		...Object.entries({
			// Request at most 500 models per page.
			limit: String(Math.min(totalToFetch, 500)),
			...(params?.search?.owner ? { author: params.search.owner } : undefined),
			...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
			...(params?.search?.query ? { search: params.search.query } : undefined),
			// Check `.length`: an empty array is truthy and would send an empty filter value.
			...(params?.search?.inferenceProviders?.length
				? { inference_provider: params.search.inferenceProviders.join(",") }
				: undefined),
			...(params?.search?.apps?.length ? { apps: params.search.apps.join(",") } : undefined),
			...(params?.sort ? { sort: params.sort } : undefined),
		}),
		...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
		...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
		...additionalExpandKeys.map((val) => ["expand", val] satisfies [string, string]),
	]).toString();

	// `||` rather than `??` so that an empty-string hubUrl also falls back to the default.
	let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;
	while (url) {
		const res: Response = await (params?.fetch ?? fetch)(url, {
			headers: {
				accept: "application/json",
				...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
			},
		});
		if (!res.ok) {
			throw await createApiError(res);
		}
		const items: ApiModelInfo[] = await res.json();
		for (const item of items) {
			const additional: Record<string, unknown> = {};
			if (params?.additionalFields) {
				for (const field of params.additionalFields) {
					if (field === "filePaths") {
						additional.filePaths = (item.siblings ?? []).map((s) => s.rfilename);
					} else if (field === "inferenceProviderMapping" && item.inferenceProviderMapping) {
						additional.inferenceProviderMapping = normalizeInferenceProviderMapping(
							item.id,
							item.inferenceProviderMapping,
						);
					} else {
						additional[field] = item[field as keyof ApiModelInfo];
					}
				}
			}
			// Base fields are spread after `additional`, so they take precedence on any key collision.
			yield {
				...additional,
				id: item._id,
				name: item.id,
				private: item.private,
				task: item.pipeline_tag,
				downloads: item.downloads,
				gated: item.gated,
				likes: item.likes,
				updatedAt: new Date(item.lastModified),
			} as ModelEntry & ResolveModelAdditionalFields<T>;
			totalToFetch--;
			if (totalToFetch <= 0) {
				return;
			}
		}
		// Follow pagination; stop when the Hub provides no `next` link.
		const linkHeader = res.headers.get("Link");
		url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
		// Could update url to reduce the limit if we don't need the whole 500 of the next batch.
	}
}