UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

106 lines (105 loc) 3.69 kB
"use strict";

import { Logger } from '../../common/Logger';
import { ResourceFetcher } from '../../utils/ResourceFetcher';

/**
 * JS wrapper around a native speech-to-text module (separate encoder and
 * decoder models plus a tokenizer).
 *
 * `load()` must complete before any other method is used. Every other method
 * dereferences `this.nativeModule`, and option validation reads
 * `this.modelConfig`; both are assigned only in `load()`.
 *
 * NOTE(review): this file appears to be compiled build output. The trailing
 * sourceMappingURL will no longer line up with this hand-edited code.
 */
export class SpeechToTextModule {
  /**
   * Downloads the model resources and initializes the native module.
   *
   * The tokenizer is fetched without progress reporting. The encoder and
   * decoder are fetched together, reporting through
   * `onDownloadProgressCallback`. Both fetches run in parallel.
   *
   * @param {object} model - Model config. Reads `tokenizerSource`,
   *   `encoderSource` and `decoderSource` here, and `isMultilingual` later
   *   in `validateOptions`.
   * @param {Function} [onDownloadProgressCallback] - Forwarded to
   *   `ResourceFetcher.fetch` for the encoder/decoder download.
   * @throws {Error} 'Download interrupted.' if the tokenizer, encoder or
   *   decoder source is missing after fetching.
   */
  async load(model, onDownloadProgressCallback = () => {}) {
    this.modelConfig = model;
    const tokenizerLoadPromise = ResourceFetcher.fetch(undefined, model.tokenizerSource);
    const encoderDecoderPromise = ResourceFetcher.fetch(
      onDownloadProgressCallback,
      model.encoderSource,
      model.decoderSource,
    );
    const [tokenizerSources, encoderDecoderResults] = await Promise.all([
      tokenizerLoadPromise,
      encoderDecoderPromise,
    ]);
    // A missing result is treated as an interrupted download
    // (presumably ResourceFetcher yields nothing on cancel — confirm).
    const encoderSource = encoderDecoderResults?.[0];
    const decoderSource = encoderDecoderResults?.[1];
    if (!encoderSource || !decoderSource || !tokenizerSources) {
      throw new Error('Download interrupted.');
    }
    this.nativeModule = await global.loadSpeechToText(
      encoderSource,
      decoderSource,
      tokenizerSources[0],
    );
  }

  /**
   * Runs the native encoder on an audio waveform.
   *
   * @param {Float32Array|number[]} waveform - Audio samples. A `number[]` is
   *   still accepted but is deprecated and is converted to a Float32Array.
   * @returns {Promise<Float32Array>} The encoder output, copied into a
   *   Float32Array.
   */
  async encode(waveform) {
    if (Array.isArray(waveform)) {
      Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead');
      waveform = new Float32Array(waveform);
    }
    return new Float32Array(await this.nativeModule.encode(waveform));
  }

  /**
   * Runs the native decoder for the given tokens and encoder output.
   *
   * @param {Int32Array|number[]} tokens - Token ids. A `number[]` is
   *   deprecated and is converted to an Int32Array.
   * @param {Float32Array|number[]} encoderOutput - Output of `encode()`.
   *   A `number[]` is deprecated and is converted to a Float32Array.
   * @returns {Promise<Float32Array>} The decoder output, copied into a
   *   Float32Array.
   */
  async decode(tokens, encoderOutput) {
    if (Array.isArray(tokens)) {
      Logger.info('Passing tokens as number[] is deprecated, use Int32Array instead');
      tokens = new Int32Array(tokens);
    }
    if (Array.isArray(encoderOutput)) {
      Logger.info('Passing encoderOutput as number[] is deprecated, use Float32Array instead');
      encoderOutput = new Float32Array(encoderOutput);
    }
    return new Float32Array(await this.nativeModule.decode(tokens, encoderOutput));
  }

  /**
   * Transcribes a complete waveform in one native call.
   *
   * @param {Float32Array|number[]} waveform - Audio samples. A `number[]` is
   *   deprecated and is converted to a Float32Array.
   * @param {{language?: string}} [options] - `language` is required for
   *   multilingual models and forbidden otherwise. An empty string is sent to
   *   the native side when it is omitted.
   * @returns Whatever native `transcribe` resolves to (not visible here).
   * @throws {Error} If `options.language` does not match the model (see
   *   `validateOptions`).
   */
  async transcribe(waveform, options = {}) {
    this.validateOptions(options);
    if (Array.isArray(waveform)) {
      Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead');
      waveform = new Float32Array(waveform);
    }
    return this.nativeModule.transcribe(waveform, options.language || '');
  }

  /**
   * Starts a native streaming session and yields partial results as they
   * arrive.
   *
   * Each yielded value is `{ committed, nonCommitted }`, exactly as passed to
   * the native callback. The generator finishes when the native callback
   * reports `isDone` or the native `stream()` call settles. If the native call
   * rejects, all results already queued are yielded first, and then the
   * rejection is re-thrown to the consumer.
   * Audio is presumably supplied via `streamInsert()` and the session ended
   * via `streamStop()` — confirm against the native module.
   *
   * Because this is an async generator, option validation runs on the first
   * `next()`, not when `stream()` is called.
   *
   * @param {{language?: string}} [options] - See `transcribe`.
   */
  async *stream(options = {}) {
    this.validateOptions(options);
    const queue = [];
    let waiter = null;
    let finished = false;
    let error;
    // Resolve the pending consumer wait, if any.
    const wake = () => {
      waiter?.();
      waiter = null;
    };

    // Drive the native stream in the background. Its callback feeds `queue`.
    // This promise never rejects because every failure is captured in `error`.
    (async () => {
      try {
        await this.nativeModule.stream((committed, nonCommitted, isDone) => {
          queue.push({ committed, nonCommitted });
          if (isDone) {
            finished = true;
          }
          wake();
        }, options.language || '');
        finished = true;
        wake();
      } catch (e) {
        error = e;
        finished = true;
        wake();
      }
    })();

    while (true) {
      if (queue.length > 0) {
        yield queue.shift();
        // Re-evaluate state on the next iteration instead of returning as
        // soon as the queue drains. Returning here would swallow an `error`
        // recorded while results were still queued.
        continue;
      }
      if (error) throw error;
      if (finished) return;
      // The executor runs synchronously, so `waiter` is set before any
      // native callback can run on the JS thread.
      await new Promise((resolve) => {
        waiter = resolve;
      });
    }
  }

  /**
   * Pushes more audio into the active streaming session.
   *
   * @param {Float32Array|number[]} waveform - Audio samples. A `number[]` is
   *   deprecated and is converted to a Float32Array.
   * @returns Whatever native `streamInsert` returns.
   */
  async streamInsert(waveform) {
    if (Array.isArray(waveform)) {
      Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead');
      waveform = new Float32Array(waveform);
    }
    return this.nativeModule.streamInsert(waveform);
  }

  /**
   * Forwards to native `streamStop`.
   *
   * @returns Whatever native `streamStop` returns.
   */
  async streamStop() {
    return this.nativeModule.streamStop();
  }

  /**
   * Checks that `options.language` agrees with the loaded model: a language
   * must be given for multilingual models and must not be given otherwise.
   * Reads `this.modelConfig`, so `load()` must have run first.
   *
   * @param {{language?: string}} options
   * @throws {Error} When the language option does not match the model.
   */
  validateOptions(options) {
    if (!this.modelConfig.isMultilingual && options.language) {
      throw new Error('Model is not multilingual, cannot set language');
    }
    if (this.modelConfig.isMultilingual && !options.language) {
      throw new Error('Model is multilingual, provide a language');
    }
  }
}
//# sourceMappingURL=SpeechToTextModule.js.map