react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
66 lines (65 loc) • 2.81 kB
JavaScript
"use strict";
import { symbols } from '../constants/ocr/symbols';
import { ETError, getError } from '../Error';
import { _VerticalOCRModule } from '../native/RnExecutorchModules';
import { fetchResource, calculateDownloadProgres } from '../utils/fetchResource';
/**
 * Controller around the native `_VerticalOCRModule`: it fetches the detector
 * and recognizer resources, loads them into the native module and forwards
 * inputs to it.
 *
 * State is exposed through the `isReady`, `isGenerating` and `error` fields,
 * and changes are additionally pushed to the callbacks supplied at
 * construction time.
 */
export class VerticalOCRController {
  /** `true` once `loadModel` has completed successfully. */
  isReady = false;
  /** `true` while a `forward` call is in flight. */
  isGenerating = false;
  /** Result of `getError` for the most recent `loadModel` failure, or `null`. */
  error = null;

  /**
   * @param {object} [callbacks] - Optional callbacks; every one defaults to a no-op.
   * @param {Function} [callbacks.modelDownloadProgressCallback] - Handed to
   *   `calculateDownloadProgres` for each fetched resource.
   * @param {Function} [callbacks.isReadyCallback] - Called with the new `isReady` value.
   * @param {Function} [callbacks.isGeneratingCallback] - Called with the new `isGenerating` value.
   * @param {Function} [callbacks.errorCallback] - Called with `getError(e)` when `loadModel` fails.
   */
  constructor({
    modelDownloadProgressCallback = _downloadProgress => {},
    isReadyCallback = _isReady => {},
    isGeneratingCallback = _isGenerating => {},
    errorCallback = _error => {}
  } = {}) {
    this.nativeModule = new _VerticalOCRModule();
    this.modelDownloadProgressCallback = modelDownloadProgressCallback;
    this.isReadyCallback = isReadyCallback;
    this.isGeneratingCallback = isGeneratingCallback;
    this.errorCallback = errorCallback;
  }

  /**
   * Fetches the recognizer and both detectors, then loads them into the
   * native module.
   *
   * @param {object} detectorSources - Must have exactly two keys; read as
   *   `detectorLarge` and `detectorNarrow`.
   * @param {object} recognizerSources - Must have exactly two keys; read as
   *   `recognizerLarge` and `recognizerSmall`.
   * @param {string} language - Key into the `symbols` table.
   * @param {boolean} independentCharacters - When truthy the small recognizer
   *   is fetched instead of the large one; also passed to the native module.
   * @returns {Promise<void>}
   * @throws {Error} Only when a failure occurs and `errorCallback` is falsy;
   *   otherwise failures are reported through `errorCallback`.
   */
  loadModel = async (detectorSources, recognizerSources, language, independentCharacters) => {
    try {
      // Source sets without exactly two entries are ignored without reporting an error.
      if (Object.keys(detectorSources).length !== 2 || Object.keys(recognizerSources).length !== 2) return;

      // A real load attempt starts here, so drop any error left from a previous one.
      this.error = null;

      if (!symbols[language]) {
        throw new Error(getError(ETError.LanguageNotSupported));
      }

      this.isReady = false;
      this.isReadyCallback(this.isReady);

      // Three resources are fetched one after another; the numeric arguments are
      // presumably (total resources, index of this resource) - confirm against
      // `calculateDownloadProgres`.
      const recognizerSource = independentCharacters
        ? recognizerSources.recognizerSmall
        : recognizerSources.recognizerLarge;
      const recognizerPath = await fetchResource(
        recognizerSource,
        calculateDownloadProgres(3, 0, this.modelDownloadProgressCallback)
      );
      const detectorPaths = {
        detectorLarge: await fetchResource(
          detectorSources.detectorLarge,
          calculateDownloadProgres(3, 1, this.modelDownloadProgressCallback)
        ),
        detectorNarrow: await fetchResource(
          detectorSources.detectorNarrow,
          calculateDownloadProgres(3, 2, this.modelDownloadProgressCallback)
        )
      };

      await this.nativeModule.loadModule(
        detectorPaths.detectorLarge,
        detectorPaths.detectorNarrow,
        recognizerPath,
        symbols[language],
        independentCharacters
      );

      this.isReady = true;
      this.isReadyCallback(this.isReady);
    } catch (e) {
      const message = getError(e);
      // Keep the public field in step with what the callback receives.
      this.error = message;
      if (this.errorCallback) {
        this.errorCallback(message);
      } else {
        throw new Error(message, { cause: e });
      }
    }
  };

  /**
   * Passes `input` to the native module's `forward` and returns its result.
   *
   * @param {*} input - Forwarded unchanged to `nativeModule.forward`.
   * @returns {Promise<*>} Whatever `nativeModule.forward` resolves to.
   * @throws {Error} If the model is not loaded, if another call is already in
   *   flight, or if the native call fails (original error kept as `cause`).
   */
  forward = async input => {
    if (!this.isReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    if (this.isGenerating) {
      throw new Error(getError(ETError.ModelGenerating));
    }
    try {
      this.isGenerating = true;
      this.isGeneratingCallback(this.isGenerating);
      return await this.nativeModule.forward(input);
    } catch (e) {
      throw new Error(getError(e), { cause: e });
    } finally {
      // Always clear the flag so that a failed call does not block later ones.
      this.isGenerating = false;
      this.isGeneratingCallback(this.isGenerating);
    }
  };
}
//# sourceMappingURL=VerticalOCRController.js.map