UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch.

156 lines (151 loc) 5.61 kB
"use strict";

import {
  cacheDirectory,
  createDownloadResumable,
  getInfoAsync,
  makeDirectoryAsync,
  moveAsync,
  FileSystemSessionType,
  writeAsStringAsync,
  EncodingType,
  deleteAsync,
  readDirectoryAsync
} from 'expo-file-system';
import { Asset } from 'expo-asset';
import { RNEDirectory } from '../constants/directories';

/**
 * Resolves resources (remote URLs, `file://` URIs, bundled assets, or plain
 * objects) to local file-system paths, caching copies under `RNEDirectory`.
 */
export class ResourceFetcher {
  /**
   * Resolves `source` to a local path, downloading or copying it into
   * `RNEDirectory` when no cached copy exists yet.
   *
   * @param {object|number|string} source - A plain object (written out as a JSON
   *   file), a bundled asset module (resolved with `Asset.fromModule`), or a URI.
   * @param {(progress: number) => void} [callback] - Receives
   *   `bytesWritten / bytesExpected`; only invoked while a remote file downloads.
   * @returns {Promise<string>} The local path, without a `file://` prefix.
   * @throws {Error} If a bundled asset has no local URI, or a remote download
   *   does not complete with HTTP status 200.
   */
  static async fetch(source, callback = () => {}) {
    if (typeof source === 'object') {
      return this.handleObject(source);
    }
    const uri = typeof source === 'number' ? Asset.fromModule(source).uri : source;

    // Local files are used in place.
    if (uri.startsWith('file://')) {
      return this.removeFilePrefix(uri);
    }

    const filename = this.getFilenameFromUri(uri);
    const fileUri = `${RNEDirectory}${filename}`;
    if (await this.checkFileExists(fileUri)) {
      return this.removeFilePrefix(fileUri);
    }
    await this.createDirectoryIfNoExists();

    // Bundled asset files in release mode: the URI has no scheme.
    if (!uri.includes('://')) {
      const asset = Asset.fromModule(source);
      // Assets are stored with their type appended, so the extension-less cache
      // check above never matches them. Check the actual destination as well
      // instead of re-copying the asset on every call.
      const fileUriWithType = `${fileUri}.${asset.type}`;
      if (await this.checkFileExists(fileUriWithType)) {
        return this.removeFilePrefix(fileUriWithType);
      }
      await asset.downloadAsync();
      if (!asset.localUri) {
        throw new Error(`Asset local URI is not available for ${source}`);
      }
      await moveAsync({ from: asset.localUri, to: fileUriWithType });
      return this.removeFilePrefix(fileUriWithType);
    }

    // Remote file: download into the cache directory; only a download that
    // finished with status 200 is moved into RNEDirectory.
    const cacheFileUri = `${cacheDirectory}${filename}`;
    const downloadResumable = createDownloadResumable(
      uri,
      cacheFileUri,
      { sessionType: FileSystemSessionType.BACKGROUND },
      ({ totalBytesWritten, totalBytesExpectedToWrite }) => {
        // A zero or negative expected size (as may be reported when the length
        // is unknown) would turn the ratio into NaN, Infinity or a negative
        // number, so such updates are skipped.
        if (totalBytesExpectedToWrite > 0) {
          callback(totalBytesWritten / totalBytesExpectedToWrite);
        }
      }
    );
    const result = await downloadResumable.downloadAsync();
    if (!result || result.status !== 200) {
      throw new Error(`Failed to fetch resource from '${uri}'`);
    }
    await moveAsync({ from: cacheFileUri, to: fileUri });
    this.triggerHuggingFaceDownloadCounter(uri);
    return this.removeFilePrefix(fileUri);
  }

  /**
   * Fetches `sources` sequentially, reporting combined progress.
   *
   * @param {(progress: number) => void} [callback] - Receives overall progress;
   *   each source contributes an equal share (see `calculateDownloadProgress`).
   * @param {...(object|number|string)} sources - Sources accepted by `fetch`.
   * @returns {Promise<string[]>} Local paths in the same order as `sources`.
   */
  static async fetchMultipleResources(callback = () => {}, ...sources) {
    const paths = [];
    for (const [idx, source] of sources.entries()) {
      const onProgress = this.calculateDownloadProgress(sources.length, idx, callback);
      paths.push(await this.fetch(source, onProgress));
    }
    return paths;
  }

  /**
   * Deletes the cached copies of the given URIs from `RNEDirectory`, skipping
   * any that are not present.
   *
   * Expects URI strings: the file name is derived with `getFilenameFromUri`,
   * which calls string methods on each source.
   * NOTE(review): bundled assets are cached with their asset type appended to
   * the file name, so they are not matched by this lookup.
   *
   * @param {...string} sources - URIs previously passed to `fetch`.
   */
  static async deleteMultipleResources(...sources) {
    for (const source of sources) {
      const filename = this.getFilenameFromUri(source);
      const fileUri = `${RNEDirectory}${filename}`;
      if (await this.checkFileExists(fileUri)) {
        await deleteAsync(fileUri);
      }
    }
  }

  /**
   * Builds a per-file progress callback that maps a single file's progress
   * (0..1) onto overall progress across `numberOfFiles` equally weighted files.
   *
   * @param {number} numberOfFiles - Total number of files being fetched.
   * @param {number} currentFileIndex - Zero-based index of the current file.
   * @param {(progress: number) => void} setProgress - Receives overall progress.
   * @returns {(progress: number) => void} Callback for the current file.
   */
  static calculateDownloadProgress(numberOfFiles, currentFileIndex, setProgress) {
    return progress => {
      // Report exactly 1 when the last file completes.
      if (progress === 1 && currentFileIndex === numberOfFiles - 1) {
        setProgress(1);
        return;
      }
      const contributionPerFile = 1 / numberOfFiles;
      const baseProgress = contributionPerFile * currentFileIndex;
      const scaledProgress = progress * contributionPerFile;
      setProgress(baseProgress + scaledProgress);
    };
  }

  /**
   * Lists every file stored in `RNEDirectory`.
   *
   * @returns {Promise<string[]>} Full URIs of the files; empty if the directory
   *   has not been created yet.
   */
  static async listDownloadedFiles() {
    // readDirectoryAsync rejects for a missing directory; nothing has been
    // downloaded in that case.
    if (!(await this.checkFileExists(RNEDirectory))) {
      return [];
    }
    const files = await readDirectoryAsync(RNEDirectory);
    return files.map(file => `${RNEDirectory}${file}`);
  }

  /**
   * Lists the downloaded files whose name ends with `.pte`.
   *
   * @returns {Promise<string[]>} Full URIs of the `.pte` files.
   */
  static async listDownloadedModels() {
    const files = await this.listDownloadedFiles();
    return files.filter(file => file.endsWith('.pte'));
  }

  /**
   * Serializes `source` to JSON and stores it as `<hash>.json` in
   * `RNEDirectory`, reusing an existing file with the same name.
   *
   * @param {object} source - Value to serialize with `JSON.stringify`.
   * @returns {Promise<string>} The local path, without a `file://` prefix.
   */
  static async handleObject(source) {
    const jsonString = JSON.stringify(source);
    const digest = this.hashObject(jsonString);
    const filename = `${digest}.json`;
    const path = `${RNEDirectory}${filename}`;
    if (await this.checkFileExists(path)) {
      return this.removeFilePrefix(path);
    }
    await this.createDirectoryIfNoExists();
    await writeAsStringAsync(path, jsonString, { encoding: EncodingType.UTF8 });
    return this.removeFilePrefix(path);
  }

  /**
   * Derives a flat, file-system-safe file name from a URI.
   *
   * The `http://` or `https://` scheme, the query string and the fragment are
   * removed. Every remaining character outside `[a-zA-Z0-9._-]` becomes `_`.
   *
   * @param {string} uri - Source URI.
   * @returns {string} Sanitized file name.
   */
  static getFilenameFromUri(uri) {
    const withoutScheme = uri.replace(/^https?:\/\//, '');
    // String#split always yields at least one element, so [0] is defined.
    const withoutQuery = withoutScheme.split('?')[0].split('#')[0];
    return withoutQuery.replace(/[^a-zA-Z0-9._-]/g, '_');
  }

  /**
   * Strips a leading `file://` from `uri`, if present.
   *
   * @param {string} uri - A URI or path.
   * @returns {string} The path without the `file://` scheme.
   */
  static removeFilePrefix(uri) {
    return uri.startsWith('file://') ? uri.slice(7) : uri;
  }

  /**
   * Computes a 32-bit string hash (`hash * 31 + charCode`, wrapped to 32 bits).
   * It is not collision resistant; it is only used to name cached JSON files.
   *
   * @param {string} jsonString - String to hash.
   * @returns {string} The hash as an unsigned decimal string.
   */
  static hashObject(jsonString) {
    let hash = 0;
    for (let i = 0; i < jsonString.length; i++) {
      // eslint-disable-next-line no-bitwise
      hash = (hash << 5) - hash + jsonString.charCodeAt(i);
      // eslint-disable-next-line no-bitwise
      hash |= 0;
    }
    // eslint-disable-next-line no-bitwise
    return (hash >>> 0).toString();
  }

  /*
   * Increments the Hugging Face download counter if the URI points to a Software Mansion Hugging Face repo.
   * More information: https://huggingface.co/docs/hub/models-download-stats
   *
   * Sends a HEAD request for the repository's `resolve/main/config.json`. This
   * is best-effort: any failure is ignored so it never affects the caller.
   */
  static triggerHuggingFaceDownloadCounter(uri) {
    try {
      const url = new URL(uri);
      if (url.host !== 'huggingface.co' || !url.pathname.startsWith('/software-mansion/')) {
        return;
      }
      // Cut at the '/resolve/' path segment, not at any occurrence of the
      // substring "resolve", which can also appear in a repository name.
      const resolveIndex = url.pathname.indexOf('/resolve/');
      if (resolveIndex === -1) {
        return;
      }
      const repoPath = url.pathname.slice(0, resolveIndex + 1);
      // Fire-and-forget: a failed ping must not become an unhandled rejection.
      fetch(`${url.protocol}//${url.host}${repoPath}resolve/main/config.json`, {
        method: 'HEAD'
      }).catch(() => {});
    } catch {
      // Best-effort telemetry: never fail the surrounding fetch because of it.
    }
  }

  /**
   * Creates `RNEDirectory` (including intermediate directories) if it does not
   * exist yet.
   */
  static async createDirectoryIfNoExists() {
    if (!(await this.checkFileExists(RNEDirectory))) {
      await makeDirectoryAsync(RNEDirectory, { intermediates: true });
    }
  }

  /**
   * @param {string} fileUri - File or directory URI to check.
   * @returns {Promise<boolean>} Whether `getInfoAsync` reports it as existing.
   */
  static async checkFileExists(fileUri) {
    const fileInfo = await getInfoAsync(fileUri);
    return fileInfo.exists;
  }
}
//# sourceMappingURL=ResourceFetcher.js.map