echogarden
An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.
import chalk from 'chalk'
import type * as Onnx from 'onnxruntime-node'
import { Logger } from '../utilities/Logger.js'
import { computeMelSpectrogramUsingFilterbanks, Filterbank } from '../dsp/MelSpectrogram.js'
import { clip, getIntegerRange, getTopKIndexes, splitFloat32Array, yieldToEventLoop } from '../utilities/Utilities.js'
import { indexOfMax, logOfVector, logSumExp, meanOfVector, medianOfVector, softmax, sumAndSumOfSquaresOfVector, sumOfSquaresOfVector, sumVector } from '../math/VectorMath.js'
import { alignDTWWindowed } from '../alignment/DTWSequenceAlignmentWindowed.js'
import { extendDeep } from '../utilities/ObjectUtilities.js'
import { Timeline, TimelineEntry } from '../utilities/Timeline.js'
import { AlignmentPath } from '../alignment/SpeechAlignment.js'
import { getRawAudioDuration, RawAudio, sliceRawAudio } from '../audio/AudioUtilities.js'
import { readFileAsUtf8 } from '../utilities/FileSystem.js'
import { logLevelGreaterOrEqualTo, type LanguageDetectionResults } from '../api/API.js'
import { formatLanguageCodeWithName, getShortLanguageCode, languageCodeToName } from '../utilities/Locale.js'
import { loadPackage } from '../utilities/PackageManager.js'
import { XorShift32PRNG } from '../utilities/RandomGenerator.js'
import { detectSpeechLanguageByParts } from '../api/SpeechLanguageDetection.js'
import { type Tiktoken } from 'tiktoken/lite'
import { includesPunctuation, isWhitespace, splitToWords } from '../nlp/Segmentation.js'
import { medianOf5Filter } from '../math/MedianFilter.js'
import { getDeflateCompressionMetricsForString } from '../utilities/Compression.js'
import { dmlProviderAvailable, getOnnxSessionOptions, makeOnnxLikeFloat32Tensor, OnnxExecutionProvider, OnnxLikeFloat32Tensor } from '../utilities/OnnxUtilities.js'
import { murmurHash3_int32Input } from '../utilities/Hashing.js'
import { containsInvalidCodepoint, getTokenRepetitionScore } from '../utilities/StringUtilities.js'
import { joinPath } from '../utilities/PathUtilities.js'
import { Timer } from '../utilities/Timer.js'
export async function recognize(
sourceRawAudio: RawAudio,
modelName: WhisperModelName,
modelDir: string,
task: WhisperTask,
sourceLanguage: string,
options: WhisperOptions,
onPart?: WhisperPartCallback) {
options = extendDeep(defaultWhisperOptions, options)
if (sourceRawAudio.sampleRate !== 16000) {
throw new Error('Source audio must have a sample rate of 16000 Hz')
}
sourceLanguage = getShortLanguageCode(sourceLanguage)
if (!(sourceLanguage in languageIdLookup)) {
throw new Error(`The language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`)
}
if (isEnglishOnlyModel(modelName) && sourceLanguage !== 'en') {
throw new Error(`The model '${modelName}' can only be used with English inputs. However, the given source language was ${languageCodeToName(sourceLanguage)}.`)
}
if (modelName === 'large-v3-turbo' && task === 'translate') {
throw new Error(`The 'large-v3-turbo' model doesn't support translation tasks.`)
}
if (options.temperature && options.temperature < 0) {
throw new Error(`Temperature can't be negative`)
}
// Work around an issue with large-v3-turbo that produces invalid results when a prompt is passed to it.
// Always disable autoprompting for that model.
if (options.autoPromptParts && modelName === 'large-v3-turbo') {
options.autoPromptParts = false
}
// Select encoder ONNX provider
const encoderProviders: OnnxExecutionProvider[] =
options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)
// Select decoder ONNX provider
const decoderProviders: OnnxExecutionProvider[] =
options.decoderProvider ? [options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName)
const seed = options.seed
const whisper = new Whisper(
modelName,
modelDir,
encoderProviders,
decoderProviders,
seed)
const result = await whisper.recognize(sourceRawAudio, task, sourceLanguage, options, undefined, onPart)
return result
}
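// Example usage (a hypothetical sketch, not part of the original module): assuming a
// 16 kHz mono RawAudio named `audio` and a locally available model (the model name and
// directory below are illustrative), recognition could look roughly like:
//
//   const { transcript, timeline } = await recognize(
//       audio,
//       'tiny',
//       './models/whisper-tiny',
//       'transcribe',
//       'en',
//       {},
//   )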
export async function align(
sourceRawAudio: RawAudio,
transcript: string,
modelName: WhisperModelName,
modelDir: string,
sourceLanguage: string,
options: WhisperAlignmentOptions) {
options = extendDeep(defaultWhisperAlignmentOptions, options)
if (sourceRawAudio.sampleRate !== 16000) {
throw new Error('Source audio must have a sample rate of 16000 Hz')
}
sourceLanguage = getShortLanguageCode(sourceLanguage)
if (!(sourceLanguage in languageIdLookup)) {
throw new Error(`The language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`)
}
if (isEnglishOnlyModel(modelName) && sourceLanguage !== 'en') {
throw new Error(`The model '${modelName}' can only be used with English inputs. However, the given source language was ${languageCodeToName(sourceLanguage)}.`)
}
// Select encoder ONNX provider
const encoderProviders: OnnxExecutionProvider[] =
options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)
// Select decoder ONNX provider
const decoderProviders: OnnxExecutionProvider[] =
options.decoderProvider ? [options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName)
const whisper = new Whisper(
modelName,
modelDir,
encoderProviders,
decoderProviders)
const timeline = await whisper.align(sourceRawAudio, transcript, sourceLanguage, 'transcribe', options)
return timeline
}
export async function alignEnglishTranslation(
sourceRawAudio: RawAudio,
translatedTranscript: string,
modelName: WhisperModelName,
modelDir: string,
sourceLanguage: string,
options: WhisperAlignmentOptions) {
options = extendDeep(defaultWhisperAlignmentOptions, options)
if (sourceRawAudio.sampleRate !== 16000) {
throw new Error('Source audio must have a sample rate of 16000 Hz')
}
sourceLanguage = getShortLanguageCode(sourceLanguage)
if (!(sourceLanguage in languageIdLookup)) {
throw new Error(`The source language ${formatLanguageCodeWithName(sourceLanguage)} is not supported by the Whisper engine.`)
}
if (modelName === 'large-v3-turbo') {
throw new Error(`The 'large-v3-turbo' model doesn't support translation tasks, so cannot be used for translation alignment.`)
}
if (isEnglishOnlyModel(modelName)) {
throw new Error(`Translation alignment can only be done with multilingual models.`)
}
// Select encoder ONNX provider
const encoderProviders: OnnxExecutionProvider[] =
options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)
// Select decoder ONNX provider
const decoderProviders: OnnxExecutionProvider[] =
options.decoderProvider ? [options.decoderProvider] : getDefaultDecoderProvidersForModel(modelName)
const whisper = new Whisper(
modelName,
modelDir,
encoderProviders,
decoderProviders)
const timeline = await whisper.align(sourceRawAudio, translatedTranscript, sourceLanguage, 'translate', options)
return timeline
}
export async function detectLanguage(
sourceRawAudio: RawAudio,
modelName: WhisperModelName,
modelDir: string,
options: WhisperLanguageDetectionOptions) {
options = extendDeep(defaultWhisperLanguageDetectionOptions, options)
if (sourceRawAudio.sampleRate !== 16000) {
throw new Error('Source audio must have a sample rate of 16000 Hz')
}
if (!isMultilingualModel(modelName)) {
throw new Error(`Language detection is only supported with multilingual models.`)
}
if (options.temperature! < 0) {
throw new Error(`Temperature cannot be negative`)
}
// Select encoder ONNX provider
const encoderProviders: OnnxExecutionProvider[] =
options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)
// Select decoder ONNX provider
const decoderProviders: OnnxExecutionProvider[] =
options.decoderProvider ? [options.decoderProvider] : []
const whisper = new Whisper(
modelName,
modelDir,
encoderProviders,
decoderProviders)
async function detectLanguageForPart(partAudio: RawAudio) {
const audioFeatures = await whisper.encodeAudio(partAudio)
const partResults = await whisper.detectLanguage(audioFeatures, options.temperature!)
return partResults
}
const results = await detectSpeechLanguageByParts(sourceRawAudio, detectLanguageForPart)
results.sort((entry1, entry2) => entry2.probability - entry1.probability)
return results
}
export async function detectVoiceActivity(
sourceRawAudio: RawAudio,
modelName: WhisperModelName,
modelDir: string,
options: WhisperVADOptions) {
options = extendDeep(defaultWhisperVADOptions, options)
if (sourceRawAudio.sampleRate !== 16000) {
throw new Error('Source audio must have a sample rate of 16000 Hz')
}
if (options.temperature! < 0) {
throw new Error(`Temperature cannot be negative`)
}
const audioSamples = sourceRawAudio.audioChannels[0]
const partDuration = 5
const maxSamplesCountForPart = sourceRawAudio.sampleRate * partDuration
// Select encoder ONNX provider
const encoderProviders: OnnxExecutionProvider[] =
options.encoderProvider ? [options.encoderProvider] : getDefaultEncoderProvidersForModel(modelName)
// Select decoder ONNX provider
const decoderProviders: OnnxExecutionProvider[] =
options.decoderProvider ? [options.decoderProvider] : []
const whisper = new Whisper(
modelName,
modelDir,
encoderProviders,
decoderProviders)
const partProbabilities: Timeline = []
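// Score the audio in 5-second slices: each slice is encoded, and the decoder's
// no-speech token probability (complemented) is used as the slice's speech confidence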
for (let sampleOffset = 0; sampleOffset < audioSamples.length; sampleOffset += maxSamplesCountForPart) {
const partSamples = sliceRawAudio(sourceRawAudio, sampleOffset, sampleOffset + maxSamplesCountForPart)
const samplesCountForPart = partSamples.audioChannels[0].length
const startTime = sampleOffset / sourceRawAudio.sampleRate
const endTime = (sampleOffset + samplesCountForPart) / sourceRawAudio.sampleRate
const encodedPartSamples = await whisper.encodeAudio(partSamples)
const probabilityForPart = await whisper.detectVoiceActivity(encodedPartSamples, options.temperature!)
partProbabilities.push({
type: 'segment',
text: '',
startTime,
endTime,
confidence: probabilityForPart,
})
}
return { partProbabilities }
}
export class Whisper {
isMultiligualModel: boolean
audioEncoder?: Onnx.InferenceSession
textDecoder?: Onnx.InferenceSession
tiktoken?: Tiktoken
tokenConfig: {
endOfTextToken: number
startOfTextToken: number
languageTokensStart: number
languageTokensEnd: number
translateTaskToken: number
transcribeTaskToken: number
startOfPromptToken: number
nonSpeechToken: number
noTimestampsToken: number
timestampTokensStart: number
timestampTokensEnd: number
}
randomGen: XorShift32PRNG
constructor(
public readonly modelName: WhisperModelName,
public readonly modelDir: string,
public readonly encoderExecutionProviders: OnnxExecutionProvider[],
public readonly decoderExecutionProviders: OnnxExecutionProvider[],
prngSeed = 1234) {
this.isMultiligualModel = isMultilingualModel(this.modelName)
if (this.isMultiligualModel) {
this.tokenConfig = {
endOfTextToken: 50257,
startOfTextToken: 50258,
languageTokensStart: 50259,
languageTokensEnd: 50358,
translateTaskToken: 50358,
transcribeTaskToken: 50359,
startOfPromptToken: 50361,
nonSpeechToken: 50362,
noTimestampsToken: 50363,
timestampTokensStart: 50364,
timestampTokensEnd: 50364 + 1501,
}
} else {
this.tokenConfig = {
endOfTextToken: 50256,
startOfTextToken: 50257,
languageTokensStart: 50258,
languageTokensEnd: 50358,
translateTaskToken: 50358,
transcribeTaskToken: 50359,
startOfPromptToken: 50360,
nonSpeechToken: 50361,
noTimestampsToken: 50362,
timestampTokensStart: 50363,
timestampTokensEnd: 50363 + 1501,
}
}
this.randomGen = new XorShift32PRNG(murmurHash3_int32Input(prngSeed))
}
async recognize(
rawAudio: RawAudio,
task: WhisperTask,
language: string,
options: WhisperOptions,
logitFilter?: WhisperLogitFilter,
onPart?: WhisperPartCallback,
) {
await this.initializeIfNeeded()
const logger = new Logger()
options = extendDeep(defaultWhisperOptions, options)
options.model = this.modelName
if (!options.timestampAccuracy) {
options.timestampAccuracy = this.defaultTimestampAccuracy
}
if (options.maxTokensPerPart! > largestMaximumTokensPerPart) {
//throw new Error(`The number of tokens per part cannot be greater than ${largestMaximumTokensPerPart}`)
options.maxTokensPerPart = largestMaximumTokensPerPart
}
const audioSamples = rawAudio.audioChannels[0]
const sampleRate = rawAudio.sampleRate
const prompt = options.prompt
const decodeTimestampTokens = options.decodeTimestampTokens!
const maxAudioSamplesPerPart = sampleRate * 30
let previousPartTextTokens: number[] = []
let timeline: Timeline = []
let allDecodedTokens: number[] = []
let wrappedLogitFilter: WhisperLogitFilter | undefined
if (logitFilter) {
wrappedLogitFilter = (logits, partDecodedTokens, isFirstPart, isFinalPart) => {
return logitFilter(logits, [...allDecodedTokens, ...partDecodedTokens], isFirstPart, isFinalPart)
}
}
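// Process the audio in windows of up to 30 seconds. If a window's decoding ends with a
// timestamp token (and it isn't the final window), the next window starts at that
// timestamp; otherwise it starts where the current window ends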
for (let audioOffset = 0; audioOffset < audioSamples.length;) {
const segmentStartTime = audioOffset / sampleRate
await logger.startAsync(`\nPrepare audio part at time position ${segmentStartTime.toFixed(2)}`, undefined, chalk.magentaBright)
const audioPartSamples = audioSamples.slice(audioOffset, audioOffset + maxAudioSamplesPerPart)
const audioPartRawAudio: RawAudio = { audioChannels: [audioPartSamples], sampleRate }
const audioPartDuration = getRawAudioDuration(audioPartRawAudio)
logger.end()
const audioPartFeatures = await this.encodeAudio(audioPartRawAudio)
const isFirstPart = audioOffset === 0
const isFinalPart = audioOffset + maxAudioSamplesPerPart >= audioSamples.length
let initialTokens: number[] = []
if (isFirstPart && prompt) {
let promptTokens = this.textToTokens(prompt)
if (promptTokens.length > largestMaximumTokensPerPart) {
promptTokens = promptTokens.slice(promptTokens.length - largestMaximumTokensPerPart)
}
initialTokens = [this.tokenConfig.startOfPromptToken, ...promptTokens]
} else if (options.autoPromptParts && previousPartTextTokens.length > 0) {
initialTokens = [this.tokenConfig.startOfPromptToken, ...previousPartTextTokens]
}
initialTokens = [...initialTokens, ...this.getTextStartTokens(language, task, !decodeTimestampTokens)]
//logger.log(`Initial tokens count: ${initialTokens.length}`)
logger.end()
let {
decodedTokens: partTokens,
decodedTokensConfidence: partTokensConfidence,
decodedTokensCrossAttentionQKs: partTokensCrossAttentionQKs,
decodedTokensDecodingTime: partTokensDecodingTime,
decodedTokensInferenceTime: partTokensInferenceTime,
decodedTokensOverheadTime: partTokensOverheadTime,
} = await this.decodeTokens(
audioPartFeatures,
initialTokens,
audioPartDuration,
isFirstPart,
isFinalPart,
options,
wrappedLogitFilter,
)
const lastToken = partTokens[partTokens.length - 1]
const lastTokenIsTimestamp = this.isTimestampToken(lastToken)
let audioEndOffset: number
if (!isFinalPart && lastTokenIsTimestamp) {
const timePosition = this.timestampTokenToSeconds(lastToken)
audioEndOffset = audioOffset + Math.floor(timePosition * sampleRate)
} else {
audioEndOffset = Math.min(audioOffset + maxAudioSamplesPerPart, audioSamples.length)
}
const segmentEndTime = audioEndOffset / sampleRate
const segmentFrameCount = this.secondsRangeToFrameCount(segmentStartTime, segmentEndTime)
await logger.startAsync(`Extract timeline for part (timestamp accuracy: ${options.timestampAccuracy!})`)
if (partTokens.length !== partTokensCrossAttentionQKs.length) {
throw new Error('Unexpected: partTokens.length !== partCrossAttentionQKs.length')
}
// Prepare tokens
partTokens = partTokens.slice(initialTokens.length)
partTokensConfidence = partTokensConfidence.slice(initialTokens.length)
partTokensCrossAttentionQKs = partTokensCrossAttentionQKs.slice(initialTokens.length)
// Find alignment path
let alignmentHeads: number[] | undefined
if (options.timestampAccuracy === 'medium' || options.model === 'large-v3-turbo') {
alignmentHeads = this.alignmentHeadIndexes
} else if (options.timestampAccuracy === 'high') {
alignmentHeads = undefined
} else {
throw new Error(`Unsupported timestamp accuracy '${options.timestampAccuracy}', can only be 'medium' or 'high'.`)
}
const alignmentPath = await this.findAlignmentPathFromQKs(partTokensCrossAttentionQKs, partTokens, 0, segmentFrameCount, alignmentHeads)
// Generate timeline from alignment path
const partTimeline = await this.getTokenTimelineFromAlignmentPath(alignmentPath, partTokens, segmentStartTime, segmentEndTime, partTokensConfidence)
if (onPart) {
const partWordTimeline = this.tokenTimelineToWordTimeline(partTimeline, language)
const partTranscript = this.tokensToText(partTokens)
onPart(partTranscript, partTimeline, partWordTimeline)
}
// Add tokens to output
allDecodedTokens.push(...partTokens)
timeline.push(...partTimeline)
// Determine compression ratio for recognized text (normalized to lowercase) of this part
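// Highly repetitive output (a common failure mode) compresses unusually well under deflate,
// so the compression ratio acts as a repetition heuristic: parts that compress too well are
// not reused as auto-prompts for the next part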
const compressionRatioForPart = (await getDeflateCompressionMetricsForString(this.tokensToText(partTokens).toLocaleLowerCase())).ratio
// If the recognized text isn't too repetitive
if (compressionRatioForPart < options.repetitionThreshold!) {
// Set current part tokens as the previous part text tokens
previousPartTextTokens = partTokens.filter(token => this.isTextToken(token))
} else {
// Otherwise, set previous part tokens to an empty array
previousPartTextTokens = []
}
audioOffset = audioEndOffset
logger.end()
if (logLevelGreaterOrEqualTo('trace')) {
const promptDecodingTime = partTokensDecodingTime[0]
const medianTokenDecodingTime = medianOfVector(partTokensDecodingTime.slice(1))
const medianTokenInferenceTime = medianOfVector(partTokensInferenceTime.slice(1))
const medianOverheadTime = medianOfVector(partTokensOverheadTime.slice(1))
logger.log(`${chalk.blueBright('Context')}: ${initialTokens.length + partTokens.length} tokens (${initialTokens.length} prompt, ${partTokens.length} decoded)\n${chalk.blueBright('Prompt decode time')}: ${promptDecodingTime.toFixed(1)}ms\n${chalk.blueBright('Median token decode time')}: ${medianTokenDecodingTime.toFixed(1)}ms (${medianTokenInferenceTime.toFixed(1)}ms inference, ${medianOverheadTime.toFixed(2)}ms overhead)`, 'trace')
}
}
// Convert token timeline to word timeline
timeline = this.tokenTimelineToWordTimeline(timeline, language)
// Convert tokens to transcript
const transcript = this.tokensToText(allDecodedTokens).trim()
logger.end()
return { transcript, timeline, allDecodedTokens }
}
async align(rawAudio: RawAudio, transcript: string, sourceLanguage: string, task: 'transcribe' | 'translate', whisperAlignmentOptions: WhisperAlignmentOptions) {
await this.initializeTokenizerIfNeeded()
whisperAlignmentOptions = extendDeep(defaultWhisperAlignmentOptions, whisperAlignmentOptions)
if (!whisperAlignmentOptions.timestampAccuracy) {
whisperAlignmentOptions.timestampAccuracy = this.defaultTimestampAccuracy
}
if (whisperAlignmentOptions.maxTokensPerPart! > largestMaximumTokensPerPart) {
//throw new Error(`The number of tokens per part cannot be greater than ${largestMaximumTokensPerPart}`)
whisperAlignmentOptions.maxTokensPerPart = largestMaximumTokensPerPart
}
const targetLanguage = task === 'transcribe' ? sourceLanguage : 'en'
let simplifiedTranscript = ''
{
const words = (await splitToWords(transcript, targetLanguage)).nonPunctuationWords
simplifiedTranscript = words.join(' ')
}
// Tokenize the transcript
const simplifiedTranscriptTokens = this.textToTokens(simplifiedTranscript)
// Initialize custom logit filter that allows only the transcript tokens to be decoded
// in order.
const endOfTextToken = this.tokenConfig.endOfTextToken
const logitFilter: WhisperLogitFilter = (logits, decodedTokens, isFirstPart, isFinalPart) => {
const decodedTextTokens = decodedTokens.filter(token => this.isTextToken(token))
const nextTokenToDecode = simplifiedTranscriptTokens[decodedTextTokens.length] ?? endOfTextToken
const newLogits = logits.map((logit, index) => {
if (index === nextTokenToDecode) {
return logit
}
// If it's the final part, the end-of-text token logit is set to -Infinity.
// This forces all transcript tokens to be decoded even if the model doesn't
// recognize them.
if (!isFinalPart && index === endOfTextToken) {
return logit
}
return -Infinity
})
return newLogits
}
// Set options for alignment
const options: WhisperOptions = {
model: this.modelName,
temperature: 0.0,
prompt: undefined,
topCandidateCount: 1,
punctuationThreshold: Infinity,
autoPromptParts: false,
maxTokensPerPart: whisperAlignmentOptions.maxTokensPerPart!,
suppressRepetition: false,
repetitionThreshold: Infinity,
decodeTimestampTokens: true,
endTokenThreshold: whisperAlignmentOptions.endTokenThreshold!,
includeEndTokenInCandidates: false,
timestampAccuracy: whisperAlignmentOptions.timestampAccuracy!,
encoderProvider: whisperAlignmentOptions.encoderProvider!,
decoderProvider: whisperAlignmentOptions.decoderProvider!,
seed: undefined,
}
// Recognize
const { timeline, allDecodedTokens } = await this.recognize(rawAudio, task, sourceLanguage, options, logitFilter)
{
// If not all tokens were decoded, add the remaining ones to the timeline
const lastKnownWordStartTime = timeline.length > 0 ? timeline[timeline.length - 1].startTime : 0
const allDecodedTextTokens = allDecodedTokens.filter(token => this.isTextToken(token))
while (allDecodedTextTokens.length < simplifiedTranscriptTokens.length) {
const token = simplifiedTranscriptTokens[allDecodedTextTokens.length]
const tokenText = this.tokenToText(token)
allDecodedTextTokens.push(token)
const newTokenEntry: TimelineEntry = {
type: 'token',
text: tokenText,
startTime: lastKnownWordStartTime,
endTime: lastKnownWordStartTime,
id: token,
confidence: 0,
}
if (tokenText.startsWith(' ') || timeline.length === 0) {
timeline.push({
type: 'word',
text: tokenText.trim(),
startTime: lastKnownWordStartTime,
endTime: lastKnownWordStartTime,
timeline: [newTokenEntry],
confidence: 0,
})
} else {
const lastWordEntry = timeline[timeline.length - 1]
lastWordEntry.timeline!.push(newTokenEntry)
lastWordEntry.text += tokenText
}
}
}
return timeline
}
async detectLanguage(audioFeatures: Onnx.Tensor, temperature: number): Promise<LanguageDetectionResults> {
if (!this.isMultiligualModel) {
throw new Error('Language detection is only supported with multilingual models')
}
await this.initializeDecoderSessionIfNeeded()
// Prepare and run decoder
const logger = new Logger()
await logger.startAsync('Detect language with Whisper model')
const sotToken = this.tokenConfig.startOfTextToken
const initialTokens = [sotToken]
const offset = 0
const Onnx = await import('onnxruntime-node')
const initialKvDimensions = this.getKvDimensions(1, initialTokens.length)
const kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions)
const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(initialTokens.map(token => BigInt(token))), [1, initialTokens.length])
const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), [])
const decoderInputs = {
tokens: tokensTensor,
audio_features: audioFeatures,
kv_cache: kvCacheTensor,
offset: offsetTensor
}
const decoderOutputs = await this.textDecoder!.run(decoderInputs)
const logitsBuffer = decoderOutputs['logits'].data as Float32Array
const tokenConfig = this.tokenConfig
const languageTokensLogits = Array.from(logitsBuffer.slice(tokenConfig.languageTokensStart, tokenConfig.languageTokensEnd))
const languageTokensProbabilities = softmax(languageTokensLogits, temperature)
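// Map each supported language code to the softmax probability of its language token
// (languageIdLookup holds each language's offset within the language-token range)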
const results: LanguageDetectionResults = []
for (const language in languageIdLookup) {
const langId = languageIdLookup[language]
const probability = languageTokensProbabilities[langId]
results.push({
language,
languageName: languageCodeToName(language),
probability
})
}
logger.end()
return results
}
async detectVoiceActivity(audioFeatures: Onnx.Tensor, temperature: number): Promise<number> {
await this.initializeDecoderSessionIfNeeded()
// Prepare and run decoder
const logger = new Logger()
await logger.startAsync('Detect voice activity with Whisper model')
const sotToken = this.tokenConfig.startOfTextToken
const initialTokens = [sotToken]
const offset = 0
const Onnx = await import('onnxruntime-node')
const initialKvDimensions = this.getKvDimensions(1, initialTokens.length)
const kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions)
const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(initialTokens.map(token => BigInt(token))), [1, initialTokens.length])
const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), [])
const decoderInputs = {
tokens: tokensTensor,
audio_features: audioFeatures,
kv_cache: kvCacheTensor,
offset: offsetTensor
}
const decoderOutputs = await this.textDecoder!.run(decoderInputs)
const logitsBuffer = decoderOutputs['logits'].data as Float32Array
const tokenConfig = this.tokenConfig
const logits = Array.from(logitsBuffer)
const probabilities = softmax(logits, temperature)
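// A single decoding step was run on just the start-of-text token; the probability assigned
// to the no-speech token estimates how likely it is that this window contains no speech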
const noSpeechProbability = probabilities[tokenConfig.nonSpeechToken]
return 1.0 - noSpeechProbability
}
// Decode tokens using the decoder model
async decodeTokens(
audioFeatures: Onnx.Tensor,
initialTokens: number[],
audioDuration: number,
isFirstPart: boolean,
isFinalPart: boolean,
options: WhisperOptions,
logitFilter?: WhisperLogitFilter) {
// Initialize
await this.initializeTokenizerIfNeeded()
await this.initializeDecoderSessionIfNeeded()
const logger = new Logger()
await logger.startAsync('Decode text tokens with Whisper decoder model')
options = extendDeep(defaultWhisperOptions, options)
const Onnx = await import('onnxruntime-node')
// Get token information
const endOfTextToken = this.tokenConfig.endOfTextToken
const timestampTokensStart = this.tokenConfig.timestampTokensStart
const suppressedTextTokens = this.getSuppressedTextTokens()
const suppressedMetadataTokens = this.getSuppressedMetadataTokens()
const allowedPunctuationMarks = this.getAllowedPunctuationMarks()
const spaceToken = this.textToTokens(' ')[0]
// Initialize variables for decoding loop
let decodedTokens = initialTokens.slice()
const initialKvDimensions = this.getKvDimensions(1, decodedTokens.length)
let kvCacheTensor = new Onnx.Tensor('float32', new Float32Array(initialKvDimensions[0] * initialKvDimensions[1] * initialKvDimensions[2] * initialKvDimensions[3]), initialKvDimensions)
let decodedTokensTimestampLogits: number[][] = []
let decodedTokensConfidence: number[] = []
let decodedTokensCrossAttentionQKs: OnnxLikeFloat32Tensor[] = []
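// Pre-fill placeholder entries for the initial (context/prompt) tokens so that the
// per-token arrays below stay index-aligned with decodedTokens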
for (let i = 0; i < decodedTokens.length; i++) {
decodedTokensTimestampLogits.push(new Array(1501)) // Should the length be 1500 instead?
decodedTokensConfidence.push(1.0)
decodedTokensCrossAttentionQKs.push(undefined as any)
}
let lastTimestampTokenIndex = -1
let timestampTokenSeenCount = 0
let bufferedTokensToPrint: number[] = []
// Define method to add a token to output
function addToken(tokenToAdd: number, timestampLogits: number[], confidence: number, crossAttentionQKs: OnnxLikeFloat32Tensor) {
decodedTokens.push(tokenToAdd)
decodedTokensTimestampLogits.push(timestampLogits)
decodedTokensConfidence.push(confidence)
decodedTokensCrossAttentionQKs.push(crossAttentionQKs)
}
const maxTokensPerPart = Math.min(options.maxTokensPerPart!, largestMaximumTokensPerPart)
let decodedTokensInferenceTime: number[] = []
let decodedTokensDecodingTime: number[] = []
const tokenDecodingTimeTimer = new Timer()
// Start decoding loop
for (let decodedTokenCount = 0; decodedTokenCount < maxTokensPerPart; decodedTokenCount++) {
if (decodedTokenCount > 0) {
decodedTokensDecodingTime.push(tokenDecodingTimeTimer.getElapsedTimeAndRestart())
}
const isInitialState = decodedTokens.length === initialTokens.length
const atLeastOneTextTokenDecoded = decodedTokens.slice(initialTokens.length).some(token => this.isTextToken(token))
// If not in initial state, reshape KV Cache tensor to accommodate a new output token
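// The decoder runs incrementally: each step feeds only the newest token and reuses the
// cached keys/values of earlier positions, so the cache's token dimension must grow by one
// each step. The copy below carries the existing cache contents into the enlarged tensor,
// group by group along the first dimension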
if (!isInitialState) {
const dims = kvCacheTensor.dims
const currentKvCacheGroups = splitFloat32Array(kvCacheTensor.data as Float32Array, dims[2] * dims[3])
const reshapedKvCacheTensor = new Onnx.Tensor('float32', new Float32Array(dims[0] * dims[1] * (decodedTokens.length) * dims[3]), [dims[0], dims[1], decodedTokens.length, dims[3]])
const reshapedKvCacheGroups = splitFloat32Array(reshapedKvCacheTensor.data, decodedTokens.length * dims[3])
for (let i = 0; i < dims[0]; i++) {
reshapedKvCacheGroups[i].set(currentKvCacheGroups[i])
}
kvCacheTensor = reshapedKvCacheTensor
}
// Prepare values for decoder
const tokensToDecode = isInitialState ? decodedTokens : [decodedTokens[decodedTokens.length - 1]]
const offset = isInitialState ? 0 : decodedTokens.length
const tokensTensor = new Onnx.Tensor('int64', new BigInt64Array(tokensToDecode.map(token => BigInt(token))), [1, tokensToDecode.length])
const offsetTensor = new Onnx.Tensor('int64', new BigInt64Array([BigInt(offset)]), [])
const decoderInputs = {
tokens: tokensTensor,
audio_features: audioFeatures,
kv_cache: kvCacheTensor,
offset: offsetTensor
}
//// Infer with ONNX decoder model
const tokenInferenceTimeTimer = new Timer()
const decoderOutputs = await this.textDecoder!.run(decoderInputs)
decodedTokensInferenceTime.push(tokenInferenceTimeTimer.elapsedTime)
// Extract decoder model results
const logitsBuffer = decoderOutputs['logits'].data as Float32Array
kvCacheTensor = decoderOutputs['output_kv_cache'] as any
const crossAttentionQKsForTokenOnnx = decoderOutputs['cross_attention_qks']
const crossAttentionQKsForToken = makeOnnxLikeFloat32Tensor(crossAttentionQKsForTokenOnnx)
crossAttentionQKsForTokenOnnx.dispose()
// Get logits
const resultLogitsFloatArrays = splitFloat32Array(logitsBuffer, logitsBuffer.length / decoderOutputs['logits'].dims[1])
const allTokenLogits = Array.from(resultLogitsFloatArrays[resultLogitsFloatArrays.length - 1])
// Suppress metadata tokens in the suppression set
for (const suppressedTokenIndex of suppressedMetadataTokens) {
allTokenLogits[suppressedTokenIndex] = -Infinity
}
if (!atLeastOneTextTokenDecoded) {
// If no text token has been decoded yet, suppress the end-of-text token
allTokenLogits[endOfTextToken] = -Infinity
}
const timestampTokenLogits = allTokenLogits.slice(timestampTokensStart)
const decodeTimestampTokenIfNeeded = () => {
// Try to decode a timestamp token, if needed
// If timestamp token decoding is disabled in options, don't decode a timestamp
if (!options.decodeTimestampTokens) {
return false
}
// If this is the first token in the part, unconditionally decode a timestamp token
// for time 0.0
if (isInitialState) {
addToken(timestampTokensStart, timestampTokenLogits, 1.0, crossAttentionQKsForToken)
return true
}
const previousTokenWasTimestamp = this.isTimestampToken(decodedTokens[decodedTokens.length - 1])
const secondPreviousTokenWasTimestamp = this.isTimestampToken(decodedTokens[decodedTokens.length - 2])
// If there are two successive timestamp tokens decoded, or the previous timestamp was the first token,
// don't decode a timestamp
if (previousTokenWasTimestamp &&
((decodedTokens.length === initialTokens.length + 1) || secondPreviousTokenWasTimestamp)) {
return false
}
// Derive token probabilities
const allTokenProbabilities = softmax(allTokenLogits as any, 1.0)
const allTokenLogProbabilities = logOfVector(allTokenProbabilities)
const nonTimestampTokenLogProbs = allTokenLogProbabilities.slice(0, timestampTokensStart)
// Find highest non-timestamp token
const indexOfMaxNonTimestampLogProb = indexOfMax(nonTimestampTokenLogProbs)
const valueOfMaxNonTimestampLogProb = nonTimestampTokenLogProbs[indexOfMaxNonTimestampLogProb]
// Find highest timestamp token
const timestampTokenLogProbs = allTokenLogProbabilities.slice(timestampTokensStart)
const indexOfMaxTimestampLogProb = indexOfMax(timestampTokenLogProbs)
// Compute the log of the sum of exponentials of the log probabilities
// of the timestamp tokens
const logSumExpOfTimestampTokenLogProbs = logSumExp(timestampTokenLogProbs)
// If the sum isn't greater than the log probability of the highest non-timestamp token,
// don't decode a timestamp
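// In other words, a timestamp is decoded only when the total probability mass of all
// timestamp tokens exceeds the probability of the single best non-timestamp token
// (this mirrors the timestamp rule used in the reference Whisper decoder)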
if (logSumExpOfTimestampTokenLogProbs <= valueOfMaxNonTimestampLogProb) {
return false
}
// Decode a timestamp token
timestampTokenSeenCount += 1
if (previousTokenWasTimestamp) {
// If previously decoded token was a timestamp token, repeat it
const previousToken = decodedTokens[decodedTokens.length - 1]
const previousTokenTimestampLogits = decodedTokensTimestampLogits[decodedTokensTimestampLogits.length - 1]
const previousTokenConfidence = decodedTokensConfidence[decodedTokensConfidence.length - 1]
addToken(previousToken, previousTokenTimestampLogits, previousTokenConfidence, crossAttentionQKsForToken)
lastTimestampTokenIndex = decodedTokens.length
} else {
// Otherwise decode the highest probability timestamp
const timestampToken = timestampTokensStart + indexOfMaxTimestampLogProb
const confidence = allTokenProbabilities[timestampToken]
addToken(timestampToken, timestampTokenLogits, confidence, crossAttentionQKsForToken)
}
return true
}
// Call the method to decode timestamp token if needed
const timestampTokenDecoded = decodeTimestampTokenIfNeeded()
if (timestampTokenDecoded) {
await yieldToEventLoop()
continue
}
// Decode a non-timestamp token
let nonTimestampTokenLogits = allTokenLogits.slice(0, timestampTokensStart)
let shouldDecodeEndOfTextToken = false
// If at least one text token was decoded, and the end-of-text token's probability is
// sufficiently higher than the second highest ranked token, then accept end-of-text
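// For example, with an endTokenThreshold of 0.9, end-of-text is accepted only when
// softmax([endOfTextLogit, bestOtherLogit])[0] > 0.9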
if (atLeastOneTextTokenDecoded) {
const endOfTextTokenLogit = nonTimestampTokenLogits[endOfTextToken]
const otherTokensLogits = nonTimestampTokenLogits.slice()
otherTokensLogits[endOfTextToken] = -Infinity
const indexOfMaximumOtherTokenLogit = indexOfMax(otherTokensLogits)
const maximumOtherTokenLogit = nonTimestampTokenLogits[indexOfMaximumOtherTokenLogit]
const endProbabilities = softmax([endOfTextTokenLogit, maximumOtherTokenLogit], 1.0)
if (endProbabilities[0] > options.endTokenThreshold!) {
shouldDecodeEndOfTextToken = true
}
}
if (logitFilter) {
// Apply custom logit filter function if given
nonTimestampTokenLogits = logitFilter(nonTimestampTokenLogits, decodedTokens, isFirstPart, isFinalPart)
// If the custom filter set the end-of-text token to be Infinity, or -Infinity,
// then override any previous decision and accept or reject it, respectively
if (nonTimestampTokenLogits[endOfTextToken] === Infinity) {
shouldDecodeEndOfTextToken = true
} else if (nonTimestampTokenLogits[endOfTextToken] === -Infinity) {
shouldDecodeEndOfTextToken = false
}
// If the filter caused all word token logits to be -Infinity, there is no
// other token to decode. Fall back to accepting end-of-text
if (nonTimestampTokenLogits.slice(0, endOfTextToken).every(logit => logit === -Infinity)) {
shouldDecodeEndOfTextToken = true
}
} else {
// Otherwise, suppress text tokens in the suppression set
for (const suppressedTokenIndex of suppressedTextTokens) {
nonTimestampTokenLogits[suppressedTokenIndex] = -Infinity
}
// Suppress the space token if at initial state
if (isInitialState) {
nonTimestampTokenLogits[spaceToken] = -Infinity
}
}
// If end-of-text token should be decoded, then add it and break
// out of the loop
if (shouldDecodeEndOfTextToken) {
addToken(endOfTextToken, timestampTokenLogits, 1.0, crossAttentionQKsForToken)
break
}
// Suppress end-of-text token if it shouldn't be included in candidates
if (!options.includeEndTokenInCandidates) {
nonTimestampTokenLogits[endOfTextToken] = -Infinity
}
// Find top candidates
const sortedTopCandidateTokens = getTopKIndexes(nonTimestampTokenLogits, options.topCandidateCount!, true)
let topCandidates = Array.from(sortedTopCandidateTokens).map(token => {
const logit = nonTimestampTokenLogits[token]
return {
token,
logit,
text: this.tokenToText(token, true)
}
})
// Apply repetition suppression if enabled
if (options.suppressRepetition) {
// Using some hardcoded constants, for now:
const tokenWindowSize = 30
const thresholdMatchLength = 4
const thresholdCycleRepetitionCount = 3
const filteredCandidates: typeof topCandidates = []
for (const candidate of topCandidates) {
const decodedTextTokens = decodedTokens.filter(token => this.isTextToken(token))
const lastDecodedTextTokens = decodedTextTokens
.slice(Math.max(decodedTextTokens.length - tokenWindowSize, 0))
.reverse()
const { longestMatchLength, longestCycleRepetitionCount } = getTokenRepetitionScore([candidate.token, ...lastDecodedTextTokens])
if (longestMatchLength >= thresholdMatchLength || longestCycleRepetitionCount >= thresholdCycleRepetitionCount) {
continue
}
filteredCandidates.push(candidate)
}
// If all candidates have been filtered out, accept an end-of-text token
if (filteredCandidates.length === 0) {
filteredCandidates.push({
token: endOfTextToken,
logit: Infinity,
text: this.tokenToText(endOfTextToken, true)
})
}
topCandidates = filteredCandidates
}
// Compute top candidate probabilities
const topCandidateProbabilities = softmax(topCandidates.map(a => a.logit), options.temperature)
// Find highest ranking punctuation token
const rankOfPromisingPunctuationToken = topCandidates.findIndex((entry, index) => {
const tokenText = this.tokenToText(entry.token).trim()
const isPunctuationToken = allowedPunctuationMarks.includes(tokenText)
if (!isPunctuationToken) {
return false
}
const tokenProb = topCandidateProbabilities[index]
return tokenProb >= options.punctuationThreshold!
})
// Find rank of space token
let rankOfSpaceToken = topCandidates.findIndex(candidate => candidate.token === spaceToken)
if (rankOfSpaceToken < 0) {
rankOfSpaceToken = Infinity
}
// Choose token
let chosenCandidateRank: number
// Select a high-ranking punctuation token if one was found and it ranks above the space token
if (rankOfPromisingPunctuationToken >= 0 &&
rankOfPromisingPunctuationToken < rankOfSpaceToken) {
chosenCandidateRank = rankOfPromisingPunctuationToken
} else {
// Otherwise, select randomly from top k distribution
chosenCandidateRank = this.randomGen.selectRandomIndexFromDistribution(topCandidateProbabilities)
}
// Add chosen token
const chosenToken = topCandidates[chosenCandidateRank].token
const chosenTokenConfidence = topCandidateProbabilities[chosenCandidateRank]
addToken(chosenToken, timestampTokenLogits, chosenTokenConfidence, crossAttentionQKsForToken)
// If chosen token is the end-of-text token, break
if (chosenToken === endOfTextToken) {
break
}
// Print token if needed
if (this.isTextToken(chosenToken)) {
bufferedTokensToPrint.push(chosenToken)
let textToPrint = this.tokensToText(bufferedTokensToPrint)
// If the decoded text is valid, print it
if (!containsInvalidCodepoint(textToPrint)) {
if (isFirstPart && decodedTokens.every(token => this.isMetadataToken(token))) {
textToPrint = textToPrint.trimStart()
}
logger.write(textToPrint)
bufferedTokensToPrint = []
}
}
await yieldToEventLoop()
}
if (decodedTokensDecodingTime.length === decodedTokensInferenceTime.length - 1) {
decodedTokensDecodingTime.push(tokenDecodingTimeTimer.getElapsedTimeAndRestart())
}
// If at least two timestamp tokens were decoded and it's not the final part,
// truncate up to the last timestamp token
if (timestampTokenSeenCount >= 2 && !isFinalPart) {
const sliceEndTokenIndex = lastTimestampTokenIndex
decodedTokens = decodedTokens.slice(0, sliceEndTokenIndex)
decodedTokensTimestampLogits = decodedTokensTimestampLogits.slice(0, sliceEndTokenIndex)
decodedTokensCrossAttentionQKs = decodedTokensCrossAttentionQKs.slice(0, sliceEndTokenIndex)
decodedTokensConfidence = decodedTokensConfidence.slice(0, sliceEndTokenIndex)
decodedTokensDecodingTime = decodedTokensDecodingTime.slice(0, sliceEndTokenIndex)
decodedTokensInferenceTime = decodedTokensInferenceTime.slice(0, sliceEndTokenIndex)
}
const decodedTokensOverheadTime = decodedTokensDecodingTime.map((time, index) => time - decodedTokensInferenceTime[index])
logger.write('\n')
logger.end()
// Return the decoded tokens
return {
decodedTokens,
decodedTokensTimestampLogits,
decodedTokensConfidence,
decodedTokensCrossAttentionQKs,
decodedTokensDecodingTime,
decodedTokensInferenceTime,
decodedTokensOverheadTime,
}
}
// Encode audio using the encoder model
async encodeAudio(rawAudio: RawAudio) {
await this.initializeEncoderSessionIfNeeded()
const Onnx = await import('onnxruntime-node')
const logger = new Logger()
const audioSamples = rawAudio.audioChannels[0]
const sampleRate = rawAudio.sampleRate
const fftOrder = 400
const fftWindowSize = 400
const fftHopLength = 160
const filterbankCount = this.filterbankCount
const filterbanks = this.filterbanks
const maxAudioSamples = sampleRate * 30
const maxAudioFrames = 3000
if (sampleRate !== 16000) {
throw new Error('Audio must have a sample rate of 16000 Hz')
}
if (audioSamples.length > maxAudioSamples) {
throw new Error(`Audio part is longer than 30 seconds`)
}
await logger.startAsync('Extract Mel spectrogram from audio part')
// Pad audio samples to ensure they have a duration of 30 seconds
const paddedAudioSamples = new Float32Array(maxAudioSamples)
paddedAudioSamples.set(audioSamples, 0)
const rawAudioPart: RawAudio = { audioChannels: [paddedAudioSamples], sampleRate }
// Compute Mel spectrogram
const { melSpectrogram } = await computeMelSpectrogramUsingFilterbanks(rawAudioPart, fftOrder, fftWindowSize, fftHopLength, filterbanks)
// Flatten, transpose, apply logarithm and normalize Mel spectrogram
await logger.startAsync('Process Mel spectrogram')
const flattenedLogMelSpectrogram = new Float32Array(maxAudioFrames * filterbankCount)
let maxLogMel = -Infinity
for (let i = 0; i < filterbankCount; i++) {
for (let j = 0; j < maxAudioFrames; j++) {
const mel = melSpectrogram[j][i]
const logMel = Math.log10(Math.max(mel, 1e-10))
if (logMel > maxLogMel) {
maxLogMel = logMel
}
flattenedLogMelSpectrogram[(i * maxAudioFrames) + j] = logMel
}
}
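// Normalization matches the standard Whisper preprocessing: clamp each log-Mel value to at
// most 8 below the global maximum, then shift and scale it with (x + 4) / 4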
for (let i = 0; i < flattenedLogMelSpectrogram.length; i++) {
const logMel = flattenedLogMelSpectrogram[i]
const normalizedLogMel = (Math.max(logMel, maxLogMel - 8) + 4) / 4
flattenedLogMelSpectrogram[i] = normalizedLogMel
}
// Run the encoder model
await logger.startAsync('Encode Mel spectrogram with Whisper encoder model')
const inputTensor = new Onnx.Tensor('float32', flattenedLogMelSpectrogram, [1, filterbankCount, maxAudioFrames])
const encoderInputs = { mel: inputTensor }
const encoderOutputs = await this.audioEncoder!.run(encoderInputs)
const encodedAudioFeatures = encoderOutputs['output']
logger.end()
return encodedAudioFeatures
}
tokenTimelineToWordTimeline(tokenTimeline: Timeline, language: string): Timeline {
function isSeparatorCharacter(char: string) {
const nonSeparatingPunctuation = [`'`, `-`, `.`, `·`, `•`]
if (nonSeparatingPunctuation.includes(char)) {
return false
}
return isWhitespace(char) || includesPunctuation(char)
}
function startsWithSeparatorCharacter(text: string) {
return isSeparatorCharacter(text[0])
}
function endsWithSeparatorCharacter(text: string) {
return isSeparatorCharacter(text[text.length - 1])
}
if (language !== 'zh' && language !== 'ja') {
tokenTimeline = tokenTimeline.filter(entry => this.isTextToken(entry.id!))
}
const resultTimeline: Timeline = []
let groups: Timeline[] = []
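// Group consecutive tokens into word candidates: a new group starts at a token whose text
// is empty or begins with a separator character, or when the previous token's text ends
// with a separator character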
for (let tokenIndex = 0; tokenIndex < tokenTimeline.length; tokenIndex++) {
const entry = tokenTimeline[tokenIndex]
const previousEntry = tokenIndex > 0 ? tokenTimeline[tokenIndex - 1] : undefined
const text = entry.text
const previousEntryText = previousEntry?.text
if (groups.length === 0 ||
text === '' ||
startsWithSeparatorCharacter(text) ||
(previousEntryText != null && endsWithSeparatorCharacter(previousEntryText))) {
groups.push([entry])
} else {
groups[groups.length - 1].push(entry)
}
}
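// Split a trailing '.' token into its own group when the next group (if any) starts with a
// space or '[', i.e. when the period likely ends a sentence rather than an abbreviation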
{
const splitGroups: Timeline[] = []
for (let groupIndex = 0; groupIndex < groups.length; groupIndex++) {
const group = groups[groupIndex]
const nextGroup = groups[groupIndex + 1]
if (
group.length > 1 &&
group[group.length - 1].text === '.' &&
(!nextGroup || [' ', '['].includes(nextGroup[0].text[0]))) {
splitGroups.push(group.slice(0, group.length - 1))
splitGroups.push(group.slice(group.length - 1))
} else {
splitGroups.push(group)
}
}
groups = splitGroups
}
for (const group of groups) {
let groupText = this.tokensToText(group.map(entry => entry.id!))
if (groupText === '') {
continue
}
const startTime = group[0].startTime
const endTime = group[group.length - 1].endTime
let confidence: number | undefined = undefined
if (group[0].confidence != null) {
confidence = meanOfVector(group.map(entry => entry.confidence!))
}
const newEntry: TimelineEntry = {
type: 'word',
text: groupText.trim(),
startTime,
endTime,
confidence,
timeline: group,
}
resultTimeline.push(newEntry)
}
return resultTimeline
}
async getTokenTimelineFromAlignmentPath(alignmentPath: AlignmentPath, tokens: number[], startTimeOffset: number, endTimeOffset: number, tokensConfidence?: number[], correctionAmount = 0.0) {
if (alignmentPath.length === 0) {
return []
}
const tokenTimeline: Timeline = []
for (let pathIndex = 0; pathIndex < alignmentPath.length; pathIndex++) {
if (pathIndex !== 0 && alignmentPath[pathIndex].source === alignmentPath[pathIndex - 1].source) {
continue
}
const tokenMappingEntry = alignmentPath[pathIndex]
const tokenIndex = tokenMappingEntry.source
const token = tokens[tokenIndex]
const tokenConfidence = tokensConfidence ? tokensConfidence[tokenIndex] : undefined
const tokenText = this.tokenToText(token, true)
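// Each 'dest' index is a Whisper encoder output frame covering 0.02 seconds of audio,
// hence the 0.02 factor below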
let startTime = startTimeOffset + (tokenMappingEntry.dest * 0.02)
startTime = Math.max(startTime + correctionAmount, startTimeOffset)
if (tokenTimeline.length > 0) {
tokenTimeline[tokenTimeline.length - 1].endTime = startTime
}
tokenTimeline.push({
type: 'token',
text: tokenText,
id: token,
startTime,
endTime: -1,
confidence: tokenConfidence
})
}
if (tokenTimeline.length > 0) {
tokenTimeline[tokenTimeline.length - 1].endTime = endTimeOffset
}
return tokenTimeline
}
async findAlignmentPathFromQKs(qksTensors: OnnxLikeFloat32Tensor[], tokens: number[], segmentStartFrame: number, segmentEndFrame: number, headIndexes?: number[]) {
const segmentFrameCount = segmentEndFrame - segmentStartFrame
if (segmentFrameCount === 0 || tokens.length === 0 || qksTensors.length === 0) {
return []
}
const tokenCount = qksTensors.length
const layerCount = qksTensors[0].dims[0]
const headCount = qksTensors[0].dims[2]
const frameCount = qksTensors[0].dims[4]
if (!headIndexes) {
headIndexes = getIntegerRange(0, layerCount * headCount)
}
// Load attention head weights from tensors
const attentionHeads: Float32Array[][] = [] // structure: [heads, tokens, frames]
for (const headIndex of headIndexes) {
const attentionHead: Float32Array[] = [] // structure: [tokens, frames]
for (let tokenIndex = 0; tokenIndex < tokenCount; tokenIndex++) {
const bufferOffset = headIndex * frameCount
const startIndexInBuffer = bufferOffset + segmentStartFrame
cons