// react-native-executorch
// Version: (unspecified)
// An easy way to run AI models in React Native with ExecuTorch
// 88 lines (87 loc) • 3.23 kB
// JavaScript
import { useEffect, useCallback, useState } from 'react';
import { ETError, getError } from '../../Error';
import { SpeechToTextModule } from '../../modules/natural_language_processing/SpeechToTextModule';
/**
 * React hook wrapping `SpeechToTextModule` with load/download state and
 * generation guards.
 *
 * @param {object} params
 * @param {object} params.model - Model descriptor; the hook reads
 *   `isMultilingual`, `encoderSource`, `decoderSource`, `tokenizerSource`.
 * @param {boolean} [params.preventLoad=false] - When true, skips loading the model.
 * @returns {object} `{ error, isReady, isGenerating, downloadProgress,
 *   committedTranscription, nonCommittedTranscription, encode, decode,
 *   transcribe, stream, streamStop, streamInsert }`
 */
export const useSpeechToText = ({
  model,
  preventLoad = false
}) => {
  const [error, setError] = useState(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);
  const [downloadProgress, setDownloadProgress] = useState(0);
  // Lazy initializer keeps one stable module instance per hook instance.
  const [modelInstance] = useState(() => new SpeechToTextModule());
  const [committedTranscription, setCommittedTranscription] = useState('');
  const [nonCommittedTranscription, setNonCommittedTranscription] = useState('');
  useEffect(() => {
    if (preventLoad) return;
    // Stale-effect guard: if the model props change (or the component
    // unmounts) while a download is in flight, the superseded run must not
    // flip isReady/error/progress after the newer run has started.
    let cancelled = false;
    (async () => {
      setDownloadProgress(0);
      setError(null);
      try {
        setIsReady(false);
        await modelInstance.load({
          isMultilingual: model.isMultilingual,
          encoderSource: model.encoderSource,
          decoderSource: model.decoderSource,
          tokenizerSource: model.tokenizerSource
        }, progress => {
          if (!cancelled) setDownloadProgress(progress);
        });
        if (!cancelled) setIsReady(true);
      } catch (err) {
        // A thrown value is not guaranteed to be an Error; don't assume `.message`.
        if (!cancelled) setError(err instanceof Error ? err.message : String(err));
      }
    })();
    return () => {
      cancelled = true;
    };
  }, [modelInstance, model.isMultilingual, model.encoderSource, model.decoderSource, model.tokenizerSource, preventLoad]);
  // Wraps a module method so it refuses to run before load / during another
  // generation, and holds isGenerating for its duration.
  const stateWrapper = useCallback(fn => async (...args) => {
    if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
    if (isGenerating) throw new Error(getError(ETError.ModelGenerating));
    setIsGenerating(true);
    try {
      return await fn.apply(modelInstance, args);
    } finally {
      setIsGenerating(false);
    }
  }, [isReady, isGenerating, modelInstance]);
  // Consumes the module's async-iterator stream, mirroring committed /
  // non-committed text into state; resolves with the full committed transcript.
  const stream = useCallback(async options => {
    if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
    if (isGenerating) throw new Error(getError(ETError.ModelGenerating));
    setIsGenerating(true);
    setCommittedTranscription('');
    setNonCommittedTranscription('');
    let transcription = '';
    try {
      for await (const {
        committed,
        nonCommitted
      } of modelInstance.stream(options)) {
        setCommittedTranscription(prev => prev + committed);
        setNonCommittedTranscription(nonCommitted);
        transcription += committed;
      }
    } finally {
      setIsGenerating(false);
    }
    return transcription;
  }, [isReady, isGenerating, modelInstance]);
  // Lighter wrapper for control methods (stop/insert): only requires the
  // model to be loaded; does not toggle isGenerating.
  const wrapper = useCallback(fn => {
    return (...args) => {
      if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
      return fn.apply(modelInstance, args);
    };
  }, [isReady, modelInstance]);
  return {
    error,
    isReady,
    isGenerating,
    downloadProgress,
    committedTranscription,
    nonCommittedTranscription,
    encode: stateWrapper(SpeechToTextModule.prototype.encode),
    decode: stateWrapper(SpeechToTextModule.prototype.decode),
    transcribe: stateWrapper(SpeechToTextModule.prototype.transcribe),
    stream,
    streamStop: wrapper(SpeechToTextModule.prototype.streamStop),
    streamInsert: wrapper(SpeechToTextModule.prototype.streamInsert)
  };
};
//# sourceMappingURL=useSpeechToText.js.map