react-native-executorch
An easy way to run AI models in React Native with ExecuTorch
import { Logger } from '../../common/Logger';
import { ResourceFetcher } from '../../utils/ResourceFetcher';
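/**
 * Thin JavaScript wrapper around the native ExecuTorch speech-to-text runtime.
 * Downloads the model artifacts, then exposes one-shot encoding/decoding and
 * transcription as well as a streaming API backed by an async generator.
 */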
export class SpeechToTextModule {
textDecoder = new TextDecoder('utf-8', {
fatal: false,
ignoreBOM: true
});
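  // Fetches the tokenizer, encoder and decoder (reporting download progress for
  // the encoder/decoder binaries only) and instantiates the native module.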
async load(model, onDownloadProgressCallback = () => {}) {
this.modelConfig = model;
const tokenizerLoadPromise = ResourceFetcher.fetch(undefined, model.tokenizerSource);
const encoderDecoderPromise = ResourceFetcher.fetch(onDownloadProgressCallback, model.encoderSource, model.decoderSource);
const [tokenizerSources, encoderDecoderResults] = await Promise.all([tokenizerLoadPromise, encoderDecoderPromise]);
const encoderSource = encoderDecoderResults?.[0];
const decoderSource = encoderDecoderResults?.[1];
if (!encoderSource || !decoderSource || !tokenizerSources) {
throw new Error('Download interrupted.');
}
this.nativeModule = await global.loadSpeechToText(encoderSource, decoderSource, tokenizerSources[0]);
}
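  // Unloads the native module, releasing the resources held by the loaded model.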
delete() {
this.nativeModule.unload();
}
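  // Runs the encoder over a raw audio waveform and returns its output as a
  // Float32Array. Plain number[] input is still accepted but deprecated.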
async encode(waveform) {
if (Array.isArray(waveform)) {
Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead');
waveform = new Float32Array(waveform);
}
return new Float32Array(await this.nativeModule.encode(waveform));
}
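  // Runs the decoder for the given token ids against a previously computed
  // encoder output and returns the result as a Float32Array.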
async decode(tokens, encoderOutput) {
if (Array.isArray(tokens)) {
Logger.info('Passing tokens as number[] is deprecated, use Int32Array instead');
tokens = new Int32Array(tokens);
}
if (Array.isArray(encoderOutput)) {
Logger.info('Passing encoderOutput as number[] is deprecated, use Float32Array instead');
encoderOutput = new Float32Array(encoderOutput);
}
return new Float32Array(await this.nativeModule.decode(tokens, encoderOutput));
}
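  // Transcribes an entire waveform in a single call; the native side returns
  // UTF-8 bytes, which are decoded into a string here.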
async transcribe(waveform, options = {}) {
this.validateOptions(options);
if (Array.isArray(waveform)) {
Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead');
waveform = new Float32Array(waveform);
}
const transcriptionBytes = await this.nativeModule.transcribe(waveform, options.language || '');
return this.textDecoder.decode(new Uint8Array(transcriptionBytes));
}
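  // Streaming transcription as an async generator. The native callback acts as
  // the producer, pushing { committed, nonCommitted } chunks onto a queue; the
  // generator is the consumer, draining the queue and parking on `waiter`
  // whenever it is empty.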
async *stream(options = {}) {
this.validateOptions(options);
const queue = [];
let waiter = null;
let finished = false;
let error;
const wake = () => {
waiter?.();
waiter = null;
};
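    // Producer: run the native stream in the background, enqueue each partial
    // result, and wake the consumer after every push, on completion, or on error.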
(async () => {
try {
await this.nativeModule.stream((committed, nonCommitted, isDone) => {
queue.push({
committed: this.textDecoder.decode(new Uint8Array(committed)),
nonCommitted: this.textDecoder.decode(new Uint8Array(nonCommitted))
});
if (isDone) {
finished = true;
}
wake();
}, options.language || '');
finished = true;
wake();
} catch (e) {
error = e;
finished = true;
wake();
}
})();
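    // Consumer loop: yield queued chunks in order; when the queue is empty,
    // surface any producer error, return once the stream has finished, and
    // otherwise wait until the producer wakes us again.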
while (true) {
if (queue.length > 0) {
yield queue.shift();
if (finished && queue.length === 0) {
return;
}
continue;
}
if (error) throw error;
if (finished) return;
await new Promise(r => waiter = r);
}
}
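  // Feeds the next chunk of audio into the currently running native stream.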
streamInsert(waveform) {
if (Array.isArray(waveform)) {
Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead');
waveform = new Float32Array(waveform);
}
this.nativeModule.streamInsert(waveform);
}
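  // Asks the native side to stop the current streaming session.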
streamStop() {
this.nativeModule.streamStop();
}
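  // Multilingual models require a language option; monolingual models reject it.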
validateOptions(options) {
if (!this.modelConfig.isMultilingual && options.language) {
throw new Error('Model is not multilingual, cannot set language');
}
if (this.modelConfig.isMultilingual && !options.language) {
throw new Error('Model is multilingual, provide a language');
}
}
}
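// Example usage (illustrative sketch, not part of this file; the WHISPER_TINY_EN
// config name is an assumption, so substitute whichever speech-to-text model
// config the package exports, and pass the waveform as a Float32Array of mono
// audio samples):
//
//   import { SpeechToTextModule, WHISPER_TINY_EN } from 'react-native-executorch';
//
//   const stt = new SpeechToTextModule();
//   await stt.load(WHISPER_TINY_EN, (progress) => console.log('download', progress));
//   const text = await stt.transcribe(waveform);
//   stt.delete();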
//# sourceMappingURL=SpeechToTextModule.js.map