// react-native-executorch
// An easy way to run AI models in React Native with ExecuTorch
// (package-viewer metadata: 194 lines (183 loc) • 7.32 kB, JavaScript)
;
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils';
import { Logger } from '../../common/Logger';
/**
* Module for Speech to Text (STT) functionalities.
* @category Typescript API
*/
/**
 * Module for Speech to Text (STT) functionalities.
 *
 * Wraps a native STT runner and exposes one-shot transcription
 * (`transcribe`), low-level encoder/decoder access (`encode`/`decode`),
 * and a streaming session (`stream` / `streamInsert` / `streamStop`).
 * Only the Whisper architecture is currently supported by the native side.
 * @category Typescript API
 */
export class SpeechToTextModule {
  /**
   * @param nativeModule - Handle to the native STT runner (as returned by `global.loadSpeechToText`).
   * @param modelConfig - Model configuration; its `isMultilingual` flag is read by `validateOptions`.
   */
  constructor(nativeModule, modelConfig) {
    this.nativeModule = nativeModule;
    this.modelConfig = modelConfig;
  }
  /**
   * Creates a Speech to Text instance for a built-in model.
   * @param namedSources - Configuration object containing model name, sources, and multilingual flag.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `SpeechToTextModule` instance.
   * @throws A normalized error (via `parseUnknownError`) if the download or native load fails.
   * @example
   * ```ts
   * import { SpeechToTextModule, WHISPER_TINY_EN } from 'react-native-executorch';
   * const stt = await SpeechToTextModule.fromModelName(WHISPER_TINY_EN);
   * ```
   */
  static async fromModelName(namedSources, onDownloadProgress = () => {}) {
    try {
      const nativeModule = await SpeechToTextModule.loadWhisper(namedSources, onDownloadProgress);
      return new SpeechToTextModule(nativeModule, namedSources);
    } catch (error) {
      // Log the raw error for diagnostics, then rethrow in the library's
      // normalized error shape so callers only ever catch RnExecutorchError.
      Logger.error('Load failed:', error);
      throw parseUnknownError(error);
    }
  }
  /**
   * Creates a Speech to Text instance with user-provided model binaries.
   * Use this when working with a custom-exported STT model.
   * Internally uses `'custom'` as the model name for telemetry.
   * @remarks The native model contract for this method is not formally defined and may change
   * between releases. Currently only the Whisper architecture is supported by the native runner.
   * Refer to the native source code for the current expected interface.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param tokenizerSource - A fetchable resource pointing to the tokenizer file.
   * @param isMultilingual - Whether the model supports multiple languages.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `SpeechToTextModule` instance.
   */
  static fromCustomModel(modelSource, tokenizerSource, isMultilingual, onDownloadProgress = () => {}) {
    // Delegates to fromModelName with a synthetic 'custom' model config.
    return SpeechToTextModule.fromModelName({
      modelName: 'custom',
      modelSource,
      tokenizerSource,
      isMultilingual
    }, onDownloadProgress);
  }
  /**
   * Downloads the model and tokenizer in parallel, then instantiates the
   * native Whisper runner.
   * @param model - Config with `modelSource` and `tokenizerSource`.
   * @param onDownloadProgressCallback - Progress callback; note it is wired
   * only to the model download — the tokenizer fetch reports no progress.
   * @returns The native module handle produced by `global.loadSpeechToText`.
   * @throws {RnExecutorchError} `DownloadInterrupted` if either fetch yields no file.
   */
  static async loadWhisper(model, onDownloadProgressCallback) {
    // Kick off both fetches before awaiting so they run concurrently.
    const tokenizerLoadPromise = ResourceFetcher.fetch(undefined, model.tokenizerSource);
    const modelPromise = ResourceFetcher.fetch(onDownloadProgressCallback, model.modelSource);
    const [tokenizerSources, modelSources] = await Promise.all([tokenizerLoadPromise, modelPromise]);
    // Either fetch resolving to an empty result means an incomplete download.
    if (!modelSources?.[0] || !tokenizerSources?.[0]) {
      throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted, 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.');
    }
    // Currently only Whisper architecture is supported
    return await global.loadSpeechToText('whisper', modelSources[0], tokenizerSources[0]);
  }
  /**
   * Unloads the model from memory.
   * Safe to call when no native module is attached (optional chaining).
   */
  delete() {
    this.nativeModule?.unload();
  }
  /**
   * Runs the encoding part of the model on the provided waveform.
   * Returns the encoded waveform as a Float32Array.
   * @param waveform - The input audio waveform.
   * @returns The encoded output (native buffer copied into a Float32Array).
   */
  async encode(waveform) {
    const buffer = await this.nativeModule.encode(waveform);
    return new Float32Array(buffer);
  }
  /**
   * Runs the decoder of the model.
   * @param tokens - The input tokens.
   * @param encoderOutput - The encoder output.
   * @returns Decoded output (native buffer copied into a Float32Array).
   */
  async decode(tokens, encoderOutput) {
    const buffer = await this.nativeModule.decode(tokens, encoderOutput);
    return new Float32Array(buffer);
  }
  /**
   * Starts a transcription process for a given input array (16kHz waveform).
   * For multilingual models, specify the language in `options`.
   * Returns the transcription as a string. Passing `number[]` is deprecated.
   * @param waveform - The Float32Array audio data.
   * @param options - Decoding options including language.
   * @returns The transcription string.
   * @throws {RnExecutorchError} `InvalidConfig` if the language option conflicts
   * with the model's multilingual capability (see `validateOptions`).
   */
  async transcribe(waveform, options = {}) {
    this.validateOptions(options);
    // Empty string = "no language specified" for the native side.
    return await this.nativeModule.transcribe(waveform, options.language || '', !!options.verbose);
  }
  /**
   * Starts a streaming transcription session.
   * Yields objects with `committed` and `nonCommitted` transcriptions.
   * Committed transcription contains the part of the transcription that is finalized and will not change.
   * Useful for displaying stable results during streaming.
   * Non-committed transcription contains the part of the transcription that is still being processed and may change.
   * Useful for displaying live, partial results during streaming.
   * Use with `streamInsert` and `streamStop` to control the stream.
   * @param options - Decoding options including language.
   * @yields An object containing `committed` and `nonCommitted` transcription results.
   * @returns An async generator yielding transcription updates.
   */
  async *stream(options = {}) {
    this.validateOptions(options);
    const verbose = !!options.verbose;
    const language = options.language || '';
    // Producer/consumer handoff between the native result callback (producer)
    // and this generator (consumer): results are buffered in `queue`; the
    // consumer parks on `waiter` whenever the queue is empty.
    const queue = [];
    let waiter = null;
    let finished = false;
    let error;
    // Resolve the parked consumer, if any. Clearing `waiter` makes repeated
    // calls harmless.
    const wake = () => {
      waiter?.();
      waiter = null;
    };
    // Fire-and-forget producer task: runs the native stream and feeds the
    // queue. Failures are captured into `error` and re-raised from the
    // consumer loop below, so the rejection is never left unhandled.
    (async () => {
      try {
        await this.nativeModule.stream((committed, nonCommitted, isDone) => {
          queue.push({
            committed,
            nonCommitted
          });
          if (isDone) {
            finished = true;
          }
          wake();
        }, language, verbose);
        // Native stream promise resolved: no more results will arrive.
        finished = true;
        wake();
      } catch (e) {
        error = e;
        finished = true;
        wake();
      }
    })();
    // Consumer loop: drain queued results first, then surface errors, then
    // terminate once the producer has finished.
    while (true) {
      if (queue.length > 0) {
        yield queue.shift();
        if (finished && queue.length === 0) {
          return;
        }
        continue;
      }
      if (error) throw parseUnknownError(error);
      if (finished) return;
      // Park until the native callback delivers the next result. JS is
      // single-threaded, so no wakeup can slip between the checks above and
      // this await.
      await new Promise(r => waiter = r);
    }
  }
  /**
   * Inserts a new audio chunk into the streaming transcription session.
   * @param waveform - The audio chunk to insert.
   */
  streamInsert(waveform) {
    this.nativeModule.streamInsert(waveform);
  }
  /**
   * Stops the current streaming transcription session.
   */
  streamStop() {
    this.nativeModule.streamStop();
  }
  /**
   * Ensures the `language` option is consistent with the model's
   * multilingual capability.
   * @param options - Decoding options; only `language` is inspected.
   * @throws {RnExecutorchError} `InvalidConfig` when a language is supplied to
   * a monolingual model, or omitted for a multilingual one.
   */
  validateOptions(options) {
    if (!this.modelConfig.isMultilingual && options.language) {
      throw new RnExecutorchError(RnExecutorchErrorCode.InvalidConfig, 'Model is not multilingual, cannot set language');
    }
    if (this.modelConfig.isMultilingual && !options.language) {
      throw new RnExecutorchError(RnExecutorchErrorCode.InvalidConfig, 'Model is multilingual, provide a language');
    }
  }
}
//# sourceMappingURL=SpeechToTextModule.js.map