UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

117 lines (109 loc) 3.73 kB
import { useEffect, useCallback, useRef, useState } from 'react';
import { ETError, getError } from '../../Error';
import { SpeechToTextModule } from '../../modules/natural_language_processing/SpeechToTextModule';
import { DecodingOptions, SpeechToTextModelConfig } from '../../types/stt';

/**
 * React hook that owns a single `SpeechToTextModule` instance, loads it from
 * the given model sources and exposes its methods together with React state
 * describing loading and generation progress.
 *
 * @param params.model - Model configuration; its `isMultilingual`,
 *   `encoderSource`, `decoderSource` and `tokenizerSource` fields are passed
 *   to `SpeechToTextModule.load`. A change in any of them triggers a reload.
 * @param params.preventLoad - When `true`, the load effect does nothing.
 *   Defaults to `false`.
 * @returns An object with:
 *   - `error`: message of the last load failure, or `null`;
 *   - `isReady`: `true` once `load` has resolved for the current sources;
 *   - `isGenerating`: `true` while a guarded call or `stream` is in flight;
 *   - `downloadProgress`: last value reported by the load progress callback;
 *   - `committedTranscription` / `nonCommittedTranscription`: text
 *     accumulated by the current or most recent `stream` call;
 *   - `encode`, `decode`, `transcribe`: the module methods of the same name,
 *     guarded by the ready and busy checks;
 *   - `stream`: consumes `modelInstance.stream(options)` and resolves with
 *     the concatenation of all committed chunks;
 *   - `streamStop`, `streamInsert`: the module methods of the same name,
 *     guarded by the ready check only.
 * @throws The guarded functions throw `Error(getError(ETError.ModuleNotLoaded))`
 *   when the model is not ready, and (for `encode`, `decode`, `transcribe`,
 *   `stream`) `Error(getError(ETError.ModelGenerating))` when another guarded
 *   call is still running.
 */
export const useSpeechToText = ({
  model,
  preventLoad = false,
}: {
  model: SpeechToTextModelConfig;
  preventLoad?: boolean;
}) => {
  const [error, setError] = useState<null | string>(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);
  // Lazy initializer: the module is constructed once per hook instance.
  const [modelInstance] = useState(() => new SpeechToTextModule());
  const [committedTranscription, setCommittedTranscription] = useState('');
  const [nonCommittedTranscription, setNonCommittedTranscription] =
    useState('');

  // Synchronous mirror of `isGenerating`. State read from a render closure is
  // stale until the next render, so two calls issued back-to-back would both
  // pass a state-based busy check; the ref is updated immediately.
  const isGeneratingRef = useRef(false);

  useEffect(() => {
    if (preventLoad) return;

    // Set by the cleanup below when the sources change or the component
    // unmounts, so a superseded load cannot overwrite the state that belongs
    // to a newer load (or touch state after unmount).
    let cancelled = false;

    (async () => {
      setDownloadProgress(0);
      setError(null);
      try {
        setIsReady(false);
        await modelInstance.load(
          {
            isMultilingual: model.isMultilingual,
            encoderSource: model.encoderSource,
            decoderSource: model.decoderSource,
            tokenizerSource: model.tokenizerSource,
          },
          (progress: number) => {
            if (!cancelled) setDownloadProgress(progress);
          }
        );
        if (!cancelled) setIsReady(true);
      } catch (err) {
        if (cancelled) return;
        // Thrown values are not guaranteed to be Error instances.
        setError(err instanceof Error ? err.message : String(err));
      }
    })();

    return () => {
      cancelled = true;
    };
  }, [
    modelInstance,
    model.isMultilingual,
    model.encoderSource,
    model.decoderSource,
    model.tokenizerSource,
    preventLoad,
  ]);

  /**
   * Wraps an async module method so that it runs with `modelInstance` as
   * `this`, is rejected while the model is not ready or another guarded call
   * is running, and toggles the generating flag around the call.
   */
  const stateWrapper = useCallback(
    <T extends (...args: any[]) => Promise<any>>(fn: T) =>
      async (...args: Parameters<T>): Promise<Awaited<ReturnType<T>>> => {
        if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
        if (isGeneratingRef.current)
          throw new Error(getError(ETError.ModelGenerating));
        isGeneratingRef.current = true;
        setIsGenerating(true);
        try {
          return await fn.apply(modelInstance, args);
        } finally {
          isGeneratingRef.current = false;
          setIsGenerating(false);
        }
      },
    [isReady, modelInstance]
  );

  /**
   * Runs a streaming transcription: clears both transcription states, then
   * for every `{ committed, nonCommitted }` item yielded by
   * `modelInstance.stream(options)` appends `committed` to the committed
   * state and replaces the non-committed state. Resolves with all committed
   * chunks concatenated.
   */
  const stream = useCallback(
    async (options?: DecodingOptions) => {
      if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
      if (isGeneratingRef.current)
        throw new Error(getError(ETError.ModelGenerating));
      isGeneratingRef.current = true;
      setIsGenerating(true);
      setCommittedTranscription('');
      setNonCommittedTranscription('');
      let transcription = '';
      try {
        for await (const { committed, nonCommitted } of modelInstance.stream(
          options
        )) {
          setCommittedTranscription((prev) => prev + committed);
          setNonCommittedTranscription(nonCommitted);
          transcription += committed;
        }
      } finally {
        isGeneratingRef.current = false;
        setIsGenerating(false);
      }
      return transcription;
    },
    [isReady, modelInstance]
  );

  /**
   * Wraps a synchronous module method so that it runs with `modelInstance`
   * as `this` and throws when the model is not ready. It does not take part
   * in the busy guard.
   */
  const wrapper = useCallback(
    <T extends (...args: any[]) => any>(fn: T) => {
      return (...args: Parameters<T>): ReturnType<T> => {
        if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
        return fn.apply(modelInstance, args);
      };
    },
    [isReady, modelInstance]
  );

  return {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    committedTranscription,
    nonCommittedTranscription,
    encode: stateWrapper(SpeechToTextModule.prototype.encode),
    decode: stateWrapper(SpeechToTextModule.prototype.decode),
    transcribe: stateWrapper(SpeechToTextModule.prototype.transcribe),
    stream,
    streamStop: wrapper(SpeechToTextModule.prototype.streamStop),
    streamInsert: wrapper(SpeechToTextModule.prototype.streamInsert),
  };
};