UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

317 lines (313 loc) 12.9 kB
"use strict";

/**
 * Resource Fetcher
 *
 * Provides an interface for downloading files (via `ResourceFetcher.fetch()`)
 *
 * Key functionality:
 * - Download control: pause, resume, and cancel operations through:
 *   - Single file: `.pauseFetching()`, `.resumeFetching()`, `.cancelFetching()`
 * - Downloaded file management:
 *   - `.getFilesTotalSize()`, `.listDownloadedFiles()`, `.listDownloadedModels()`, `.deleteResources()`
 *
 * Remark: The pausing/resuming/canceling works only for fetching remote resources.
 *
 * Most exported functions accept:
 * - Multiple `ResourceSource` arguments (union type of string, number or object)
 *
 * Method `.fetch()` takes a callback as its first argument; the callback reports download progress.
 * Method `.fetch()` returns an array of paths to successfully saved files, or null if the download
 * was paused or cancelled (then resume functions can return paths).
 *
 * Technical Implementation:
 * - Maintains a `downloads` Map instance that tracks:
 *   - Currently downloading resources
 *   - Paused downloads
 * - Successful downloads are automatically removed from the `downloads` Map
 * - Uses the `ResourceSourceExtended` interface to enable pause/resume functionality:
 *   - Wraps user-provided `ResourceSource` elements
 *   - Implements linked list behavior via the `.next` attribute
 *   - Automatically processes subsequent downloads when `.next` contains a valid resource
 */
import { cacheDirectory, copyAsync, createDownloadResumable, moveAsync, FileSystemSessionType, writeAsStringAsync, EncodingType, deleteAsync, readDirectoryAsync } from 'expo-file-system/legacy';
import { Asset } from 'expo-asset';
import { Platform } from 'react-native';
import { RNEDirectory } from '../constants/directories';
import { ResourceFetcherUtils, HTTP_CODE, DownloadStatus, SourceType } from './ResourceFetcherUtils';

export class ResourceFetcher {
  // Map of currently downloading (or paused) files, keyed by the user-provided source,
  // populated only for downloads started through `.fetch()`.
  static downloads = new Map();

  /**
   * Fetches every given source in order and resolves with the list of local paths.
   *
   * @param callback - progress callback; wired only to REMOTE_FILE sources, and
   *   additionally invoked with `1` once the whole chain has finished.
   * @param sources - one or more resource sources (string, number or object).
   * @returns array of saved file paths, or null when a download in the chain was
   *   paused or cancelled.
   * @throws Error when called without any source.
   */
  static async fetch(callback = () => {}, ...sources) {
    if (sources.length === 0) {
      throw new Error('Empty list given as an argument!');
    }
    const { results: info, totalLength } = await ResourceFetcherUtils.getFilesSizes(sources);

    // Builds one linked-list node from a `getFilesSizes` entry.
    const toNode = (entry) => ({
      source: entry.source,
      sourceType: entry.type,
      callback:
        entry.type === SourceType.REMOTE_FILE
          ? ResourceFetcherUtils.calculateDownloadProgress(totalLength, entry.previousFilesTotalLength, entry.length, callback)
          : () => {},
      results: [],
    });

    const head = toNode(info[0]);
    let node = head;
    for (let idx = 1; idx < sources.length; idx++) {
      node.next = toNode(info[idx]);
      node = node.next;
    }
    return this.singleFetch(head);
  }

  /**
   * Handles one node of the chain according to its `sourceType`, then either
   * continues with the next node or returns the accumulated results.
   *
   * @returns array of paths, or null when a dev-mode/remote download was paused or cancelled.
   */
  static async singleFetch(sourceExtended) {
    const source = sourceExtended.source;
    switch (sourceExtended.sourceType) {
      case SourceType.OBJECT: {
        return this.returnOrStartNext(sourceExtended, await this.handleObject(source));
      }
      case SourceType.LOCAL_FILE: {
        return this.returnOrStartNext(sourceExtended, this.handleLocalFile(source));
      }
      case SourceType.RELEASE_MODE_FILE: {
        return this.returnOrStartNext(sourceExtended, await this.handleReleaseModeFile(sourceExtended));
      }
      case SourceType.DEV_MODE_FILE: {
        const result = await this.handleDevModeFile(sourceExtended);
        if (result !== null) {
          return this.returnOrStartNext(sourceExtended, result);
        }
        return null;
      }
      default: {
        // case SourceType.REMOTE_FILE
        const result = await this.handleRemoteFile(sourceExtended);
        if (result !== null) {
          return this.returnOrStartNext(sourceExtended, result);
        }
        return null;
      }
    }
  }

  /**
   * Called whenever a single source finishes successfully: records its path and
   * either triggers the next download or returns the full list of paths.
   * Results are copied forward so the last node always holds every path.
   */
  static returnOrStartNext(sourceExtended, result) {
    sourceExtended.results.push(result);
    if (sourceExtended.next) {
      const nextSource = sourceExtended.next;
      nextSource.results.push(...sourceExtended.results);
      return this.singleFetch(nextSource);
    }
    sourceExtended.callback(1);
    return sourceExtended.results;
  }

  /**
   * Pauses the tracked download of `source`.
   *
   * @throws Error when the download is already paused.
   */
  static async pause(source) {
    const resource = this.downloads.get(source);
    if (resource.status === DownloadStatus.PAUSED) {
      throw new Error("The file download is currently paused. Can't pause the download of the same file twice.");
    }
    // Status is flipped before awaiting so the pending download call sees PAUSED when it settles.
    resource.status = DownloadStatus.PAUSED;
    await resource.downloadResumable.pauseAsync();
  }

  /**
   * Resumes the tracked download of `source` and, on success, continues the
   * chain that currently owns it.
   *
   * @returns array of paths, or null when the download was paused/cancelled again.
   * @throws Error when uri info is missing, the download is already ongoing, or
   *   the server answered with an unexpected status.
   */
  static async resume(source) {
    const resource = this.downloads.get(source);
    const path = await this.resumeDownload(source, resource);
    if (path === null) {
      return null;
    }
    return this.returnOrStartNext(resource.extendedInfo, path);
  }

  /**
   * Private helper: resumes a tracked download, verifies the response, moves the
   * file from the cache to its final location and stops tracking it.
   *
   * @returns the saved file path, or null when paused/cancelled while resuming.
   */
  static async resumeDownload(source, resource) {
    const { fileUri, cacheFileUri, uri } = resource.extendedInfo;
    if (!fileUri || !cacheFileUri || !uri) {
      throw new Error('Something went wrong. File uri info is not specified!');
    }
    if (resource.status === DownloadStatus.ONGOING) {
      throw new Error("The file download is currently ongoing. Can't resume the ongoing download.");
    }
    resource.status = DownloadStatus.ONGOING;
    const result = await resource.downloadResumable.resumeAsync();
    if (!this.downloads.has(source) || this.downloads.get(source).status === DownloadStatus.PAUSED) {
      // if canceled or paused after earlier resuming.
      return null;
    }
    if (!result || (result.status !== HTTP_CODE.OK && result.status !== HTTP_CODE.PARTIAL_CONTENT)) {
      throw new Error(`Failed to fetch resource from '${uri}'`);
    }
    await moveAsync({ from: cacheFileUri, to: fileUri });
    this.downloads.delete(source);
    ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(uri);
    return ResourceFetcherUtils.removeFilePrefix(fileUri);
  }

  /** Cancels the tracked download of `source` and stops tracking it. */
  static async cancel(source) {
    const resource = this.downloads.get(source);
    await resource.downloadResumable.cancelAsync();
    this.downloads.delete(source);
  }

  /** Pauses the first of `sources` that is currently tracked in `downloads`. */
  static async pauseFetching(...sources) {
    const source = this.findActive(sources);
    await this.pause(source);
  }

  /** Resumes the first of `sources` that is currently tracked in `downloads`. */
  static async resumeFetching(...sources) {
    const source = this.findActive(sources);
    await this.resume(source);
  }

  /** Cancels the first of `sources` that is currently tracked in `downloads`. */
  static async cancelFetching(...sources) {
    const source = this.findActive(sources);
    await this.cancel(source);
  }

  /**
   * Returns the first source that has an entry in `downloads`.
   *
   * @throws Error when none of the sources is tracked.
   */
  static findActive(sources) {
    for (const source of sources) {
      if (this.downloads.has(source)) {
        return source;
      }
    }
    throw new Error('None of given sources are currently during downloading process.');
  }

  /** Lists every entry of `RNEDirectory`, each prefixed with the directory path. */
  static async listDownloadedFiles() {
    const files = await readDirectoryAsync(RNEDirectory);
    return files.map((file) => `${RNEDirectory}${file}`);
  }

  /** Lists downloaded files whose name ends with `.pte`. */
  static async listDownloadedModels() {
    const files = await this.listDownloadedFiles();
    return files.filter((file) => file.endsWith('.pte'));
  }

  /** Deletes the stored file of every given source, skipping files that do not exist. */
  static async deleteResources(...sources) {
    for (const source of sources) {
      const filename = ResourceFetcherUtils.getFilenameFromUri(source);
      const fileUri = `${RNEDirectory}${filename}`;
      if (await ResourceFetcherUtils.checkFileExists(fileUri)) {
        await deleteAsync(fileUri);
      }
    }
  }

  /** Returns the `totalLength` reported by `ResourceFetcherUtils.getFilesSizes` for the sources. */
  static async getFilesTotalSize(...sources) {
    return (await ResourceFetcherUtils.getFilesSizes(sources)).totalLength;
  }

  /**
   * Serializes an object source to `<hash>.json` inside `RNEDirectory`
   * (reusing the file when it already exists) and returns its path.
   *
   * @throws Error when `source` is not an object.
   */
  static async handleObject(source) {
    if (typeof source !== 'object') {
      throw new Error('Source is expected to be object!');
    }
    const jsonString = JSON.stringify(source);
    const digest = ResourceFetcherUtils.hashObject(jsonString);
    const filename = `${digest}.json`;
    const path = `${RNEDirectory}${filename}`;
    if (await ResourceFetcherUtils.checkFileExists(path)) {
      return ResourceFetcherUtils.removeFilePrefix(path);
    }
    await ResourceFetcherUtils.createDirectoryIfNoExists();
    await writeAsStringAsync(path, jsonString, { encoding: EncodingType.UTF8 });
    return ResourceFetcherUtils.removeFilePrefix(path);
  }

  /**
   * Returns a local file source with its file prefix removed.
   *
   * @throws Error when `source` is not a string.
   */
  static handleLocalFile(source) {
    if (typeof source !== 'string') {
      throw new Error('Source is expected to be string.');
    }
    return ResourceFetcherUtils.removeFilePrefix(source);
  }

  /**
   * Copies a bundled asset (numeric module id) into `RNEDirectory` unless it is
   * already there, and returns the path of the copy.
   *
   * @throws Error when the source is not a number.
   */
  static async handleReleaseModeFile(sourceExtended) {
    const source = sourceExtended.source;
    if (typeof source !== 'number') {
      throw new Error('Source is expected to be a number.');
    }
    const asset = Asset.fromModule(source);
    const uri = asset.uri;
    const filename = ResourceFetcherUtils.getFilenameFromUri(uri);
    const fileUri = `${RNEDirectory}${filename}`;
    // On Android, file uri does not contain file extension, so we add it manually
    const fileUriWithType = Platform.OS === 'android' ? `${fileUri}.${asset.type}` : fileUri;
    // Check the exact path the asset is copied to; checking the extension-less
    // path would never find the existing Android copy.
    if (await ResourceFetcherUtils.checkFileExists(fileUriWithType)) {
      return ResourceFetcherUtils.removeFilePrefix(fileUriWithType);
    }
    await ResourceFetcherUtils.createDirectoryIfNoExists();
    await copyAsync({ from: asset.uri, to: fileUriWithType });
    return ResourceFetcherUtils.removeFilePrefix(fileUriWithType);
  }

  /**
   * Resolves the uri of a numeric asset source and downloads it like a remote file.
   *
   * @throws Error when the source is not a number.
   */
  static async handleDevModeFile(sourceExtended) {
    const source = sourceExtended.source;
    if (typeof source !== 'number') {
      throw new Error('Source is expected to be a number.');
    }
    sourceExtended.uri = Asset.fromModule(source).uri;
    return await this.handleRemoteFile(sourceExtended);
  }

  /**
   * Downloads a remote file into the cache directory, then moves it to
   * `RNEDirectory`. A file that already exists there is returned without downloading.
   * Fetching a source whose download is paused resumes that download for the
   * current chain.
   *
   * @returns the saved file path, or null when the download was paused or cancelled.
   * @throws Error when the source is an object, is already being downloaded, has
   *   no uri, or the download fails.
   */
  static async handleRemoteFile(sourceExtended) {
    const source = sourceExtended.source;
    if (typeof source === 'object') {
      throw new Error('Source is expected to be a string or a number.');
    }
    if (this.downloads.has(source)) {
      const resource = this.downloads.get(source);
      if (resource.status !== DownloadStatus.PAUSED) {
        // if the download is ongoing, throw error.
        throw new Error('Already downloading this file.');
      }
      // if the download is paused, `fetch` is treated like `resume`: the current
      // node takes over the download so its chain (and progress callback) continues.
      const previous = resource.extendedInfo;
      sourceExtended.uri = previous.uri;
      sourceExtended.fileUri = previous.fileUri;
      sourceExtended.cacheFileUri = previous.cacheFileUri;
      resource.extendedInfo = sourceExtended;
      return this.resumeDownload(source, resource);
    }
    if (typeof source === 'number' && !sourceExtended.uri) {
      throw new Error('Source Uri is expected to be available here.');
    }
    if (typeof source === 'string') {
      sourceExtended.uri = source;
    }
    const uri = sourceExtended.uri;
    const filename = ResourceFetcherUtils.getFilenameFromUri(uri);
    sourceExtended.fileUri = `${RNEDirectory}${filename}`;
    sourceExtended.cacheFileUri = `${cacheDirectory}${filename}`;
    if (await ResourceFetcherUtils.checkFileExists(sourceExtended.fileUri)) {
      return ResourceFetcherUtils.removeFilePrefix(sourceExtended.fileUri);
    }
    await ResourceFetcherUtils.createDirectoryIfNoExists();

    const onProgress = ({ totalBytesWritten, totalBytesExpectedToWrite }) => {
      // Report to whichever node currently owns the download (it may have been
      // taken over by a later `fetch` after a pause).
      const owner = this.downloads.get(source)?.extendedInfo ?? sourceExtended;
      if (totalBytesExpectedToWrite === -1) {
        // If totalBytesExpectedToWrite is -1, it means the server does not provide content length.
        owner.callback(0);
        return;
      }
      owner.callback(totalBytesWritten / totalBytesExpectedToWrite);
    };
    const downloadResumable = createDownloadResumable(
      uri,
      sourceExtended.cacheFileUri,
      { sessionType: FileSystemSessionType.BACKGROUND },
      onProgress
    );

    // create value for the this.downloads Map
    const downloadResource = {
      downloadResumable: downloadResumable,
      status: DownloadStatus.ONGOING,
      extendedInfo: sourceExtended,
    };
    // add key-value pair to map
    this.downloads.set(source, downloadResource);

    const result = await downloadResumable.downloadAsync();
    if (!this.downloads.has(source) || this.downloads.get(source).status === DownloadStatus.PAUSED) {
      // if canceled or paused during the download
      return null;
    }
    if (!result || result.status !== HTTP_CODE.OK) {
      throw new Error(`Failed to fetch resource from '${source}'`);
    }
    await moveAsync({ from: sourceExtended.cacheFileUri, to: sourceExtended.fileUri });
    this.downloads.delete(source);
    ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(uri);
    return ResourceFetcherUtils.removeFilePrefix(sourceExtended.fileUri);
  }
}
//# sourceMappingURL=ResourceFetcher.js.map