react-native-executorch

An easy way to run AI models in React Native with ExecuTorch

"use strict"; import { Logger } from '../../common/Logger'; import { ResourceFetcher } from '../../utils/ResourceFetcher'; export class SpeechToTextModule { textDecoder = new TextDecoder('utf-8', { fatal: false, ignoreBOM: true }); async load(model, onDownloadProgressCallback = () => {}) { this.modelConfig = model; const tokenizerLoadPromise = ResourceFetcher.fetch(undefined, model.tokenizerSource); const encoderDecoderPromise = ResourceFetcher.fetch(onDownloadProgressCallback, model.encoderSource, model.decoderSource); const [tokenizerSources, encoderDecoderResults] = await Promise.all([tokenizerLoadPromise, encoderDecoderPromise]); const encoderSource = encoderDecoderResults?.[0]; const decoderSource = encoderDecoderResults?.[1]; if (!encoderSource || !decoderSource || !tokenizerSources) { throw new Error('Download interrupted.'); } this.nativeModule = await global.loadSpeechToText(encoderSource, decoderSource, tokenizerSources[0]); } delete() { this.nativeModule.unload(); } async encode(waveform) { if (Array.isArray(waveform)) { Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead'); waveform = new Float32Array(waveform); } return new Float32Array(await this.nativeModule.encode(waveform)); } async decode(tokens, encoderOutput) { if (Array.isArray(tokens)) { Logger.info('Passing tokens as number[] is deprecated, use Int32Array instead'); tokens = new Int32Array(tokens); } if (Array.isArray(encoderOutput)) { Logger.info('Passing encoderOutput as number[] is deprecated, use Float32Array instead'); encoderOutput = new Float32Array(encoderOutput); } return new Float32Array(await this.nativeModule.decode(tokens, encoderOutput)); } async transcribe(waveform, options = {}) { this.validateOptions(options); if (Array.isArray(waveform)) { Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead'); waveform = new Float32Array(waveform); } const transcriptionBytes = await this.nativeModule.transcribe(waveform, options.language || ''); return this.textDecoder.decode(new Uint8Array(transcriptionBytes)); } async *stream(options = {}) { this.validateOptions(options); const queue = []; let waiter = null; let finished = false; let error; const wake = () => { waiter?.(); waiter = null; }; (async () => { try { await this.nativeModule.stream((committed, nonCommitted, isDone) => { queue.push({ committed: this.textDecoder.decode(new Uint8Array(committed)), nonCommitted: this.textDecoder.decode(new Uint8Array(nonCommitted)) }); if (isDone) { finished = true; } wake(); }, options.language || ''); finished = true; wake(); } catch (e) { error = e; finished = true; wake(); } })(); while (true) { if (queue.length > 0) { yield queue.shift(); if (finished && queue.length === 0) { return; } continue; } if (error) throw error; if (finished) return; await new Promise(r => waiter = r); } } streamInsert(waveform) { if (Array.isArray(waveform)) { Logger.info('Passing waveform as number[] is deprecated, use Float32Array instead'); waveform = new Float32Array(waveform); } this.nativeModule.streamInsert(waveform); } streamStop() { this.nativeModule.streamStop(); } validateOptions(options) { if (!this.modelConfig.isMultilingual && options.language) { throw new Error('Model is not multilingual, cannot set language'); } if (this.modelConfig.isMultilingual && !options.language) { throw new Error('Model is multilingual, provide a language'); } } } //# sourceMappingURL=SpeechToTextModule.js.map