@huggingface/hub
Version:
Utilities to interact with the Hugging Face hub
125 lines (113 loc) • 3.29 kB
text/typescript
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { listFiles } from "./list-files";
import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
import { spaceInfo } from "./space-info";
import { datasetInfo } from "./dataset-info";
import { modelInfo } from "./model-info";
import { toRepoId } from "../utils/toRepoId";
import { join, dirname } from "node:path";
import { mkdir, writeFile } from "node:fs/promises";
import { downloadFileToCacheDir } from "./download-file-to-cache-dir";
export const DEFAULT_REVISION = "main";
/**
* Downloads an entire repository at a given revision in the cache directory {@link getHFHubCachePath}.
* You can list all cached repositories using {@link scanCachedRepo}
* @remarks It uses internally {@link downloadFileToCacheDir}.
*/
export async function snapshotDownload(
params: {
repo: RepoDesignation;
cacheDir?: string;
/**
* An optional Git revision id which can be a branch name, a tag, or a commit hash.
*
* @default "main"
*/
revision?: string;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<string> {
let cacheDir: string;
if (params.cacheDir) {
cacheDir = params.cacheDir;
} else {
cacheDir = getHFHubCachePath();
}
let revision: string;
if (params.revision) {
revision = params.revision;
} else {
revision = DEFAULT_REVISION;
}
const repoId = toRepoId(params.repo);
// get repository revision value (sha)
let repoInfo: { sha: string };
switch (repoId.type) {
case "space":
repoInfo = await spaceInfo({
...params,
name: repoId.name,
additionalFields: ["sha"],
revision: revision,
});
break;
case "dataset":
repoInfo = await datasetInfo({
...params,
name: repoId.name,
additionalFields: ["sha"],
revision: revision,
});
break;
case "model":
repoInfo = await modelInfo({
...params,
name: repoId.name,
additionalFields: ["sha"],
revision: revision,
});
break;
default:
throw new Error(`invalid repository type ${repoId.type}`);
}
const commitHash: string = repoInfo.sha;
// get storage folder
const storageFolder = join(cacheDir, getRepoFolderName(repoId));
const snapshotFolder = join(storageFolder, "snapshots", commitHash);
// if passed revision is not identical to commit_hash
// then revision has to be a branch name or tag name.
// In that case store a ref.
if (revision !== commitHash) {
const refPath = join(storageFolder, "refs", revision);
await mkdir(dirname(refPath), { recursive: true });
await writeFile(refPath, commitHash);
}
const cursor = listFiles({
...params,
repo: params.repo,
recursive: true,
revision: repoInfo.sha,
});
for await (const entry of cursor) {
switch (entry.type) {
case "file":
await downloadFileToCacheDir({
...params,
path: entry.path,
revision: commitHash,
cacheDir: cacheDir,
});
break;
case "directory":
await mkdir(join(snapshotFolder, entry.path), { recursive: true });
break;
default:
throw new Error(`unknown entry type: ${entry.type}`);
}
}
return snapshotFolder;
}