react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch.
92 lines (90 loc) • 2.71 kB
JavaScript
import { useEffect, useRef, useState } from 'react';
import { fetchResource } from '../utils/fetchResource';
import { ETError, getError } from '../Error';
import { getTypeIdentifier } from '../types/common';
/**
 * React hook that fetches a model via `fetchResource`, loads it into `module`,
 * and exposes helpers that run inference through `module.forward`.
 *
 * @param {object} params
 * @param {*} params.modelSource - Passed to `fetchResource`; when falsy, no model is loaded.
 * @param {object} params.module - Object exposing `loadModule(fileUri)` and `forward(...)`.
 * @returns {{
 *   error: *,
 *   isReady: boolean,
 *   isGenerating: boolean,
 *   downloadProgress: *,
 *   forwardETInput: Function,
 *   forwardImage: Function
 * }} `error` is the `getError` result of the last failed load (or null);
 *   `downloadProgress` starts at 0 and is updated by `fetchResource`'s
 *   progress callback.
 */
export const useModule = ({
  modelSource,
  module
}) => {
  const [error, setError] = useState(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);
  // `isGenerating` is a per-render snapshot: two forward calls issued before
  // React re-renders would both read `false` and run concurrently. The ref is
  // flipped synchronously, so it is what actually enforces one pass at a time.
  const generatingRef = useRef(false);

  useEffect(() => {
    if (!modelSource) return undefined;
    // Flipped by the cleanup below when the source/module changes or the
    // component unmounts, so a stale load cannot overwrite newer state.
    let cancelled = false;
    const loadModel = async () => {
      try {
        setIsReady(false);
        // Clear any failure left over from a previous model source.
        setError(null);
        const fileUri = await fetchResource(modelSource, (progress) => {
          if (!cancelled) setDownloadProgress(progress);
        });
        await module.loadModule(fileUri);
        if (!cancelled) setIsReady(true);
      } catch (e) {
        if (!cancelled) setError(getError(e));
      }
    };
    // loadModel catches its own errors, so this promise never rejects.
    void loadModel();
    return () => {
      cancelled = true;
    };
  }, [modelSource, module]);

  // Throws if inference cannot start: the model is not loaded, or another
  // forward pass is still in flight.
  const assertCanForward = () => {
    if (!isReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    if (generatingRef.current) {
      throw new Error(getError(ETError.ModelGenerating));
    }
  };

  // Calls `module.forward(...args)` while marking the hook as generating. The
  // flag is always cleared; failures are rethrown as an Error built from
  // `getError`, keeping the original error as `cause`.
  const runForward = async (...args) => {
    generatingRef.current = true;
    setIsGenerating(true);
    try {
      return await module.forward(...args);
    } catch (e) {
      throw new Error(getError(e), { cause: e });
    } finally {
      generatingRef.current = false;
      setIsGenerating(false);
    }
  };

  /**
   * Runs the model on `input`, passed to `module.forward` unchanged.
   *
   * @param {*} input - Forwarded as the single argument of `module.forward`.
   * @returns {Promise<*>} Whatever `module.forward` resolves to.
   * @throws {Error} If the model is not loaded, a pass is in flight, or forward fails.
   */
  const forwardImage = async (input) => {
    assertCanForward();
    return runForward(input);
  };

  /**
   * Runs the model on one or more typed inputs together with their shapes.
   *
   * @param {*} input - A single input or an array of inputs; each must be
   *   recognised by `getTypeIdentifier` and be spreadable into an array.
   * @param {Array} shape - A single shape, or an array of shapes (detected by
   *   whether `shape[0]` is itself an array).
   * @returns {Promise<*>} Whatever `module.forward` resolves to.
   * @throws {Error} If the model is not loaded, a pass is in flight, an input
   *   type is unsupported, or forward fails.
   */
  const forwardETInput = async (input, shape) => {
    assertCanForward();
    // Since the native module expects an array of inputs and an array of shapes,
    // if the user provides a single ETInput, we want to "unsqueeze" the array so
    // the data is properly processed on the native side
    const inputs = Array.isArray(input) ? input : [input];
    const shapes = Array.isArray(shape[0]) ? shape : [shape];
    const inputTypeIdentifiers = [];
    const modelInputs = [];
    for (const item of inputs) {
      const typeIdentifier = getTypeIdentifier(item);
      if (typeIdentifier === -1) {
        throw new Error(getError(ETError.InvalidArgument));
      }
      inputTypeIdentifiers.push(typeIdentifier);
      modelInputs.push([...item]);
    }
    return runForward(modelInputs, shapes, inputTypeIdentifiers);
  };

  return {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    forwardETInput,
    forwardImage
  };
};
//# sourceMappingURL=useModule.js.map
;