UNPKG

@magenta/music

Version:

Make music with machine learning, in the browser.

272 lines 11.5 kB
import * as tf from '@tensorflow/tfjs';
import { loadAudioFromFile, loadAudioFromUrl } from '../core/audio_utils';
import * as logging from '../core/logging';
import { fetch, performance } from '../core/compat/global';
import { preprocessAudio } from './audio_utils';
import { MEL_SPEC_BINS, MIDI_PITCHES } from './constants';
import { batchInput, pianorollToNoteSequence, unbatchOutput } from './transcription_utils';

/**
 * Looks up a named entry in a weights map, failing loudly when it is absent.
 *
 * @param {Object} vars Map from variable name to value (as produced by
 *     `tf.io.loadWeights` in `OnsetsAndFrames.initialize`).
 * @param {string} name Key to look up.
 * @returns {*} The stored value.
 * @throws {Error} If `vars[name]` is `undefined`.
 */
function getVar(vars, name) {
  const v = vars[name];
  if (v === undefined) {
    throw new Error(`Variable not found: ${name}`);
  }
  return v;
}

/**
 * Transcription model composed of three `AcousticCnn` stacks (onsets,
 * frame activations, velocities) and two `Lstm` stacks (onsets, frames).
 */
class OnsetsAndFrames {
  /**
   * @param {string} checkpointURL Base URL under which
   *     `weights_manifest.json` and the weight files are fetched.
   * @param {number} [chunkLength=250] Chunk length handed to `batchInput`,
   *     `unbatchOutput` and the LSTM `predict` calls.
   */
  constructor(checkpointURL, chunkLength = 250) {
    this.checkpointURL = checkpointURL;
    this.chunkLength = chunkLength;
    // Explicit so that isInitialized() returns a boolean before first load.
    this.initialized = false;
  }

  /**
   * Loads the checkpoint weights and builds the sub-models. Any previously
   * built sub-models are disposed first.
   *
   * @returns {Promise<void>} Resolves once the model is ready.
   */
  async initialize() {
    this.dispose();
    const startTime = performance.now();
    const response = await fetch(`${this.checkpointURL}/weights_manifest.json`);
    const manifest = await response.json();
    const vars = await tf.io.loadWeights(manifest, this.checkpointURL);
    try {
      this.build(vars);
    } finally {
      // The loaded tensors are released whether or not build() succeeded.
      Object.keys(vars).forEach((name) => vars[name].dispose());
    }
    this.initialized = true;
    logging.logWithDuration('Initialized model', startTime, 'O&F');
  }

  /** Disposes all sub-models. A no-op when the model is not initialized. */
  dispose() {
    if (!this.initialized) {
      return;
    }
    this.onsetsCnn.dispose();
    this.onsetsRnn.dispose();
    this.activationCnn.dispose();
    this.frameRnn.dispose();
    this.velocityCnn.dispose();
    this.initialized = false;
  }

  /** @returns {boolean} Whether `initialize()` has completed. */
  isInitialized() {
    return this.initialized;
  }

  /**
   * Transcribes a mel spectrogram.
   *
   * @param {Array} melSpec Spectrogram passed to `batchInput`; its `length`
   *     is used as the full (unbatched) length. Presumably one entry per
   *     frame - confirm against `batchInput`.
   * @param {number} [parallelBatches=4] Batch size forwarded to the CNN
   *     `predict` calls.
   * @returns {Promise<*>} Whatever `pianorollToNoteSequence` resolves to.
   */
  async transcribeFromMelSpec(melSpec, parallelBatches = 4) {
    if (!this.isInitialized()) {
      // Must be awaited: the sub-models do not exist until this resolves.
      await this.initialize();
    }
    const startTime = performance.now();
    logging.log('Computing onsets, frames, and velocities...', 'O&F', 20);
    const [frameProbs, onsetProbs, velocities] = tf.tidy(() => {
      const batches = batchInput(melSpec, this.chunkLength);
      return this.processBatches(
          batches, this.chunkLength, melSpec.length, parallelBatches);
    });
    logging.log('Converting to NoteSequence...', 'O&F', 20);
    let ns;
    try {
      ns = await pianorollToNoteSequence(frameProbs, onsetProbs, velocities);
    } finally {
      // Release the output tensors even if the conversion rejects.
      frameProbs.dispose();
      onsetProbs.dispose();
      velocities.dispose();
    }
    logging.logWithDuration('Transcribed from mel spec', startTime, 'O&F');
    return ns;
  }

  /**
   * Converts an audio buffer to a mel spectrogram and transcribes it.
   *
   * @param {*} audioBuffer Input handed to `preprocessAudio`.
   * @param {number} [batchSize=4] Forwarded to `transcribeFromMelSpec` as
   *     `parallelBatches`.
   * @returns {Promise<*>} Result of `transcribeFromMelSpec`.
   */
  async transcribeFromAudioBuffer(audioBuffer, batchSize = 4) {
    const startTime = performance.now();
    const spec = await preprocessAudio(audioBuffer);
    logging.logWithDuration('Converted audio to mel spec', startTime, 'O&F', 20);
    // batchSize belongs to transcribeFromMelSpec, not to Array#map.
    return this.transcribeFromMelSpec(spec.map((a) => Array.from(a)), batchSize);
  }

  /**
   * Loads audio via `loadAudioFromFile` and transcribes it.
   *
   * @param {*} blob Input handed to `loadAudioFromFile`.
   * @returns {Promise<*>} Result of `transcribeFromAudioBuffer`.
   */
  async transcribeFromAudioFile(blob) {
    const audio = await loadAudioFromFile(blob);
    return this.transcribeFromAudioBuffer(audio);
  }

  /**
   * Loads audio via `loadAudioFromUrl` and transcribes it.
   *
   * @param {string} url Input handed to `loadAudioFromUrl`.
   * @returns {Promise<*>} Result of `transcribeFromAudioBuffer`.
   */
  async transcribeFromAudioURL(url) {
    const audio = await loadAudioFromUrl(url);
    return this.transcribeFromAudioBuffer(audio);
  }

  /**
   * Runs the three CNNs over the batched input, re-assembles their outputs
   * when there is more than one batch, then runs the onset and frame LSTMs.
   *
   * Intermediate tensors created here are only partly disposed explicitly;
   * the caller in this file wraps the call in `tf.tidy`.
   *
   * @param {tf.Tensor} batches Batched input; a trailing dimension is added
   *     with `expandDims(-1)` before it is fed to the CNNs.
   * @param {number} chunkLength Forwarded to `unbatchOutput`.
   * @param {number} fullLength Forwarded to `unbatchOutput`.
   * @param {number} batchSize Batch size for the CNN `predict` calls.
   * @returns {tf.Tensor[]} `[frameProbs, onsetProbs, velocities]`, each
   *     squeezed.
   */
  processBatches(batches, chunkLength, fullLength, batchSize) {
    const runCnns = (batch) => [
      this.onsetsCnn.predict(batch, batchSize),
      this.activationCnn.predict(batch, batchSize),
      this.velocityCnn.predict(batch, batchSize),
    ];

    let onsetsCnnOut;
    let activationProbs;
    let scaledVelocities;
    if (batches.shape[0] === 1) {
      // A single batch needs no re-assembly.
      [onsetsCnnOut, activationProbs, scaledVelocities] =
          runCnns(batches.expandDims(-1));
    } else {
      const batchesOutput = runCnns(batches.expandDims(-1));
      const allOutputs = [];
      for (let i = 0; i < 3; ++i) {
        allOutputs.push(
            unbatchOutput(batchesOutput[i], chunkLength, fullLength));
      }
      [onsetsCnnOut, activationProbs, scaledVelocities] = allOutputs;
    }

    const onsetProbs = this.onsetsRnn.predict(onsetsCnnOut, this.chunkLength);
    onsetsCnnOut.dispose();

    // The frame LSTM sees onset probabilities concatenated with activations.
    const frameProbInputs = tf.concat3d([onsetProbs, activationProbs], -1);
    activationProbs.dispose();
    const frameProbs = this.frameRnn.predict(frameProbInputs, this.chunkLength);

    // Clip to [0, 1], then map linearly onto integers in [10, 90].
    const velocities = tf.clipByValue(scaledVelocities, 0., 1.)
                           .mul(tf.scalar(80.))
                           .add(tf.scalar(10.))
                           .toInt();
    scaledVelocities.dispose();

    return [frameProbs.squeeze(), onsetProbs.squeeze(), velocities.squeeze()];
  }

  /**
   * Constructs the five sub-models and assigns their weights from `vars`.
   *
   * @param {Object} vars Map from variable name to weight tensor.
   */
  build(vars) {
    tf.tidy(() => {
      this.onsetsCnn = new AcousticCnn();
      this.onsetsCnn.setWeights(vars, 'onsets');
      this.onsetsRnn = new Lstm([null, this.onsetsCnn.outputShape[2]]);
      this.onsetsRnn.setWeights(vars, 'onsets', 'onset_probs');

      this.activationCnn = new AcousticCnn('sigmoid');
      this.activationCnn.setWeights(vars, 'frame', 'activation_probs');
      // Input width is doubled because onset and activation probabilities
      // are concatenated in processBatches.
      this.frameRnn = new Lstm([null, MIDI_PITCHES * 2]);
      this.frameRnn.setWeights(vars, 'frame', 'frame_probs');

      this.velocityCnn = new AcousticCnn('linear');
      this.velocityCnn.setWeights(vars, 'velocity', 'onset_velocities');
    });
  }
}

/**
 * Sequential stack of three conv2d / batch-norm / ReLU groups (with max
 * pooling after the second and third), a reshape, a 768-unit dense layer,
 * and an optional final dense layer with `MIDI_PITCHES` units.
 */
class AcousticCnn {
  /**
   * @param {string} [finalDenseActivation] Activation for the final dense
   *     layer. When omitted (falsy) the final dense layer is not added.
   */
  constructor(finalDenseActivation) {
    this.nn = tf.sequential();

    const convConfig = {
      filters: 48,
      kernelSize: [3, 3],
      activation: 'linear',
      useBias: false,
      padding: 'same',
      dilationRate: [1, 1],
      inputShape: [null, MEL_SPEC_BINS, 1],
      trainable: false
    };
    const batchNormConfig = { scale: false, trainable: false };

    this.nn.add(tf.layers.conv2d(convConfig));
    this.nn.add(tf.layers.batchNormalization(batchNormConfig));
    this.nn.add(tf.layers.activation({ activation: 'relu' }));

    // Only the first layer declares an input shape.
    convConfig.inputShape = null;
    this.nn.add(tf.layers.conv2d(convConfig));
    this.nn.add(tf.layers.batchNormalization(batchNormConfig));
    this.nn.add(tf.layers.activation({ activation: 'relu' }));
    this.nn.add(tf.layers.maxPooling2d({ poolSize: [1, 2], strides: [1, 2] }));

    convConfig.filters = 96;
    this.nn.add(tf.layers.conv2d(convConfig));
    this.nn.add(tf.layers.batchNormalization(batchNormConfig));
    this.nn.add(tf.layers.activation({ activation: 'relu' }));
    this.nn.add(tf.layers.maxPooling2d({ poolSize: [1, 2], strides: [1, 2] }));

    // Merge the last two dimensions so the dense layers see one vector per
    // position along dims[1].
    const dims = this.nn.outputShape;
    this.nn.add(
        tf.layers.reshape({ targetShape: [dims[1], dims[2] * dims[3]] }));
    this.nn.add(
        tf.layers.dense({ units: 768, activation: 'relu', trainable: false }));

    if (finalDenseActivation) {
      this.nn.add(tf.layers.dense({
        units: MIDI_PITCHES,
        activation: finalDenseActivation,
        trainable: false
      }));
    }
    this.outputShape = this.nn.outputShape;
  }

  /** Disposes the underlying sequential model. */
  dispose() {
    this.nn.dispose();
  }

  /**
   * @param {tf.Tensor} inputs Input to the sequential model.
   * @param {number} batchSize Passed as the `batchSize` predict option.
   * @returns {tf.Tensor} Model output.
   */
  predict(inputs, batchSize) {
    return this.nn.predict(inputs, { batchSize });
  }

  /**
   * Assigns weights from `vars`, reading names under `scope`.
   *
   * @param {Object} vars Map from variable name to weight tensor.
   * @param {string} scope Name prefix of the variables to read.
   * @param {string} [denseName] Sub-scope of the final dense layer. Its
   *     weights and biases are appended only when this is provided.
   * @throws {Error} If any required variable is missing from `vars`.
   */
  setWeights(vars, scope, denseName) {
    let weights = [
      getVar(vars, `${scope}/conv0/weights`),
      getVar(vars, `${scope}/conv0/BatchNorm/beta`),
      getVar(vars, `${scope}/conv0/BatchNorm/moving_mean`),
      getVar(vars, `${scope}/conv0/BatchNorm/moving_variance`),
      getVar(vars, `${scope}/conv1/weights`),
      getVar(vars, `${scope}/conv1/BatchNorm/beta`),
      getVar(vars, `${scope}/conv1/BatchNorm/moving_mean`),
      getVar(vars, `${scope}/conv1/BatchNorm/moving_variance`),
      getVar(vars, `${scope}/conv2/weights`),
      getVar(vars, `${scope}/conv2/BatchNorm/beta`),
      getVar(vars, `${scope}/conv2/BatchNorm/moving_mean`),
      getVar(vars, `${scope}/conv2/BatchNorm/moving_variance`),
      getVar(vars, `${scope}/fc_end/weights`),
      getVar(vars, `${scope}/fc_end/biases`),
    ];
    if (denseName) {
      weights = weights.concat([
        getVar(vars, `${scope}/${denseName}/weights`),
        getVar(vars, `${scope}/${denseName}/biases`)
      ]);
    }
    this.nn.setWeights(weights);
  }
}

/**
 * An LSTM layer with explicit state inputs/outputs, followed by a sigmoid
 * dense layer with `MIDI_PITCHES` units. Prediction is done chunk by chunk,
 * carrying the LSTM state from one chunk to the next.
 */
class Lstm {
  /**
   * @param {Array} inputShape Input shape of the LSTM (without batch dim).
   * @param {number} [units=384] Number of LSTM units.
   */
  constructor(inputShape, units = 384) {
    this.dense = tf.sequential();
    this.units = units;

    const lstm = tf.layers.lstm({
      inputShape,
      units,
      returnSequences: true,
      recurrentActivation: 'sigmoid',
      returnState: true,
      kernelInitializer: 'zeros',
      recurrentInitializer: 'zeros',
      biasInitializer: 'zeros',
      trainable: false
    });
    // Inputs are the sequence plus the two state tensors, so that state can
    // be fed back in between chunks.
    const inputs = [
      tf.input({ shape: inputShape }),
      tf.input({ shape: [units] }),
      tf.input({ shape: [units] })
    ];
    const outputs = lstm.apply(inputs);
    this.lstm = tf.model({ inputs, outputs });

    this.dense.add(tf.layers.dense({
      inputShape: [null, units],
      units: MIDI_PITCHES,
      activation: 'sigmoid',
      trainable: false
    }));
  }

  /** Disposes both the LSTM model and the dense model. */
  dispose() {
    this.lstm.dispose();
    this.dense.dispose();
  }

  /**
   * Assigns LSTM and dense weights from `vars`.
   *
   * The stored kernel and bias have their four gate blocks reordered from
   * positions [0, 1, 2, 3] to [0, 2, 1, 3] along the last axis, and 1.0 is
   * added to the bias block that ends up second. The kernel is then split
   * along its first axis into an input part and a `units`-row recurrent
   * part.
   *
   * @param {Object} vars Map from variable name to weight tensor.
   * @param {string} scope Name prefix of the variables to read.
   * @param {string} denseName Sub-scope of the dense layer.
   * @throws {Error} If any required variable is missing from `vars`.
   */
  setWeights(vars, scope, denseName) {
    const reorderGates = (weights, forgetBias = 0) => {
      const [i, c, f, o] = tf.split(weights, 4, -1);
      return tf.concat([i, f.add(tf.scalar(forgetBias)), c, o], -1);
    };
    const splitAndReorderKernel = (kernel) => tf.split(
        reorderGates(kernel), [kernel.shape[0] - this.units, this.units]);

    const LSTM_PREFIX =
        'cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell';
    const kernelParts =
        splitAndReorderKernel(getVar(vars, `${scope}/${LSTM_PREFIX}/kernel`));
    const bias = reorderGates(getVar(vars, `${scope}/${LSTM_PREFIX}/bias`), 1.0);
    this.lstm.setWeights(kernelParts.concat(bias));

    this.dense.setWeights([
      getVar(vars, `${scope}/${denseName}/weights`),
      getVar(vars, `${scope}/${denseName}/biases`)
    ]);
  }

  /**
   * Runs the LSTM and dense layers over `inputs` inside `tf.tidy`.
   *
   * @param {tf.Tensor} inputs Input whose axis 1 is split into chunks.
   * @param {number} chunkSize Maximum length of each chunk along axis 1.
   * @returns {tf.Tensor} Dense outputs, concatenated along axis 1.
   */
  predict(inputs, chunkSize) {
    return tf.tidy(() => this.predictImpl(inputs, chunkSize));
  }

  /**
   * Chunked prediction; see `predict`. The initial state is shaped
   * `[1, units]`, so a batch of one sequence is presumably expected.
   *
   * @param {tf.Tensor} inputs Input whose axis 1 is split into chunks.
   * @param {number} chunkSize Maximum length of each chunk along axis 1.
   * @returns {tf.Tensor} Dense outputs, concatenated along axis 1.
   */
  predictImpl(inputs, chunkSize) {
    const fullLength = inputs.shape[1];
    const numChunks = Math.ceil(fullLength / chunkSize);
    let state = [tf.zeros([1, this.units]), tf.zeros([1, this.units])];
    const outputChunks = [];
    for (let i = 0; i < numChunks; ++i) {
      // The last chunk takes whatever remains (-1 means "to the end").
      const chunk = inputs.slice(
          [0, i * chunkSize], [-1, (i < numChunks - 1) ? chunkSize : -1]);
      const result = this.lstm.predict([chunk, state[0], state[1]]);
      outputChunks.push(this.dense.predict(result[0]));
      // Carry the returned state into the next chunk.
      state = result.slice(1);
    }
    return outputChunks.length === 1 ? outputChunks[0] :
                                       tf.concat3d(outputChunks, 1);
  }
}

export { OnsetsAndFrames };
//# sourceMappingURL=model.js.map