"use strict";
import { HAMMING_DIST_THRESHOLD, MODEL_CONFIGS, SECOND, MODES, NUM_TOKENS_TO_TRIM, STREAMING_ACTION } from '../constants/sttDefaults';
import { SpeechToTextNativeModule } from '../native/RnExecutorchModules';
import { TokenizerModule } from '../modules/natural_language_processing/TokenizerModule';
import { ResourceFetcher } from '../utils/ResourceFetcher';
import { longCommonInfPref } from '../utils/stt';
import { ETError, getError } from '../Error';
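// Controller for running Whisper-style speech-to-text models exported to ExecuTorch.
// It splits the input waveform into overlapping chunks, greedily decodes each chunk with the
// native encoder/decoder, and merges the overlapping token sequences into one transcription.
// Both one-shot transcription (transcribe) and streaming (streamingTranscribe) are supported.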
export class SpeechToTextController {
speechToTextNativeModule = SpeechToTextNativeModule;
sequence = [];
isReady = false;
isGenerating = false;
chunks = [];
seqs = [];
prevSeq = [];
waveform = [];
numOfChunks = 0;
streaming = false;
  // User callbacks are wrapped below so that internal state (isReady, isGenerating) stays in sync
  // with what the consumer is notified about.
constructor({
transcribeCallback,
modelDownloadProgressCallback,
isReadyCallback,
isGeneratingCallback,
onErrorCallback,
overlapSeconds,
windowSize,
streamingConfig
}) {
this.decodedTranscribeCallback = async seq => transcribeCallback(await this.tokenIdsToText(seq));
this.modelDownloadProgressCallback = modelDownloadProgressCallback;
this.isReadyCallback = isReady => {
this.isReady = isReady;
isReadyCallback?.(isReady);
};
this.isGeneratingCallback = isGenerating => {
this.isGenerating = isGenerating;
isGeneratingCallback?.(isGenerating);
};
this.onErrorCallback = error => {
if (onErrorCallback) {
onErrorCallback(error ? new Error(getError(error)) : undefined);
return;
} else {
throw new Error(getError(error));
}
};
this.configureStreaming(overlapSeconds, windowSize, streamingConfig || 'balanced');
}
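  // Loads the tokenizer and the encoder/decoder models. Explicit sources override the defaults from
  // MODEL_CONFIGS for the given modelName; downloads report progress via modelDownloadProgressCallback.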
async loadModel(modelName, encoderSource, decoderSource, tokenizerSource) {
this.onErrorCallback(undefined);
this.isReadyCallback(false);
this.config = MODEL_CONFIGS[modelName];
try {
await TokenizerModule.load(tokenizerSource || this.config.tokenizer.source);
[encoderSource, decoderSource] = await ResourceFetcher.fetchMultipleResources(this.modelDownloadProgressCallback, encoderSource || this.config.sources.encoder, decoderSource || this.config.sources.decoder);
} catch (e) {
this.onErrorCallback(e);
return;
}
if (modelName === 'whisperMultilingual') {
      // The underlying native class is instantiated based on the model name. There is no need for a
      // separate class for the multilingual version of Whisper, since the model is the same; we only
      // need the distinction here, in TS, for the starting tokens and such. If more Whisper variants
      // (e.g. the small one) are introduced, this should be refactored.
modelName = 'whisper';
}
try {
await this.speechToTextNativeModule.loadModule(modelName, [encoderSource, decoderSource]);
this.modelDownloadProgressCallback?.(1);
this.isReadyCallback(true);
} catch (e) {
this.onErrorCallback(e);
}
}
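  // Converts windowSize and overlapSeconds from seconds to samples (via SECOND) and clamps them so
  // that a single chunk (windowSize + 2 * overlapSeconds) never exceeds the 30-second window the
  // model works with.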
configureStreaming(overlapSeconds, windowSize, streamingConfig) {
if (streamingConfig) {
this.windowSize = MODES[streamingConfig].windowSize * SECOND;
this.overlapSeconds = MODES[streamingConfig].overlapSeconds * SECOND;
}
if (streamingConfig && (windowSize || overlapSeconds)) {
      console.warn(`windowSize and overlapSeconds override the values from streamingConfig '${streamingConfig}'.`);
}
this.windowSize = (windowSize || 0) * SECOND || this.windowSize;
this.overlapSeconds = (overlapSeconds || 0) * SECOND || this.overlapSeconds;
if (2 * this.overlapSeconds + this.windowSize >= 30 * SECOND) {
      console.warn(`Invalid values for overlapSeconds and/or windowSize provided. Expected windowSize + 2 * overlapSeconds (== ${(this.windowSize + 2 * this.overlapSeconds) / SECOND}s) to be at most 30s. Setting windowSize to ${(30 * SECOND - 2 * this.overlapSeconds) / SECOND}s.`);
this.windowSize = 30 * SECOND - 2 * this.overlapSeconds;
}
}
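  // Splits this.waveform into numOfChunks windows of windowSize samples, extending each window by
  // overlapSeconds of audio on both sides (clamped to the waveform bounds).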
chunkWaveform() {
this.numOfChunks = Math.ceil(this.waveform.length / this.windowSize);
    for (let i = 0; i < this.numOfChunks; i++) {
      const left = Math.max(this.windowSize * i - this.overlapSeconds, 0);
      const right = Math.min(this.windowSize * (i + 1) + this.overlapSeconds, this.waveform.length);
      this.chunks.push(this.waveform.slice(left, right));
    }
}
resetState() {
this.sequence = [];
this.seqs = [];
this.waveform = [];
this.prevSeq = [];
this.chunks = [];
this.decodedTranscribeCallback([]);
this.onErrorCallback(undefined);
}
expectedChunkLength() {
    // Only the first chunk may be shorter: before any sequence has been decoded there is no leading
    // overlap, so the first chunk carries just a trailing overlap.
return this.seqs.length ? this.windowSize + 2 * this.overlapSeconds : this.windowSize + this.overlapSeconds;
}
async getStartingTokenIds(audioLanguage) {
    // We need different starting token ids depending on whether the model is multilingual.
    // The English-only version needs just the BOS token, while the multilingual one needs
    // [BOS, LANG, TRANSCRIBE]. We also pass the notimestamps token, since timestamps are
    // not yet supported.
if (!audioLanguage) {
return [this.config.tokenizer.bos];
}
    // FIXME: the BOS token id should also be resolved via TokenizerModule.tokenToId and removed from the config
const langTokenId = await TokenizerModule.tokenToId(`<|${audioLanguage}|>`);
const transcribeTokenId = await TokenizerModule.tokenToId('<|transcribe|>');
const noTimestampsTokenId = await TokenizerModule.tokenToId('<|notimestamps|>');
const startingTokenIds = [this.config.tokenizer.bos, langTokenId, transcribeTokenId, noTimestampsTokenId];
return startingTokenIds;
}
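  // Encodes one chunk and greedily decodes tokens until EOS. While decoding, already confirmed
  // tokens from the previous sequence are streamed out through decodedTranscribeCallback so the
  // consumer can display a provisional transcription.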
async decodeChunk(chunk, audioLanguage) {
const seq = await this.getStartingTokenIds(audioLanguage);
let prevSeqTokenIdx = 0;
this.prevSeq = this.sequence.slice();
try {
await this.encode(chunk);
} catch (error) {
      this.onErrorCallback(new Error('Encoding error: ' + getError(error)));
return [];
}
let lastToken = seq.at(-1);
while (lastToken !== this.config.tokenizer.eos) {
try {
lastToken = await this.decode(seq);
} catch (error) {
        this.onErrorCallback(new Error('Decoding error: ' + getError(error)));
return [...seq, this.config.tokenizer.eos];
}
seq.push(lastToken);
if (this.seqs.length > 0 && seq.length < this.seqs.at(-1).length && seq.length % 3 !== 0) {
this.prevSeq.push(this.seqs.at(-1)[prevSeqTokenIdx++]);
this.decodedTranscribeCallback(this.prevSeq);
}
}
return seq;
}
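  // Resolves the overlap between the two most recent sequences: longCommonInfPref picks a cut-off
  // index in the second-to-last sequence based on its agreement with the last one, and only the
  // tokens before that index are committed to this.sequence.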
async handleOverlaps(seqs) {
const maxInd = longCommonInfPref(seqs.at(-2), seqs.at(-1), HAMMING_DIST_THRESHOLD);
this.sequence = [...this.sequence, ...seqs.at(-2).slice(0, maxInd)];
this.decodedTranscribeCallback(this.sequence);
return this.sequence.slice();
}
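  // trimLeft/trimRight strip the starting special tokens (plus NUM_TOKENS_TO_TRIM extra tokens)
  // from the newest sequence and the EOS-side tokens from the previous one before overlaps are resolved.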
trimLeft(numOfTokensToTrim) {
const idx = this.seqs.length - 1;
if (this.seqs[idx][0] === this.config.tokenizer.bos) {
this.seqs[idx] = this.seqs[idx].slice(numOfTokensToTrim);
}
}
trimRight(numOfTokensToTrim) {
const idx = this.seqs.length - 2;
if (this.seqs[idx].at(-1) === this.config.tokenizer.eos) {
this.seqs[idx] = this.seqs[idx].slice(0, -numOfTokensToTrim);
}
}
  // Since this is called every time (except the first) right after a new seq is pushed to this.seqs,
  // we only ever trim the left of the last seq and the right of the second-to-last seq.
async trimSequences(audioLanguage) {
const numSpecialTokens = (await this.getStartingTokenIds(audioLanguage)).length;
this.trimLeft(numSpecialTokens + NUM_TOKENS_TO_TRIM);
this.trimRight(numSpecialTokens + NUM_TOKENS_TO_TRIM);
}
  // If the last chunk is too short, merge it into the second-to-last chunk to improve quality.
validateAndFixLastChunk() {
if (this.chunks.length < 2) return;
const lastChunkLength = this.chunks.at(-1).length / SECOND;
const secondToLastChunkLength = this.chunks.at(-2).length / SECOND;
if (lastChunkLength < 5 && secondToLastChunkLength + lastChunkLength < 30) {
      // Drop the trailing overlap from the second-to-last chunk (guarding against a -0 slice when
      // there is no overlap) and append the last chunk to it.
      const precedingChunk = this.overlapSeconds ? this.chunks.at(-2).slice(0, -this.overlapSeconds * 2) : this.chunks.at(-2);
      this.chunks[this.chunks.length - 2] = [...precedingChunk, ...this.chunks.at(-1)];
this.chunks = this.chunks.slice(0, -1);
}
}
async tokenIdsToText(tokenIds) {
try {
      // Await here so that a rejected promise is caught by the catch block below.
      return await TokenizerModule.decode(tokenIds, true);
} catch (e) {
this.onErrorCallback(new Error(`An error has occurred when decoding the token ids: ${e}`));
return '';
}
}
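  // Transcribes a full waveform in one call: chunks it, decodes every chunk, trims special tokens,
  // resolves overlaps between consecutive chunks and returns the decoded text. audioLanguage is
  // required for multilingual models and must be omitted for English-only ones.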
async transcribe(waveform, audioLanguage) {
try {
      if (!this.isReady) throw new Error(getError(ETError.ModuleNotLoaded));
      if (this.isGenerating || this.streaming) throw new Error(getError(ETError.ModelGenerating));
if (!!audioLanguage !== this.config.isMultilingual) throw new Error(getError(ETError.MultilingualConfiguration));
} catch (e) {
this.onErrorCallback(e);
return '';
}
    // resetState() below also clears any previously reported error before we start.
this.isGeneratingCallback(true);
this.resetState();
this.waveform = waveform;
this.chunkWaveform();
this.validateAndFixLastChunk();
for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) {
const seq = await this.decodeChunk(this.chunks.at(chunkId), audioLanguage);
      // The whole audio fits in a single chunk; no overlap processing is required.
if (this.chunks.length === 1) {
this.sequence = seq;
this.decodedTranscribeCallback(seq);
break;
}
this.seqs.push(seq);
if (this.seqs.length < 2) continue;
      // Remove the starting token ids plus a few additional tokens (NUM_TOKENS_TO_TRIM).
await this.trimSequences(audioLanguage);
this.prevSeq = await this.handleOverlaps(this.seqs);
      // The last sequence has been processed and its overlap is already resolved, so just append it.
if (this.seqs.length === this.chunks.length) {
this.sequence = [...this.sequence, ...this.seqs.at(-1)];
this.decodedTranscribeCallback(this.sequence);
this.prevSeq = this.sequence;
}
}
const decodedText = await this.tokenIdsToText(this.sequence);
this.isGeneratingCallback(false);
return decodedText;
}
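  // Streaming variant driven by STREAMING_ACTION.START / DATA / STOP. Incoming audio is buffered in
  // this.waveform; whenever a full chunk is available it is decoded and merged, and STOP flushes
  // whatever audio is still buffered.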
async streamingTranscribe(streamAction, waveform, audioLanguage) {
try {
      if (!this.isReady) throw new Error(getError(ETError.ModuleNotLoaded));
      if (!!audioLanguage !== this.config.isMultilingual) throw new Error(getError(ETError.MultilingualConfiguration));
      if (streamAction === STREAMING_ACTION.START && !this.streaming && this.isGenerating) throw new Error(getError(ETError.ModelGenerating));
      if (streamAction === STREAMING_ACTION.START && this.streaming) throw new Error(getError(ETError.ModelGenerating));
      if (streamAction === STREAMING_ACTION.DATA && !this.streaming) throw new Error(getError(ETError.StreamingNotStarted));
      if (streamAction === STREAMING_ACTION.STOP && !this.streaming) throw new Error(getError(ETError.StreamingNotStarted));
if (streamAction === STREAMING_ACTION.DATA && !waveform) throw new Error(getError(ETError.MissingDataChunk));
} catch (e) {
this.onErrorCallback(e);
return '';
}
if (streamAction === STREAMING_ACTION.START) {
this.resetState();
this.streaming = true;
this.isGeneratingCallback(true);
}
this.waveform = [...this.waveform, ...(waveform || [])];
    // While the buffer holds at least one full chunk, slice it off and decode it.
while (this.waveform.length >= this.expectedChunkLength()) {
const chunk = this.waveform.slice(0, this.windowSize + this.overlapSeconds * (1 + Number(this.seqs.length > 0)));
      this.chunks = [chunk]; // keep the last chunk around for STREAMING_ACTION.STOP
this.waveform = this.waveform.slice(this.windowSize - this.overlapSeconds * Number(this.seqs.length === 0));
const seq = await this.decodeChunk(chunk, audioLanguage);
this.seqs.push(seq);
if (this.seqs.length < 2) continue;
await this.trimSequences(audioLanguage);
await this.handleOverlaps(this.seqs);
}
    // A STOP action means the final packet arrived: process all remaining waveform data.
    // Because the loop above has already consumed full chunks, at most one chunk of audio remains.
if (streamAction === STREAMING_ACTION.STOP) {
      // Pad the remaining waveform with audio from the previous chunk, up to windowSize + 2 * overlapSeconds.
const chunk = this.chunks.length ? [...this.chunks[0].slice(0, this.windowSize), ...this.waveform].slice(-this.windowSize - 2 * this.overlapSeconds) : this.waveform;
this.waveform = [];
const seq = await this.decodeChunk(chunk, audioLanguage);
this.seqs.push(seq);
if (this.seqs.length === 1) {
this.sequence = this.seqs[0];
} else {
await this.trimSequences(audioLanguage);
await this.handleOverlaps(this.seqs);
this.sequence = [...this.sequence, ...this.seqs.at(-1)];
}
this.decodedTranscribeCallback(this.sequence);
this.isGeneratingCallback(false);
this.streaming = false;
}
const decodedText = await this.tokenIdsToText(this.sequence);
return decodedText;
}
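  // Thin wrappers around the native encoder/decoder calls used by decodeChunk.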
async encode(waveform) {
return await this.speechToTextNativeModule.encode(waveform);
}
async decode(seq, encodings) {
return await this.speechToTextNativeModule.decode(seq, encodings || []);
}
}
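// Example usage (a minimal sketch, not part of this file; the 'whisper' model name, the callback
// shapes and the plain number[] waveform are assumptions drawn from the code above):
//
//   const stt = new SpeechToTextController({
//     transcribeCallback: text => console.log('partial transcription:', text),
//     isReadyCallback: ready => console.log('model ready:', ready),
//     onErrorCallback: err => err && console.error(err),
//     streamingConfig: 'balanced',
//   });
//   await stt.loadModel('whisper');
//   const text = await stt.transcribe(waveform); // waveform: number[] of audio samples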
//# sourceMappingURL=SpeechToTextController.js.map