react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
156 lines (151 loc) • 5.61 kB
JavaScript
;
import { cacheDirectory, createDownloadResumable, getInfoAsync, makeDirectoryAsync, moveAsync, FileSystemSessionType, writeAsStringAsync, EncodingType, deleteAsync, readDirectoryAsync } from 'expo-file-system';
import { Asset } from 'expo-asset';
import { RNEDirectory } from '../constants/directories';
/**
 * Resolves resources (remote URLs, bundled assets, local `file://` paths and
 * plain objects) to local filesystem paths, keeping copies under `RNEDirectory`.
 */
export class ResourceFetcher {
  /**
   * Makes `source` available on local storage and returns its path without
   * the `file://` prefix.
   *
   * @param {object|number|string} source - An object (stored as a `.json`
   *   file), an asset module passed to `Asset.fromModule`, or a URI string.
   * @param {(progress: number) => void} [callback] - Receives download progress
   *   as a fraction; it is always called with 1 once the resource is available,
   *   including when it was already cached or is a local file.
   * @returns {Promise<string>} Local path of the resource.
   * @throws {Error} When a remote download does not answer with HTTP 200, or a
   *   bundled asset has no local URI after downloading.
   */
  static async fetch(source, callback = () => {}) {
    const path = await this.fetchResource(source, callback);
    // Cached, local, bundled and object resources never emit download
    // progress, so completion is signalled explicitly for every success.
    callback(1);
    return path;
  }

  /**
   * Resolves `source` to a local path (see `fetch` for accepted sources),
   * forwarding remote download progress to `callback`.
   */
  static async fetchResource(source, callback) {
    if (typeof source === 'object') {
      return this.handleObject(source);
    }
    const uri = typeof source === 'number' ? Asset.fromModule(source).uri : source;

    // Local files are used in place, without copying.
    if (uri.startsWith('file://')) {
      return this.removeFilePrefix(uri);
    }

    const filename = this.getFilenameFromUri(uri);
    const fileUri = `${RNEDirectory}${filename}`;
    if (await this.checkFileExists(fileUri)) {
      return this.removeFilePrefix(fileUri);
    }
    await this.createDirectoryIfNoExists();

    // Local asset files in release mode resolve to a URI without a scheme.
    if (!uri.includes('://')) {
      return this.fetchBundledAsset(source, fileUri);
    }
    return this.downloadRemoteFile(uri, filename, fileUri, callback);
  }

  /**
   * Copies a bundled asset into `RNEDirectory` as `<fileUri>.<asset.type>`,
   * reusing a previously stored copy when present.
   */
  static async fetchBundledAsset(source, fileUri) {
    const asset = Asset.fromModule(source);
    const fileUriWithType = `${fileUri}.${asset.type}`;
    // The stored copy carries the type extension, so the plain `fileUri`
    // check in `fetchResource` never matches it; check the real path here.
    if (await this.checkFileExists(fileUriWithType)) {
      return this.removeFilePrefix(fileUriWithType);
    }
    await asset.downloadAsync();
    if (!asset.localUri) {
      throw new Error(`Asset local URI is not available for ${source}`);
    }
    await moveAsync({
      from: asset.localUri,
      to: fileUriWithType,
    });
    return this.removeFilePrefix(fileUriWithType);
  }

  /**
   * Downloads `uri` into the cache directory, then moves it to `fileUri`.
   *
   * @throws {Error} When the download does not answer with HTTP 200.
   */
  static async downloadRemoteFile(uri, filename, fileUri, callback) {
    const cacheFileUri = `${cacheDirectory}${filename}`;
    const downloadResumable = createDownloadResumable(
      uri,
      cacheFileUri,
      { sessionType: FileSystemSessionType.BACKGROUND },
      ({ totalBytesWritten, totalBytesExpectedToWrite }) => {
        // An unknown total size (e.g. no Content-Length) would produce a
        // negative or NaN fraction, so such updates are not forwarded.
        if (totalBytesExpectedToWrite > 0) {
          callback(totalBytesWritten / totalBytesExpectedToWrite);
        }
      }
    );
    const result = await downloadResumable.downloadAsync();
    if (!result || result.status !== 200) {
      // Best-effort removal of whatever was written to the cache path; a
      // cleanup failure must not mask the download error thrown below.
      await deleteAsync(cacheFileUri, { idempotent: true }).catch(() => {});
      throw new Error(`Failed to fetch resource from '${uri}'`);
    }
    await moveAsync({
      from: cacheFileUri,
      to: fileUri,
    });
    this.triggerHuggingFaceDownloadCounter(uri);
    return this.removeFilePrefix(fileUri);
  }

  /**
   * Fetches `sources` one after another, reporting their combined progress.
   *
   * @param {(progress: number) => void} [callback] - Receives overall progress.
   * @param {...(object|number|string)} sources - Resources accepted by `fetch`.
   * @returns {Promise<string[]>} Local paths, in the order of `sources`.
   */
  static async fetchMultipleResources(callback = () => {}, ...sources) {
    const paths = [];
    for (let idx = 0; idx < sources.length; idx++) {
      paths.push(
        await this.fetch(
          sources[idx],
          this.calculateDownloadProgress(sources.length, idx, callback)
        )
      );
    }
    return paths;
  }

  /**
   * Deletes the copies that `fetch` stored in `RNEDirectory` for `sources`.
   * Files that do not exist are skipped.
   *
   * @param {...(object|number|string)} sources - Resources accepted by `fetch`.
   */
  static async deleteMultipleResources(...sources) {
    for (const source of sources) {
      for (const fileUri of this.getStoredFileUris(source)) {
        if (await this.checkFileExists(fileUri)) {
          await deleteAsync(fileUri);
        }
      }
    }
  }

  /**
   * Lists the `RNEDirectory` URIs under which `fetch` may have stored `source`,
   * mirroring the naming used by `handleObject`, `fetchBundledAsset` and
   * `downloadRemoteFile`.
   */
  static getStoredFileUris(source) {
    if (typeof source === 'object') {
      return [`${RNEDirectory}${this.hashObject(JSON.stringify(source))}.json`];
    }
    const uri = typeof source === 'number' ? Asset.fromModule(source).uri : source;
    const fileUri = `${RNEDirectory}${this.getFilenameFromUri(uri)}`;
    if (typeof source === 'number' && !uri.includes('://')) {
      // Bundled release-mode assets are stored with their type as extension.
      return [fileUri, `${fileUri}.${Asset.fromModule(source).type}`];
    }
    return [fileUri];
  }

  /**
   * Builds a progress callback for file `currentFileIndex` of `numberOfFiles`
   * that maps that file's fraction onto its equal share of overall progress.
   * The last file reaching 1 reports exactly 1 to avoid rounding drift.
   */
  static calculateDownloadProgress(numberOfFiles, currentFileIndex, setProgress) {
    return (progress) => {
      if (progress === 1 && currentFileIndex === numberOfFiles - 1) {
        setProgress(1);
        return;
      }
      const contributionPerFile = 1 / numberOfFiles;
      const baseProgress = contributionPerFile * currentFileIndex;
      const scaledProgress = progress * contributionPerFile;
      setProgress(baseProgress + scaledProgress);
    };
  }

  /**
   * Lists every file in `RNEDirectory` as a full URI; returns an empty list
   * when the directory has not been created yet.
   */
  static async listDownloadedFiles() {
    if (!(await this.checkFileExists(RNEDirectory))) {
      return [];
    }
    const files = await readDirectoryAsync(RNEDirectory);
    return files.map((file) => `${RNEDirectory}${file}`);
  }

  /** Lists the downloaded files whose name ends with `.pte`. */
  static async listDownloadedModels() {
    const files = await this.listDownloadedFiles();
    return files.filter((file) => file.endsWith('.pte'));
  }

  /**
   * Serialises `source` to JSON and stores it as `<hash>.json` in
   * `RNEDirectory`, reusing an existing file with the same name.
   */
  static async handleObject(source) {
    const jsonString = JSON.stringify(source);
    const digest = this.hashObject(jsonString);
    const filename = `${digest}.json`;
    const path = `${RNEDirectory}${filename}`;
    if (await this.checkFileExists(path)) {
      return this.removeFilePrefix(path);
    }
    await this.createDirectoryIfNoExists();
    await writeAsStringAsync(path, jsonString, {
      encoding: EncodingType.UTF8,
    });
    return this.removeFilePrefix(path);
  }

  /**
   * Derives a flat filename from `uri`: drops an `http(s)://` prefix, the
   * query string and the fragment, then replaces every character outside
   * `[a-zA-Z0-9._-]` with `_`.
   */
  static getFilenameFromUri(uri) {
    const withoutScheme = uri.replace(/^https?:\/\//, '');
    // `split` always yields at least one element, so index 0 is always set.
    const withoutQuery = withoutScheme.split('?')[0].split('#')[0];
    return withoutQuery.replace(/[^a-zA-Z0-9._-]/g, '_');
  }

  /** Strips a leading `file://` from `uri`, if present. */
  static removeFilePrefix(uri) {
    return uri.startsWith('file://') ? uri.slice(7) : uri;
  }

  /**
   * Computes a 32-bit string hash (hash * 31 + charCode per UTF-16 unit) and
   * returns it as an unsigned decimal string. Not cryptographic.
   */
  static hashObject(jsonString) {
    let hash = 0;
    for (let i = 0; i < jsonString.length; i++) {
      // eslint-disable-next-line no-bitwise
      hash = (hash << 5) - hash + jsonString.charCodeAt(i);
      // eslint-disable-next-line no-bitwise
      hash |= 0;
    }
    // eslint-disable-next-line no-bitwise
    return (hash >>> 0).toString();
  }

  /*
   * Increments the Hugging Face download counter if the URI points to a Software Mansion Hugging Face repo.
   * More information: https://huggingface.co/docs/hub/models-download-stats
   * This is best-effort: any failure is ignored so it never fails a completed fetch.
   */
  static triggerHuggingFaceDownloadCounter(uri) {
    try {
      const url = new URL(uri);
      if (url.host !== 'huggingface.co' || !url.pathname.startsWith('/software-mansion/')) {
        return;
      }
      // Split on the full `/resolve/` segment so repo names that merely
      // contain "resolve" are not truncated.
      const [repoPath] = url.pathname.split('/resolve/');
      fetch(`${url.protocol}//${url.host}${repoPath}/resolve/main/config.json`, {
        method: 'HEAD',
      }).catch(() => {
        // Network errors only affect download statistics; ignore them.
      });
    } catch {
      // Unparseable URI or unsupported URL API: skip counting.
    }
  }

  /** Creates `RNEDirectory` (with intermediate directories) if it is missing. */
  static async createDirectoryIfNoExists() {
    if (!(await this.checkFileExists(RNEDirectory))) {
      await makeDirectoryAsync(RNEDirectory, {
        intermediates: true,
      });
    }
  }

  /** Resolves to whether `getInfoAsync` reports `fileUri` as existing. */
  static async checkFileExists(fileUri) {
    const fileInfo = await getInfoAsync(fileUri);
    return fileInfo.exists;
  }
}
//# sourceMappingURL=ResourceFetcher.js.map