@huggingface/hub
Utilities to interact with the Hugging Face hub
import type { CredentialsParams } from "../types/public";
import { typedInclude } from "../utils/typedInclude";
import type { CommitOutput, CommitParams, CommitProgressEvent, ContentSource } from "./commit";
import { commitIter } from "./commit";
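/**
 * Tracks per-part upload progress for multipart uploads, keyed by the progress
 * callback, so a single combined fraction can be reported across all parts.
 */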
const multipartUploadTracking = new WeakMap<
(progress: number) => void,
{
numParts: number;
partsProgress: Record<number, number>;
}
>();
/**
 * Uploads files to the Hub, yielding progress events along the way.
 *
 * XMLHttpRequest must be available for upload progress events.
 * Set useWebWorkers to true in order to also get progress events for hashing.
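 *
 * Below is a usage sketch; the repo name, token, and progress-event fields
 * (`event`, `path`, `progress`) are illustrative assumptions rather than a
 * guaranteed API shape.
 *
 * @example
 * for await (const event of uploadFilesWithProgress({
 *   repo: "my-user/my-model",
 *   accessToken: "hf_...",
 *   files: [new File(["hello"], "hello.txt")],
 * })) {
 *   if (event.event === "fileProgress") {
 *     console.log(`${event.path}: ${Math.round(event.progress * 100)}%`);
 *   }
 * }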
*/
export async function* uploadFilesWithProgress(
params: {
repo: CommitParams["repo"];
files: Array<URL | File | { path: string; content: ContentSource }>;
commitTitle?: CommitParams["title"];
commitDescription?: CommitParams["description"];
hubUrl?: CommitParams["hubUrl"];
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
abortSignal?: CommitParams["abortSignal"];
maxFolderDepth?: CommitParams["maxFolderDepth"];
useXet?: CommitParams["useXet"];
/**
* Set this to true in order to have progress events for hashing
*/
useWebWorkers?: CommitParams["useWebWorkers"];
} & Partial<CredentialsParams>
): AsyncGenerator<CommitProgressEvent, CommitOutput> {
return yield* commitIter({
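		// Prefer an explicit access token when provided; otherwise pass cookie-based credentials.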
...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
repo: params.repo,
operations: params.files.map((file) => ({
operation: "addOrUpdate",
path: file instanceof URL ? file.pathname.split("/").at(-1) ?? "file" : "path" in file ? file.path : file.name,
content: "content" in file ? file.content : file,
})),
title: params.commitTitle ?? `Add ${params.files.length} files`,
description: params.commitDescription,
hubUrl: params.hubUrl,
branch: params.branch,
isPullRequest: params.isPullRequest,
parentCommit: params.parentCommit,
useWebWorkers: params.useWebWorkers,
abortSignal: params.abortSignal,
useXet: params.useXet,
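		// Custom fetch implementation: upload requests that carry a progressHint are routed
		// through XMLHttpRequest so that upload progress events can be observed.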
fetch: async (input, init) => {
if (!init) {
return fetch(input);
}
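			// Fall back to a plain fetch whenever upload progress cannot be reported through
			// XMLHttpRequest: non-upload methods, requests without a progressHint, environments
			// without XMLHttpRequest, non-string inputs, or body types XHR cannot measure.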
if (
!typedInclude(["PUT", "POST"], init.method) ||
!("progressHint" in init) ||
!init.progressHint ||
typeof XMLHttpRequest === "undefined" ||
typeof input !== "string" ||
(!(init.body instanceof ArrayBuffer) &&
!(init.body instanceof Blob) &&
!(init.body instanceof File) &&
typeof init.body !== "string")
) {
return fetch(input, init);
}
const progressHint = init.progressHint as {
progressCallback: (progress: number) => void;
} & (Record<string, never> | { part: number; numParts: number });
const progressCallback = progressHint.progressCallback;
const xhr = new XMLHttpRequest();
xhr.upload.addEventListener("progress", (event) => {
if (event.lengthComputable) {
if (progressHint.part !== undefined) {
let tracking = multipartUploadTracking.get(progressCallback);
if (!tracking) {
tracking = { numParts: progressHint.numParts, partsProgress: {} };
multipartUploadTracking.set(progressCallback, tracking);
}
tracking.partsProgress[progressHint.part] = event.loaded / event.total;
let totalProgress = 0;
for (const partProgress of Object.values(tracking.partsProgress)) {
totalProgress += partProgress;
}
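						// Once every part has been fully sent, cap the reported value just below 1;
						// presumably the final 100% is only signaled after the server acknowledges the upload.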
if (totalProgress === tracking.numParts) {
progressCallback(0.9999999999);
} else {
progressCallback(totalProgress / tracking.numParts);
}
} else {
if (event.loaded === event.total) {
progressCallback(0.9999999999);
} else {
progressCallback(event.loaded / event.total);
}
}
}
});
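			// Replay the original request through XHR with the same method, URL, headers and body.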
xhr.open(init.method, input, true);
if (init.headers) {
const headers = new Headers(init.headers);
headers.forEach((value, key) => {
xhr.setRequestHeader(key, value);
});
}
init.signal?.throwIfAborted();
xhr.send(init.body);
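			// Adapt the XHR result back into a fetch-style Response so commitIter can consume it unchanged.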
return new Promise((resolve, reject) => {
xhr.addEventListener("load", () => {
resolve(
new Response(xhr.responseText, {
status: xhr.status,
statusText: xhr.statusText,
headers: Object.fromEntries(
xhr
.getAllResponseHeaders()
.trim()
.split("\n")
.map((header) => [header.slice(0, header.indexOf(":")), header.slice(header.indexOf(":") + 1).trim()])
),
})
);
});
xhr.addEventListener("error", () => {
reject(new Error(xhr.statusText));
});
if (init.signal) {
init.signal.addEventListener("abort", () => {
xhr.abort();
try {
init.signal?.throwIfAborted();
} catch (err) {
reject(err);
}
});
}
});
},
});
}