echogarden

An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.

1,465 lines (1,083 loc) 97.9 kB
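For orientation, here is a minimal, hypothetical usage sketch of this module's exported `recognize` function. The import path, the model directory, the model name value, and the inline `RawAudio` object shape are illustrative assumptions; only the function signature, the 16000 Hz sample-rate requirement, and the returned `transcript` field are taken from the source below.

// --- Usage sketch (illustrative, not part of the module source) ---
import { recognize } from './recognition/WhisperSTT.js' // hypothetical path to this module

export async function transcribeClip(samples: Float32Array) {
	// Whisper requires 16 kHz mono input; `recognize` throws otherwise
	const sourceRawAudio = { audioChannels: [samples], sampleRate: 16000 }

	const result = await recognize(
		sourceRawAudio,      // RawAudio (shape assumed from its usage in this file)
		'tiny.en',           // WhisperModelName (example value)
		'./models/whisper',  // hypothetical model directory
		'transcribe',        // WhisperTask
		'en',                // source language, as a short code
		{},                  // WhisperOptions; defaults are merged in via extendDeep
	)

	return result.transcript
}
// --- End of usage sketch ---
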
import chalk from 'chalk'
import type * as Onnx from 'onnxruntime-node'

import { Logger } from '../utilities/Logger.js'
import { computeMelSpectrogramUsingFilterbanks, Filterbank } from '../dsp/MelSpectrogram.js'
import { clip, getIntegerRange, getTopKIndexes, splitFloat32Array, yieldToEventLoop } from '../utilities/Utilities.js'
import { indexOfMax, logOfVector, logSumExp, meanOfVector, medianOfVector, softmax, sumAndSumOfSquaresOfVector, sumOfSquaresOfVector, sumVector } from '../math/VectorMath.js'
import { alignDTWWindowed } from '../alignment/DTWSequenceAlignmentWindowed.js'
import { extendDeep } from '../utilities/ObjectUtilities.js'
import { Timeline, TimelineEntry } from '../utilities/Timeline.js'
import { AlignmentPath } from '../alignment/SpeechAlignment.js'
import { getRawAudioDuration, RawAudio, sliceRawAudio } from '../audio/AudioUtilities.js'
import { readFileAsUtf8 } from '../utilities/FileSystem.js'
import { logLevelGreaterOrEqualTo, type LanguageDetectionResults } from '../api/API.js'
import { formatLanguageCodeWithName, getShortLanguageCode, languageCodeToName } from '../utilities/Locale.js'
import { loadPackage } from '../utilities/PackageManager.js'
import { XorShift32PRNG } from '../utilities/RandomGenerator.js'
import { detectSpeechLanguageByParts } from '../api/SpeechLanguageDetection.js'
import { type Tiktoken } from 'tiktoken/lite'
import { includesPunctuation, isWhitespace, splitToWords } from '../nlp/Segmentation.js'
import { medianOf5Filter } from '../math/MedianFilter.js'
import { getDeflateCompressionMetricsForString } from '../utilities/Compression.js'
import { dmlProviderAvailable, getOnnxSessionOptions, makeOnnxLikeFloat32Tensor, OnnxExecutionProvider, OnnxLikeFloat32Tensor } from '../utilities/OnnxUtilities.js'
import { murmurHash3_int32Input } from '../utilities/Hashing.js'
import { containsInvalidCodepoint, getTokenRepetitionScore } from '../utilities/StringUtilities.js'
import { joinPath } from '../utilities/PathUtilities.js'
import { Timer } from '../utilities/Timer.js'

export async function recognize(
	sourceRawAudio: RawAudio,
	modelName: WhisperModelName,
	modelDir: string,
	task: WhisperTask,
	sourceLanguage: string,
	options: WhisperOptions,
	onPart?: WhisperPartCallback) {

	options = extendDeep(defaultWhisperOptions, options)

	if (sourceRawAudio.sampleRate !== 16000) {
		throw new Error('Source audio must have a sample rate of 16000 Hz')
	}

	sourceLanguage = getShortLanguageCode(sourceLanguage)

	if (!(sourceLanguage in languageIdLookup)) {
		throw new Error(`The language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`)
	}

	if (isEnglishOnlyModel(modelName) && sourceLanguage !== 'en') {
		throw new Error(`The model '${modelName}' can only be used with English inputs. However, the given source language was ${languageCodeToName(sourceLanguage)}.`)
	}

	if (modelName === 'large-v3-turbo' && task === 'translate') {
		throw new Error(`The 'large-v3-turbo' model doesn't support translation tasks.`)
	}

	if (options.temperature && options.temperature < 0) {
		throw new Error(`Temperature can't be negative`)
	}

	// Workaround issue with large-v3-turbo that produces invalid results when a prompt is passed to it.
	// Always disable autoprompting for that model.
	if (options.autoPromptParts && modelName === 'large-v3-turbo') {
		options.autoPromptParts = false
	}

	// Select encoder ONNX provider
	const encoderProviders: OnnxExecutionProvider[] = options.encoderProvider ?
		[options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)

	// Select decoder ONNX provider
	const decoderProviders: OnnxExecutionProvider[] = options.decoderProvider ?
		[options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName)

	const seed = options.seed

	const whisper = new Whisper(
		modelName,
		modelDir,
		encoderProviders,
		decoderProviders,
		seed)

	const result = await whisper.recognize(sourceRawAudio, task, sourceLanguage, options, undefined, onPart)

	return result
}

export async function align(
	sourceRawAudio: RawAudio,
	transcript: string,
	modelName: WhisperModelName,
	modelDir: string,
	sourceLanguage: string,
	options: WhisperAlignmentOptions) {

	options = extendDeep(defaultWhisperAlignmentOptions, options)

	if (sourceRawAudio.sampleRate !== 16000) {
		throw new Error('Source audio must have a sample rate of 16000 Hz')
	}

	sourceLanguage = getShortLanguageCode(sourceLanguage)

	if (!(sourceLanguage in languageIdLookup)) {
		throw new Error(`The language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`)
	}

	if (isEnglishOnlyModel(modelName) && sourceLanguage !== 'en') {
		throw new Error(`The model '${modelName}' can only be used with English inputs. However, the given source language was ${languageCodeToName(sourceLanguage)}.`)
	}

	// Select encoder ONNX provider
	const encoderProviders: OnnxExecutionProvider[] = options.encoderProvider ?
		[options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)

	// Select decoder ONNX provider
	const decoderProviders: OnnxExecutionProvider[] = options.decoderProvider ?
		[options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName)

	const whisper = new Whisper(
		modelName,
		modelDir,
		encoderProviders,
		decoderProviders,
	)

	const timeline = await whisper.align(sourceRawAudio, transcript, sourceLanguage, 'transcribe', options)

	return timeline
}

export async function alignEnglishTranslation(
	sourceRawAudio: RawAudio,
	translatedTranscript: string,
	modelName: WhisperModelName,
	modelDir: string,
	sourceLanguage: string,
	options: WhisperAlignmentOptions) {

	options = extendDeep(defaultWhisperAlignmentOptions, options)

	if (sourceRawAudio.sampleRate !== 16000) {
		throw new Error('Source audio must have a sample rate of 16000 Hz')
	}

	sourceLanguage = getShortLanguageCode(sourceLanguage)

	if (!(sourceLanguage in languageIdLookup)) {
		throw new Error(`The source language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`)
	}

	if (modelName === 'large-v3-turbo') {
		throw new Error(`The 'large-v3-turbo' model doesn't support translation tasks, so cannot be used for translation alignment.`)
	}

	if (isEnglishOnlyModel(modelName)) {
		throw new Error(`Translation alignment can only be done with multilingual models.`)
	}

	// Select encoder ONNX provider
	const encoderProviders: OnnxExecutionProvider[] = options.encoderProvider ?
		[options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)

	// Select decoder ONNX provider
	const decoderProviders: OnnxExecutionProvider[] = options.decoderProvider ?
		[options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName)

	const whisper = new Whisper(
		modelName,
		modelDir,
		encoderProviders,
		decoderProviders,
	)

	const timeline = await whisper.align(sourceRawAudio, translatedTranscript, sourceLanguage, 'translate', options)

	return timeline
}

export async function detectLanguage(
	sourceRawAudio: RawAudio,
	modelName: WhisperModelName,
	modelDir: string,
	options: WhisperLanguageDetectionOptions) {

	options = extendDeep(defaultWhisperLanguageDetectionOptions, options)

	if (sourceRawAudio.sampleRate !== 16000) {
		throw new Error('Source audio must have a sample rate of 16000 Hz')
	}

	if (!isMultilingualModel(modelName)) {
		throw new Error(`Language detection is only supported with multilingual models.`)
	}

	if (options.temperature! < 0) {
		throw new Error(`Temperature cannot be negative`)
	}

	// Select encoder ONNX provider
	const encoderProviders: OnnxExecutionProvider[] = options.encoderProvider ?
		[options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)

	// Select decoder ONNX provider
	const decoderProviders: OnnxExecutionProvider[] = options.decoderProvider ?
		[options.decoderProvider] : []

	const whisper = new Whisper(
		modelName,
		modelDir,
		encoderProviders,
		decoderProviders)

	async function detectLanguageForPart(partAudio: RawAudio) {
		const audioFeatures = await whisper.encodeAudio(partAudio)

		const partResults = await whisper.detectLanguage(audioFeatures, options.temperature!)

		return partResults
	}

	const results = await detectSpeechLanguageByParts(sourceRawAudio, detectLanguageForPart)

	results.sort((entry1, entry2) => entry2.probability - entry1.probability)

	return results
}

export async function detectVoiceActivity(
	sourceRawAudio: RawAudio,
	modelName: WhisperModelName,
	modelDir: string,
	options: WhisperVADOptions) {

	options = extendDeep(defaultWhisperVADOptions, options)

	if (sourceRawAudio.sampleRate !== 16000) {
		throw new Error('Source audio must have a sample rate of 16000 Hz')
	}

	if (options.temperature! < 0) {
		throw new Error(`Temperature cannot be negative`)
	}

	const audioSamples = sourceRawAudio.audioChannels[0]

	const partDuration = 5
	const maxSamplesCountForPart = sourceRawAudio.sampleRate * partDuration

	// Select encoder ONNX provider
	const encoderProviders: OnnxExecutionProvider[] = options.encoderProvider ?
		[options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)

	// Select decoder ONNX provider
	const decoderProviders: OnnxExecutionProvider[] = options.decoderProvider ?
		[options.decoderProvider] : []

	const whisper = new Whisper(
		modelName,
		modelDir,
		encoderProviders,
		decoderProviders)

	const partProbabilities: Timeline = []

	for (let sampleOffset = 0; sampleOffset < audioSamples.length; sampleOffset += maxSamplesCountForPart) {
		const partSamples = sliceRawAudio(sourceRawAudio, sampleOffset, sampleOffset + maxSamplesCountForPart)
		const samplesCountForPart = partSamples.audioChannels[0].length

		const startTime = sampleOffset / sourceRawAudio.sampleRate
		const endTime = (sampleOffset + samplesCountForPart) / sourceRawAudio.sampleRate

		const encodedPartSamples = await whisper.encodeAudio(partSamples)

		const probabilityForPart = await whisper.detectVoiceActivity(encodedPartSamples, options.temperature!)
		partProbabilities.push({
			type: 'segment',
			text: '',
			startTime,
			endTime,
			confidence: probabilityForPart,
		})
	}

	return { partProbabilities }
}

export class Whisper {
	isMultiligualModel: boolean

	audioEncoder?: Onnx.InferenceSession
	textDecoder?: Onnx.InferenceSession

	tiktoken?: Tiktoken

	tokenConfig: {
		endOfTextToken: number
		startOfTextToken: number

		languageTokensStart: number
		languageTokensEnd: number

		translateTaskToken: number
		transcribeTaskToken: number
		startOfPromptToken: number
		nonSpeechToken: number
		noTimestampsToken: number

		timestampTokensStart: number
		timestampTokensEnd: number
	}

	randomGen: XorShift32PRNG

	constructor(
		public readonly modelName: WhisperModelName,
		public readonly modelDir: string,
		public readonly encoderExecutionProviders: OnnxExecutionProvider[],
		public readonly decoderExecutionProviders: OnnxExecutionProvider[],
		prngSeed = 1234) {

		this.isMultiligualModel = isMultilingualModel(this.modelName)

		if (this.isMultiligualModel) {
			this.tokenConfig = {
				endOfTextToken: 50257,
				startOfTextToken: 50258,

				languageTokensStart: 50259,
				languageTokensEnd: 50358,

				translateTaskToken: 50358,
				transcribeTaskToken: 50359,
				startOfPromptToken: 50361,
				nonSpeechToken: 50362,
				noTimestampsToken: 50363,

				timestampTokensStart: 50364,
				timestampTokensEnd: 50364 + 1501,
			}
		} else {
			this.tokenConfig = {
				endOfTextToken: 50256,
				startOfTextToken: 50257,

				languageTokensStart: 50258,
				languageTokensEnd: 50358,

				translateTaskToken: 50358,
				transcribeTaskToken: 50359,
				startOfPromptToken: 50360,
				nonSpeechToken: 50361,
				noTimestampsToken: 50362,

				timestampTokensStart: 50363,
				timestampTokensEnd: 50363 + 1501,
			}
		}

		this.randomGen = new XorShift32PRNG(murmurHash3_int32Input(prngSeed))
	}

	async recognize(
		rawAudio: RawAudio,
		task: WhisperTask,
		language: string,
		options: WhisperOptions,
		logitFilter?: WhisperLogitFilter,
		onPart?: WhisperPartCallback,
	) {
		await this.initializeIfNeeded()

		const logger = new Logger()

		options = extendDeep(defaultWhisperOptions, options)

		options.model = this.modelName

		if (!options.timestampAccuracy) {
			options.timestampAccuracy = this.defaultTimestampAccuracy
		}

		if (options.maxTokensPerPart! > largestMaximumTokensPerPart) {
			//throw new Error(`The number of tokens per part cannot be greater than ${largestMaximumTokensPerPart}`)
			options.maxTokensPerPart = largestMaximumTokensPerPart
		}

		const audioSamples = rawAudio.audioChannels[0]
		const sampleRate = rawAudio.sampleRate

		const prompt = options.prompt

		const decodeTimestampTokens = options.decodeTimestampTokens!
const maxAudioSamplesPerPart = sampleRate * 30 let previousPartTextTokens: number[] = [] let timeline: Timeline = [] let allDecodedTokens: number[] = [] let wrappedLogitFilter: WhisperLogitFilter | undefined if (logitFilter) { wrappedLogitFilter = (logits, partDecodedTokens, isFirstPart, isFinalPart) => { return logitFilter(logits, [...allDecodedTokens, ...partDecodedTokens], isFirstPart, isFinalPart) } } for (let audioOffset = 0; audioOffset < audioSamples.length;) { const segmentStartTime = audioOffset / sampleRate await logger.startAsync(`\nPrepare audio part at time position ${segmentStartTime.toFixed(2)}`, undefined, chalk.magentaBright) const audioPartSamples = audioSamples.slice(audioOffset, audioOffset + maxAudioSamplesPerPart) const audioPartRawAudio: RawAudio = { audioChannels: [audioPartSamples], sampleRate } const audioPartDuration = getRawAudioDuration(audioPartRawAudio) logger.end() const audioPartFeatures = await this.encodeAudio(audioPartRawAudio) const isFirstPart = audioOffset === 0 const isFinalPart = audioOffset + maxAudioSamplesPerPart >= audioSamples.length let initialTokens: number[] = [] if (isFirstPart && prompt) { let promptTokens = this.textToTokens(prompt) if (promptTokens.length > largestMaximumTokensPerPart) { promptTokens = promptTokens.slice(promptTokens.length - largestMaximumTokensPerPart) } initialTokens = [this.tokenConfig.startOfPromptToken, ...promptTokens] } else if (options.autoPromptParts && previousPartTextTokens.length > 0) { initialTokens = [this.tokenConfig.startOfPromptToken, ...previousPartTextTokens] } initialTokens = [...initialTokens, ...this.getTextStartTokens(language, task, !decodeTimestampTokens)] //logger.log(`Initial tokens count: ${initialTokens.length}`) logger.end() let { decodedTokens: partTokens, decodedTokensConfidence: partTokensConfidence, decodedTokensCrossAttentionQKs: partTokensCrossAttentionQKs, decodedTokensDecodingTime: partTokensDecodingTime, decodedTokensInferenceTime: partTokensInferenceTime, decodedTokensOverheadTime: partTokensOverheadTime, } = await this.decodeTokens( audioPartFeatures, initialTokens, audioPartDuration, isFirstPart, isFinalPart, options, wrappedLogitFilter, ) const lastToken = partTokens[partTokens.length - 1] const lastTokenIsTimestamp = this.isTimestampToken(lastToken) let audioEndOffset: number if (!isFinalPart && lastTokenIsTimestamp) { const timePosition = this.timestampTokenToSeconds(lastToken) audioEndOffset = audioOffset + Math.floor(timePosition * sampleRate) } else { audioEndOffset = Math.min(audioOffset + maxAudioSamplesPerPart, audioSamples.length) } const segmentEndTime = audioEndOffset / sampleRate const segmentFrameCount = this.secondsRangeToFrameCount(segmentStartTime, segmentEndTime) await logger.startAsync(`Extract timeline for part (timestamp accuracy: ${options.timestampAccuracy!})`) if (partTokens.length !== partTokensCrossAttentionQKs.length) { throw new Error('Unexpected: partTokens.length !== partCrossAttentionQKs.length') } // Prepare tokens partTokens = partTokens.slice(initialTokens.length) partTokensConfidence = partTokensConfidence.slice(initialTokens.length) partTokensCrossAttentionQKs = partTokensCrossAttentionQKs.slice(initialTokens.length) // Find alignment path let alignmentHeads: number[] | undefined if (options.timestampAccuracy === 'medium' || options.model === 'large-v3-turbo') { alignmentHeads = this.alignmentHeadIndexes } else if (options.timestampAccuracy === 'high') { alignmentHeads = undefined } else { throw new Error(`Unsupported timestamp accuracy 
'${options.timestampAccuracy}', can only be 'medium' or 'high'.`) } const alignmentPath = await this.findAlignmentPathFromQKs(partTokensCrossAttentionQKs, partTokens, 0, segmentFrameCount, alignmentHeads) // Generate timeline from alignment path const partTimeline = await this.getTokenTimelineFromAlignmentPath(alignmentPath, partTokens, segmentStartTime, segmentEndTime, partTokensConfidence) if (onPart) { const partWordTimeline = this.tokenTimelineToWordTimeline(partTimeline, language) const partTranscript = this.tokensToText(partTokens) onPart(partTranscript, partTimeline, partWordTimeline) } // Add tokens to output allDecodedTokens.push(...partTokens) timeline.push(...partTimeline) // Determine compression ratio for recognized text (normalized to lowercase) of this part const compressionRatioForPart = (await getDeflateCompressionMetricsForString(this.tokensToText(partTokens).toLocaleLowerCase())).ratio // If the recognized text isn't too repetitive if (compressionRatioForPart < options.repetitionThreshold!) { // Set current part tokens as the previous part text tokens previousPartTextTokens = partTokens.filter(token => this.isTextToken(token)) } else { // Otherwise, set previous part tokens to an empty array previousPartTextTokens = [] } audioOffset = audioEndOffset logger.end() if (logLevelGreaterOrEqualTo('trace')) { const promptDecodingTime = partTokensDecodingTime[0] const medianTokenDecodingTime = medianOfVector(partTokensDecodingTime.slice(1)) const medianTokenInferenceTime = medianOfVector(partTokensInferenceTime.slice(1)) const medianOverheadTime = medianOfVector(partTokensOverheadTime.slice(1)) logger.log(`${chalk.blueBright('Context')}: ${initialTokens.length + partTokens.length} tokens (${initialTokens.length} prompt, ${partTokens.length} decoded)\n${chalk.blueBright('Prompt decode time')}: ${promptDecodingTime.toFixed(1)}ms\n${chalk.blueBright('Median token decode time')}: ${medianTokenDecodingTime.toFixed(1)}ms (${medianTokenInferenceTime.toFixed(1)}ms inference, ${medianOverheadTime.toFixed(2)}ms overhead)`, 'trace') } } // Convert token timeline to word timeline timeline = this.tokenTimelineToWordTimeline(timeline, language) // Convert tokens to transcript const transcript = this.tokensToText(allDecodedTokens).trim() logger.end() return { transcript, timeline, allDecodedTokens } } async align(rawAudio: RawAudio, transcript: string, sourceLanguage: string, task: 'transcribe' | 'translate', whisperAlignmentOptions: WhisperAlignmentOptions) { await this.initializeTokenizerIfNeeded() whisperAlignmentOptions = extendDeep(defaultWhisperAlignmentOptions, whisperAlignmentOptions) if (!whisperAlignmentOptions.timestampAccuracy) { whisperAlignmentOptions.timestampAccuracy = this.defaultTimestampAccuracy } if (whisperAlignmentOptions.maxTokensPerPart! > largestMaximumTokensPerPart) { //throw new Error(`The number of tokens per part cannot be greater than ${largestMaximumTokensPerPart}`) whisperAlignmentOptions.maxTokensPerPart = largestMaximumTokensPerPart } const targetLanguage = task === 'transcribe' ? sourceLanguage : 'en' let simplifiedTranscript = '' { const words = (await splitToWords(transcript, targetLanguage)).nonPunctuationWords simplifiedTranscript = words.join(' ') } // Tokenize the transcript const simplifiedTranscriptTokens = this.textToTokens(simplifiedTranscript) // Initialize custom logit filter that allows only the transcript tokens to be decoded // in order. 
const endOfTextToken = this.tokenConfig.endOfTextToken const logitFilter: WhisperLogitFilter = (logits, decodedTokens, isFirstPart, isFinalPart) => { const decodedTextTokens = decodedTokens.filter(token => this.isTextToken(token)) const nextTokenToDecode = simplifiedTranscriptTokens[decodedTextTokens.length] ?? endOfTextToken const newLogits = logits.map((logit, index) => { if (index === nextTokenToDecode) { return logit } // If it's the final part, the ent-of-text token logit is set to -Infinity. // This will force to force all transcript tokens to be decoded even if the model doesn't // recognize them. if (!isFinalPart && index === endOfTextToken) { return logit } return -Infinity }) return newLogits } // Set options for alignment const options: WhisperOptions = { model: this.modelName, temperature: 0.0, prompt: undefined, topCandidateCount: 1, punctuationThreshold: Infinity, autoPromptParts: false, maxTokensPerPart: whisperAlignmentOptions.maxTokensPerPart!, suppressRepetition: false, repetitionThreshold: Infinity, decodeTimestampTokens: true, endTokenThreshold: whisperAlignmentOptions.endTokenThreshold!, includeEndTokenInCandidates: false, timestampAccuracy: whisperAlignmentOptions.timestampAccuracy!, encoderProvider: whisperAlignmentOptions.encoderProvider!, decoderProvider: whisperAlignmentOptions.decoderProvider!, seed: undefined, } // Recognize const { timeline, allDecodedTokens } = await this.recognize(rawAudio, task, sourceLanguage, options, logitFilter) { // If not all tokens were decoded, add the remaining ones to the timeline const lastKnownWordStartTime = timeline.length > 0 ? timeline[timeline.length - 1].startTime : 0 const allDecodedTextTokens = allDecodedTokens.filter(token => this.isTextToken(token)) while (allDecodedTextTokens.length < simplifiedTranscriptTokens.length) { const token = simplifiedTranscriptTokens[allDecodedTextTokens.length] const tokenText = this.tokenToText(token) allDecodedTextTokens.push(token) const newTokenEntry: TimelineEntry = { type: 'token', text: tokenText, startTime: lastKnownWordStartTime, endTime: lastKnownWordStartTime, id: token, confidence: 0, } if (tokenText.startsWith(' ') || timeline.length === 0) { timeline.push({ type: 'word', text: tokenText.trim(), startTime: lastKnownWordStartTime, endTime: lastKnownWordStartTime, timeline: [newTokenEntry], confidence: 0, }) } else { const lastWordEntry = timeline[timeline.length - 1] lastWordEntry.timeline!.push(newTokenEntry) lastWordEntry.text += tokenText } } } return timeline } async detectLanguage(audioFeatures: Onnx.Tensor, temperature: number): Promise<LanguageDetectionResults> { if (!this.isMultiligualModel) { throw new Error('Language detection is only supported with multilingual models') } await this.initializeDecoderSessionIfNeeded() // Prepare and run decoder const logger = new Logger() await logger.startAsync('Detect language with Whisper model') const sotToken = this.tokenConfig.startOfTextToken const initialTokens = [sotToken] const offset = 0 const Onnx = await import('onnxruntime-node') const initialKvDimensions = this.getKvDimensions(1, initialTokens.length) const kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions) const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(initialTokens.map(token => BigInt(token))), [1, initialTokens.length]) const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), []) const 
decoderInputs = { tokens: tokensTensor, audio_features: audioFeatures, kv_cache: kvCacheTensor, offset: offsetTensor } const decoderOutputs = await this.textDecoder!.run(decoderInputs) const logitsBuffer = decoderOutputs['logits'].data as Float32Array const tokenConfig = this.tokenConfig const languageTokensLogits = Array.from(logitsBuffer.slice(tokenConfig.languageTokensStart, tokenConfig.languageTokensEnd)) const languageTokensProbabilities = softmax(languageTokensLogits, temperature) const results: LanguageDetectionResults = [] for (const language in languageIdLookup) { const langId = languageIdLookup[language] const probability = languageTokensProbabilities[langId] results.push({ language, languageName: languageCodeToName(language), probability }) } logger.end() return results } async detectVoiceActivity(audioFeatures: Onnx.Tensor, temperature: number): Promise<number> { await this.initializeDecoderSessionIfNeeded() // Prepare and run decoder const logger = new Logger() await logger.startAsync('Detect voice activity with Whisper model') const sotToken = this.tokenConfig.startOfTextToken const initialTokens = [sotToken] const offset = 0 const Onnx = await import('onnxruntime-node') const initialKvDimensions = this.getKvDimensions(1, initialTokens.length) const kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions) const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(initialTokens.map(token => BigInt(token))), [1, initialTokens.length]) const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), []) const decoderInputs = { tokens: tokensTensor, audio_features: audioFeatures, kv_cache: kvCacheTensor, offset: offsetTensor } const decoderOutputs = await this.textDecoder!.run(decoderInputs) const logitsBuffer = decoderOutputs['logits'].data as Float32Array const tokenConfig = this.tokenConfig const logits = Array.from(logitsBuffer) const probabilities = softmax(logits, temperature) const noSpeechProbability = probabilities[tokenConfig.nonSpeechToken] return 1.0 - noSpeechProbability } // Decode tokens using the decoder model async decodeTokens( audioFeatures: Onnx.Tensor, initialTokens: number[], audioDuration: number, isFirstPart: boolean, isFinalPart: boolean, options: WhisperOptions, logitFilter?: WhisperLogitFilter) { // Initialize await this.initializeTokenizerIfNeeded() await this.initializeDecoderSessionIfNeeded() const logger = new Logger() await logger.startAsync('Decode text tokens with Whisper decoder model') options = extendDeep(defaultWhisperOptions, options) const Onnx = await import('onnxruntime-node') // Get token information const endOfTextToken = this.tokenConfig.endOfTextToken const timestampTokensStart = this.tokenConfig.timestampTokensStart const suppressedTextTokens = this.getSuppressedTextTokens() const suppressedMetadataTokens = this.getSuppressedMetadataTokens() const allowedPunctuationMarks = this.getAllowedPunctuationMarks() const spaceToken = this.textToTokens(' ')[0] // Initialize variables for decoding loop let decodedTokens = initialTokens.slice() const initialKvDimensions = this.getKvDimensions(1, decodedTokens.length) let kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions) let decodedTokensTimestampLogits: number[][] = [] let decodedTokensConfidence: number[] = [] let 
decodedTokensCrossAttentionQKs: OnnxLikeFloat32Tensor[] = [] for (let i = 0; i < decodedTokens.length; i++) { decodedTokensTimestampLogits.push(new Array(1501)) // Should the length be 1500 instead? decodedTokensConfidence.push(1.0) decodedTokensCrossAttentionQKs.push(undefined as any) } let lastTimestampTokenIndex = -1 let timestampTokenSeenCount = 0 let bufferedTokensToPrint: number[] = [] // Define method to add a token to output function addToken(tokenToAdd: number, timestampLogits: number[], confidence: number, crossAttentionQKs: OnnxLikeFloat32Tensor) { decodedTokens.push(tokenToAdd) decodedTokensTimestampLogits.push(timestampLogits) decodedTokensConfidence.push(confidence) decodedTokensCrossAttentionQKs.push(crossAttentionQKs) } const maxTokensPerPart = Math.min(options.maxTokensPerPart!, largestMaximumTokensPerPart) let decodedTokensInferenceTime: number[] = [] let decodedTokensDecodingTime: number[] = [] const tokenDecodingTimeTimer = new Timer() // Start decoding loop for (let decodedTokenCount = 0; decodedTokenCount < maxTokensPerPart; decodedTokenCount++) { if (decodedTokenCount > 0) { decodedTokensDecodingTime.push(tokenDecodingTimeTimer.getElapsedTimeAndRestart()) } const isInitialState = decodedTokens.length === initialTokens.length const atLeastOneTextTokenDecoded = decodedTokens.slice(initialTokens.length).some(token => this.isTextToken(token)) // If not in initial state, reshape KV Cache tensor to accomodate a new output token if (!isInitialState) { const dims = kvCacheTensor.dims const currentKvCacheGroups = splitFloat32Array(kvCacheTensor.data as Float32Array, dims[2] * dims[3]) const reshapedKvCacheTensor = new Onnx.Tensor('float32', new Float32Array(dims[0] * dims[1] * (decodedTokens.length) * dims[3]), [dims[0], dims[1], decodedTokens.length, dims[3]]) const reshapedKvCacheGroups = splitFloat32Array(reshapedKvCacheTensor.data, decodedTokens.length * dims[3]) for (let i = 0; i < dims[0]; i++) { reshapedKvCacheGroups[i].set(currentKvCacheGroups[i]) } kvCacheTensor = reshapedKvCacheTensor } // Prepare values for decoder const tokensToDecode = isInitialState ? decodedTokens : [decodedTokens[decodedTokens.length - 1]] const offset = isInitialState ? 
0 : decodedTokens.length const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(tokensToDecode.map(token => BigInt(token))), [1, tokensToDecode.length]) const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), []) const decoderInputs = { tokens: tokensTensor, audio_features: audioFeatures, kv_cache: kvCacheTensor, offset: offsetTensor } //// Infer with ONNX decoder model const tokenInferenceTimeTimer = new Timer() const decoderOutputs = await this.textDecoder!.run(decoderInputs) decodedTokensInferenceTime.push(tokenInferenceTimeTimer.elapsedTime) // Extract decoder model results const logitsBuffer = decoderOutputs['logits'].data as Float32Array kvCacheTensor = decoderOutputs['output_kv_cache'] as any const crossAttentionQKsForTokenOnnx = decoderOutputs['cross_attention_qks'] const crossAttentionQKsForToken = makeOnnxLikeFloat32Tensor(crossAttentionQKsForTokenOnnx) crossAttentionQKsForTokenOnnx.dispose() // Get logits const resultLogitsFloatArrays = splitFloat32Array(logitsBuffer, logitsBuffer.length / decoderOutputs['logits'].dims[1]) const allTokenLogits = Array.from(resultLogitsFloatArrays[resultLogitsFloatArrays.length - 1]) // Suppress metadata tokens in the suppression set for (const suppressedTokenIndex of suppressedMetadataTokens) { allTokenLogits[suppressedTokenIndex] = -Infinity } if (!atLeastOneTextTokenDecoded) { // If in initial state, suppress end-of-text token allTokenLogits[endOfTextToken] = -Infinity } const timestampTokenLogits = allTokenLogits.slice(timestampTokensStart) const decodeTimestampTokenIfNeeded = () => { // Try to decode a timestamp token, if needed // If timestamp tokens is disabled in options, don't decode a timestamp if (!options.decodeTimestampTokens) { return false } // If this is the first token in the part, unconditionally decode a timestamp token // for time 0.0 if (isInitialState) { addToken(timestampTokensStart, timestampTokenLogits, 1.0, crossAttentionQKsForToken) return true } const previousTokenWasTimestamp = this.isTimestampToken(decodedTokens[decodedTokens.length - 1]) const secondPreviousTokenWasTimestamp = this.isTimestampToken(decodedTokens[decodedTokens.length - 2]) // If there are two successive timestamp tokens decoded, or the previous timestamp was the first token, // don't decode a timestamp if (previousTokenWasTimestamp && ((decodedTokens.length === initialTokens.length + 1) || secondPreviousTokenWasTimestamp)) { return false } // Derive token probabilities const allTokenProbabilities = softmax(allTokenLogits as any, 1.0) const allTokenLogProbabilities = logOfVector(allTokenProbabilities) const nonTimestampTokenLogProbs = allTokenLogProbabilities.slice(0, timestampTokensStart) // Find highest non-timestamp token const indexOfMaxNonTimestampLogProb = indexOfMax(nonTimestampTokenLogProbs) const valueOfMaxNonTimestampLogProb = nonTimestampTokenLogProbs[indexOfMaxNonTimestampLogProb] // Find highest timestamp token const timestampTokenLogProbs = allTokenLogProbabilities.slice(timestampTokensStart) const indexOfMaxTimestampLogProb = indexOfMax(timestampTokenLogProbs) // Compute the log of the sum of exponentials of the log probabilities // of the timestamp tokens const logSumExpOfTimestampTokenLogProbs = logSumExp(timestampTokenLogProbs) // If the sum isn't greater than the log probability of the highest non-timestamp token, // don't decode a timestamp if (logSumExpOfTimestampTokenLogProbs <= valueOfMaxNonTimestampLogProb) { return false } // Decode a timestamp token timestampTokenSeenCount += 1 if 
(previousTokenWasTimestamp) { // If previously decoded token was a timestamp token, repeat it const previousToken = decodedTokens[decodedTokens.length - 1] const previousTokenTimestampLogits = decodedTokensTimestampLogits[decodedTokensTimestampLogits.length - 1] const previousTokenConfidence = decodedTokensConfidence[decodedTokensConfidence.length - 1] addToken(previousToken, previousTokenTimestampLogits, previousTokenConfidence, crossAttentionQKsForToken) lastTimestampTokenIndex = decodedTokens.length } else { // Otherwise decode the highest probability timestamp const timestampToken = timestampTokensStart + indexOfMaxTimestampLogProb const confidence = allTokenProbabilities[timestampToken] addToken(timestampToken, timestampTokenLogits, confidence, crossAttentionQKsForToken) } return true } // Call the method to decode timestamp token if needed const timestampTokenDecoded = decodeTimestampTokenIfNeeded() if (timestampTokenDecoded) { await yieldToEventLoop() continue } // Decode a non-timestamp token let nonTimestampTokenLogits = allTokenLogits.slice(0, timestampTokensStart) let shouldDecodeEndfOfTextToken = false // If at least one text token was decoded, and the end-of-text token's probability is // sufficiently higher than the second highest ranked token, then accept end-of-text if (atLeastOneTextTokenDecoded) { const endOfTextTokenLogit = nonTimestampTokenLogits[endOfTextToken] const otherTokensLogits = nonTimestampTokenLogits.slice() otherTokensLogits[endOfTextToken] = -Infinity const indexOfMaximumOtherTokenLogit = indexOfMax(otherTokensLogits) const maximumOtherTokenLogit = nonTimestampTokenLogits[indexOfMaximumOtherTokenLogit] const endProbabilities = softmax([endOfTextTokenLogit, maximumOtherTokenLogit], 1.0) if (endProbabilities[0] > options.endTokenThreshold!) { shouldDecodeEndfOfTextToken = true } } if (logitFilter) { // Apply custom logit filter function if given nonTimestampTokenLogits = logitFilter(nonTimestampTokenLogits, decodedTokens, isFirstPart, isFinalPart) // If the custom filter set the end-of-text token to be Infinity, or -Infinity, // then override any previous decision and accept or reject it, respectively if (nonTimestampTokenLogits[endOfTextToken] === Infinity) { shouldDecodeEndfOfTextToken = true } else if (nonTimestampTokenLogits[endOfTextToken] === -Infinity) { shouldDecodeEndfOfTextToken = false } // If filter caused all word token logits to be -Infinity, then there is no // other token to decode. 
Fall back to accept end-of-text if (nonTimestampTokenLogits.slice(0, endOfTextToken).every(logit => logit === -Infinity)) { shouldDecodeEndfOfTextToken = true } } else { // Otherwise, suppress text tokens in the suppression set for (const suppressedTokenIndex of suppressedTextTokens) { nonTimestampTokenLogits[suppressedTokenIndex] = -Infinity } // Suppress the space token if at initial state if (isInitialState) { nonTimestampTokenLogits[spaceToken] = -Infinity } } // If end-of-text token should be decoded, then add it and break // out of the loop if (shouldDecodeEndfOfTextToken) { addToken(endOfTextToken, timestampTokenLogits, 1.0, crossAttentionQKsForToken) break } // Suppress end-of-text token if it shouldn't be included in candidates if (!options.includeEndTokenInCandidates) { nonTimestampTokenLogits[endOfTextToken] = -Infinity } // Find top candidates const sortedTopCandidateTokens = getTopKIndexes(nonTimestampTokenLogits, options.topCandidateCount!, true) let topCandidates = Array.from(sortedTopCandidateTokens).map(token => { const logit = nonTimestampTokenLogits[token] return { token, logit, text: this.tokenToText(token, true) } }) // Apply repetition suppression if enabled if (options.suppressRepetition) { // Using some hardcoded constants, for now: const tokenWindowSize = 30 const thresholdMatchLength = 4 const thresholdCycleRepetitionCount = 3 const filteredCandidates: typeof topCandidates = [] for (const candidate of topCandidates) { const decodedTextTokens = decodedTokens.filter(token => this.isTextToken(token)) const lastDecodedTextTokens = decodedTextTokens .slice(Math.max(decodedTextTokens.length - tokenWindowSize, 0)) .reverse() const { longestMatchLength, longestCycleRepetitionCount } = getTokenRepetitionScore([candidate.token, ...lastDecodedTextTokens]) if (longestMatchLength >= thresholdMatchLength || longestCycleRepetitionCount >= thresholdCycleRepetitionCount) { continue } filteredCandidates.push(candidate) } // If all candidates have been filtered out, accept an end-of-text token if (filteredCandidates.length === 0) { filteredCandidates.push({ token: endOfTextToken, logit: Infinity, text: this.tokenToText(endOfTextToken, true) }) } topCandidates = filteredCandidates } // Compute top candidate probabilities const topCandidateProbabilities = softmax(topCandidates.map(a => a.logit), options.temperature) // Find highest ranking punctuation token const rankOfPromisingPunctuationToken = topCandidates.findIndex((entry, index) => { const tokenText = this.tokenToText(entry.token).trim() const isPunctuationToken = allowedPunctuationMarks.includes(tokenText) if (!isPunctuationToken) { return false } const tokenProb = topCandidateProbabilities[index] return tokenProb >= options.punctuationThreshold! 
}) // Find rank of space token let rankOfSpaceToken = topCandidates.findIndex(candidate => candidate.token === spaceToken) if (rankOfSpaceToken < 0) { rankOfSpaceToken = Infinity } // Choose token let chosenCandidateRank: number // Select a high-ranking punctuation token if found, and it has // a rank higher than the space token, if (rankOfPromisingPunctuationToken >= 0 && rankOfPromisingPunctuationToken < rankOfSpaceToken) { chosenCandidateRank = rankOfPromisingPunctuationToken } else { // Otherwise, select randomly from top k distribution chosenCandidateRank = this.randomGen.selectRandomIndexFromDistribution(topCandidateProbabilities) } // Add chosen token const chosenToken = topCandidates[chosenCandidateRank].token const chosenTokenConfidence = topCandidateProbabilities[chosenCandidateRank] addToken(chosenToken, timestampTokenLogits, chosenTokenConfidence, crossAttentionQKsForToken) // If chosen token is the end-of-text token, break if (chosenToken === endOfTextToken) { break } // Print token if needed if (this.isTextToken(chosenToken)) { bufferedTokensToPrint.push(chosenToken) let textToPrint = this.tokensToText(bufferedTokensToPrint) // If the decoded text is valid, print it if (!containsInvalidCodepoint(textToPrint)) { if (isFirstPart && decodedTokens.every(token => this.isMetadataToken(token))) { textToPrint = textToPrint.trimStart() } logger.write(textToPrint) bufferedTokensToPrint = [] } } await yieldToEventLoop() } if (decodedTokensDecodingTime.length === decodedTokensInferenceTime.length - 1) { decodedTokensDecodingTime.push(tokenDecodingTimeTimer.getElapsedTimeAndRestart()) } // If at least two timestamp tokens were decoded and it's not the final part, // truncate up to the last timestamp token if (timestampTokenSeenCount >= 2 && !isFinalPart) { const sliceEndTokenIndex = lastTimestampTokenIndex decodedTokens = decodedTokens.slice(0, sliceEndTokenIndex) decodedTokensTimestampLogits = decodedTokensTimestampLogits.slice(0, sliceEndTokenIndex) decodedTokensCrossAttentionQKs = decodedTokensCrossAttentionQKs.slice(0, sliceEndTokenIndex) decodedTokensConfidence = decodedTokensConfidence.slice(0, sliceEndTokenIndex) decodedTokensDecodingTime = decodedTokensDecodingTime.slice(0, sliceEndTokenIndex) decodedTokensInferenceTime = decodedTokensInferenceTime.slice(0, sliceEndTokenIndex) } const decodedTokensOverheadTime = decodedTokensDecodingTime.map((time, index) => time - decodedTokensInferenceTime[index]) logger.write('\n') logger.end() // Return the decoded tokens return { decodedTokens, decodedTokensTimestampLogits, decodedTokensConfidence, decodedTokensCrossAttentionQKs, decodedTokensDecodingTime, decodedTokensInferenceTime, decodedTokensOverheadTime, } } // Encode audio using the encoder model async encodeAudio(rawAudio: RawAudio) { await this.initializeEncoderSessionIfNeeded() const Onnx = await import('onnxruntime-node') const logger = new Logger() const audioSamples = rawAudio.audioChannels[0] const sampleRate = rawAudio.sampleRate const fftOrder = 400 const fftWindowSize = 400 const fftHopLength = 160 const filterbankCount = this.filterbankCount const filterbanks = this.filterbanks const maxAudioSamples = sampleRate * 30 const maxAudioFrames = 3000 if (sampleRate !== 16000) { throw new Error('Audio must have a sample rate of 16000 Hz') } if (audioSamples.length > maxAudioSamples) { throw new Error(`Audio part is longer than 30 seconds`) } await logger.startAsync('Extract Mel spectrogram from audio part') // Pad audio samples to ensure that have a duration of 30 seconds const 
paddedAudioSamples = new Float32Array(maxAudioSamples) paddedAudioSamples.set(audioSamples, 0) const rawAudioPart: RawAudio = { audioChannels: [paddedAudioSamples], sampleRate } // Compute Mel spectrogram const { melSpectrogram } = await computeMelSpectrogramUsingFilterbanks(rawAudioPart, fftOrder, fftWindowSize, fftHopLength, filterbanks) // Flatten, transpose, apply logarithm and normalize Mel spectrogram await logger.startAsync('Process Mel spectrogram') const flattenedLogMelSpectrogram = new Float32Array(maxAudioFrames * filterbankCount) let maxLogMel = -Infinity for (let i = 0; i < filterbankCount; i++) { for (let j = 0; j < maxAudioFrames; j++) { const mel = melSpectrogram[j][i] const logMel = Math.log10(Math.max(mel, 1e-10)) if (logMel > maxLogMel) { maxLogMel = logMel } flattenedLogMelSpectrogram[(i * maxAudioFrames) + j] = logMel } } for (let i = 0; i < flattenedLogMelSpectrogram.length; i++) { const logMel = flattenedLogMelSpectrogram[i] const normalizedLogMel = (Math.max(logMel, maxLogMel - 8) + 4) / 4 flattenedLogMelSpectrogram[i] = normalizedLogMel } // Run the encoder model await logger.startAsync('Encode Mel spectrogram with Whisper encoder model') const inputTensor = new Onnx.Tensor('float32', flattenedLogMelSpectrogram, [1, filterbankCount, maxAudioFrames]) const encoderInputs = { mel: inputTensor } const encoderOutputs = await this.audioEncoder!.run(encoderInputs) const encodedAudioFeatures = encoderOutputs['output'] logger.end() return encodedAudioFeatures } tokenTimelineToWordTimeline(tokenTimeline: Timeline, language: string): Timeline { function isSeparatorCharacter(char: string) { const nonSeparatingPunctuation = [`'`, `-`, `.`, `·`, `•`] if (nonSeparatingPunctuation.includes(char)) { return false } return isWhitespace(char) || includesPunctuation(char) } function startsWithSeparatorCharacter(text: string) { return isSeparatorCharacter(text[0]) } function endsWithSeparatorCharacter(text: string) { return isSeparatorCharacter(text[text.length - 1]) } if (language !== 'zh' && language !== 'ja') { tokenTimeline = tokenTimeline.filter(entry => this.isTextToken(entry.id!)) } const resultTimeline: Timeline = [] let groups: Timeline[] = [] for (let tokenIndex = 0; tokenIndex < tokenTimeline.length; tokenIndex++) { const entry = tokenTimeline[tokenIndex] const previousEntry = tokenIndex > 0 ? tokenTimeline[tokenIndex - 1] : undefined const text = entry.text const previousEntryText = previousEntry?.text if (groups.length === 0 || text === '' || startsWithSeparatorCharacter(text) || (previousEntryText != null && endsWithSeparatorCharacter(previousEntryText))) { groups.push([entry]) } else { groups[groups.length - 1].push(entry) } } { const splitGroups: Timeline[] = [] for (let groupIndex = 0; groupIndex < groups.length; groupIndex++) { const group = groups[groupIndex] const nextGroup = groups[groupIndex + 1] if ( group.length > 1 && group[group.length - 1].text === '.' 
&& (!nextGroup || [' ', '['].includes(nextGroup[0].text[0]))) { splitGroups.push(group.slice(0, group.length - 1)) splitGroups.push(group.slice(group.length - 1)) } else { splitGroups.push(group) } } groups = splitGroups } for (const group of groups) { let groupText = this.tokensToText(group.map(entry => entry.id!)) if (groupText === '') { continue } const startTime = group[0].startTime const endTime = group[group.length - 1].endTime let confidence: number | undefined = undefined if (group[0].confidence != null) { confidence = meanOfVector(group.map(entry => entry.confidence!)) } const newEntry: TimelineEntry = { type: 'word', text: groupText.trim(), startTime, endTime, confidence, timeline: group, } resultTimeline.push(newEntry) } return resultTimeline } async getTokenTimelineFromAlignmentPath(alignmentPath: AlignmentPath, tokens: number[], startTimeOffset: number, endTimeOffset: number, tokensConfidence?: number[], correctionAmount = 0.0) { if (alignmentPath.length === 0) { return [] } const tokenTimeline: Timeline = [] for (let pathIndex = 0; pathIndex < alignmentPath.length; pathIndex++) { if (pathIndex !== 0 && alignmentPath[pathIndex].source === alignmentPath[pathIndex - 1].source) { continue } const tokenMappingEntry = alignmentPath[pathIndex] const tokenIndex = tokenMappingEntry.source const token = tokens[tokenIndex] const tokenConfidence = tokensConfidence ? tokensConfidence[tokenIndex] : undefined const tokenText = this.tokenToText(token, true) let startTime = startTimeOffset + (tokenMappingEntry.dest * 0.02) startTime = Math.max(startTime + correctionAmount, startTimeOffset) if (tokenTimeline.length > 0) { tokenTimeline[tokenTimeline.length - 1].endTime = startTime } tokenTimeline.push({ type: 'token', text: tokenText, id: token, startTime, endTime: -1, confidence: tokenConfidence }) } if (tokenTimeline.length > 0) { tokenTimeline[tokenTimeline.length - 1].endTime = endTimeOffset } return tokenTimeline } async findAlignmentPathFromQKs(qksTensors: OnnxLikeFloat32Tensor[], tokens: number[], segmentStartFrame: number, segmentEndFrame: number, headIndexes?: number[]) { const segmentFrameCount = segmentEndFrame - segmentStartFrame if (segmentFrameCount === 0 || tokens.length === 0 || qksTensors.length === 0) { return [] } const tokenCount = qksTensors.length const layerCount = qksTensors[0].dims[0] const headCount = qksTensors[0].dims[2] const frameCount = qksTensors[0].dims[4] if (!headIndexes) { headIndexes = getIntegerRange(0, layerCount * headCount) } // Load attention head weights from tensors const attentionHeads: Float32Array[][] = [] // structure: [heads, tokens, frames] for (const headIndex of headIndexes) { const attentionHead: Float32Array[] = [] // structure: [tokens, frames] for (let tokenIndex = 0; tokenIndex < tokenCount; tokenIndex++) { const bufferOffset = headIndex * frameCount const startIndexInBuffer = bufferOffset + segmentStartFrame cons