UNPKG

echogarden

Version:

An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.

897 lines (896 loc) 102 kB
import { Logger } from '../utilities/Logger.js'; import { computeMelSpectogramUsingFilterbanks } from '../dsp/MelSpectogram.js'; import { clip, getIntegerRange, splitFloat32Array, yieldToEventLoop } from '../utilities/Utilities.js'; import { indexOfMax, logOfVector, logSumExp, meanOfVector, softmax, sumOfSquaresForVector, sumVector } from '../math/VectorMath.js'; import { alignDTWWindowed } from '../alignment/DTWSequenceAlignmentWindowed.js'; import { extendDeep } from '../utilities/ObjectUtilities.js'; import { getRawAudioDuration, sliceRawAudio } from '../audio/AudioUtilities.js'; import { readFileAsUtf8 } from '../utilities/FileSystem.js'; import { formatLanguageCodeWithName, getShortLanguageCode, languageCodeToName } from '../utilities/Locale.js'; import { loadPackage } from '../utilities/PackageManager.js'; import chalk from 'chalk'; import { XorShift32PRNG } from '../utilities/RandomGenerator.js'; import { detectSpeechLanguageByParts } from '../api/SpeechLanguageDetection.js'; import { isPunctuation, isWhitespace, isWord, splitToSentences, splitToWords } from '../nlp/Segmentation.js'; import { medianOf5Filter } from '../math/MedianFilter.js'; import { getDeflateCompressionMetricsForString } from '../utilities/Compression.js'; import { dmlProviderAvailable, getOnnxSessionOptions, makeOnnxLikeFloat32Tensor } from '../utilities/OnnxUtilities.js'; import { murmurHash3_int32Input } from '../utilities/Hashing.js'; import { containsInvalidCodepoint, getTokenRepetitionScore } from '../utilities/StringUtilities.js'; import { decodeUtf8 } from '../encodings/Utf8.js'; import { joinPath } from '../utilities/PathUtilities.js'; export async function recognize(sourceRawAudio, modelName, modelDir, task, sourceLanguage, options) { options = extendDeep(defaultWhisperOptions, options); if (sourceRawAudio.sampleRate != 16000) { throw new Error('Source audio must have a sample rate of 16000 Hz'); } sourceLanguage = getShortLanguageCode(sourceLanguage); if (!(sourceLanguage in languageIdLookup)) { throw new Error(`The language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`); } if (isEnglishOnlyModel(modelName) && sourceLanguage != 'en') { throw new Error(`The model '${modelName}' can only be used with English inputs. However, the given source language was ${languageCodeToName(sourceLanguage)}.`); } if (modelName === 'large-v3-turbo' && task === 'translate') { throw new Error(`The 'large-v3-turbo' model doesn't support translation tasks.`); } if (options.temperature && options.temperature < 0) { throw new Error(`Temperature can't be negative`); } // Workaround issue with large-v3-turbo that produces invalid results when a prompt is passed to it. // Always disable autoprompting for that model. if (options.autoPromptParts && modelName == 'large-v3-turbo') { options.autoPromptParts = false; } // Select encoder ONNX provider const encoderProviders = options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName); // Select decoder ONNX provider const decoderProviders = options.decoderProvider ? [options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName); const seed = options.seed; const whisper = new Whisper(modelName, modelDir, encoderProviders, decoderProviders, seed); const result = await whisper.recognize(sourceRawAudio, task, sourceLanguage, options); return result; } export async function align(sourceRawAudio, transcript, modelName, modelDir, sourceLanguage, options) { options = extendDeep(defaultWhisperAlignmentOptions, options); if (sourceRawAudio.sampleRate != 16000) { throw new Error('Source audio must have a sample rate of 16000 Hz'); } sourceLanguage = getShortLanguageCode(sourceLanguage); if (!(sourceLanguage in languageIdLookup)) { throw new Error(`The language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`); } if (isEnglishOnlyModel(modelName) && sourceLanguage != 'en') { throw new Error(`The model '${modelName}' can only be used with English inputs. However, the given source language was ${languageCodeToName(sourceLanguage)}.`); } // Select encoder ONNX provider const encoderProviders = options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName); // Select decoder ONNX provider const decoderProviders = options.decoderProvider ? [options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName); const whisper = new Whisper(modelName, modelDir, encoderProviders, decoderProviders); const timeline = await whisper.align(sourceRawAudio, transcript, sourceLanguage, 'transcribe', options); return timeline; } export async function alignEnglishTranslation(sourceRawAudio, translatedTranscript, modelName, modelDir, sourceLanguage, options) { options = extendDeep(defaultWhisperAlignmentOptions, options); if (sourceRawAudio.sampleRate != 16000) { throw new Error('Source audio must have a sample rate of 16000 Hz'); } sourceLanguage = getShortLanguageCode(sourceLanguage); if (!(sourceLanguage in languageIdLookup)) { throw new Error(`The source language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`); } if (modelName === 'large-v3-turbo') { throw new Error(`The 'large-v3-turbo' model doesn't support translation tasks, so cannot be used for translation alignment.`); } if (isEnglishOnlyModel(modelName)) { throw new Error(`Translation alignment can only be done with multilingual models.`); } // Select encoder ONNX provider const encoderProviders = options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName); // Select decoder ONNX provider const decoderProviders = options.decoderProvider ? [options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName); const whisper = new Whisper(modelName, modelDir, encoderProviders, decoderProviders); const timeline = await whisper.align(sourceRawAudio, translatedTranscript, sourceLanguage, 'translate', options); return timeline; } export async function detectLanguage(sourceRawAudio, modelName, modelDir, options) { options = extendDeep(defaultWhisperLanguageDetectionOptions, options); if (sourceRawAudio.sampleRate != 16000) { throw new Error('Source audio must have a sample rate of 16000 Hz'); } if (!isMultilingualModel(modelName)) { throw new Error(`Language detection is only supported with multilingual models.`); } if (options.temperature < 0) { throw new Error(`Temperature cannot be negative`); } // Select encoder ONNX provider const encoderProviders = options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName); // Select decoder ONNX provider const decoderProviders = options.decoderProvider ? [options.decoderProvider] : []; const whisper = new Whisper(modelName, modelDir, encoderProviders, decoderProviders); async function detectLanguageForPart(partAudio) { const audioFeatures = await whisper.encodeAudio(partAudio); const partResults = await whisper.detectLanguage(audioFeatures, options.temperature); return partResults; } const results = await detectSpeechLanguageByParts(sourceRawAudio, detectLanguageForPart); results.sort((entry1, entry2) => entry2.probability - entry1.probability); return results; } export async function detectVoiceActivity(sourceRawAudio, modelName, modelDir, options) { options = extendDeep(defaultWhisperVADOptions, options); if (sourceRawAudio.sampleRate != 16000) { throw new Error('Source audio must have a sample rate of 16000 Hz'); } if (options.temperature < 0) { throw new Error(`Temperature cannot be negative`); } const audioSamples = sourceRawAudio.audioChannels[0]; const partDuration = 5; const maxSamplesCountForPart = sourceRawAudio.sampleRate * partDuration; // Select encoder ONNX provider const encoderProviders = options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName); // Select decoder ONNX provider const decoderProviders = options.decoderProvider ? [options.decoderProvider] : []; const whisper = new Whisper(modelName, modelDir, encoderProviders, decoderProviders); const partProbabilities = []; for (let sampleOffset = 0; sampleOffset < audioSamples.length; sampleOffset += maxSamplesCountForPart) { const partSamples = sliceRawAudio(sourceRawAudio, sampleOffset, sampleOffset + maxSamplesCountForPart); const samplesCountForPart = partSamples.audioChannels[0].length; const startTime = sampleOffset / sourceRawAudio.sampleRate; const endTime = (sampleOffset + samplesCountForPart) / sourceRawAudio.sampleRate; const encodedPartSamples = await whisper.encodeAudio(partSamples); const probabilityForPart = await whisper.detectVoiceActivity(encodedPartSamples, options.temperature); partProbabilities.push({ type: 'segment', text: '', startTime, endTime, confidence: probabilityForPart, }); } return { partProbabilities }; } export class Whisper { modelName; modelDir; encoderExecutionProviders; decoderExecutionProviders; isMultiligualModel; audioEncoder; textDecoder; tiktoken; tokenConfig; randomGen; constructor(modelName, modelDir, encoderExecutionProviders, decoderExecutionProviders, prngSeed = 1234) { this.modelName = modelName; this.modelDir = modelDir; this.encoderExecutionProviders = encoderExecutionProviders; this.decoderExecutionProviders = decoderExecutionProviders; this.isMultiligualModel = isMultilingualModel(this.modelName); if (this.isMultiligualModel) { this.tokenConfig = { endOfTextToken: 50257, startOfTextToken: 50258, languageTokensStart: 50259, languageTokensEnd: 50358, translateTaskToken: 50358, transcribeTaskToken: 50359, startOfPromptToken: 50361, nonSpeechToken: 50362, noTimestampsToken: 50363, timestampTokensStart: 50364, timestampTokensEnd: 50364 + 1501, }; } else { this.tokenConfig = { endOfTextToken: 50256, startOfTextToken: 50257, languageTokensStart: 50258, languageTokensEnd: 50358, translateTaskToken: 50358, transcribeTaskToken: 50359, startOfPromptToken: 50360, nonSpeechToken: 50361, noTimestampsToken: 50362, timestampTokensStart: 50363, timestampTokensEnd: 50363 + 1501, }; } this.randomGen = new XorShift32PRNG(murmurHash3_int32Input(prngSeed)); } async recognize(rawAudio, task, language, options, logitFilter) { await this.initializeIfNeeded(); const logger = new Logger(); options = extendDeep(defaultWhisperOptions, options); options.model = this.modelName; if (!options.timestampAccuracy) { options.timestampAccuracy = this.defaultTimestampAccuracy; } const audioSamples = rawAudio.audioChannels[0]; const sampleRate = rawAudio.sampleRate; const prompt = options.prompt; const decodeTimestampTokens = options.decodeTimestampTokens; const maxAudioSamplesPerPart = sampleRate * 30; let previousPartTextTokens = []; let timeline = []; let allDecodedTokens = []; let wrappedLogitFilter; if (logitFilter) { wrappedLogitFilter = (logits, partDecodedTokens, isFirstPart, isFinalPart) => { return logitFilter(logits, [...allDecodedTokens, ...partDecodedTokens], isFirstPart, isFinalPart); }; } for (let audioOffset = 0; audioOffset < audioSamples.length;) { const segmentStartTime = audioOffset / sampleRate; await logger.startAsync(`\nPrepare audio part at time position ${segmentStartTime.toFixed(2)}`, undefined, chalk.magentaBright); const audioPartSamples = audioSamples.slice(audioOffset, audioOffset + maxAudioSamplesPerPart); const audioPartRawAudio = { audioChannels: [audioPartSamples], sampleRate }; const audioPartDuration = getRawAudioDuration(audioPartRawAudio); logger.end(); const audioPartFeatures = await this.encodeAudio(audioPartRawAudio); const isFirstPart = audioOffset === 0; const isFinalPart = audioOffset + maxAudioSamplesPerPart >= audioSamples.length; let initialTokens = []; if (isFirstPart && prompt) { const promptTokens = this.textToTokens(prompt); initialTokens = [this.tokenConfig.startOfPromptToken, ...promptTokens]; } else if (options.autoPromptParts && previousPartTextTokens.length > 0) { initialTokens = [this.tokenConfig.startOfPromptToken, ...previousPartTextTokens]; } initialTokens = [...initialTokens, ...this.getTextStartTokens(language, task, !decodeTimestampTokens)]; logger.end(); let { decodedTokens: partTokens, decodedTokensConfidence: partTokensConfidence, decodedTokensCrossAttentionQKs: partCrossAttentionQKs, } = await this.decodeTokens(audioPartFeatures, initialTokens, audioPartDuration, isFirstPart, isFinalPart, options, wrappedLogitFilter); const lastToken = partTokens[partTokens.length - 1]; const lastTokenIsTimestamp = this.isTimestampToken(lastToken); let audioEndOffset; if (!isFinalPart && lastTokenIsTimestamp) { const timePosition = this.timestampTokenToSeconds(lastToken); audioEndOffset = audioOffset + Math.floor(timePosition * sampleRate); } else { audioEndOffset = Math.min(audioOffset + maxAudioSamplesPerPart, audioSamples.length); } const segmentEndTime = audioEndOffset / sampleRate; const segmentFrameCount = this.secondsRangeToFrameCount(segmentStartTime, segmentEndTime); await logger.startAsync(`Extract timeline for part (timestamp accuracy: ${options.timestampAccuracy})`); if (partTokens.length != partCrossAttentionQKs.length) { throw new Error('Unexpected: partTokens.length != partCrossAttentionQKs.length'); } // Prepare tokens partTokens = partTokens.slice(initialTokens.length); partTokensConfidence = partTokensConfidence.slice(initialTokens.length); partCrossAttentionQKs = partCrossAttentionQKs.slice(initialTokens.length); // Find alignment path let alignmentHeads; if (options.timestampAccuracy === 'medium' || options.model == 'large-v3-turbo') { alignmentHeads = this.alignmentHeadIndexes; } else if (options.timestampAccuracy === 'high') { alignmentHeads = undefined; } else { throw new Error(`Unsupported timestamp accuracy '${options.timestampAccuracy}', can only be 'medium' or 'high'.`); } const alignmentPath = await this.findAlignmentPathFromQKs(partCrossAttentionQKs, partTokens, 0, segmentFrameCount, alignmentHeads); // Generate timeline from alignment path const partTimeline = await this.getTokenTimelineFromAlignmentPath(alignmentPath, partTokens, segmentStartTime, segmentEndTime, partTokensConfidence); // Add tokens to output allDecodedTokens.push(...partTokens); timeline.push(...partTimeline); // Determine compression ratio for recognized text (normalized to lowercase) of this part const compressionRatioForPart = (await getDeflateCompressionMetricsForString(this.tokensToText(partTokens).toLocaleLowerCase())).ratio; // If the recognized text isn't too repetitive if (compressionRatioForPart < options.repetitionThreshold) { // Set current part tokens as the previous part text tokens previousPartTextTokens = partTokens.filter(token => this.isTextToken(token)); } else { // Otherwise, set previous part tokens to an empty array previousPartTextTokens = []; } audioOffset = audioEndOffset; logger.end(); } // Convert token timeline to word timeline timeline = this.tokenTimelineToWordTimeline(timeline, language); // Convert tokens to transcript const transcript = this.tokensToText(allDecodedTokens).trim(); logger.end(); return { transcript, timeline, allDecodedTokens }; } async align(rawAudio, transcript, sourceLanguage, task, whisperAlignmentOptions) { await this.initializeTokenizerIfNeeded(); whisperAlignmentOptions = extendDeep(defaultWhisperAlignmentOptions, whisperAlignmentOptions); if (!whisperAlignmentOptions.timestampAccuracy) { whisperAlignmentOptions.timestampAccuracy = this.defaultTimestampAccuracy; } const targetLanguage = task === 'transcribe' ? sourceLanguage : 'en'; const shouldSplitToSentences = false; let simplifiedTranscript = ''; if (shouldSplitToSentences) { const sentences = splitToSentences(transcript, targetLanguage); for (const sentence of sentences) { let sentenceWords = await splitToWords(sentence, targetLanguage); sentenceWords = sentenceWords.filter(word => isWord(word)); simplifiedTranscript += sentenceWords.join(' '); simplifiedTranscript += ' '; } } else { let words = await splitToWords(transcript, targetLanguage); words = words.map(word => word.trim()); words = words.filter(word => isWord(word)); simplifiedTranscript = words.join(' '); } // Tokenize the transcript const simplifiedTranscriptTokens = this.textToTokens(simplifiedTranscript); // Initialize custom logit filter that allows only the transcript tokens to be decoded // in order. const endOfTextToken = this.tokenConfig.endOfTextToken; const logitFilter = (logits, decodedTokens, isFirstPart, isFinalPart) => { const decodedTextTokens = decodedTokens.filter(token => this.isTextToken(token)); const nextTokenToDecode = simplifiedTranscriptTokens[decodedTextTokens.length] ?? endOfTextToken; const newLogits = logits.map((logit, index) => { if (index === nextTokenToDecode) { return logit; } // If it's the final part, the ent-of-text token logit is set to -Infinity. // This will force to force all transcript tokens to be decoded even if the model doesn't // recognize them. if (!isFinalPart && index === endOfTextToken) { return logit; } return -Infinity; }); return newLogits; }; // Set options for alignment const options = { model: this.modelName, temperature: 0.0, prompt: undefined, topCandidateCount: 1, punctuationThreshold: Infinity, autoPromptParts: false, maxTokensPerPart: whisperAlignmentOptions.maxTokensPerPart, suppressRepetition: false, repetitionThreshold: Infinity, decodeTimestampTokens: true, endTokenThreshold: whisperAlignmentOptions.endTokenThreshold, includeEndTokenInCandidates: false, timestampAccuracy: whisperAlignmentOptions.timestampAccuracy, encoderProvider: whisperAlignmentOptions.encoderProvider, decoderProvider: whisperAlignmentOptions.decoderProvider, seed: undefined, }; // Recognize const { timeline, allDecodedTokens } = await this.recognize(rawAudio, task, sourceLanguage, options, logitFilter); { // If not all tokens were decoded, add the remaining ones to the timeline const lastKnownWordStartTime = timeline.length > 0 ? timeline[timeline.length - 1].startTime : 0; const allDecodedTextTokens = allDecodedTokens.filter(token => this.isTextToken(token)); while (allDecodedTextTokens.length < simplifiedTranscriptTokens.length) { const token = simplifiedTranscriptTokens[allDecodedTextTokens.length]; const tokenText = this.tokenToText(token); allDecodedTextTokens.push(token); const newTokenEntry = { type: 'token', text: tokenText, startTime: lastKnownWordStartTime, endTime: lastKnownWordStartTime, id: token, confidence: 0, }; if (tokenText.startsWith(' ') || timeline.length === 0) { timeline.push({ type: 'word', text: tokenText.trim(), startTime: lastKnownWordStartTime, endTime: lastKnownWordStartTime, timeline: [newTokenEntry], confidence: 0, }); } else { const lastWordEntry = timeline[timeline.length - 1]; lastWordEntry.timeline.push(newTokenEntry); lastWordEntry.text += tokenText; } } } return timeline; } async detectLanguage(audioFeatures, temperature) { if (!this.isMultiligualModel) { throw new Error('Language detection is only supported with multilingual models'); } await this.initializeDecoderSessionIfNeeded(); // Prepare and run decoder const logger = new Logger(); await logger.startAsync('Detect language with Whisper model'); const sotToken = this.tokenConfig.startOfTextToken; const initialTokens = [sotToken]; const offset = 0; const Onnx = await import('onnxruntime-node'); const initialKvDimensions = this.getKvDimensions(1, initialTokens.length); const kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions); const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(initialTokens.map(token => BigInt(token))), [1, initialTokens.length]); const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), []); const decoderInputs = { tokens: tokensTensor, audio_features: audioFeatures, kv_cache: kvCacheTensor, offset: offsetTensor }; const decoderOutputs = await this.textDecoder.run(decoderInputs); const logitsBuffer = decoderOutputs['logits'].data; const tokenConfig = this.tokenConfig; const languageTokensLogits = Array.from(logitsBuffer.slice(tokenConfig.languageTokensStart, tokenConfig.languageTokensEnd)); const languageTokensProbabilities = softmax(languageTokensLogits, temperature); const results = []; for (const language in languageIdLookup) { const langId = languageIdLookup[language]; const probability = languageTokensProbabilities[langId]; results.push({ language, languageName: languageCodeToName(language), probability }); } logger.end(); return results; } async detectVoiceActivity(audioFeatures, temperature) { await this.initializeDecoderSessionIfNeeded(); // Prepare and run decoder const logger = new Logger(); await logger.startAsync('Detect voice activity with Whisper model'); const sotToken = this.tokenConfig.startOfTextToken; const initialTokens = [sotToken]; const offset = 0; const Onnx = await import('onnxruntime-node'); const initialKvDimensions = this.getKvDimensions(1, initialTokens.length); const kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions); const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(initialTokens.map(token => BigInt(token))), [1, initialTokens.length]); const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), []); const decoderInputs = { tokens: tokensTensor, audio_features: audioFeatures, kv_cache: kvCacheTensor, offset: offsetTensor }; const decoderOutputs = await this.textDecoder.run(decoderInputs); const logitsBuffer = decoderOutputs['logits'].data; const tokenConfig = this.tokenConfig; const logits = Array.from(logitsBuffer); const probabilities = softmax(logits, temperature); const noSpeechProbability = probabilities[tokenConfig.nonSpeechToken]; return 1.0 - noSpeechProbability; } // Decode tokens using the decoder model async decodeTokens(audioFeatures, initialTokens, audioDuration, isFirstPart, isFinalPart, options, logitFilter) { // Initialize await this.initializeTokenizerIfNeeded(); await this.initializeDecoderSessionIfNeeded(); const logger = new Logger(); await logger.startAsync('Decode text tokens with Whisper decoder model'); options = extendDeep(defaultWhisperOptions, options); const Onnx = await import('onnxruntime-node'); // Get token information const endOfTextToken = this.tokenConfig.endOfTextToken; const timestampTokensStart = this.tokenConfig.timestampTokensStart; const suppressedTextTokens = this.getSuppressedTextTokens(); const suppressedMetadataTokens = this.getSuppressedMetadataTokens(); const allowedPunctuationMarks = this.getAllowedPunctuationMarks(); const spaceToken = this.textToTokens(' ')[0]; // Initialize variables for decoding loop let decodedTokens = initialTokens.slice(); const initialKvDimensions = this.getKvDimensions(1, decodedTokens.length); let kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions); let decodedTokensTimestampLogits = []; let decodedTokensConfidence = []; let decodedTokensCrossAttentionQKs = []; for (let i = 0; i < decodedTokens.length; i++) { decodedTokensTimestampLogits.push(new Array(1501)); // Should the length be 1500 instead? decodedTokensConfidence.push(1.0); decodedTokensCrossAttentionQKs.push(undefined); } let lastTimestampTokenIndex = -1; let timestampTokenSeenCount = 0; let bufferedTokensToPrint = []; // Define method to add a token to output function addToken(tokenToAdd, timestampLogits, confidence, crossAttentionQKs) { decodedTokens.push(tokenToAdd); decodedTokensTimestampLogits.push(timestampLogits); decodedTokensConfidence.push(confidence); decodedTokensCrossAttentionQKs.push(crossAttentionQKs); } // Start decoding loop for (let decodedTokenCount = 0; decodedTokenCount < options.maxTokensPerPart; decodedTokenCount++) { const isInitialState = decodedTokens.length === initialTokens.length; const atLeastOneTextTokenDecoded = decodedTokens.slice(initialTokens.length).some(token => this.isTextToken(token)); // If not in initial state, reshape KV Cache tensor to accomodate a new output token if (!isInitialState) { const dims = kvCacheTensor.dims; const currentKvCacheGroups = splitFloat32Array(kvCacheTensor.data, dims[2] * dims[3]); const reshapedKvCacheTensor = new Onnx.Tensor('float32', new Float32Array(dims[0] * dims[1] * (decodedTokens.length) * dims[3]), [dims[0], dims[1], decodedTokens.length, dims[3]]); const reshapedKvCacheGroups = splitFloat32Array(reshapedKvCacheTensor.data, decodedTokens.length * dims[3]); for (let i = 0; i < dims[0]; i++) { reshapedKvCacheGroups[i].set(currentKvCacheGroups[i]); } kvCacheTensor = reshapedKvCacheTensor; } // Prepare values for decoder const tokensToDecode = isInitialState ? decodedTokens : [decodedTokens[decodedTokens.length - 1]]; const offset = isInitialState ? 0 : decodedTokens.length; const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(tokensToDecode.map(token => BigInt(token))), [1, tokensToDecode.length]); const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), []); const decoderInputs = { tokens: tokensTensor, audio_features: audioFeatures, kv_cache: kvCacheTensor, offset: offsetTensor }; // Run decoder model const decoderOutputs = await this.textDecoder.run(decoderInputs); // Extract decoder model results const logitsBuffer = decoderOutputs['logits'].data; kvCacheTensor = decoderOutputs['output_kv_cache']; const crossAttentionQKsForTokenOnnx = decoderOutputs['cross_attention_qks']; const crossAttentionQKsForToken = makeOnnxLikeFloat32Tensor(crossAttentionQKsForTokenOnnx); crossAttentionQKsForTokenOnnx.dispose(); // Get logits const resultLogitsFloatArrays = splitFloat32Array(logitsBuffer, logitsBuffer.length / decoderOutputs['logits'].dims[1]); const allTokenLogits = Array.from(resultLogitsFloatArrays[resultLogitsFloatArrays.length - 1]); // Suppress metadata tokens in the suppression set for (const suppressedTokenIndex of suppressedMetadataTokens) { allTokenLogits[suppressedTokenIndex] = -Infinity; } if (!atLeastOneTextTokenDecoded) { // If in initial state, suppress end-of-text token allTokenLogits[endOfTextToken] = -Infinity; } const timestampTokenLogits = allTokenLogits.slice(timestampTokensStart); const decodeTimestampTokenIfNeeded = () => { // Try to decode a timestamp token, if needed // If timestamp tokens is disabled in options, don't decode a timestamp if (!options.decodeTimestampTokens) { return false; } // If this is the first token in the part, unconditionally decode a timestamp token // for time 0.0 if (isInitialState) { addToken(timestampTokensStart, timestampTokenLogits, 1.0, crossAttentionQKsForToken); return true; } const previousTokenWasTimestamp = this.isTimestampToken(decodedTokens[decodedTokens.length - 1]); const secondPreviousTokenWasTimestamp = this.isTimestampToken(decodedTokens[decodedTokens.length - 2]); // If there are two successive timestamp tokens decoded, or the previous timestamp was the first token, // don't decode a timestamp if (previousTokenWasTimestamp && ((decodedTokens.length === initialTokens.length + 1) || secondPreviousTokenWasTimestamp)) { return false; } // Derive token probabilities const allTokenProbabilities = softmax(allTokenLogits, 1.0); const allTokenLogProbabilities = logOfVector(allTokenProbabilities); const nonTimestampTokenLogProbs = allTokenLogProbabilities.slice(0, timestampTokensStart); // Find highest non-timestamp token const indexOfMaxNonTimestampLogProb = indexOfMax(nonTimestampTokenLogProbs); const valueOfMaxNonTimestampLogProb = nonTimestampTokenLogProbs[indexOfMaxNonTimestampLogProb]; // Find highest timestamp token const timestampTokenLogProbs = allTokenLogProbabilities.slice(timestampTokensStart); const indexOfMaxTimestampLogProb = indexOfMax(timestampTokenLogProbs); // Compute the log of the sum of exponentials of the log probabilities // of the timestamp tokens const logSumExpOfTimestampTokenLogProbs = logSumExp(timestampTokenLogProbs); // If the sum isn't greater than the log probability of the highest non-timestamp token, // don't decode a timestamp if (logSumExpOfTimestampTokenLogProbs <= valueOfMaxNonTimestampLogProb) { return false; } // Decode a timestamp token timestampTokenSeenCount += 1; if (previousTokenWasTimestamp) { // If previously decoded token was a timestamp token, repeat it const previousToken = decodedTokens[decodedTokens.length - 1]; const previousTokenTimestampLogits = decodedTokensTimestampLogits[decodedTokensTimestampLogits.length - 1]; const previousTokenConfidence = decodedTokensConfidence[decodedTokensConfidence.length - 1]; addToken(previousToken, previousTokenTimestampLogits, previousTokenConfidence, crossAttentionQKsForToken); lastTimestampTokenIndex = decodedTokens.length; } else { // Otherwise decode the highest probability timestamp const timestampToken = timestampTokensStart + indexOfMaxTimestampLogProb; const confidence = allTokenProbabilities[timestampToken]; addToken(timestampToken, timestampTokenLogits, confidence, crossAttentionQKsForToken); } return true; }; // Call the method to decode timestamp token if needed const timestampTokenDecoded = decodeTimestampTokenIfNeeded(); if (timestampTokenDecoded) { await yieldToEventLoop(); continue; } // Decode a non-timestamp token let nonTimestampTokenLogits = allTokenLogits.slice(0, timestampTokensStart); let shouldDecodeEndfOfTextToken = false; // If at least one text token was decoded, and the end-of-text token's probability is // sufficiently higher than the second highest ranked token, then accept end-of-text if (atLeastOneTextTokenDecoded) { const endOfTextTokenLogit = nonTimestampTokenLogits[endOfTextToken]; const otherTokensLogits = nonTimestampTokenLogits.slice(); otherTokensLogits[endOfTextToken] = -Infinity; const indexOfMaximumOtherTokenLogit = indexOfMax(otherTokensLogits); const maximumOtherTokenLogit = nonTimestampTokenLogits[indexOfMaximumOtherTokenLogit]; const endProbabilities = softmax([endOfTextTokenLogit, maximumOtherTokenLogit], 1.0); if (endProbabilities[0] > options.endTokenThreshold) { shouldDecodeEndfOfTextToken = true; } } if (logitFilter) { // Apply custom logit filter function if given nonTimestampTokenLogits = logitFilter(nonTimestampTokenLogits, decodedTokens, isFirstPart, isFinalPart); // If the custom filter set the end-of-text token to be Infinity, or -Infinity, // then override any previous decision and accept or reject it, respectively if (nonTimestampTokenLogits[endOfTextToken] === Infinity) { shouldDecodeEndfOfTextToken = true; } else if (nonTimestampTokenLogits[endOfTextToken] === -Infinity) { shouldDecodeEndfOfTextToken = false; } // If filter caused all word token logits to be -Infinity, then there is no // other token to decode. Fall back to accept end-of-text if (nonTimestampTokenLogits.slice(0, endOfTextToken).every(logit => logit === -Infinity)) { shouldDecodeEndfOfTextToken = true; } } else { // Otherwise, suppress text tokens in the suppression set for (const suppressedTokenIndex of suppressedTextTokens) { nonTimestampTokenLogits[suppressedTokenIndex] = -Infinity; } // Suppress the space token if at initial state if (isInitialState) { nonTimestampTokenLogits[spaceToken] = -Infinity; } } // If end-of-text token should be decoded, then add it and break // out of the loop if (shouldDecodeEndfOfTextToken) { addToken(endOfTextToken, timestampTokenLogits, 1.0, crossAttentionQKsForToken); break; } // Suppress end-of-text token if it shouldn't be included in candidates if (!options.includeEndTokenInCandidates) { nonTimestampTokenLogits[endOfTextToken] = -Infinity; } // Find top candidates const sortedNonTimestampLogitsWithIndexes = Array.from(nonTimestampTokenLogits).map((logit, index) => ({ token: index, logit })); sortedNonTimestampLogitsWithIndexes.sort((a, b) => b.logit - a.logit); let topCandidates = sortedNonTimestampLogitsWithIndexes.slice(0, options.topCandidateCount) .map(entry => ({ token: entry.token, logit: entry.logit, text: this.tokenToText(entry.token, true) })); // Apply repetition suppression if enabled if (options.suppressRepetition) { // Using some hardcoded constants, for now const tokenWindowSize = 30; const thresholdMatchLength = 4; const thresholdCycleRepetition = 3; const filteredCandidates = []; for (const candidate of topCandidates) { const lastDecodedTextTokens = decodedTokens .filter(token => this.isTextToken(token)) .reverse() .slice(0, tokenWindowSize); const { longestMatch, longestCycleRepetition } = getTokenRepetitionScore([candidate.token, ...lastDecodedTextTokens]); if (longestMatch >= thresholdMatchLength || longestCycleRepetition >= thresholdCycleRepetition) { continue; } filteredCandidates.push(candidate); } // If all candidates have been filtered out, accept an end-of-text token if (filteredCandidates.length === 0) { filteredCandidates.push({ token: endOfTextToken, logit: Infinity, text: this.tokenToText(endOfTextToken, true) }); } topCandidates = filteredCandidates; } // Compute top candidate probabilities const topCandidateProbabilities = softmax(topCandidates.map(a => a.logit), options.temperature); // Find highest ranking punctuation token const rankOfPromisingPunctuationToken = topCandidates.findIndex((entry, index) => { const tokenText = this.tokenToText(entry.token).trim(); const isPunctuationToken = allowedPunctuationMarks.includes(tokenText); if (!isPunctuationToken) { return false; } const tokenProb = topCandidateProbabilities[index]; return tokenProb >= options.punctuationThreshold; }); // Find rank of space token let rankOfSpaceToken = topCandidates.findIndex(candidate => candidate.token === spaceToken); if (rankOfSpaceToken < 0) { rankOfSpaceToken = Infinity; } // Choose token let chosenCandidateRank; // Select a high-ranking punctuation token if found, and it has // a rank higher than the space token, if (rankOfPromisingPunctuationToken >= 0 && rankOfPromisingPunctuationToken < rankOfSpaceToken) { chosenCandidateRank = rankOfPromisingPunctuationToken; } else { // Otherwise, select randomly from top k distribution chosenCandidateRank = this.randomGen.selectRandomIndexFromDistribution(topCandidateProbabilities); } // Add chosen token const chosenToken = topCandidates[chosenCandidateRank].token; const chosenTokenConfidence = topCandidateProbabilities[chosenCandidateRank]; addToken(chosenToken, timestampTokenLogits, chosenTokenConfidence, crossAttentionQKsForToken); // If chosen token is the end-of-text token, break if (chosenToken === endOfTextToken) { break; } // Print token if needed if (this.isTextToken(chosenToken)) { bufferedTokensToPrint.push(chosenToken); let textToPrint = this.tokensToText(bufferedTokensToPrint); // If the decoded text is valid, print it if (!containsInvalidCodepoint(textToPrint)) { if (isFirstPart && decodedTokens.every(token => this.isMetadataToken(token))) { textToPrint = textToPrint.trimStart(); } logger.write(textToPrint); bufferedTokensToPrint = []; } } await yieldToEventLoop(); } // If at least two timestamp tokens were decoded and it's not the final part, // truncate up to the last timestamp token if (timestampTokenSeenCount >= 2 && !isFinalPart) { const sliceEndTokenIndex = lastTimestampTokenIndex; decodedTokens = decodedTokens.slice(0, sliceEndTokenIndex); decodedTokensTimestampLogits = decodedTokensTimestampLogits.slice(0, sliceEndTokenIndex); decodedTokensCrossAttentionQKs = decodedTokensCrossAttentionQKs.slice(0, sliceEndTokenIndex); decodedTokensConfidence = decodedTokensConfidence.slice(0, sliceEndTokenIndex); } logger.write('\n'); logger.end(); // Return the decoded tokens return { decodedTokens, decodedTokensTimestampLogits, decodedTokensConfidence, decodedTokensCrossAttentionQKs, }; } // Encode audio using the encoder model async encodeAudio(rawAudio) { await this.initializeEncoderSessionIfNeeded(); const Onnx = await import('onnxruntime-node'); const logger = new Logger(); const audioSamples = rawAudio.audioChannels[0]; const sampleRate = rawAudio.sampleRate; const fftOrder = 400; const fftWindowSize = 400; const fftHopLength = 160; const filterbankCount = this.filterbankCount; const filterbanks = this.filterbanks; const maxAudioSamples = sampleRate * 30; const maxAudioFrames = 3000; if (sampleRate != 16000) { throw new Error('Audio must have a sample rate of 16000 Hz'); } if (audioSamples.length > maxAudioSamples) { throw new Error(`Audio part is longer than 30 seconds`); } // Compute a mel spectogram await logger.startAsync('Extract mel spectogram from audio part'); // Pad audio samples to ensure that have a duration of 30 seconds const paddedAudioSamples = new Float32Array(maxAudioSamples); paddedAudioSamples.set(audioSamples, 0); const rawAudioPart = { audioChannels: [paddedAudioSamples], sampleRate }; const { melSpectogram } = await computeMelSpectogramUsingFilterbanks(rawAudioPart, fftOrder, fftWindowSize, fftHopLength, filterbanks); await logger.startAsync('Normalize mel spectogram'); const logMelSpectogram = melSpectogram.map(spectrum => spectrum.map(mel => Math.log10(Math.max(mel, 1e-10)))); // Find maximum log mel value in the spectrum let maxLogMel = -Infinity; for (const spectrum of logMelSpectogram) { for (const mel of spectrum) { if (mel > maxLogMel) { maxLogMel = mel; } } } // Normalize log mel spectogram (based on Python reference code) const normalizedLogMelSpectogram = logMelSpectogram.map(spectrum => spectrum.map(logMel => (Math.max(logMel, maxLogMel - 8) + 4) / 4)); // Flatten the normalized log mel spectogram const flattenedNormalizedLogMelSpectogram = new Float32Array(maxAudioFrames * filterbankCount); for (let i = 0; i < filterbankCount; i++) { for (let j = 0; j < maxAudioFrames; j++) { flattenedNormalizedLogMelSpectogram[(i * maxAudioFrames) + j] = normalizedLogMelSpectogram[j][i]; } } // Run the encoder model await logger.startAsync('Encode mel spectogram with Whisper encoder model'); const inputTensor = new Onnx.Tensor('float32', flattenedNormalizedLogMelSpectogram, [1, filterbankCount, maxAudioFrames]); const encoderInputs = { mel: inputTensor }; const encoderOutputs = await this.audioEncoder.run(encoderInputs); const encodedAudioFeatures = encoderOutputs['output']; logger.end(); return encodedAudioFeatures; } tokenTimelineToWordTimeline(tokenTimeline, language) { function isSeparatorCharacter(char) { const nonSeparatingPunctuation = [`'`, `-`, `.`, `·`, `•`]; if (nonSeparatingPunctuation.includes(char)) { return false; } return isWhitespace(char) || isPunctuation(char); } function startsWithSeparatorCharacter(text) { return isSeparatorCharacter(text[0]); } function endsWithSeparatorCharacter(text) { return isSeparatorCharacter(text[text.length - 1]); } if (language != 'zh' && language != 'ja') { tokenTimeline = tokenTimeline.filter(entry => this.isTextToken(entry.id)); } const resultTimeline = []; let groups = []; for (let tokenIndex = 0; tokenIndex < tokenTimeline.length; tokenIndex++) { const entry = tokenTimeline[tokenIndex]; const previousEntry = tokenIndex > 0 ? tokenTimeline[tokenIndex - 1] : undefined; const text = entry.text; const previousEntryText = previousEntry?.text; if (groups.length == 0 || text === '' || startsWithSeparatorCharacter(text) || (previousEntryText != null && endsWithSeparatorCharacter(previousEntryText))) { groups.push([entry]); } else { groups[groups.length - 1].push(entry); } } { const splitGroups = []; for (let groupIndex = 0; groupIndex < groups.length; groupIndex++) { const group = groups[groupIndex]; const nextGroup = groups[groupIndex + 1]; if (group.length > 1 && group[group.length - 1].text === '.' && (!nextGroup || [' ', '['].includes(nextGroup[0].text[0]))) { splitGroups.push(group.slice(0, group.length - 1)); splitGroups.push(group.slice(group.length - 1)); } else { splitGroups.push(group); } } groups = splitGroups; } for (const group of groups) { let groupText = this.tokensToText(group.map(entry => entry.id)); if (groupText === '') { continue; } const startTime = group[0].startTime; const endTime = group[group.length - 1].endTime; let confidence = undefined;