echogarden
An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.
import { Logger } from '../utilities/Logger.js';
import { loadPackage } from '../utilities/PackageManager.js';
import { alignDTWWindowed } from './DTWSequenceAlignmentWindowed.js';
import { cosineDistance } from '../math/VectorMath.js';
import { includesPunctuation, isWord, parseText } from '../nlp/Segmentation.js';
import { extractEntries } from '../utilities/Timeline.js';
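// Aligns the words of an existing timeline to a different text (for example, a translation
// of the spoken content) by semantically matching words via multilingual E5 token embeddings
// and DTW. Returns a word-level timeline for `text`, with timing carried over from the matched
// source timeline words. `textLangCode` is the language of `text`.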
export async function alignTimelineToTextSemantically(timeline, text, textLangCode) {
const logger = new Logger();
logger.start(`Prepare text for semantic alignment`);
const timelineSentenceEntries = extractEntries(timeline, entry => entry.type === 'sentence');
const timelineWordEntryGroups = [];
const timelineWordGroups = [];
for (const sentenceEntry of timelineSentenceEntries) {
const wordEntryGroup = sentenceEntry.timeline
.filter(wordEntry => isWord(wordEntry.text));
timelineWordEntryGroups.push(wordEntryGroup);
timelineWordGroups.push(wordEntryGroup.map(wordEntry => wordEntry.text));
}
const timelineWordEntriesFiltered = timelineWordEntryGroups.flat();
const segmentedText = await parseText(text, textLangCode);
const textWordGroups = [];
for (const sentenceEntry of segmentedText.sentences) {
const wordGroup = sentenceEntry.words.nonPunctuationWords;
textWordGroups.push(wordGroup);
}
const textWords = textWordGroups.flat();
logger.end();
const wordMappingEntries = await alignWordsToWordsSemantically(timelineWordGroups, textWordGroups);
logger.start(`Build timeline for translation`);
const mappingGroups = new Map();
for (const wordMappingEntry of wordMappingEntries) {
const wordIndex1 = wordMappingEntry.wordIndex1;
const wordIndex2 = wordMappingEntry.wordIndex2;
let group = mappingGroups.get(wordIndex1);
if (!group) {
group = [];
mappingGroups.set(wordIndex1, group);
}
if (!group.includes(wordIndex2)) {
group.push(wordIndex2);
}
}
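// Each source word's time range is now split evenly among the target words it was mapped to.
// For example (hypothetical values), a source word spanning 2.0s..3.2s that maps to two
// target words contributes the slice 2.0..2.6 to the first and 2.6..3.2 to the second.
// A target word's final time range spans from the start of its first slice to the end of its
// last slice, as built in the result loop below.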
const timeSlicesLookup = new Map();
for (const [wordIndex1, mappedWordIndexes] of mappingGroups) {
if (mappedWordIndexes.length === 0) {
continue;
}
const startTime = timelineWordEntriesFiltered[wordIndex1].startTime;
const endTime = timelineWordEntriesFiltered[wordIndex1].endTime;
const splitCount = mappedWordIndexes.length;
const sliceDuration = (endTime - startTime) / splitCount;
let timeOffset = 0;
for (let i = 0; i < splitCount; i++) {
const timeSlice = {
startTime: startTime + timeOffset,
endTime: startTime + timeOffset + sliceDuration
};
const wordIndex2 = mappedWordIndexes[i];
let timeSlicesForTargetWord = timeSlicesLookup.get(wordIndex2);
if (!timeSlicesForTargetWord) {
timeSlicesForTargetWord = [];
timeSlicesLookup.set(wordIndex2, timeSlicesForTargetWord);
}
timeSlicesForTargetWord.push(timeSlice);
timeOffset += sliceDuration;
}
}
const resultTimeline = [];
for (const [key, value] of timeSlicesLookup) {
resultTimeline.push({
type: 'word',
text: textWords[key],
startTime: value[0].startTime,
endTime: value[value.length - 1].endTime
});
}
logger.end();
return resultTimeline;
}
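// Semantically aligns two word sequences, each given as word groups (one string array per
// sentence), by aligning their E5 token embeddings with windowed DTW and mapping the token
// alignment path back to word indexes.
//
// Usage sketch (the example words are hypothetical, not from the original sources):
//
//   const mapping = await alignWordsToWordsSemantically(
//       [['Hello', 'world']],   // word groups from source 1
//       [['Hallo', 'Welt']]     // word groups from source 2
//   )
//
//   // Each mapping entry has the shape:
//   //   { wordIndex1, word1, wordIndex2, word2 }
//   // and a single source word may appear in several entries when multiple tokens map to it.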
export async function alignWordsToWordsSemantically(wordsGroups1, wordsGroups2, windowTokenCount = 20000) {
const logger = new Logger();
// Load embedding model
const modelPath = await loadPackage(`xenova-multilingual-e5-small-fp16`);
const embeddingModel = new E5TextEmbedding(modelPath);
logger.start(`Initialize E5 embedding model`);
await embeddingModel.initializeIfNeeded();
async function extractEmbeddingsFromWordGroups(wordGroups) {
const logger = new Logger();
const maxTokensPerFragment = 512;
const { Tensor } = await import('@echogarden/transformers-nodejs-lite');
const words = [];
const embeddings = [];
const tokenToWordIndexMapping = [];
for (const wordGroup of wordGroups) {
const { joinedText: joinedTextForGroup, offsets: offsetsForGroup } = joinAndGetOffsets(wordGroup);
logger.start(`Tokenize text`);
const inputsForGroup = await embeddingModel.tokenizeToModelInputs(joinedTextForGroup);
logger.start(`Infer embeddings for text`);
const allTokenIds = inputsForGroup['input_ids'].data;
const allAttentionMask = inputsForGroup['attention_mask'].data;
let embeddingsForGroup = [];
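// The model is run in fragments of at most `maxTokensPerFragment` (512) tokens, presumably to
// respect the model's maximum input length. For example (hypothetical count), 1100 tokens are
// processed as three fragments covering token ranges [0, 512), [512, 1024) and [1024, 1100),
// and the resulting token embeddings are concatenated in order.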
for (let tokenStart = 0; tokenStart < allTokenIds.length; tokenStart += maxTokensPerFragment) {
const tokenEnd = Math.min(tokenStart + maxTokensPerFragment, allTokenIds.length);
const fragmentTokenCount = tokenEnd - tokenStart;
const fragmentInputIdsTensor = new Tensor('int64', allTokenIds.slice(tokenStart, tokenEnd), [1, fragmentTokenCount]);
const fragmentAttentionMaskTensor = new Tensor('int64', allAttentionMask.slice(tokenStart, tokenEnd), [1, fragmentTokenCount]);
const inputsForFragment = { input_ids: fragmentInputIdsTensor, attention_mask: fragmentAttentionMaskTensor };
const embeddingsForFragment = await embeddingModel.inferTokenEmbeddings(inputsForFragment);
embeddingsForGroup.push(...embeddingsForFragment);
}
logger.start(`Compute token to word mapping for text`);
const filteredEmbeddingsForGroup = embeddingsForGroup.filter((embedding) => embedding.text !== '▁' && embedding.text !== '<s>' && embedding.text !== '</s>');
const tokenToWordIndexMappingForGroup = mapTokenEmbeddingsToWordIndexes(filteredEmbeddingsForGroup, joinedTextForGroup, offsetsForGroup);
const tokenToWordIndexMappingForGroupWithOffset = tokenToWordIndexMappingForGroup.map(value => words.length + value);
embeddings.push(...filteredEmbeddingsForGroup);
tokenToWordIndexMapping.push(...tokenToWordIndexMappingForGroupWithOffset);
words.push(...wordGroup);
}
return { words, embeddings, tokenToWordIndexMapping };
}
logger.start(`Extract embeddings from source 1`);
const { words: words1, embeddings: embeddings1, tokenToWordIndexMapping: tokenToWordIndexMapping1 } = await extractEmbeddingsFromWordGroups(wordsGroups1);
logger.start(`Extract embeddings from source 2`);
const { words: words2, embeddings: embeddings2, tokenToWordIndexMapping: tokenToWordIndexMapping2 } = await extractEmbeddingsFromWordGroups(wordsGroups2);
// Align token embeddings from the two sources using windowed DTW. The cost between two tokens
// is the cosine distance between their embedding vectors; a punctuation token paired with a
// non-punctuation token gets a fixed cost of 1.0.
function costFunction(a, b) {
const aIsPunctuation = includesPunctuation(a.text);
const bIsPunctuation = includesPunctuation(b.text);
if (aIsPunctuation === bIsPunctuation) {
return cosineDistance(a.embeddingVector, b.embeddingVector);
}
else {
return 1.0;
}
}
logger.start(`Align token embedding vectors using DTW`);
const { path } = alignDTWWindowed(embeddings1, embeddings2, costFunction, windowTokenCount);
// Use the DTW alignment path to derive word-to-word mappings
logger.start(`Map tokens to words`);
const wordMapping = [];
for (let i = 0; i < path.length; i++) {
const pathEntry = path[i];
const sourceTokenIndex = pathEntry.source;
const destTokenIndex = pathEntry.dest;
const mappedWordIndex1 = tokenToWordIndexMapping1[sourceTokenIndex];
const mappedWordIndex2 = tokenToWordIndexMapping2[destTokenIndex];
wordMapping.push({
wordIndex1: mappedWordIndex1,
word1: words1[mappedWordIndex1],
wordIndex2: mappedWordIndex2,
word2: words2[mappedWordIndex2],
});
}
logger.end();
return wordMapping;
}
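// Maps each token embedding to the index of the word it falls inside, by locating the token's
// text within the joined text and finding the word offset range containing that position.
//
// Worked example (hypothetical tokens): for words ['Hello', 'world'], joinAndGetOffsets (below)
// yields text 'Hello world ' and offsets [0, 6, 12]. Token embeddings whose texts are '▁Hel',
// 'lo' and '▁world' match at positions 0, 3 and 6, giving the mapping [0, 0, 1].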
function mapTokenEmbeddingsToWordIndexes(embeddings, text, textWordOffsets) {
const tokenToWordIndex = [];
let currentTextOffset = 0;
for (let i = 0; i < embeddings.length; i++) {
const embedding = embeddings[i];
let tokenText = embedding.text;
if (tokenText === '<s>' || tokenText === '</s>') {
tokenToWordIndex.push(-1);
continue;
}
if (tokenText.startsWith('▁')) {
tokenText = tokenText.substring(1);
}
const matchPosition = text.indexOf(tokenText, currentTextOffset);
if (matchPosition === -1) {
throw new Error(`Token '${tokenText}' not found in text`);
}
currentTextOffset = matchPosition + tokenText.length;
let tokenMatchingWordIndex = textWordOffsets.findIndex((index) => index > matchPosition);
if (tokenMatchingWordIndex === -1) {
throw new Error(`Token '${tokenText}' could not be mapped to a word offset`);
}
else {
tokenMatchingWordIndex = Math.max(tokenMatchingWordIndex - 1, 0);
}
tokenToWordIndex.push(tokenMatchingWordIndex);
}
return tokenToWordIndex;
}
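// Joins words with a trailing space after each word and records the start offset of every word,
// plus the total joined length as a final sentinel offset.
//
// Example: joinAndGetOffsets(['Hello', 'world'])
//   //=> { joinedText: 'Hello world ', offsets: [0, 6, 12] }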
function joinAndGetOffsets(words) {
let joinedText = '';
const offsets = [];
let offset = 0;
for (const word of words) {
const extendedWord = `${word} `;
joinedText += extendedWord;
offsets.push(offset);
offset += extendedWord.length;
}
offsets.push(joinedText.length);
return { joinedText, offsets };
}
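// Thin wrapper around the multilingual E5 model loaded via @echogarden/transformers-nodejs-lite,
// exposing tokenization and per-token embedding inference.
//
// Usage sketch (assumes `modelPath` points to a local copy of the model, as resolved by
// loadPackage above):
//
//   const embedding = new E5TextEmbedding(modelPath)
//   const inputs = await embedding.tokenizeToModelInputs('hello world')
//   const tokenEmbeddings = await embedding.inferTokenEmbeddings(inputs)
//   // Each entry: { id, text, embeddingVector }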
export class E5TextEmbedding {
modelPath;
tokenizer;
model;
constructor(modelPath) {
this.modelPath = modelPath;
}
async tokenizeToModelInputs(text) {
await this.initializeIfNeeded();
const inputs = await this.tokenizer(text);
return inputs;
}
async inferTokenEmbeddings(inputs) {
await this.initializeIfNeeded();
const tokensText = this.tokenizer.model.convert_ids_to_tokens(Array.from(inputs.input_ids.data));
const result = await this.model(inputs);
const lastHiddenState = result.last_hidden_state;
const tokenCount = lastHiddenState.dims[1];
const embeddingSize = lastHiddenState.dims[2];
const tokenEmbeddings = [];
for (let i = 0; i < tokenCount; i++) {
const tokenEmbeddingVector = lastHiddenState.data.slice(i * embeddingSize, (i + 1) * embeddingSize);
const tokenId = Number(inputs.input_ids.data[i]);
const tokenText = tokensText[i];
tokenEmbeddings.push({
id: tokenId,
text: tokenText,
embeddingVector: tokenEmbeddingVector
});
}
return tokenEmbeddings;
}
async initializeIfNeeded() {
if (this.tokenizer && this.model) {
return;
}
const { AutoTokenizer, AutoModel } = await import('@echogarden/transformers-nodejs-lite');
this.tokenizer = await AutoTokenizer.from_pretrained(this.modelPath);
this.model = await AutoModel.from_pretrained(this.modelPath);
}
}
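// ISO 639-1 codes of the languages covered by the multilingual E5 embedding model.
//
// Illustrative check (hypothetical usage; assumes callers pass a bare two-letter code):
//
//   if (!e5SupportedLanguages.includes(shortLanguageCode)) {
//       throw new Error(`Language '${shortLanguageCode}' is not supported for semantic alignment`)
//   }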
export const e5SupportedLanguages = [
'af', // Afrikaans
'am', // Amharic
'ar', // Arabic
'as', // Assamese
'az', // Azerbaijani
'be', // Belarusian
'bg', // Bulgarian
'bn', // Bengali
'br', // Breton
'bs', // Bosnian
'ca', // Catalan
'cs', // Czech
'cy', // Welsh
'da', // Danish
'de', // German
'el', // Greek
'en', // English
'eo', // Esperanto
'es', // Spanish
'et', // Estonian
'eu', // Basque
'fa', // Persian
'fi', // Finnish
'fr', // French
'fy', // Western Frisian
'ga', // Irish
'gd', // Scottish Gaelic
'gl', // Galician
'gu', // Gujarati
'ha', // Hausa
'he', // Hebrew
'hi', // Hindi
'hr', // Croatian
'hu', // Hungarian
'hy', // Armenian
'id', // Indonesian
'is', // Icelandic
'it', // Italian
'ja', // Japanese
'jv', // Javanese
'ka', // Georgian
'kk', // Kazakh
'km', // Khmer
'kn', // Kannada
'ko', // Korean
'ku', // Kurdish
'ky', // Kyrgyz
'la', // Latin
'lo', // Lao
'lt', // Lithuanian
'lv', // Latvian
'mg', // Malagasy
'mk', // Macedonian
'ml', // Malayalam
'mn', // Mongolian
'mr', // Marathi
'ms', // Malay
'my', // Burmese
'ne', // Nepali
'nl', // Dutch
'no', // Norwegian
'om', // Oromo
'or', // Oriya
'pa', // Panjabi
'pl', // Polish
'ps', // Pashto
'pt', // Portuguese
'ro', // Romanian
'ru', // Russian
'sa', // Sanskrit
'sd', // Sindhi
'si', // Sinhala
'sk', // Slovak
'sl', // Slovenian
'so', // Somali
'sq', // Albanian
'sr', // Serbian
'su', // Sundanese
'sv', // Swedish
'sw', // Swahili
'ta', // Tamil
'te', // Telugu
'th', // Thai
'tl', // Tagalog
'tr', // Turkish
'ug', // Uyghur
'uk', // Ukrainian
'ur', // Urdu
'uz', // Uzbek
'vi', // Vietnamese
'xh', // Xhosa
'yi', // Yiddish
'zh', // Chinese
];
//# sourceMappingURL=SemanticTextAlignment.js.map