react-native-executorch

An easy way to run AI models in React Native with ExecuTorch.

SpeechToTextController.js
"use strict"; import { HAMMING_DIST_THRESHOLD, MODEL_CONFIGS, SECOND, MODES, NUM_TOKENS_TO_TRIM, STREAMING_ACTION } from '../constants/sttDefaults'; import { SpeechToTextNativeModule } from '../native/RnExecutorchModules'; import { TokenizerModule } from '../modules/natural_language_processing/TokenizerModule'; import { ResourceFetcher } from '../utils/ResourceFetcher'; import { longCommonInfPref } from '../utils/stt'; import { ETError, getError } from '../Error'; export class SpeechToTextController { speechToTextNativeModule = SpeechToTextNativeModule; sequence = []; isReady = false; isGenerating = false; chunks = []; seqs = []; prevSeq = []; waveform = []; numOfChunks = 0; streaming = false; // User callbacks constructor({ transcribeCallback, modelDownloadProgressCallback, isReadyCallback, isGeneratingCallback, onErrorCallback, overlapSeconds, windowSize, streamingConfig }) { this.decodedTranscribeCallback = async seq => transcribeCallback(await this.tokenIdsToText(seq)); this.modelDownloadProgressCallback = modelDownloadProgressCallback; this.isReadyCallback = isReady => { this.isReady = isReady; isReadyCallback?.(isReady); }; this.isGeneratingCallback = isGenerating => { this.isGenerating = isGenerating; isGeneratingCallback?.(isGenerating); }; this.onErrorCallback = error => { if (onErrorCallback) { onErrorCallback(error ? new Error(getError(error)) : undefined); return; } else { throw new Error(getError(error)); } }; this.configureStreaming(overlapSeconds, windowSize, streamingConfig || 'balanced'); } async loadModel(modelName, encoderSource, decoderSource, tokenizerSource) { this.onErrorCallback(undefined); this.isReadyCallback(false); this.config = MODEL_CONFIGS[modelName]; try { await TokenizerModule.load(tokenizerSource || this.config.tokenizer.source); [encoderSource, decoderSource] = await ResourceFetcher.fetchMultipleResources(this.modelDownloadProgressCallback, encoderSource || this.config.sources.encoder, decoderSource || this.config.sources.decoder); } catch (e) { this.onErrorCallback(e); return; } if (modelName === 'whisperMultilingual') { // The underlying native class is instantiated based on the name of the model. There is no need to // create a separate class for multilingual version of Whisper, since it is the same. We just need // the distinction here, in TS, for start tokens and such. If we introduce // more versions of Whisper, such as the small one, this should be refactored. modelName = 'whisper'; } try { await this.speechToTextNativeModule.loadModule(modelName, [encoderSource, decoderSource]); this.modelDownloadProgressCallback?.(1); this.isReadyCallback(true); } catch (e) { this.onErrorCallback(e); } } configureStreaming(overlapSeconds, windowSize, streamingConfig) { if (streamingConfig) { this.windowSize = MODES[streamingConfig].windowSize * SECOND; this.overlapSeconds = MODES[streamingConfig].overlapSeconds * SECOND; } if (streamingConfig && (windowSize || overlapSeconds)) { console.warn(`windowSize and overlapSeconds overrides values from streamingConfig ${streamingConfig}.`); } this.windowSize = (windowSize || 0) * SECOND || this.windowSize; this.overlapSeconds = (overlapSeconds || 0) * SECOND || this.overlapSeconds; if (2 * this.overlapSeconds + this.windowSize >= 30 * SECOND) { console.warn(`Invalid values for overlapSeconds and/or windowSize provided. Expected windowSize + 2 * overlapSeconds (== ${this.windowSize + 2 * this.overlapSeconds}) <= 30. 
Setting windowSize to ${30 * SECOND - 2 * this.overlapSeconds}.`); this.windowSize = 30 * SECOND - 2 * this.overlapSeconds; } } chunkWaveform() { this.numOfChunks = Math.ceil(this.waveform.length / this.windowSize); for (let i = 0; i < this.numOfChunks; i++) { let chunk = []; const left = Math.max(this.windowSize * i - this.overlapSeconds, 0); const right = Math.min(this.windowSize * (i + 1) + this.overlapSeconds, this.waveform.length); chunk = this.waveform.slice(left, right); this.chunks.push(chunk); } } resetState() { this.sequence = []; this.seqs = []; this.waveform = []; this.prevSeq = []; this.chunks = []; this.decodedTranscribeCallback([]); this.onErrorCallback(undefined); } expectedChunkLength() { //only first chunk can be of shorter length, for first chunk there are no seqs decoded return this.seqs.length ? this.windowSize + 2 * this.overlapSeconds : this.windowSize + this.overlapSeconds; } async getStartingTokenIds(audioLanguage) { // We need different starting token ids based on the multilingualism of the model. // The eng version only needs BOS token, while the multilingual one needs: // [BOS, LANG, TRANSCRIBE]. Optionally we should also set notimestamps token, as timestamps // is not yet supported. if (!audioLanguage) { return [this.config.tokenizer.bos]; } // FIXME: I should use .getTokenId for the BOS as well, should remove it from config const langTokenId = await TokenizerModule.tokenToId(`<|${audioLanguage}|>`); const transcribeTokenId = await TokenizerModule.tokenToId('<|transcribe|>'); const noTimestampsTokenId = await TokenizerModule.tokenToId('<|notimestamps|>'); const startingTokenIds = [this.config.tokenizer.bos, langTokenId, transcribeTokenId, noTimestampsTokenId]; return startingTokenIds; } async decodeChunk(chunk, audioLanguage) { const seq = await this.getStartingTokenIds(audioLanguage); let prevSeqTokenIdx = 0; this.prevSeq = this.sequence.slice(); try { await this.encode(chunk); } catch (error) { this.onErrorCallback(new Error(getError(error) + ' encoding error')); return []; } let lastToken = seq.at(-1); while (lastToken !== this.config.tokenizer.eos) { try { lastToken = await this.decode(seq); } catch (error) { this.onErrorCallback(new Error(getError(error) + ' decoding error')); return [...seq, this.config.tokenizer.eos]; } seq.push(lastToken); if (this.seqs.length > 0 && seq.length < this.seqs.at(-1).length && seq.length % 3 !== 0) { this.prevSeq.push(this.seqs.at(-1)[prevSeqTokenIdx++]); this.decodedTranscribeCallback(this.prevSeq); } } return seq; } async handleOverlaps(seqs) { const maxInd = longCommonInfPref(seqs.at(-2), seqs.at(-1), HAMMING_DIST_THRESHOLD); this.sequence = [...this.sequence, ...seqs.at(-2).slice(0, maxInd)]; this.decodedTranscribeCallback(this.sequence); return this.sequence.slice(); } trimLeft(numOfTokensToTrim) { const idx = this.seqs.length - 1; if (this.seqs[idx][0] === this.config.tokenizer.bos) { this.seqs[idx] = this.seqs[idx].slice(numOfTokensToTrim); } } trimRight(numOfTokensToTrim) { const idx = this.seqs.length - 2; if (this.seqs[idx].at(-1) === this.config.tokenizer.eos) { this.seqs[idx] = this.seqs[idx].slice(0, -numOfTokensToTrim); } } // since we are calling this every time (except first) after a new seq is pushed to this.seqs // we can only trim left the last seq and trim right the second to last seq async trimSequences(audioLanguage) { const numSpecialTokens = (await this.getStartingTokenIds(audioLanguage)).length; this.trimLeft(numSpecialTokens + NUM_TOKENS_TO_TRIM); this.trimRight(numSpecialTokens + 
NUM_TOKENS_TO_TRIM); } // if last chunk is too short combine it with second to last to improve quality validateAndFixLastChunk() { if (this.chunks.length < 2) return; const lastChunkLength = this.chunks.at(-1).length / SECOND; const secondToLastChunkLength = this.chunks.at(-2).length / SECOND; if (lastChunkLength < 5 && secondToLastChunkLength + lastChunkLength < 30) { this.chunks[this.chunks.length - 2] = [...this.chunks.at(-2).slice(0, -this.overlapSeconds * 2), ...this.chunks.at(-1)]; this.chunks = this.chunks.slice(0, -1); } } async tokenIdsToText(tokenIds) { try { return TokenizerModule.decode(tokenIds, true); } catch (e) { this.onErrorCallback(new Error(`An error has occurred when decoding the token ids: ${e}`)); return ''; } } async transcribe(waveform, audioLanguage) { try { if (!this.isReady) throw Error(getError(ETError.ModuleNotLoaded)); if (this.isGenerating || this.streaming) throw Error(getError(ETError.ModelGenerating)); if (!!audioLanguage !== this.config.isMultilingual) throw new Error(getError(ETError.MultilingualConfiguration)); } catch (e) { this.onErrorCallback(e); return ''; } // Making sure that the error is not set when we get there this.isGeneratingCallback(true); this.resetState(); this.waveform = waveform; this.chunkWaveform(); this.validateAndFixLastChunk(); for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) { const seq = await this.decodeChunk(this.chunks.at(chunkId), audioLanguage); // whole audio is inside one chunk, no processing required if (this.chunks.length === 1) { this.sequence = seq; this.decodedTranscribeCallback(seq); break; } this.seqs.push(seq); if (this.seqs.length < 2) continue; // Remove starting tokenIds and some additional ones await this.trimSequences(audioLanguage); this.prevSeq = await this.handleOverlaps(this.seqs); // last sequence processed // overlaps are already handled, so just append the last seq if (this.seqs.length === this.chunks.length) { this.sequence = [...this.sequence, ...this.seqs.at(-1)]; this.decodedTranscribeCallback(this.sequence); this.prevSeq = this.sequence; } } const decodedText = await this.tokenIdsToText(this.sequence); this.isGeneratingCallback(false); return decodedText; } async streamingTranscribe(streamAction, waveform, audioLanguage) { try { if (!this.isReady) throw Error(getError(ETError.ModuleNotLoaded)); if (!!audioLanguage !== this.config.isMultilingual) throw new Error(getError(ETError.MultilingualConfiguration)); if (streamAction === STREAMING_ACTION.START && !this.streaming && this.isGenerating) throw Error(getError(ETError.ModelGenerating)); if (streamAction === STREAMING_ACTION.START && this.streaming) throw Error(getError(ETError.ModelGenerating)); if (streamAction === STREAMING_ACTION.DATA && !this.streaming) throw Error(getError(ETError.StreamingNotStarted)); if (streamAction === STREAMING_ACTION.STOP && !this.streaming) throw Error(getError(ETError.StreamingNotStarted)); if (streamAction === STREAMING_ACTION.DATA && !waveform) throw new Error(getError(ETError.MissingDataChunk)); } catch (e) { this.onErrorCallback(e); return ''; } if (streamAction === STREAMING_ACTION.START) { this.resetState(); this.streaming = true; this.isGeneratingCallback(true); } this.waveform = [...this.waveform, ...(waveform || [])]; // while buffer has at least required size get chunk and decode while (this.waveform.length >= this.expectedChunkLength()) { const chunk = this.waveform.slice(0, this.windowSize + this.overlapSeconds * (1 + Number(this.seqs.length > 0))); this.chunks = [chunk]; //save last chunk 
for STREAMING_ACTION.STOP this.waveform = this.waveform.slice(this.windowSize - this.overlapSeconds * Number(this.seqs.length === 0)); const seq = await this.decodeChunk(chunk, audioLanguage); this.seqs.push(seq); if (this.seqs.length < 2) continue; await this.trimSequences(audioLanguage); await this.handleOverlaps(this.seqs); } // got final package, process all remaining waveform data // since we run the loop above the waveform has at most one chunk in it if (streamAction === STREAMING_ACTION.STOP) { // pad remaining waveform data with previous chunk to this.windowSize + 2 * this.overlapSeconds const chunk = this.chunks.length ? [...this.chunks[0].slice(0, this.windowSize), ...this.waveform].slice(-this.windowSize - 2 * this.overlapSeconds) : this.waveform; this.waveform = []; const seq = await this.decodeChunk(chunk, audioLanguage); this.seqs.push(seq); if (this.seqs.length === 1) { this.sequence = this.seqs[0]; } else { await this.trimSequences(audioLanguage); await this.handleOverlaps(this.seqs); this.sequence = [...this.sequence, ...this.seqs.at(-1)]; } this.decodedTranscribeCallback(this.sequence); this.isGeneratingCallback(false); this.streaming = false; } const decodedText = await this.tokenIdsToText(this.sequence); return decodedText; } async encode(waveform) { return await this.speechToTextNativeModule.encode(waveform); } async decode(seq, encodings) { return await this.speechToTextNativeModule.decode(seq, encodings || []); } } //# sourceMappingURL=SpeechToTextController.js.map
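
For reference, a minimal usage sketch of the batch path follows. It is not part of the source file: the package-root import path, the 'whisper' model name, and the shape of waveform (a plain number array of audio samples, SECOND samples per second) are assumptions inferred from the code above.

// Hypothetical usage sketch; import path and model name are assumptions.
import { SpeechToTextController } from 'react-native-executorch';

const controller = new SpeechToTextController({
  // Receives decoded text whenever a partial or final sequence is available.
  transcribeCallback: (text) => console.log('transcription:', text),
  // Progress in [0, 1] while the encoder, decoder, and tokenizer are fetched.
  modelDownloadProgressCallback: (p) => console.log('download:', p),
  isReadyCallback: (ready) => console.log('ready:', ready),
  onErrorCallback: (error) => error && console.error(error),
});

async function transcribeOnce(waveform) {
  // 'whisper' is treated here as the English-only variant, so no audioLanguage
  // is passed; a multilingual model requires one (see the isMultilingual check).
  await controller.loadModel('whisper');
  return controller.transcribe(waveform);
}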
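The streaming path follows the START/DATA/STOP protocol enforced at the top of streamingTranscribe. Another hedged sketch: that STREAMING_ACTION is re-exported from the package root, and that micBuffers is an async iterable of sample arrays, are assumptions.

// Hypothetical streaming sketch, assuming STREAMING_ACTION is exported publicly.
import { STREAMING_ACTION } from 'react-native-executorch';

async function streamTranscription(controller, micBuffers) {
  await controller.streamingTranscribe(STREAMING_ACTION.START);
  for await (const buffer of micBuffers) {
    // Buffers accumulate in the internal waveform; whole windows are decoded
    // as soon as expectedChunkLength() samples are available.
    await controller.streamingTranscribe(STREAMING_ACTION.DATA, buffer);
  }
  // STOP flushes the remaining samples and returns the final transcription.
  return controller.streamingTranscribe(STREAMING_ACTION.STOP);
}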