UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

419 lines (389 loc) 14 kB
/**
 * Resource Fetcher
 *
 * Provides an interface for downloading files (via `ResourceFetcher.fetch()`).
 *
 * Key functionality:
 * - Download control: pause, resume, and cancel operations through:
 *   - Single file: `.pauseFetching()`, `.resumeFetching()`, `.cancelFetching()`
 * - Downloaded file management:
 *   - `.getFilesTotalSize()`, `.listDownloadedFiles()`, `.listDownloadedModels()`, `.deleteResources()`
 *
 * Remark: The pausing/resuming/canceling works only for fetching remote resources.
 *
 * Most exported functions accept:
 * - Multiple `ResourceSource` arguments (union type of string, number or object)
 *
 * Method `.fetch()` takes as its first argument a callback that reports download progress.
 * Method `.fetch()` returns an array of paths to successfully saved files, or null if the
 * download was paused or cancelled (in that case `.resumeFetching()` can later return the paths).
 *
 * Technical Implementation:
 * - Maintains a `downloads` Map instance that tracks:
 *   - Currently downloading resources
 *   - Paused downloads
 * - Successful downloads are automatically removed from the `downloads` Map
 * - Uses the `ResourceSourceExtended` interface to enable pause/resume functionality:
 *   - Wraps user-provided `ResourceSource` elements
 *   - Implements linked list behavior via the `.next` attribute
 *   - Automatically processes subsequent downloads when `.next` contains a valid resource
 */
import {
  cacheDirectory,
  copyAsync,
  createDownloadResumable,
  moveAsync,
  FileSystemSessionType,
  writeAsStringAsync,
  EncodingType,
  deleteAsync,
  readDirectoryAsync,
} from 'expo-file-system/legacy';
import { Asset } from 'expo-asset';
import { Platform } from 'react-native';
import { RNEDirectory } from '../constants/directories';
import { ResourceSource } from '../types/common';
import {
  ResourceFetcherUtils,
  HTTP_CODE,
  DownloadStatus,
  SourceType,
  ResourceSourceExtended,
  DownloadResource,
} from './ResourceFetcherUtils';

export class ResourceFetcher {
  /**
   * Map of currently downloading (or paused) files, if the download was
   * started by the `.fetch()` method. Entries are removed once a download
   * finishes successfully or is cancelled.
   */
  static downloads = new Map<ResourceSource, DownloadResource>();

  /**
   * Fetches every given source in order and stores it under `RNEDirectory`
   * (local files are returned as-is, without copying).
   *
   * @param callback Receives overall download progress; only remote files
   *   report intermediate progress, and the last node reports `1` on completion.
   * @param sources One or more resources to fetch.
   * @returns Paths of all saved files, or `null` if a download was paused or cancelled.
   * @throws If `sources` is empty or any individual fetch fails.
   */
  static async fetch(
    callback: (downloadProgress: number) => void = () => {},
    ...sources: ResourceSource[]
  ) {
    if (sources.length === 0) {
      throw new Error('Empty list given as an argument!');
    }
    const { results: info, totalLength } =
      await ResourceFetcherUtils.getFilesSizes(sources);

    // One linked-list node per source; only remote files get a real progress
    // callback scaled to their share of the total size.
    const nodes = sources.map((_, idx): ResourceSourceExtended => {
      const fileInfo = info[idx]!;
      return {
        source: fileInfo.source,
        sourceType: fileInfo.type,
        callback:
          fileInfo.type === SourceType.REMOTE_FILE
            ? ResourceFetcherUtils.calculateDownloadProgress(
                totalLength,
                fileInfo.previousFilesTotalLength,
                fileInfo.length,
                callback
              )
            : () => {},
        results: [],
      };
    });
    // Link the nodes so that each finished fetch triggers the next one.
    for (let idx = 0; idx < nodes.length - 1; idx++) {
      nodes[idx]!.next = nodes[idx + 1];
    }
    return this.singleFetch(nodes[0]!);
  }

  /**
   * Fetches a single node of the chain according to its source type and, on
   * success, continues with the next node via `returnOrStartNext`.
   * Returns `null` when a remote/dev-mode download was paused or cancelled.
   */
  private static async singleFetch(
    sourceExtended: ResourceSourceExtended
  ): Promise<string[] | null> {
    const source = sourceExtended.source;
    switch (sourceExtended.sourceType) {
      case SourceType.OBJECT: {
        return this.returnOrStartNext(
          sourceExtended,
          await this.handleObject(source)
        );
      }
      case SourceType.LOCAL_FILE: {
        return this.returnOrStartNext(
          sourceExtended,
          this.handleLocalFile(source)
        );
      }
      case SourceType.RELEASE_MODE_FILE: {
        return this.returnOrStartNext(
          sourceExtended,
          await this.handleReleaseModeFile(sourceExtended)
        );
      }
      case SourceType.DEV_MODE_FILE: {
        const result = await this.handleDevModeFile(sourceExtended);
        if (result !== null) {
          return this.returnOrStartNext(sourceExtended, result);
        }
        return null;
      }
      default: {
        // case SourceType.REMOTE_FILE
        const result = await this.handleRemoteFile(sourceExtended);
        if (result !== null) {
          return this.returnOrStartNext(sourceExtended, result);
        }
        return null;
      }
    }
  }

  /**
   * Called whenever a node's file is ready. Records its path, then either
   * starts the next node (handing over all paths collected so far) or, for
   * the last node, reports progress `1` and returns the collected paths.
   */
  private static returnOrStartNext(
    sourceExtended: ResourceSourceExtended,
    result: string
  ) {
    sourceExtended.results.push(result);

    if (sourceExtended.next) {
      const nextSource = sourceExtended.next;
      nextSource.results.push(...sourceExtended.results);
      return this.singleFetch(nextSource);
    }
    sourceExtended.callback!(1);
    return sourceExtended.results;
  }

  /**
   * Pauses a tracked download.
   * @throws If the download is already paused.
   */
  private static async pause(source: ResourceSource) {
    const resource = this.downloads.get(source)!;
    if (resource.status === DownloadStatus.PAUSED) {
      throw new Error(
        "The file download is currently paused. Can't pause the download of the same file twice."
      );
    }
    // Update state before the native call so the pending `downloadAsync`
    // observes the pause when it settles.
    resource.status = DownloadStatus.PAUSED;
    await resource.downloadResumable.pauseAsync();
  }

  /**
   * Resumes a paused download. When it completes, the file is moved into
   * `RNEDirectory` and the remainder of its original `.fetch()` chain continues.
   *
   * @returns Paths of the original chain, or `null` if paused/cancelled again.
   * @throws If file URI info is missing, the download is ongoing, or the HTTP
   *   status is neither 200 nor 206.
   */
  private static async resume(
    source: ResourceSource
  ): Promise<string[] | null> {
    const resource = this.downloads.get(source)!;
    if (
      !resource.extendedInfo.fileUri ||
      !resource.extendedInfo.cacheFileUri ||
      !resource.extendedInfo.uri
    ) {
      throw new Error('Something went wrong. File uri info is not specified!');
    }
    if (resource.status === DownloadStatus.ONGOING) {
      throw new Error(
        "The file download is currently ongoing. Can't resume the ongoing download."
      );
    }

    resource.status = DownloadStatus.ONGOING;
    const result = await resource.downloadResumable.resumeAsync();
    if (
      !this.downloads.has(source) ||
      this.downloads.get(source)!.status === DownloadStatus.PAUSED
    ) {
      // Cancelled or paused again after this resume started.
      return null;
    }
    if (
      !result ||
      (result.status !== HTTP_CODE.OK &&
        result.status !== HTTP_CODE.PARTIAL_CONTENT)
    ) {
      throw new Error(
        `Failed to fetch resource from '${resource.extendedInfo.uri}'`
      );
    }
    await moveAsync({
      from: resource.extendedInfo.cacheFileUri,
      to: resource.extendedInfo.fileUri,
    });
    this.downloads.delete(source);
    ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(
      resource.extendedInfo.uri
    );

    return this.returnOrStartNext(
      resource.extendedInfo,
      ResourceFetcherUtils.removeFilePrefix(resource.extendedInfo.fileUri)
    );
  }

  /** Cancels a tracked download and stops tracking it. */
  private static async cancel(source: ResourceSource) {
    const resource = this.downloads.get(source)!;
    // Drop the entry before the native call (mirroring `pause`) so a
    // `downloadAsync` that settles first is treated as cancelled, not failed.
    this.downloads.delete(source);
    await resource.downloadResumable.cancelAsync();
  }

  /** Pauses the first of `sources` that is currently being downloaded. */
  static async pauseFetching(...sources: ResourceSource[]) {
    const source = this.findActive(sources);
    await this.pause(source);
  }

  /**
   * Resumes the first of `sources` that is currently being downloaded.
   * @returns Paths of the resumed `.fetch()` chain, or `null` if it was paused/cancelled again.
   */
  static async resumeFetching(...sources: ResourceSource[]) {
    const source = this.findActive(sources);
    return await this.resume(source);
  }

  /** Cancels the first of `sources` that is currently being downloaded. */
  static async cancelFetching(...sources: ResourceSource[]) {
    const source = this.findActive(sources);
    await this.cancel(source);
  }

  /**
   * Returns the first source tracked in `downloads`.
   * @throws If none of the sources is tracked.
   */
  private static findActive(sources: ResourceSource[]) {
    for (const source of sources) {
      if (this.downloads.has(source)) {
        return source;
      }
    }
    throw new Error(
      'None of given sources are currently during downloading process.'
    );
  }

  /** Lists full URIs of all entries in `RNEDirectory`. */
  static async listDownloadedFiles() {
    const files = await readDirectoryAsync(RNEDirectory);
    return files.map((file) => `${RNEDirectory}${file}`);
  }

  /** Lists downloaded files with the `.pte` extension. */
  static async listDownloadedModels() {
    const files = await this.listDownloadedFiles();
    return files.filter((file) => file.endsWith('.pte'));
  }

  /**
   * Deletes the files in `RNEDirectory` named after each source's URI.
   * NOTE(review): sources are treated as strings; number/object sources are
   * not mapped to their stored file names here.
   */
  static async deleteResources(...sources: ResourceSource[]) {
    for (const source of sources) {
      const filename = ResourceFetcherUtils.getFilenameFromUri(
        source as string
      );
      const fileUri = `${RNEDirectory}${filename}`;
      if (await ResourceFetcherUtils.checkFileExists(fileUri)) {
        await deleteAsync(fileUri);
      }
    }
  }

  /** Returns the combined size reported by `getFilesSizes` for `sources`. */
  static async getFilesTotalSize(...sources: ResourceSource[]) {
    return (await ResourceFetcherUtils.getFilesSizes(sources)).totalLength;
  }

  /**
   * Serializes an object source to `<hash>.json` in `RNEDirectory`, reusing
   * the existing file when present.
   */
  private static async handleObject(source: ResourceSource) {
    if (typeof source !== 'object') {
      throw new Error('Source is expected to be object!');
    }
    const jsonString = JSON.stringify(source);
    const digest = ResourceFetcherUtils.hashObject(jsonString);
    const filename = `${digest}.json`;
    const path = `${RNEDirectory}${filename}`;

    if (await ResourceFetcherUtils.checkFileExists(path)) {
      return ResourceFetcherUtils.removeFilePrefix(path);
    }

    await ResourceFetcherUtils.createDirectoryIfNoExists();
    await writeAsStringAsync(path, jsonString, {
      encoding: EncodingType.UTF8,
    });

    return ResourceFetcherUtils.removeFilePrefix(path);
  }

  /** Returns a local file path without its file prefix (nothing is copied). */
  private static handleLocalFile(source: ResourceSource) {
    if (typeof source !== 'string') {
      throw new Error('Source is expected to be string.');
    }
    return ResourceFetcherUtils.removeFilePrefix(source);
  }

  /**
   * Copies a bundled asset (`require`d module id) into `RNEDirectory`,
   * reusing a previous copy when present.
   */
  private static async handleReleaseModeFile(
    sourceExtended: ResourceSourceExtended
  ) {
    const source = sourceExtended.source;
    if (typeof source !== 'number') {
      throw new Error('Source is expected to be a number.');
    }
    const asset = Asset.fromModule(source);
    const uri = asset.uri;
    const filename = ResourceFetcherUtils.getFilenameFromUri(uri);
    const fileUri = `${RNEDirectory}${filename}`;
    // On Android, file uri does not contain file extension, so we add it manually
    const fileUriWithType =
      Platform.OS === 'android' ? `${fileUri}.${asset.type}` : fileUri;
    // Look for the same path we copy to; checking `fileUri` would never find
    // the Android copy and re-copy the asset on every fetch.
    if (await ResourceFetcherUtils.checkFileExists(fileUriWithType)) {
      return ResourceFetcherUtils.removeFilePrefix(fileUriWithType);
    }
    await ResourceFetcherUtils.createDirectoryIfNoExists();
    await copyAsync({
      from: asset.uri,
      to: fileUriWithType,
    });
    return ResourceFetcherUtils.removeFilePrefix(fileUriWithType);
  }

  /** Resolves a module id to its asset URI and downloads it like a remote file. */
  private static async handleDevModeFile(
    sourceExtended: ResourceSourceExtended
  ) {
    const source = sourceExtended.source;
    if (typeof source !== 'number') {
      throw new Error('Source is expected to be a number.');
    }
    sourceExtended.uri = Asset.fromModule(source).uri;
    return await this.handleRemoteFile(sourceExtended);
  }

  /**
   * Downloads `sourceExtended.uri` (or the string source itself) into the
   * cache directory, then moves it into `RNEDirectory`. Reuses an existing file.
   *
   * @returns The saved path, or `null` if the download was paused or cancelled.
   * @throws If the file is already being downloaded, or the HTTP status is not 200.
   */
  private static async handleRemoteFile(
    sourceExtended: ResourceSourceExtended
  ): Promise<string | null> {
    const source = sourceExtended.source;
    if (typeof source === 'object') {
      throw new Error('Source is expected to be a string or a number.');
    }
    const existing = this.downloads.get(source);
    if (existing) {
      if (existing.status === DownloadStatus.PAUSED) {
        // If the download is paused, `fetch` is treated like `resume`. Awaiting it
        // surfaces resume errors to this caller instead of leaving them unhandled.
        const pausedFileUri = existing.extendedInfo.fileUri;
        const resumed = await this.resume(source);
        if (resumed === null || !pausedFileUri) {
          return null;
        }
        return ResourceFetcherUtils.removeFilePrefix(pausedFileUri);
      }
      // If the download is ongoing, throw error.
      throw new Error('Already downloading this file.');
    }
    if (typeof source === 'number' && !sourceExtended.uri) {
      throw new Error('Source Uri is expected to be available here.');
    }
    if (typeof source === 'string') {
      sourceExtended.uri = source;
    }
    const uri = sourceExtended.uri!;
    const filename = ResourceFetcherUtils.getFilenameFromUri(uri);
    sourceExtended.fileUri = `${RNEDirectory}${filename}`;
    sourceExtended.cacheFileUri = `${cacheDirectory}${filename}`;

    if (await ResourceFetcherUtils.checkFileExists(sourceExtended.fileUri)) {
      return ResourceFetcherUtils.removeFilePrefix(sourceExtended.fileUri);
    }
    await ResourceFetcherUtils.createDirectoryIfNoExists();

    const downloadResumable = createDownloadResumable(
      uri,
      sourceExtended.cacheFileUri,
      { sessionType: FileSystemSessionType.BACKGROUND },
      ({ totalBytesWritten, totalBytesExpectedToWrite }) => {
        if (totalBytesExpectedToWrite === -1) {
          // If totalBytesExpectedToWrite is -1, it means the server does not provide content length.
          sourceExtended.callback!(0);
          return;
        }
        sourceExtended.callback!(totalBytesWritten / totalBytesExpectedToWrite);
      }
    );
    // Track the download so it can be paused, resumed or cancelled.
    const downloadResource: DownloadResource = {
      downloadResumable: downloadResumable,
      status: DownloadStatus.ONGOING,
      extendedInfo: sourceExtended,
    };
    this.downloads.set(source, downloadResource);

    const result = await downloadResumable.downloadAsync();
    if (
      !this.downloads.has(source) ||
      this.downloads.get(source)!.status === DownloadStatus.PAUSED
    ) {
      // Cancelled or paused during the download.
      return null;
    }
    if (!result || result.status !== HTTP_CODE.OK) {
      throw new Error(`Failed to fetch resource from '${uri}'`);
    }
    await moveAsync({
      from: sourceExtended.cacheFileUri,
      to: sourceExtended.fileUri,
    });
    this.downloads.delete(source);
    ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(uri);
    return ResourceFetcherUtils.removeFilePrefix(sourceExtended.fileUri);
  }
}