@huggingface/hub

Utilities to interact with the Hugging Face hub

import type { CredentialsParams, RepoDesignation } from "../types/public";
import { omit } from "../utils/omit";
import { toRepoId } from "../utils/toRepoId";
import { typedEntries } from "../utils/typedEntries";
import { downloadFile } from "./download-file";
import { fileExists } from "./file-exists";
import { promisesQueue } from "../utils/promisesQueue";
import type { SetRequired } from "../vendor/type-fest/set-required";

export const SAFETENSORS_FILE = "model.safetensors";
export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
/// We advise model/library authors to use the filenames above for convention inside model repos,
/// but in some situations safetensors weights have different filenames.
export const RE_SAFETENSORS_FILE = /\.safetensors$/;
export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
export const RE_SAFETENSORS_SHARD_FILE =
	/^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5,6})-of-(?<total>\d{5,6})\.safetensors$/;

export interface SafetensorsShardFileInfo {
	prefix: string;
	basePrefix: string;
	shard: string;
	total: string;
}

export function parseSafetensorsShardFilename(filename: string): SafetensorsShardFileInfo | null {
	const match = RE_SAFETENSORS_SHARD_FILE.exec(filename);
	if (match && match.groups) {
		return {
			prefix: match.groups["prefix"],
			basePrefix: match.groups["basePrefix"],
			shard: match.groups["shard"],
			total: match.groups["total"],
		};
	}
	return null;
}

const PARALLEL_DOWNLOADS = 20;
const MAX_HEADER_LENGTH = 25_000_000;
const GPTQ_QWEIGHT_SUFFIX = "qweight";

class SafetensorParseError extends Error {}

type FileName = string;

export type TensorName = string;
export type Dtype =
	| "F64"
	| "F32"
	| "F16"
	| "F8_E4M3"
	| "F8_E5M2"
	| "E8M0"
	| "F6_E3M2"
	| "F6_E2M3"
	| "F4"
	| "FP4"
	| "BF16"
	| "I64"
	| "I32"
	| "I16"
	| "I8"
	| "U16"
	| "U8"
	| "UE8"
	| "BOOL";

export interface TensorInfo {
	dtype: Dtype;
	shape: number[];
	data_offsets: [number, number];
}

export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
	__metadata__: { total_parameters?: string | number } & Record<string, string>;
};

export interface SafetensorsIndexJson {
	dtype?: string;
	/// ^there's sometimes a dtype but it looks inconsistent.
	metadata?: { total_parameters?: string | number } & Record<string, string>;
	/// ^ why the naming inconsistency?
	weight_map: Record<TensorName, FileName>;
}

export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;

export type SafetensorsParseFromRepo =
	| {
			sharded: false;
			header: SafetensorsFileHeader;
			parameterCount?: Partial<Record<Dtype, number>>;
			parameterTotal?: number;
	  }
	| {
			sharded: true;
			index: SafetensorsIndexJson;
			headers: SafetensorsShardedHeaders;
			parameterCount?: Partial<Record<Dtype, number>>;
			parameterTotal?: number;
	  };

/**
 * Fetches and parses model config.json
 */
async function fetchModelConfig(
	params: {
		repo: RepoDesignation;
		revision?: string;
		hubUrl?: string;
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<ModelConfig | null> {
	try {
		const configBlob = await downloadFile({
			...params,
			path: "config.json",
		});
		if (!configBlob) {
			return null;
		}
		const config = JSON.parse(await configBlob.text());
		return config as ModelConfig;
	} catch (error) {
		// Config file might not exist or be inaccessible
		return null;
	}
}

async function parseSingleFile(
	path: string,
	params: {
		repo: RepoDesignation;
		revision?: string;
		hubUrl?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsFileHeader> {
	const blob = await downloadFile({ ...params, path });
	if (!blob) {
		throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header length.`);
	}

	const bufLengthOfHeaderLE = await blob.slice(0, 8).arrayBuffer();
	const lengthOfHeader = new DataView(bufLengthOfHeaderLE).getBigUint64(0, true);
	// ^little-endian
	if (lengthOfHeader <= 0) {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
	}
	if (lengthOfHeader > MAX_HEADER_LENGTH) {
		throw new SafetensorParseError(
			`Failed to parse file ${path}: safetensor header is too big. Maximum supported size is ${MAX_HEADER_LENGTH} bytes.`
		);
	}

	try {
		// no validation for now, we assume it's a valid FileHeader.
		const header: SafetensorsFileHeader = JSON.parse(await blob.slice(8, 8 + Number(lengthOfHeader)).text());
		return header;
	} catch (err) {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is not valid JSON.`);
	}
}

async function parseShardedIndex(
	path: string,
	params: {
		repo: RepoDesignation;
		revision?: string;
		hubUrl?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsIndexJson> {
	const indexBlob = await downloadFile({
		...params,
		path,
	});

	if (!indexBlob) {
		throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
	}

	try {
		// no validation for now, we assume it's a valid IndexJson.
		const index = JSON.parse(await indexBlob.slice(0, 20_000_000).text());
		return index;
	} catch (error) {
		throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
	}
}

async function fetchAllHeaders(
	path: string,
	index: SafetensorsIndexJson,
	params: {
		repo: RepoDesignation;
		revision?: string;
		hubUrl?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsShardedHeaders> {
	const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
	const filenames = [...new Set(Object.values(index.weight_map))];

	const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
		await promisesQueue(
			filenames.map(
				(filename) => async () =>
					[filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader]
			),
			PARALLEL_DOWNLOADS
		)
	);

	return shardedMap;
}

/**
 * Analyze model.safetensors.index.json or model.safetensors from a model hosted
 * on Hugging Face using smart range requests to extract its metadata.
 */
export async function parseSafetensorsMetadata(
	params: {
		/** Only models are supported */
		repo: RepoDesignation;
		/**
		 * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
		 */
		path?: string;
		/**
		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
		 *
		 * @default false
		 */
		computeParametersCount: true;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SetRequired<SafetensorsParseFromRepo, "parameterCount">>;
export async function parseSafetensorsMetadata(
	params: {
		/** Only models are supported */
		repo: RepoDesignation;
		path?: string;
		/**
		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
		 *
		 * @default false
		 */
		computeParametersCount?: boolean;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsParseFromRepo>;
export async function parseSafetensorsMetadata(
	params: {
		repo: RepoDesignation;
		path?: string;
		computeParametersCount?: boolean;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsParseFromRepo> {
	const repoId = toRepoId(params.repo);

	if (repoId.type !== "model") {
		throw new TypeError("Only model repos should contain safetensors files.");
	}

	// Fetch model config for quantization information
	const modelConfig = params.computeParametersCount ? await fetchModelConfig(params) : null;
	const quantConfig = modelConfig?.quantization_config;

	if (
		(params.path && RE_SAFETENSORS_FILE.test(params.path)) ||
		(await fileExists({ ...params, path: SAFETENSORS_FILE }))
	) {
		const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params);
		return {
			sharded: false,
			header,
			...(params.computeParametersCount
				? {
						parameterCount: computeNumOfParamsByDtypeSingleFile(header, quantConfig),
						parameterTotal:
							/// shortcut: get param count directly from metadata
							header.__metadata__.total_parameters
								? typeof header.__metadata__.total_parameters === "number"
									? header.__metadata__.total_parameters
									: typeof header.__metadata__.total_parameters === "string"
										? parseInt(header.__metadata__.total_parameters)
										: undefined
								: undefined,
					}
				: undefined),
		};
	} else if (
		(params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) ||
		(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
	) {
		const path = params.path ?? SAFETENSORS_INDEX_FILE;
		const index = await parseShardedIndex(path, params);
		const shardedMap = await fetchAllHeaders(path, index, params);

		return {
			sharded: true,
			index,
			headers: shardedMap,
			...(params.computeParametersCount
				? {
						parameterCount: computeNumOfParamsByDtypeSharded(shardedMap, quantConfig),
						parameterTotal:
							/// shortcut: get param count directly from metadata
							index.metadata?.total_parameters
								? typeof index.metadata.total_parameters === "number"
									? index.metadata.total_parameters
									: typeof index.metadata.total_parameters === "string"
										? parseInt(index.metadata.total_parameters)
										: undefined
								: undefined,
					}
				: undefined),
		};
	} else {
		throw new Error("model id does not seem to contain safetensors weights");
	}
}

export interface QuantizationConfig {
	quant_method?: string;
	modules_to_not_convert?: string[];
	bits?: number;
	load_in_4bit?: boolean;
	load_in_8bit?: boolean;
}

export interface ModelConfig {
	quantization_config?: QuantizationConfig;
}

/**
 * Determines if a tensor is quantized based on quantization config and tensor name
 */
function isQuantizedTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
	if (!quantConfig) {
		return false;
	}

	if (!quantConfig.modules_to_not_convert || quantConfig.modules_to_not_convert.length === 0) {
		return true;
	}

	for (const pattern of quantConfig.modules_to_not_convert) {
		const regexPattern = pattern.replace(/\*/g, ".*");
		const regex = new RegExp(regexPattern);
		if (regex.test(tensorName)) {
			return false;
		}
	}

	return true;
}

/**
 * Gets the parameter multiplier for a quantized tensor based on quantization method
 */
function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig?: QuantizationConfig): number {
	if (!quantConfig || !isQuantizedTensor(tensorName, quantConfig)) {
		return 1;
	}

	const quantMethod = quantConfig.quant_method?.toLowerCase();

	switch (quantMethod) {
		case "mxfp4":
			if (dtype === "U8" && tensorName.includes("_blocks")) {
				return 2;
			}
			return 1;

		case "gptq":
		case "awq":
			if (getTensorSuffix(tensorName) === GPTQ_QWEIGHT_SUFFIX) {
				const bits = quantConfig.bits && quantConfig.bits > 0 ? quantConfig.bits : 4;
				return Math.max(1, Math.floor(32 / bits));
			}
			if (quantConfig.bits === 4 && dtype === "U8") {
				return 2;
			}
			if (quantConfig.bits === 2 && dtype === "U8") {
				return 4;
			}
			return 1;

		case "bitsandbytes":
			if (quantConfig.load_in_4bit && dtype === "U8") {
				return 2;
			}
			if (quantConfig.load_in_8bit) {
				return 1;
			}
			return 1;

		default:
			if (quantConfig.load_in_4bit && dtype === "U8") {
				return 2;
			}
			if (quantConfig.bits === 4 && dtype === "U8") {
				return 2;
			}
			return 1;
	}
}

function computeNumOfParamsByDtypeSingleFile(
	header: SafetensorsFileHeader,
	quantConfig?: QuantizationConfig
): Partial<Record<Dtype, number>> {
	const counter: Partial<Record<Dtype, number>> = {};
	const tensors = omit(header, "__metadata__");

	for (const [tensorName, v] of typedEntries(tensors)) {
		if (shouldSkipTensor(tensorName, quantConfig)) {
			continue;
		}
		if (v.shape.length === 0) {
			continue;
		}

		const elements = v.shape.reduce((a, b) => a * b);
		const multiplier = quantConfig ? getQuantizationMultiplier(tensorName, v.dtype, quantConfig) : 1;
		if (multiplier === 0) {
			continue;
		}
		counter[v.dtype] = (counter[v.dtype] ?? 0) + elements * multiplier;
	}
	return counter;
}

function computeNumOfParamsByDtypeSharded(
	shardedMap: SafetensorsShardedHeaders,
	quantConfig?: QuantizationConfig
): Partial<Record<Dtype, number>> {
	const counter: Partial<Record<Dtype, number>> = {};
	for (const header of Object.values(shardedMap)) {
		for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header, quantConfig))) {
			counter[k] = (counter[k] ?? 0) + (v ?? 0);
		}
	}
	return counter;
}

function getTensorSuffix(tensorName: string): string {
	const lastDotIndex = tensorName.lastIndexOf(".");
	return lastDotIndex === -1 ? tensorName : tensorName.slice(lastDotIndex + 1);
}

function shouldSkipTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
	const GPTQ_AWQ_AUXILIARY_SUFFIXES = ["qzeros", "g_idx", "scales"];

	if (!quantConfig) {
		return false;
	}

	const quantMethod = quantConfig.quant_method?.toLowerCase();
	if (!quantMethod || (quantMethod !== "gptq" && quantMethod !== "awq")) {
		return false;
	}

	if (!isQuantizedTensor(tensorName, quantConfig)) {
		return false;
	}

	const suffix = getTensorSuffix(tensorName);
	if (suffix === GPTQ_QWEIGHT_SUFFIX) {
		return false;
	}

	return GPTQ_AWQ_AUXILIARY_SUFFIXES.includes(suffix);
}
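
A minimal usage sketch, assuming `parseSafetensorsMetadata` and `parseSafetensorsShardFilename` are re-exported from the `@huggingface/hub` package root (both are exported above); the model id is only a placeholder:

import { parseSafetensorsMetadata, parseSafetensorsShardFilename } from "@huggingface/hub";

// Fetch just the safetensors header(s) via range requests and count parameters per dtype.
const info = await parseSafetensorsMetadata({
	repo: "openai-community/gpt2", // placeholder: any model repo with safetensors weights
	computeParametersCount: true,
	// accessToken: "hf_...", // only needed for gated or private repos
});

console.log(info.sharded ? "sharded checkpoint" : "single safetensors file");
console.log(info.parameterCount); // parameter count per dtype
console.log(info.parameterTotal); // total, if present in the header metadata

// Split a shard filename into its components.
console.log(parseSafetensorsShardFilename("model-00001-of-00002.safetensors"));
// -> { prefix: "model-", basePrefix: "model", shard: "00001", total: "00002" }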