UNPKG

react-native-executorch

Version:

An easy way to run AI models in React Native with ExecuTorch

180 lines (161 loc) 5 kB
import { Logger } from '../../common/Logger';
import { DecodingOptions, SpeechToTextModelConfig } from '../../types/stt';
import { ResourceFetcher } from '../../utils/ResourceFetcher';

/**
 * Thin TypeScript facade over the object returned by
 * `global.loadSpeechToText`. It fetches the model files, forwards calls to
 * that object and converts between JS-friendly types and the typed arrays /
 * byte sequences exchanged with it.
 *
 * `load()` must complete before any other method is called; until then the
 * underlying module and the model configuration are unset.
 */
export class SpeechToTextModule {
  // Untyped handle returned by `global.loadSpeechToText`.
  private nativeModule: any;

  // Configuration of the currently loaded model; assigned by `load()`.
  private modelConfig!: SpeechToTextModelConfig;

  // Non-fatal UTF-8 decoder used for the byte sequences returned by
  // `transcribe` and `stream`.
  private textDecoder = new TextDecoder('utf-8', {
    fatal: false,
    ignoreBOM: true,
  });

  /**
   * Fetches the tokenizer, encoder and decoder described by `model` and
   * instantiates the underlying module from them.
   *
   * @param model Model description providing `tokenizerSource`,
   *   `encoderSource` and `decoderSource`.
   * @param onDownloadProgressCallback Passed to the encoder/decoder fetch
   *   only; the tokenizer fetch is started without a progress callback.
   * @throws Error `'Download interrupted.'` when any of the three resources
   *   is missing from the fetch results.
   */
  public async load(
    model: SpeechToTextModelConfig,
    onDownloadProgressCallback: (progress: number) => void = () => {}
  ) {
    const tokenizerLoadPromise = ResourceFetcher.fetch(
      undefined,
      model.tokenizerSource
    );
    const encoderDecoderPromise = ResourceFetcher.fetch(
      onDownloadProgressCallback,
      model.encoderSource,
      model.decoderSource
    );
    const [tokenizerSources, encoderDecoderResults] = await Promise.all([
      tokenizerLoadPromise,
      encoderDecoderPromise,
    ]);

    const encoderSource = encoderDecoderResults?.[0];
    const decoderSource = encoderDecoderResults?.[1];
    // Validate the tokenizer entry itself (not just the container) so an
    // empty result can never reach the loader as `undefined`.
    const tokenizerSource = tokenizerSources?.[0];
    if (!encoderSource || !decoderSource || !tokenizerSource) {
      throw new Error('Download interrupted.');
    }

    this.nativeModule = await global.loadSpeechToText(
      encoderSource,
      decoderSource,
      tokenizerSource
    );
    // Commit the configuration only once loading succeeded, so a failed
    // (re)load does not leave a new config paired with the previous module.
    this.modelConfig = model;
  }

  /** Calls `unload()` on the underlying module. */
  public delete(): void {
    this.nativeModule.unload();
  }

  /**
   * Forwards `waveform` to the underlying `encode` call.
   *
   * @param waveform Audio samples. `number[]` is deprecated and is converted
   *   to `Float32Array` after logging a notice.
   * @returns The result of the underlying call, copied into a `Float32Array`.
   */
  public async encode(
    waveform: Float32Array | number[]
  ): Promise<Float32Array> {
    const samples = this.toFloat32Array(waveform, 'waveform');
    return new Float32Array(await this.nativeModule.encode(samples));
  }

  /**
   * Forwards `tokens` and `encoderOutput` to the underlying `decode` call.
   *
   * @param tokens Token ids. `number[]` is deprecated and is converted to
   *   `Int32Array` after logging a notice.
   * @param encoderOutput Value previously produced by `encode`. `number[]`
   *   is deprecated and is converted to `Float32Array` after logging a notice.
   * @returns The result of the underlying call, copied into a `Float32Array`.
   */
  public async decode(
    tokens: Int32Array | number[],
    encoderOutput: Float32Array | number[]
  ): Promise<Float32Array> {
    if (Array.isArray(tokens)) {
      Logger.info(
        'Passing tokens as number[] is deprecated, use Int32Array instead'
      );
      tokens = new Int32Array(tokens);
    }
    const encoded = this.toFloat32Array(encoderOutput, 'encoderOutput');
    return new Float32Array(await this.nativeModule.decode(tokens, encoded));
  }

  /**
   * Transcribes a complete waveform in one call.
   *
   * @param waveform Audio samples. `number[]` is deprecated and is converted
   *   to `Float32Array` after logging a notice.
   * @param options Decoding options; `language` is validated against the
   *   loaded model and passed on (empty string when not set).
   * @returns The bytes returned by the underlying call, decoded as UTF-8.
   * @throws Error when `options.language` does not match the model's
   *   multilingual capability.
   */
  public async transcribe(
    waveform: Float32Array | number[],
    options: DecodingOptions = {}
  ): Promise<string> {
    this.validateOptions(options);
    const samples = this.toFloat32Array(waveform, 'waveform');
    const transcriptionBytes = await this.nativeModule.transcribe(
      samples,
      options.language ?? ''
    );
    return this.textDecoder.decode(new Uint8Array(transcriptionBytes));
  }

  /**
   * Starts the underlying `stream` call and yields every update it reports
   * through its callback, in order.
   *
   * Audio is supplied separately through `streamInsert()`, and
   * `streamStop()` forwards a stop request. Nothing in this generator stops
   * the underlying stream when the consumer abandons iteration early, so
   * callers that break out of the loop should call `streamStop()` themselves.
   *
   * @param options Decoding options; `language` is validated against the
   *   loaded model and passed on (empty string when not set).
   * @returns An async generator of `{ committed, nonCommitted }` pairs, each
   *   decoded as UTF-8 from the bytes passed to the callback.
   * @throws Error when `options.language` does not match the model's
   *   multilingual capability.
   * @throws Whatever the underlying `stream` call rejects with. It is
   *   re-thrown only after every update queued before the failure has been
   *   yielded.
   */
  public async *stream(
    options: DecodingOptions = {}
  ): AsyncGenerator<{ committed: string; nonCommitted: string }> {
    this.validateOptions(options);

    const queue: { committed: string; nonCommitted: string }[] = [];
    let waiter: (() => void) | null = null;
    let finished = false;
    // Tracked separately from `error` so that a rejection carrying a falsy
    // value is still reported to the consumer.
    let failed = false;
    let error: unknown;

    // Resumes the consumer loop if it is currently parked on `waiter`.
    const wake = () => {
      waiter?.();
      waiter = null;
    };

    // Producer: runs concurrently with the consumer loop below. It never
    // rejects because every outcome is captured in the flags above.
    void (async () => {
      try {
        await this.nativeModule.stream(
          (committed: number[], nonCommitted: number[], isDone: boolean) => {
            queue.push({
              committed: this.textDecoder.decode(new Uint8Array(committed)),
              nonCommitted: this.textDecoder.decode(
                new Uint8Array(nonCommitted)
              ),
            });
            if (isDone) {
              finished = true;
            }
            wake();
          },
          options.language ?? ''
        );
      } catch (e) {
        error = e;
        failed = true;
      } finally {
        finished = true;
        wake();
      }
    })();

    // Consumer: drain the queue first, then report failure, then completion.
    // Checking in this order guarantees a producer error is never swallowed
    // just because updates were still queued when it occurred.
    while (true) {
      const next = queue.shift();
      if (next !== undefined) {
        yield next;
        continue;
      }
      if (failed) throw error;
      if (finished) return;
      await new Promise<void>((resolve) => {
        waiter = resolve;
      });
    }
  }

  /**
   * Forwards a chunk of audio to the underlying `streamInsert` call.
   *
   * @param waveform Audio samples. `number[]` is deprecated and is converted
   *   to `Float32Array` after logging a notice.
   */
  public streamInsert(waveform: Float32Array | number[]): void {
    this.nativeModule.streamInsert(this.toFloat32Array(waveform, 'waveform'));
  }

  /** Calls `streamStop()` on the underlying module. */
  public streamStop(): void {
    this.nativeModule.streamStop();
  }

  /**
   * Checks that `options.language` is consistent with the loaded model: it
   * must be set for multilingual models and unset otherwise.
   *
   * @throws Error when the two disagree.
   */
  private validateOptions(options: DecodingOptions) {
    if (!this.modelConfig.isMultilingual && options.language) {
      throw new Error('Model is not multilingual, cannot set language');
    }
    if (this.modelConfig.isMultilingual && !options.language) {
      throw new Error('Model is multilingual, provide a language');
    }
  }

  /**
   * Backward-compatibility shim for callers that still pass `number[]`:
   * logs a deprecation notice naming the argument and converts the values.
   * `Float32Array` inputs are returned untouched.
   */
  private toFloat32Array(
    values: Float32Array | number[],
    argName: string
  ): Float32Array {
    if (!Array.isArray(values)) {
      return values;
    }
    Logger.info(
      `Passing ${argName} as number[] is deprecated, use Float32Array instead`
    );
    return new Float32Array(values);
  }
}