UNPKG

react-native-executorch

Version: (not captured in this extract)

An easy way to run AI models in React Native with ExecuTorch

194 lines (183 loc) 7.32 kB
"use strict"; import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; import { Logger } from '../../common/Logger'; /** * Module for Speech to Text (STT) functionalities. * @category Typescript API */ export class SpeechToTextModule { constructor(nativeModule, modelConfig) { this.nativeModule = nativeModule; this.modelConfig = modelConfig; } /** * Creates a Speech to Text instance for a built-in model. * @param namedSources - Configuration object containing model name, sources, and multilingual flag. * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. * @returns A Promise resolving to a `SpeechToTextModule` instance. * @example * ```ts * import { SpeechToTextModule, WHISPER_TINY_EN } from 'react-native-executorch'; * const stt = await SpeechToTextModule.fromModelName(WHISPER_TINY_EN); * ``` */ static async fromModelName(namedSources, onDownloadProgress = () => {}) { try { const nativeModule = await SpeechToTextModule.loadWhisper(namedSources, onDownloadProgress); return new SpeechToTextModule(nativeModule, namedSources); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); } } /** * Creates a Speech to Text instance with user-provided model binaries. * Use this when working with a custom-exported STT model. * Internally uses `'custom'` as the model name for telemetry. * @remarks The native model contract for this method is not formally defined and may change * between releases. Currently only the Whisper architecture is supported by the native runner. * Refer to the native source code for the current expected interface. * @param modelSource - A fetchable resource pointing to the model binary. * @param tokenizerSource - A fetchable resource pointing to the tokenizer file. 
* @param isMultilingual - Whether the model supports multiple languages. * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. * @returns A Promise resolving to a `SpeechToTextModule` instance. */ static fromCustomModel(modelSource, tokenizerSource, isMultilingual, onDownloadProgress = () => {}) { return SpeechToTextModule.fromModelName({ modelName: 'custom', modelSource, tokenizerSource, isMultilingual }, onDownloadProgress); } static async loadWhisper(model, onDownloadProgressCallback) { const tokenizerLoadPromise = ResourceFetcher.fetch(undefined, model.tokenizerSource); const modelPromise = ResourceFetcher.fetch(onDownloadProgressCallback, model.modelSource); const [tokenizerSources, modelSources] = await Promise.all([tokenizerLoadPromise, modelPromise]); if (!modelSources?.[0] || !tokenizerSources?.[0]) { throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted, 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'); } // Currently only Whisper architecture is supported return await global.loadSpeechToText('whisper', modelSources[0], tokenizerSources[0]); } /** * Unloads the model from memory. */ delete() { this.nativeModule?.unload(); } /** * Runs the encoding part of the model on the provided waveform. * Returns the encoded waveform as a Float32Array. * @param waveform - The input audio waveform. * @returns The encoded output. */ async encode(waveform) { const buffer = await this.nativeModule.encode(waveform); return new Float32Array(buffer); } /** * Runs the decoder of the model. * @param tokens - The input tokens. * @param encoderOutput - The encoder output. * @returns Decoded output. */ async decode(tokens, encoderOutput) { const buffer = await this.nativeModule.decode(tokens, encoderOutput); return new Float32Array(buffer); } /** * Starts a transcription process for a given input array (16kHz waveform). 
* For multilingual models, specify the language in `options`. * Returns the transcription as a string. Passing `number[]` is deprecated. * @param waveform - The Float32Array audio data. * @param options - Decoding options including language. * @returns The transcription string. */ async transcribe(waveform, options = {}) { this.validateOptions(options); return await this.nativeModule.transcribe(waveform, options.language || '', !!options.verbose); } /** * Starts a streaming transcription session. * Yields objects with `committed` and `nonCommitted` transcriptions. * Committed transcription contains the part of the transcription that is finalized and will not change. * Useful for displaying stable results during streaming. * Non-committed transcription contains the part of the transcription that is still being processed and may change. * Useful for displaying live, partial results during streaming. * Use with `streamInsert` and `streamStop` to control the stream. * @param options - Decoding options including language. * @yields An object containing `committed` and `nonCommitted` transcription results. * @returns An async generator yielding transcription updates. 
*/ async *stream(options = {}) { this.validateOptions(options); const verbose = !!options.verbose; const language = options.language || ''; const queue = []; let waiter = null; let finished = false; let error; const wake = () => { waiter?.(); waiter = null; }; (async () => { try { await this.nativeModule.stream((committed, nonCommitted, isDone) => { queue.push({ committed, nonCommitted }); if (isDone) { finished = true; } wake(); }, language, verbose); finished = true; wake(); } catch (e) { error = e; finished = true; wake(); } })(); while (true) { if (queue.length > 0) { yield queue.shift(); if (finished && queue.length === 0) { return; } continue; } if (error) throw parseUnknownError(error); if (finished) return; await new Promise(r => waiter = r); } } /** * Inserts a new audio chunk into the streaming transcription session. * @param waveform - The audio chunk to insert. */ streamInsert(waveform) { this.nativeModule.streamInsert(waveform); } /** * Stops the current streaming transcription session. */ streamStop() { this.nativeModule.streamStop(); } validateOptions(options) { if (!this.modelConfig.isMultilingual && options.language) { throw new RnExecutorchError(RnExecutorchErrorCode.InvalidConfig, 'Model is not multilingual, cannot set language'); } if (this.modelConfig.isMultilingual && !options.language) { throw new RnExecutorchError(RnExecutorchErrorCode.InvalidConfig, 'Model is multilingual, provide a language'); } } } //# sourceMappingURL=SpeechToTextModule.js.map