react-native-executorch
Version:
An easy way to run AI models in React Native with ExecuTorch
350 lines (349 loc) • 15.7 kB
JavaScript
import { HAMMING_DIST_THRESHOLD, MODEL_CONFIGS, SECOND, MODES, NUM_TOKENS_TO_TRIM, STREAMING_ACTION, } from '../constants/sttDefaults';
import { AvailableModels } from '../types/stt';
import { TokenizerModule } from '../modules/natural_language_processing/TokenizerModule';
import { ResourceFetcher } from '../utils/ResourceFetcher';
import { longCommonInfPref } from '../utils/stt';
import { ETError, getError } from '../Error';
import { Logger } from '../common/Logger';
/**
 * Drives speech-to-text inference on top of a native ExecuTorch encoder/decoder
 * pair (Whisper-style models). Supports one-shot transcription of a whole
 * waveform (`transcribe`) and incremental streaming (`streamingTranscribe`),
 * stitching overlapping audio chunks together by finding common token
 * subsequences between consecutive decodes.
 *
 * All waveform lengths in this class are measured in SAMPLES; `windowSize` and
 * `overlapSeconds` are converted from seconds to samples in
 * `configureStreaming` (multiplied by `SECOND`, the sample rate).
 */
export class SpeechToTextController {
  speechToTextNativeModule; // native handle returned by global.loadSpeechToText
  sequence = []; // committed token ids for the transcription so far
  isReady = false;
  isGenerating = false;
  tokenizerModule;
  overlapSeconds; // overlap between adjacent chunks, in samples
  windowSize; // core chunk length, in samples
  chunks = []; // pending audio chunks (arrays of samples)
  seqs = []; // per-chunk decoded token-id sequences, pre-merge
  prevSeq = []; // last sequence surfaced to the user (for incremental UI updates)
  waveform = []; // buffered audio samples
  numOfChunks = 0;
  streaming = false; // true between STREAMING_ACTION.START and .STOP
  // User callbacks
  decodedTranscribeCallback;
  isReadyCallback;
  isGeneratingCallback;
  onErrorCallback;
  config; // entry from MODEL_CONFIGS for the loaded model

  /**
   * @param {Object} opts
   * @param {(text: string) => void} opts.transcribeCallback - receives the decoded text every time the transcription advances.
   * @param {(ready: boolean) => void} [opts.isReadyCallback]
   * @param {(generating: boolean) => void} [opts.isGeneratingCallback]
   * @param {(err: Error | undefined) => void} [opts.onErrorCallback] - receives errors; called with `undefined` to clear error state.
   * @param {number} [opts.overlapSeconds] - chunk overlap in seconds; overrides `streamingConfig`.
   * @param {number} [opts.windowSize] - chunk length in seconds; overrides `streamingConfig`.
   * @param {string} [opts.streamingConfig] - a preset name from MODES; defaults to 'balanced'.
   */
  constructor({ transcribeCallback, isReadyCallback, isGeneratingCallback, onErrorCallback, overlapSeconds, windowSize, streamingConfig, }) {
    this.tokenizerModule = new TokenizerModule();
    // Translate token ids to text before handing them to the user's callback.
    // Fire-and-forget by design: call sites do not await UI updates.
    this.decodedTranscribeCallback = async (seq) => transcribeCallback(await this.tokenIdsToText(seq));
    this.isReadyCallback = (isReady) => {
      this.isReady = isReady;
      isReadyCallback?.(isReady);
    };
    this.isGeneratingCallback = (isGenerating) => {
      this.isGenerating = isGenerating;
      isGeneratingCallback?.(isGenerating);
    };
    // BUGFIX: this wrapper is also invoked with `undefined` to CLEAR the error
    // state (see load()/resetState()). The previous implementation always threw
    // when no user callback was registered, so clearing the error crashed.
    // Only throw when there is an actual error to report.
    this.onErrorCallback = (error) => {
      if (onErrorCallback) {
        onErrorCallback(error ? new Error(getError(error)) : undefined);
      } else if (error) {
        throw new Error(getError(error));
      }
    };
    this.configureStreaming(overlapSeconds, windowSize, streamingConfig || 'balanced');
  }

  /**
   * Downloads (or resolves) model resources, loads the tokenizer, and
   * instantiates the native speech-to-text module. Signals readiness through
   * `isReadyCallback`; failures are routed to `onErrorCallback`.
   *
   * @param {Object} opts
   * @param {string} opts.modelName - key into MODEL_CONFIGS.
   * @param {string} [opts.encoderSource] - override for the encoder binary source.
   * @param {string} [opts.decoderSource] - override for the decoder binary source.
   * @param {string} [opts.tokenizerSource] - override for the tokenizer source.
   * @param {(progress: number) => void} [opts.onDownloadProgressCallback]
   */
  async load({ modelName, encoderSource, decoderSource, tokenizerSource, onDownloadProgressCallback, }) {
    this.onErrorCallback(undefined); // clear any stale error state
    this.isReadyCallback(false);
    this.config = MODEL_CONFIGS[modelName];
    try {
      // Tokenizer load and model downloads are independent — run them in parallel.
      const tokenizerLoadPromise = this.tokenizerModule.load({
        tokenizerSource: tokenizerSource || this.config.tokenizer.source,
      });
      const pathsPromise = ResourceFetcher.fetch(onDownloadProgressCallback, encoderSource || this.config.sources.encoder, decoderSource || this.config.sources.decoder);
      const [_, encoderDecoderResults] = await Promise.all([
        tokenizerLoadPromise,
        pathsPromise,
      ]);
      encoderSource = encoderDecoderResults?.[0];
      decoderSource = encoderDecoderResults?.[1];
      if (!encoderSource || !decoderSource) {
        throw new Error('Download interrupted.');
      }
    }
    catch (e) {
      this.onErrorCallback(e);
      return;
    }
    if (modelName === 'whisperMultilingual') {
      // The underlying native class is instantiated based on the name of the model. There is no need to
      // create a separate class for multilingual version of Whisper, since it is the same. We just need
      // the distinction here, in TS, for start tokens and such. If we introduce
      // more versions of Whisper, such as the small one, this should be refactored.
      modelName = AvailableModels.WHISPER;
    }
    try {
      const nativeSpeechToText = await global.loadSpeechToText(encoderSource, decoderSource, modelName);
      this.speechToTextNativeModule = nativeSpeechToText;
      this.isReadyCallback(true);
    }
    catch (e) {
      this.onErrorCallback(e);
    }
  }

  /**
   * Sets `windowSize`/`overlapSeconds` (converted to samples) from a preset
   * and/or explicit per-value overrides. Explicit values win over the preset.
   * Caps the effective chunk size so windowSize + 2*overlap stays under the
   * 30-second model context.
   *
   * @param {number} [overlapSeconds] - overlap in seconds.
   * @param {number} [windowSize] - window length in seconds.
   * @param {string} [streamingConfig] - preset key into MODES.
   */
  configureStreaming(overlapSeconds, windowSize, streamingConfig) {
    if (streamingConfig) {
      this.windowSize = MODES[streamingConfig].windowSize * SECOND;
      this.overlapSeconds = MODES[streamingConfig].overlapSeconds * SECOND;
    }
    if (streamingConfig && (windowSize || overlapSeconds)) {
      Logger.warn(`windowSize and overlapSeconds overrides values from streamingConfig ${streamingConfig}.`);
    }
    // Explicit values (in seconds) override the preset; 0/undefined falls back.
    this.windowSize = (windowSize || 0) * SECOND || this.windowSize;
    this.overlapSeconds = (overlapSeconds || 0) * SECOND || this.overlapSeconds;
    if (2 * this.overlapSeconds + this.windowSize >= 30 * SECOND) {
      // BUGFIX: the message is phrased in seconds ("<= 30") but previously
      // interpolated raw SAMPLE counts, which made the warning misleading.
      // Divide by SECOND so the reported numbers match the user-facing unit.
      Logger.warn(`Invalid values for overlapSeconds and/or windowSize provided. Expected windowSize + 2 * overlapSeconds (== ${(this.windowSize + 2 * this.overlapSeconds) / SECOND}) <= 30. Setting windowSize to ${(30 * SECOND - 2 * this.overlapSeconds) / SECOND}.`);
      this.windowSize = 30 * SECOND - 2 * this.overlapSeconds;
    }
  }

  /**
   * Splits `this.waveform` into `windowSize`-sized chunks, extending each chunk
   * by `overlapSeconds` samples on both sides (clamped to the waveform bounds)
   * so neighbouring chunks share audio for sequence stitching.
   */
  chunkWaveform() {
    this.numOfChunks = Math.ceil(this.waveform.length / this.windowSize);
    for (let i = 0; i < this.numOfChunks; i++) {
      const left = Math.max(this.windowSize * i - this.overlapSeconds, 0);
      const right = Math.min(this.windowSize * (i + 1) + this.overlapSeconds, this.waveform.length);
      this.chunks.push(this.waveform.slice(left, right));
    }
  }

  /** Clears all per-transcription state and resets the UI/error callbacks. */
  resetState() {
    this.sequence = [];
    this.seqs = [];
    this.waveform = [];
    this.prevSeq = [];
    this.chunks = [];
    this.decodedTranscribeCallback([]); // fire-and-forget UI reset
    this.onErrorCallback(undefined); // clear error state
  }

  /**
   * Number of samples required before the next streaming chunk can be decoded.
   * @returns {number} expected chunk length in samples.
   */
  expectedChunkLength() {
    //only first chunk can be of shorter length, for first chunk there are no seqs decoded
    return this.seqs.length
      ? this.windowSize + 2 * this.overlapSeconds
      : this.windowSize + this.overlapSeconds;
  }

  /**
   * Builds the decoder prompt (starting token ids) for a chunk.
   * English-only models need just BOS; multilingual models need
   * [BOS, LANG, TRANSCRIBE, NOTIMESTAMPS].
   *
   * @param {string} [audioLanguage] - language code for multilingual models; omit for English-only.
   * @returns {Promise<number[]>} starting token ids.
   */
  async getStartingTokenIds(audioLanguage) {
    // We need different starting token ids based on the multilingualism of the model.
    // The eng version only needs BOS token, while the multilingual one needs:
    // [BOS, LANG, TRANSCRIBE]. Optionally we should also set notimestamps token, as timestamps
    // is not yet supported.
    if (!audioLanguage) {
      return [this.config.tokenizer.bos];
    }
    // FIXME: I should use .getTokenId for the BOS as well, should remove it from config
    const langTokenId = await this.tokenizerModule.tokenToId(`<|${audioLanguage}|>`);
    const transcribeTokenId = await this.tokenizerModule.tokenToId('<|transcribe|>');
    const noTimestampsTokenId = await this.tokenizerModule.tokenToId('<|notimestamps|>');
    return [
      this.config.tokenizer.bos,
      langTokenId,
      transcribeTokenId,
      noTimestampsTokenId,
    ];
  }

  /**
   * Encodes one audio chunk and autoregressively decodes tokens until EOS.
   * Emits partial transcriptions through `decodedTranscribeCallback` while
   * decoding (replaying tokens of the previous sequence as a progress proxy).
   *
   * @param {number[]} chunk - audio samples.
   * @param {string} [audioLanguage]
   * @returns {Promise<number[]>} decoded token ids ([] on encode failure;
   *   a sequence force-terminated with EOS on decode failure).
   */
  async decodeChunk(chunk, audioLanguage) {
    const seq = await this.getStartingTokenIds(audioLanguage);
    let prevSeqTokenIdx = 0;
    this.prevSeq = this.sequence.slice();
    try {
      await this.encode(new Float32Array(chunk));
    }
    catch (error) {
      this.onErrorCallback(new Error(getError(error) + ' encoding error'));
      return [];
    }
    let lastToken = seq.at(-1);
    while (lastToken !== this.config.tokenizer.eos) {
      try {
        lastToken = await this.decode(seq);
      }
      catch (error) {
        this.onErrorCallback(new Error(getError(error) + ' decoding error'));
        // Terminate the sequence so downstream trimming still sees an EOS.
        return [...seq, this.config.tokenizer.eos];
      }
      seq.push(lastToken);
      // Surface progress: replay one token from the previous chunk's sequence
      // every few decoded tokens (skip every 3rd step to pace the updates).
      if (this.seqs.length > 0 &&
        seq.length < this.seqs.at(-1).length &&
        seq.length % 3 !== 0) {
        this.prevSeq.push(this.seqs.at(-1)[prevSeqTokenIdx++]);
        this.decodedTranscribeCallback(this.prevSeq);
      }
    }
    return seq;
  }

  /**
   * Merges the two most recent sequences: commits the prefix of the
   * second-to-last sequence up to the overlap point found by
   * `longCommonInfPref`, then publishes the updated transcription.
   *
   * @param {number[][]} seqs - at least two decoded sequences.
   * @returns {Promise<number[]>} a copy of the committed token sequence.
   */
  async handleOverlaps(seqs) {
    const maxInd = longCommonInfPref(seqs.at(-2), seqs.at(-1), HAMMING_DIST_THRESHOLD);
    this.sequence = [...this.sequence, ...seqs.at(-2).slice(0, maxInd)];
    this.decodedTranscribeCallback(this.sequence);
    return this.sequence.slice();
  }

  /**
   * Drops leading tokens from the LAST sequence (only if it still starts with
   * BOS, i.e. hasn't been trimmed yet).
   * @param {number} numOfTokensToTrim
   */
  trimLeft(numOfTokensToTrim) {
    const idx = this.seqs.length - 1;
    if (this.seqs[idx][0] === this.config.tokenizer.bos) {
      this.seqs[idx] = this.seqs[idx].slice(numOfTokensToTrim);
    }
  }

  /**
   * Drops trailing tokens from the SECOND-TO-LAST sequence (only if it still
   * ends with EOS, i.e. hasn't been trimmed yet).
   * @param {number} numOfTokensToTrim
   */
  trimRight(numOfTokensToTrim) {
    const idx = this.seqs.length - 2;
    if (this.seqs[idx].at(-1) === this.config.tokenizer.eos) {
      this.seqs[idx] = this.seqs[idx].slice(0, -numOfTokensToTrim);
    }
  }

  // since we are calling this every time (except first) after a new seq is pushed to this.seqs
  // we can only trim left the last seq and trim right the second to last seq
  async trimSequences(audioLanguage) {
    const numSpecialTokens = (await this.getStartingTokenIds(audioLanguage))
      .length;
    this.trimLeft(numSpecialTokens + NUM_TOKENS_TO_TRIM);
    this.trimRight(numSpecialTokens + NUM_TOKENS_TO_TRIM);
  }

  /**
   * If the last chunk is shorter than 5 seconds, merge it into the second-to-
   * last chunk (dropping that chunk's trailing overlap first) to improve
   * decode quality — provided the combined chunk stays under 30 seconds.
   */
  validateAndFixLastChunk() {
    if (this.chunks.length < 2)
      return;
    const lastChunkLength = this.chunks.at(-1).length / SECOND;
    const secondToLastChunkLength = this.chunks.at(-2).length / SECOND;
    if (lastChunkLength < 5 && secondToLastChunkLength + lastChunkLength < 30) {
      // BUGFIX: when overlapSeconds === 0, `.slice(0, -0)` === `.slice(0, 0)`
      // which returned an EMPTY array and silently dropped the whole
      // second-to-last chunk. Guard the zero-overlap case explicitly.
      const overlapSamples = this.overlapSeconds * 2;
      const trimmedSecondToLast = overlapSamples > 0
        ? this.chunks.at(-2).slice(0, -overlapSamples)
        : this.chunks.at(-2);
      this.chunks[this.chunks.length - 2] = [
        ...trimmedSecondToLast,
        ...this.chunks.at(-1),
      ];
      this.chunks = this.chunks.slice(0, -1);
    }
  }

  /**
   * Decodes token ids to text via the tokenizer.
   * @param {number[]} tokenIds
   * @returns {Promise<string>} decoded text, or '' on tokenizer failure
   *   (error is reported through onErrorCallback).
   */
  async tokenIdsToText(tokenIds) {
    try {
      return await this.tokenizerModule.decode(tokenIds, true);
    }
    catch (e) {
      this.onErrorCallback(new Error(`An error has occurred when decoding the token ids: ${e}`));
      return '';
    }
  }

  /**
   * One-shot transcription of a complete waveform. Chunks the audio, decodes
   * each chunk, and stitches overlapping sequences together.
   *
   * @param {number[]} waveform - audio samples.
   * @param {string} [audioLanguage] - required iff the model is multilingual.
   * @returns {Promise<string>} full decoded text ('' on precondition failure).
   */
  async transcribe(waveform, audioLanguage) {
    try {
      if (!this.isReady)
        throw Error(getError(ETError.ModuleNotLoaded));
      if (this.isGenerating || this.streaming)
        throw Error(getError(ETError.ModelGenerating));
      // A language must be given exactly when the model is multilingual.
      if (!!audioLanguage !== this.config.isMultilingual)
        throw new Error(getError(ETError.MultilingualConfiguration));
    }
    catch (e) {
      this.onErrorCallback(e);
      return '';
    }
    // Making sure that the error is not set when we get there
    this.isGeneratingCallback(true);
    this.resetState();
    this.waveform = waveform;
    this.chunkWaveform();
    this.validateAndFixLastChunk();
    for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) {
      const seq = await this.decodeChunk(this.chunks.at(chunkId), audioLanguage);
      // whole audio is inside one chunk, no processing required
      if (this.chunks.length === 1) {
        this.sequence = seq;
        this.decodedTranscribeCallback(seq);
        break;
      }
      this.seqs.push(seq);
      if (this.seqs.length < 2)
        continue;
      // Remove starting tokenIds and some additional ones
      await this.trimSequences(audioLanguage);
      this.prevSeq = await this.handleOverlaps(this.seqs);
      // last sequence processed
      // overlaps are already handled, so just append the last seq
      if (this.seqs.length === this.chunks.length) {
        this.sequence = [...this.sequence, ...this.seqs.at(-1)];
        this.decodedTranscribeCallback(this.sequence);
        this.prevSeq = this.sequence;
      }
    }
    const decodedText = await this.tokenIdsToText(this.sequence);
    this.isGeneratingCallback(false);
    return decodedText;
  }

  /**
   * Incremental transcription. Call with STREAMING_ACTION.START to begin,
   * .DATA (with a waveform slice) to feed audio, and .STOP to flush and
   * finish. Buffers audio until a full chunk is available, then decodes it.
   *
   * @param {string} streamAction - one of STREAMING_ACTION.{START,DATA,STOP}.
   * @param {number[]} [waveform] - audio samples (required for DATA).
   * @param {string} [audioLanguage] - required iff the model is multilingual.
   * @returns {Promise<string>} decoded text so far ('' on precondition failure).
   */
  async streamingTranscribe(streamAction, waveform, audioLanguage) {
    try {
      if (!this.isReady)
        throw Error(getError(ETError.ModuleNotLoaded));
      if (!!audioLanguage !== this.config.isMultilingual)
        throw new Error(getError(ETError.MultilingualConfiguration));
      if (streamAction === STREAMING_ACTION.START &&
        !this.streaming &&
        this.isGenerating)
        throw Error(getError(ETError.ModelGenerating));
      if (streamAction === STREAMING_ACTION.START && this.streaming)
        throw Error(getError(ETError.ModelGenerating));
      if (streamAction === STREAMING_ACTION.DATA && !this.streaming)
        throw Error(getError(ETError.StreamingNotStarted));
      if (streamAction === STREAMING_ACTION.STOP && !this.streaming)
        throw Error(getError(ETError.StreamingNotStarted));
      if (streamAction === STREAMING_ACTION.DATA && !waveform)
        throw new Error(getError(ETError.MissingDataChunk));
    }
    catch (e) {
      this.onErrorCallback(e);
      return '';
    }
    if (streamAction === STREAMING_ACTION.START) {
      this.resetState();
      this.streaming = true;
      this.isGeneratingCallback(true);
    }
    this.waveform = [...this.waveform, ...(waveform || [])];
    // while buffer has at least required size get chunk and decode
    while (this.waveform.length >= this.expectedChunkLength()) {
      // Only the first chunk lacks a leading overlap, hence the Number() terms.
      const chunk = this.waveform.slice(0, this.windowSize +
        this.overlapSeconds * (1 + Number(this.seqs.length > 0)));
      this.chunks = [chunk]; //save last chunk for STREAMING_ACTION.STOP
      this.waveform = this.waveform.slice(this.windowSize - this.overlapSeconds * Number(this.seqs.length === 0));
      const seq = await this.decodeChunk(chunk, audioLanguage);
      this.seqs.push(seq);
      if (this.seqs.length < 2)
        continue;
      await this.trimSequences(audioLanguage);
      await this.handleOverlaps(this.seqs);
    }
    // got final package, process all remaining waveform data
    // since we run the loop above the waveform has at most one chunk in it
    if (streamAction === STREAMING_ACTION.STOP) {
      // pad remaining waveform data with previous chunk to this.windowSize + 2 * this.overlapSeconds
      const chunk = this.chunks.length
        ? [
          ...this.chunks[0].slice(0, this.windowSize),
          ...this.waveform,
        ].slice(-this.windowSize - 2 * this.overlapSeconds)
        : this.waveform;
      this.waveform = [];
      const seq = await this.decodeChunk(chunk, audioLanguage);
      this.seqs.push(seq);
      if (this.seqs.length === 1) {
        this.sequence = this.seqs[0];
      }
      else {
        await this.trimSequences(audioLanguage);
        await this.handleOverlaps(this.seqs);
        this.sequence = [...this.sequence, ...this.seqs.at(-1)];
      }
      this.decodedTranscribeCallback(this.sequence);
      this.isGeneratingCallback(false);
      this.streaming = false;
    }
    const decodedText = await this.tokenIdsToText(this.sequence);
    return decodedText;
  }

  /**
   * Runs the native encoder on a chunk of audio.
   * @param {Float32Array} waveform
   */
  async encode(waveform) {
    return await this.speechToTextNativeModule.encode(waveform);
  }

  /**
   * Runs one native decoder step over the token sequence so far.
   * @param {number[]} seq
   * @returns {Promise<number>} the next token id.
   */
  async decode(seq) {
    return await this.speechToTextNativeModule.decode(seq);
  }
}