UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

88 lines (87 loc) 3.23 kB
"use strict";

import { useEffect, useCallback, useRef, useState } from 'react';
import { ETError, getError } from '../../Error';
import { SpeechToTextModule } from '../../modules/natural_language_processing/SpeechToTextModule';

/**
 * React hook that owns a single SpeechToTextModule instance, loads the model
 * described by `model`, and exposes transcription helpers plus reactive status.
 *
 * @param {object} params
 * @param {object} params.model - Model description. The hook reads
 *   `isMultilingual`, `encoderSource`, `decoderSource` and `tokenizerSource`
 *   and forwards them to `SpeechToTextModule.load`; changing any of them
 *   triggers a reload.
 * @param {boolean} [params.preventLoad=false] - When true, no load is started.
 * @returns {object} An object with:
 *   - `error`: message of the last failed load, or null.
 *   - `isReady`: true once the latest load has completed successfully.
 *   - `isGenerating`: true while encode/decode/transcribe/stream is running.
 *   - `downloadProgress`: value reported by `load`'s progress callback.
 *   - `committedTranscription` / `nonCommittedTranscription`: text produced by
 *     the current/last `stream` call.
 *   - `encode`, `decode`, `transcribe`: module methods wrapped so they reject
 *     when the model is not ready or another generation is running, and toggle
 *     `isGenerating` around the call.
 *   - `stream(options)`: runs `modelInstance.stream(options)`, updating the
 *     transcription state per chunk; resolves to the concatenated committed text.
 *   - `streamStop`, `streamInsert`: module methods that throw synchronously if
 *     the model is not ready; they do not touch `isGenerating`.
 */
export const useSpeechToText = ({
  model,
  preventLoad = false
}) => {
  const [error, setError] = useState(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);
  const [modelInstance] = useState(() => new SpeechToTextModule());
  const [committedTranscription, setCommittedTranscription] = useState('');
  const [nonCommittedTranscription, setNonCommittedTranscription] = useState('');

  // Synchronous mirror of `isGenerating`. State updates only become visible on
  // the next render, so two calls issued before a re-render would both read
  // `isGenerating === false` and run concurrently; the ref closes that window.
  const isGeneratingRef = useRef(false);

  useEffect(() => {
    if (preventLoad) return undefined;

    // Flipped by the cleanup when the model props change or the component
    // unmounts, so a load that settles late cannot overwrite the state that
    // belongs to a newer load.
    let cancelled = false;

    (async () => {
      setDownloadProgress(0);
      setError(null);
      try {
        setIsReady(false);
        await modelInstance.load({
          isMultilingual: model.isMultilingual,
          encoderSource: model.encoderSource,
          decoderSource: model.decoderSource,
          tokenizerSource: model.tokenizerSource
        }, progress => {
          if (!cancelled) setDownloadProgress(progress);
        });
        if (!cancelled) setIsReady(true);
      } catch (err) {
        if (cancelled) return;
        // Non-Error rejections have no `.message`; fall back to their string
        // form so the failure is not silently reported as `undefined`.
        const message = err != null && typeof err.message === 'string' ? err.message : String(err);
        setError(message);
      }
    })();

    return () => {
      cancelled = true;
    };
  }, [modelInstance, model.isMultilingual, model.encoderSource, model.decoderSource, model.tokenizerSource, preventLoad]);

  // Wraps a SpeechToTextModule method so it rejects when not ready or busy and
  // keeps `isGenerating` (state and ref) set for the duration of the call.
  const stateWrapper = useCallback(fn => async (...args) => {
    if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
    if (isGeneratingRef.current) throw new Error(getError(ETError.ModelGenerating));
    isGeneratingRef.current = true;
    setIsGenerating(true);
    try {
      return await fn.apply(modelInstance, args);
    } finally {
      isGeneratingRef.current = false;
      setIsGenerating(false);
    }
  }, [isReady, modelInstance]);

  const stream = useCallback(async options => {
    if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
    if (isGeneratingRef.current) throw new Error(getError(ETError.ModelGenerating));
    isGeneratingRef.current = true;
    setIsGenerating(true);
    setCommittedTranscription('');
    setNonCommittedTranscription('');
    let transcription = '';
    try {
      for await (const { committed, nonCommitted } of modelInstance.stream(options)) {
        // Committed text accumulates; non-committed text is replaced each chunk.
        setCommittedTranscription(prev => prev + committed);
        setNonCommittedTranscription(nonCommitted);
        transcription += committed;
      }
    } finally {
      isGeneratingRef.current = false;
      setIsGenerating(false);
    }
    return transcription;
  }, [isReady, modelInstance]);

  // Wraps a SpeechToTextModule method with only a readiness check (no busy
  // guard), bound to this hook's module instance.
  const wrapper = useCallback(fn => {
    return (...args) => {
      if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
      return fn.apply(modelInstance, args);
    };
  }, [isReady, modelInstance]);

  return {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    committedTranscription,
    nonCommittedTranscription,
    encode: stateWrapper(SpeechToTextModule.prototype.encode),
    decode: stateWrapper(SpeechToTextModule.prototype.decode),
    transcribe: stateWrapper(SpeechToTextModule.prototype.transcribe),
    stream,
    streamStop: wrapper(SpeechToTextModule.prototype.streamStop),
    streamInsert: wrapper(SpeechToTextModule.prototype.streamInsert)
  };
};
//# sourceMappingURL=useSpeechToText.js.map