/**
 * react-native-executorch
 * An easy way to run AI models in React Native with ExecuTorch.
 */
import { useEffect, useCallback, useState } from 'react';
import { SpeechToTextModule } from '../../modules/natural_language_processing/SpeechToTextModule';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils';
/**
* React hook for managing a Speech to Text (STT) instance.
* @category Hooks
* @param speechToTextProps - Configuration object containing `model` source and optional `preventLoad` flag.
* @returns Ready to use Speech to Text model.
*/
export const useSpeechToText = ({
  model,
  preventLoad = false
}) => {
  const [error, setError] = useState(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);
  const [sttModule, setSttModule] = useState(null);

  // Load (or reload) the native module whenever the model config changes.
  // The `alive` flag cancels stale loads: if the effect is cleaned up before
  // the download resolves, the freshly loaded module is released immediately.
  useEffect(() => {
    if (preventLoad) return;
    let alive = true;
    setDownloadProgress(0);
    setError(null);
    setIsReady(false);
    const onProgress = (progress) => {
      if (alive) setDownloadProgress(progress);
    };
    (async () => {
      try {
        const loaded = await SpeechToTextModule.fromModelName({
          modelName: model.modelName,
          isMultilingual: model.isMultilingual,
          modelSource: model.modelSource,
          tokenizerSource: model.tokenizerSource
        }, onProgress);
        if (!alive) {
          // Effect already cleaned up — free the instance we no longer need.
          loaded.delete();
          return;
        }
        setSttModule((previous) => {
          // Release any module left over from a prior load.
          previous?.delete();
          return loaded;
        });
        setIsReady(true);
      } catch (e) {
        if (alive) setError(parseUnknownError(e));
      }
    })();
    return () => {
      alive = false;
      setSttModule((previous) => {
        previous?.delete();
        return null;
      });
    };
  }, [model.modelName, model.isMultilingual, model.modelSource, model.tokenizerSource, preventLoad]);

  // Transcribe a full waveform in one shot. Rejects if the model is not
  // loaded yet or a previous run is still in flight.
  const transcribe = useCallback(async (waveform, options = {}) => {
    if (!isReady || !sttModule) {
      throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling this function.');
    }
    if (isGenerating) {
      throw new RnExecutorchError(RnExecutorchErrorCode.ModelGenerating, 'The model is currently generating. Please wait until previous model run is complete.');
    }
    setIsGenerating(true);
    try {
      return await sttModule.transcribe(waveform, options);
    } finally {
      setIsGenerating(false);
    }
  }, [isReady, isGenerating, sttModule]);

  // Streaming transcription: delegates to the module's async generator and
  // keeps `isGenerating` accurate even if the consumer stops iterating early.
  const stream = useCallback(async function* (options = {}) {
    if (!isReady || !sttModule) {
      throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling this function.');
    }
    if (isGenerating) {
      throw new RnExecutorchError(RnExecutorchErrorCode.ModelGenerating, 'The model is currently generating. Please wait until previous model run is complete.');
    }
    setIsGenerating(true);
    try {
      yield* sttModule.stream(options);
    } finally {
      setIsGenerating(false);
    }
  }, [isReady, isGenerating, sttModule]);

  // Feed an audio chunk into an active stream; no-op when not loaded.
  const streamInsert = useCallback((waveform) => {
    if (!isReady || !sttModule) return;
    sttModule.streamInsert(waveform);
  }, [isReady, sttModule]);

  // Signal the active stream to finish; no-op when not loaded.
  const streamStop = useCallback(() => {
    if (!isReady || !sttModule) return;
    sttModule.streamStop();
  }, [isReady, sttModule]);

  // Low-level access: run only the encoder on a waveform.
  const encode = useCallback((waveform) => {
    if (!sttModule) throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling this function.');
    return sttModule.encode(waveform);
  }, [sttModule]);

  // Low-level access: run only the decoder on tokens + encoder output.
  const decode = useCallback((tokens, encoderOutput) => {
    if (!sttModule) throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling this function.');
    return sttModule.decode(tokens, encoderOutput);
  }, [sttModule]);

  return {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    transcribe,
    stream,
    streamInsert,
    streamStop,
    encode,
    decode
  };
};
//# sourceMappingURL=useSpeechToText.js.map