react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
317 lines (313 loc) • 12.9 kB
JavaScript
;
/**
* Resource Fetcher
*
* Provides an interface for downloading files (via `ResourceFetcher.fetch()`)
*
* Key functionality:
* - Download control: pause, resume, and cancel operations through:
* - Single file: `.pauseFetching()`, `.resumeFetching()`, `.cancelFetching()`
* - Downloaded file management:
* - `.getFilesTotalSize()`, `.listDownloadedFiles()`, `.listDownloadedModels()`, `.deleteResources()`
*
* Remark: The pausing/resuming/canceling works only for fetching remote resources.
*
* Most exported functions accept:
* - Multiple `ResourceSource` arguments, (union type of string, number or object)
*
* Method `.fetch()` takes argument as callback that reports download progress.
* Method`.fetch()` returns array of paths to successfully saved files or null if the download was paused or cancelled (then resume functions can return paths).
*
* Technical Implementation:
* - Maintains a `downloads` Map instance that tracks:
* - Currently downloading resources
* - Paused downloads
* - Successful downloads are automatically removed from the `downloads` Map
* - Uses the `ResourceSourceExtended` interface to enable pause/resume functionality:
* - Wraps user-provided `ResourceSource` elements
* - Implements linked list behavior via the `.next` attribute
* - Automatically processes subsequent downloads when `.next` contains a valid resource
*/
import { cacheDirectory, copyAsync, createDownloadResumable, moveAsync, FileSystemSessionType, writeAsStringAsync, EncodingType, deleteAsync, readDirectoryAsync } from 'expo-file-system/legacy';
import { Asset } from 'expo-asset';
import { Platform } from 'react-native';
import { RNEDirectory } from '../constants/directories';
import { ResourceFetcherUtils, HTTP_CODE, DownloadStatus, SourceType } from './ResourceFetcherUtils';
export class ResourceFetcher {
static downloads = new Map(); //map of currently downloading (or paused) files, if the download was started by .fetch() method.
static async fetch(callback = () => {}, ...sources) {
if (sources.length === 0) {
throw new Error('Empty list given as an argument!');
}
const {
results: info,
totalLength
} = await ResourceFetcherUtils.getFilesSizes(sources);
const head = {
source: info[0].source,
sourceType: info[0].type,
callback: info[0].type === SourceType.REMOTE_FILE ? ResourceFetcherUtils.calculateDownloadProgress(totalLength, info[0].previousFilesTotalLength, info[0].length, callback) : () => {},
results: []
};
let node = head;
for (let idx = 1; idx < sources.length; idx++) {
node.next = {
source: info[idx].source,
sourceType: info[idx].type,
callback: info[idx].type === SourceType.REMOTE_FILE ? ResourceFetcherUtils.calculateDownloadProgress(totalLength, info[idx].previousFilesTotalLength, info[idx].length, callback) : () => {},
results: []
};
node = node.next;
}
return this.singleFetch(head);
}
static async singleFetch(sourceExtended) {
const source = sourceExtended.source;
switch (sourceExtended.sourceType) {
case SourceType.OBJECT:
{
return this.returnOrStartNext(sourceExtended, await this.handleObject(source));
}
case SourceType.LOCAL_FILE:
{
return this.returnOrStartNext(sourceExtended, this.handleLocalFile(source));
}
case SourceType.RELEASE_MODE_FILE:
{
return this.returnOrStartNext(sourceExtended, await this.handleReleaseModeFile(sourceExtended));
}
case SourceType.DEV_MODE_FILE:
{
const result = await this.handleDevModeFile(sourceExtended);
if (result !== null) {
return this.returnOrStartNext(sourceExtended, result);
}
return null;
}
default:
{
//case SourceType.REMOTE_FILE
const result = await this.handleRemoteFile(sourceExtended);
if (result !== null) {
return this.returnOrStartNext(sourceExtended, result);
}
return null;
}
}
}
//if any download ends successfully this function is called - it checks whether it should trigger next download or return list of paths.
static returnOrStartNext(sourceExtended, result) {
sourceExtended.results.push(result);
if (sourceExtended.next) {
const nextSource = sourceExtended.next;
nextSource.results.push(...sourceExtended.results);
return this.singleFetch(nextSource);
}
sourceExtended.callback(1);
return sourceExtended.results;
}
static async pause(source) {
const resource = this.downloads.get(source);
switch (resource.status) {
case DownloadStatus.PAUSED:
throw new Error("The file download is currently paused. Can't pause the download of the same file twice.");
default:
{
resource.status = DownloadStatus.PAUSED;
await resource.downloadResumable.pauseAsync();
}
}
}
static async resume(source) {
const resource = this.downloads.get(source);
if (!resource.extendedInfo.fileUri || !resource.extendedInfo.cacheFileUri || !resource.extendedInfo.uri) {
throw new Error('Something went wrong. File uri info is not specified!');
}
switch (resource.status) {
case DownloadStatus.ONGOING:
throw new Error("The file download is currently ongoing. Can't resume the ongoing download.");
default:
{
resource.status = DownloadStatus.ONGOING;
const result = await resource.downloadResumable.resumeAsync();
if (!this.downloads.has(source) || this.downloads.get(source).status === DownloadStatus.PAUSED) {
//if canceled or paused after earlier resuming.
return null;
}
if (!result || result.status !== HTTP_CODE.OK && result.status !== HTTP_CODE.PARTIAL_CONTENT) {
throw new Error(`Failed to fetch resource from '${resource.extendedInfo.uri}'`);
}
await moveAsync({
from: resource.extendedInfo.cacheFileUri,
to: resource.extendedInfo.fileUri
});
this.downloads.delete(source);
ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(resource.extendedInfo.uri);
return this.returnOrStartNext(resource.extendedInfo, ResourceFetcherUtils.removeFilePrefix(resource.extendedInfo.fileUri));
}
}
}
static async cancel(source) {
const resource = this.downloads.get(source);
await resource.downloadResumable.cancelAsync();
this.downloads.delete(source);
}
static async pauseFetching(...sources) {
const source = this.findActive(sources);
await this.pause(source);
}
static async resumeFetching(...sources) {
const source = this.findActive(sources);
await this.resume(source);
}
static async cancelFetching(...sources) {
const source = this.findActive(sources);
await this.cancel(source);
}
static findActive(sources) {
for (const source of sources) {
if (this.downloads.has(source)) {
return source;
}
}
throw new Error('None of given sources are currently during downloading process.');
}
static async listDownloadedFiles() {
const files = await readDirectoryAsync(RNEDirectory);
return files.map(file => `${RNEDirectory}${file}`);
}
static async listDownloadedModels() {
const files = await this.listDownloadedFiles();
return files.filter(file => file.endsWith('.pte'));
}
static async deleteResources(...sources) {
for (const source of sources) {
const filename = ResourceFetcherUtils.getFilenameFromUri(source);
const fileUri = `${RNEDirectory}${filename}`;
if (await ResourceFetcherUtils.checkFileExists(fileUri)) {
await deleteAsync(fileUri);
}
}
}
static async getFilesTotalSize(...sources) {
return (await ResourceFetcherUtils.getFilesSizes(sources)).totalLength;
}
static async handleObject(source) {
if (typeof source !== 'object') {
throw new Error('Source is expected to be object!');
}
const jsonString = JSON.stringify(source);
const digest = ResourceFetcherUtils.hashObject(jsonString);
const filename = `${digest}.json`;
const path = `${RNEDirectory}${filename}`;
if (await ResourceFetcherUtils.checkFileExists(path)) {
return ResourceFetcherUtils.removeFilePrefix(path);
}
await ResourceFetcherUtils.createDirectoryIfNoExists();
await writeAsStringAsync(path, jsonString, {
encoding: EncodingType.UTF8
});
return ResourceFetcherUtils.removeFilePrefix(path);
}
static handleLocalFile(source) {
if (typeof source !== 'string') {
throw new Error('Source is expected to be string.');
}
return ResourceFetcherUtils.removeFilePrefix(source);
}
static async handleReleaseModeFile(sourceExtended) {
const source = sourceExtended.source;
if (typeof source !== 'number') {
throw new Error('Source is expected to be string.');
}
const asset = Asset.fromModule(source);
const uri = asset.uri;
const filename = ResourceFetcherUtils.getFilenameFromUri(uri);
const fileUri = `${RNEDirectory}${filename}`;
// On Android, file uri does not contain file extension, so we add it manually
const fileUriWithType = Platform.OS === 'android' ? `${fileUri}.${asset.type}` : fileUri;
if (await ResourceFetcherUtils.checkFileExists(fileUri)) {
return ResourceFetcherUtils.removeFilePrefix(fileUri);
}
await ResourceFetcherUtils.createDirectoryIfNoExists();
await copyAsync({
from: asset.uri,
to: fileUriWithType
});
return ResourceFetcherUtils.removeFilePrefix(fileUriWithType);
}
static async handleDevModeFile(sourceExtended) {
const source = sourceExtended.source;
if (typeof source !== 'number') {
throw new Error('Source is expected to be a number.');
}
sourceExtended.uri = Asset.fromModule(source).uri;
return await this.handleRemoteFile(sourceExtended);
}
static async handleRemoteFile(sourceExtended) {
const source = sourceExtended.source;
if (typeof source === 'object') {
throw new Error('Source is expected to be a string or a number.');
}
if (this.downloads.has(source)) {
const resource = this.downloads.get(source);
if (resource.status === DownloadStatus.PAUSED) {
// if the download is paused, `fetch` is treated like `resume`
this.resume(source);
}
// if the download is ongoing, throw error.
throw new Error('Already downloading this file.');
}
if (typeof source === 'number' && !sourceExtended.uri) {
throw new Error('Source Uri is expected to be available here.');
}
if (typeof source === 'string') {
sourceExtended.uri = source;
}
const uri = sourceExtended.uri;
const filename = ResourceFetcherUtils.getFilenameFromUri(uri);
sourceExtended.fileUri = `${RNEDirectory}${filename}`;
sourceExtended.cacheFileUri = `${cacheDirectory}${filename}`;
if (await ResourceFetcherUtils.checkFileExists(sourceExtended.fileUri)) {
return ResourceFetcherUtils.removeFilePrefix(sourceExtended.fileUri);
}
await ResourceFetcherUtils.createDirectoryIfNoExists();
const downloadResumable = createDownloadResumable(uri, sourceExtended.cacheFileUri, {
sessionType: FileSystemSessionType.BACKGROUND
}, ({
totalBytesWritten,
totalBytesExpectedToWrite
}) => {
if (totalBytesExpectedToWrite === -1) {
// If totalBytesExpectedToWrite is -1, it means the server does not provide content length.
sourceExtended.callback(0);
return;
}
sourceExtended.callback(totalBytesWritten / totalBytesExpectedToWrite);
});
//create value for the this.download Map
const downloadResource = {
downloadResumable: downloadResumable,
status: DownloadStatus.ONGOING,
extendedInfo: sourceExtended
};
//add key-value pair to map
this.downloads.set(source, downloadResource);
const result = await downloadResumable.downloadAsync();
if (!this.downloads.has(source) || this.downloads.get(source).status === DownloadStatus.PAUSED) {
// if canceled or paused during the download
return null;
}
if (!result || result.status !== HTTP_CODE.OK) {
throw new Error(`Failed to fetch resource from '${source}'`);
}
await moveAsync({
from: sourceExtended.cacheFileUri,
to: sourceExtended.fileUri
});
this.downloads.delete(source);
ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(uri);
return ResourceFetcherUtils.removeFilePrefix(sourceExtended.fileUri);
}
}
//# sourceMappingURL=ResourceFetcher.js.map