UNPKG

@huggingface/transformers

Version:

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!

160 lines (138 loc) • 7.13 kB
import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js';
import { Tensor } from '../../utils/tensor.js';
import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js';

/**
 * Feature extractor for the CLAP model: converts raw audio waveforms into
 * log-mel spectrogram features suitable for the model.
 */
export class ClapFeatureExtractor extends FeatureExtractor {
    constructor(config) {
        super(config);

        // Mel filter bank matching torchaudio's defaults (HTK mel scale, no normalization).
        // Reserved for the "fusion" truncation strategy, which is not yet implemented below.
        this.mel_filters = mel_filter_bank(
            this.config.nb_frequency_bins, // num_frequency_bins
            this.config.feature_size, // num_mel_filters
            this.config.frequency_min, // min_frequency
            this.config.frequency_max, // max_frequency
            this.config.sampling_rate, // sampling_rate
            null, // norm
            "htk", // mel_scale
        );

        // Mel filter bank matching librosa's defaults (Slaney mel scale + Slaney normalization).
        // Used for every non-"fusion" truncation strategy.
        this.mel_filters_slaney = mel_filter_bank(
            this.config.nb_frequency_bins, // num_frequency_bins
            this.config.feature_size, // num_mel_filters
            this.config.frequency_min, // min_frequency
            this.config.frequency_max, // max_frequency
            this.config.sampling_rate, // sampling_rate
            "slaney", // norm
            "slaney", // mel_scale
        );

        // Hann window, reused for every spectrogram computation.
        this.window = window_function(this.config.fft_window_size, 'hann');
    }

    /**
     * Extracts the mel spectrogram and prepares it for the model based on the
     * `truncation` and `padding` arguments.
     *
     * Four different paths are possible:
     * - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram
     *   will be computed on the entire audio. 3 random crops and a downsampled version of the full mel spectrogram
     *   are then stacked together. They will later be used for `feature_fusion`. (Not yet implemented here.)
     * - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is
     *   padded based on `padding`.
     * - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded
     *   based on `padding`, and is repeated `4` times. (Not yet implemented here.)
     * - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel
     *   spectrogram will be computed on a random crop of the waveform.
     *
     * @param {Float32Array|Float64Array} waveform The input waveform.
     * @param {number} max_length The maximum length of the waveform.
     * @param {string} truncation The truncation strategy to use.
     * @param {string} padding The padding strategy to use.
     * @returns {Promise<Tensor>} The log-mel spectrogram, with a leading batch dimension added in place.
     * @private
     */
    async _get_input_mel(waveform, max_length, truncation, padding) {
        /** @type {Tensor} */
        let input_mel;

        const diff = waveform.length - max_length;
        if (diff > 0) {
            // Waveform is longer than allowed: truncate.
            if (truncation === 'rand_trunc') {
                // Compute features on a uniformly-random crop of `max_length` samples.
                const idx = Math.floor(Math.random() * (diff + 1));
                waveform = waveform.subarray(idx, idx + max_length);
                input_mel = await this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples);
            } else {
                // TODO implement fusion strategy
                throw new Error(`Truncation strategy "${truncation}" not implemented`)
            }
        } else {
            if (diff < 0) {
                // Waveform is shorter than allowed: pad up to `max_length`.
                let padded = new Float64Array(max_length); // already padded with zeros
                padded.set(waveform);

                if (padding === 'repeat') {
                    // Tile the waveform until the buffer is full (last copy may be partial).
                    for (let i = waveform.length; i < max_length; i += waveform.length) {
                        padded.set(waveform.subarray(0, Math.min(waveform.length, max_length - i)), i);
                    }
                } else if (padding === 'repeatpad') {
                    // Tile whole copies of the waveform; any remainder stays zero-padded.
                    for (let i = waveform.length; i < -diff; i += waveform.length) {
                        padded.set(waveform, i);
                    }
                }
                waveform = padded;
            }

            if (truncation === 'fusion') {
                throw new Error(`Truncation strategy "${truncation}" not implemented`)
            }
            input_mel = await this._extract_fbank_features(waveform, this.mel_filters_slaney, this.config.nb_max_samples);
        }

        return input_mel.unsqueeze_(0);
    }

    /**
     * Compute the log-mel spectrogram of the provided `waveform` using the Hann window.
     * In CLAP, two different filter banks are used depending on the truncation pattern:
     * - `this.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
     *   calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
     *   is set to `"fusion"`.
     * - `this.mel_filters_slaney`: they correspond to the default parameters of `librosa` which used
     *   `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
     *   implementation when the truncation mode is not `"fusion"`.
     *
     * @param {Float32Array|Float64Array} waveform The audio waveform to process.
     * @param {number[][]} mel_filters The mel filters to use.
     * @param {number} [max_length=null] The maximum number of frames to return.
     * @returns {Promise<Tensor>} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
     */
    async _extract_fbank_features(waveform, mel_filters, max_length = null) {
        // NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
        return spectrogram(
            waveform,
            this.window, // window
            this.config.fft_window_size, // frame_length
            this.config.hop_length, // hop_length
            {
                power: 2.0,
                mel_filters,
                log_mel: 'dB',

                // Custom
                max_num_frames: max_length,
                do_pad: false,
                transpose: true,
            }
        )
    }

    /**
     * Asynchronously extracts features from a given audio using the provided configuration.
     * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
     * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor.
     */
    async _call(audio, {
        max_length = null,
    } = {}) {
        validate_audio_inputs(audio, 'ClapFeatureExtractor');

        // convert to mel spectrogram, truncate and pad if needed.
        const padded_inputs = await this._get_input_mel(
            audio,
            max_length ?? this.config.nb_max_samples,
            this.config.truncation,
            this.config.padding,
        );

        return {
            input_features: padded_inputs.unsqueeze_(0),
        }
    }
}