UNPKG

@magenta/music

Version:

Make music with machine learning, in the browser.

358 lines 14.5 kB
import * as tf from '@tensorflow/tfjs-core';
import { fetch, performance } from '../core/compat/global';
import * as logging from '../core/logging';
import * as sequences from '../core/sequences';
import { IS_IOS, NUM_PITCHES, pianorollToSequence, sequenceToPianoroll } from './coconet_utils';

/**
 * Default network hyper-parameters. Every `Coconet` instance starts from a
 * copy of this object; `layers` is filled in per instance by
 * `Coconet.instantiateFromSpec()`.
 */
const DEFAULT_SPEC = {
  useSoftmaxLoss: true,
  batchNormVarianceEpsilon: 1.0e-07,
  numInstruments: 4,
  numFilters: 128,
  numLayers: 33,
  numRegularConvLayers: 0,
  // One [height, width] dilation pair per middle layer (30 entries).
  dilation: [
    [1, 1], [2, 2], [4, 4], [8, 8], [16, 16], [16, 32],
    [1, 1], [2, 2], [4, 4], [8, 8], [16, 16], [16, 32],
    [1, 1], [2, 2], [4, 4], [8, 8], [16, 16], [16, 32],
    [1, 1], [2, 2], [4, 4], [8, 8], [16, 16], [16, 32],
    [1, 1], [2, 2], [4, 4], [8, 8], [16, 16], [16, 32],
  ],
  layers: null,
  interleaveSplitEveryNLayers: 16,
  numPointwiseSplits: 4,
};

/**
 * Stack of (separable) convolution layers with batch normalisation and
 * periodic residual connections, driven by a layer spec and a dictionary of
 * named weight tensors.
 */
class ConvNet {
  /**
   * @param spec Model spec; `spec.layers` must already be populated.
   * @param vars Map from variable name (e.g. `model/conv3/gamma`) to tensor.
   */
  constructor(spec, vars) {
    // A residual is saved on odd counter values and added on even ones.
    this.residualPeriod = 2;
    this.outputForResidual = null;
    this.residualCounter = -1;
    this.spec = spec;
    this.rawVars = vars;
  }

  /** Releases the weight tensors and any residual tensor still referenced. */
  dispose() {
    if (this.rawVars !== null) {
      tf.dispose(this.rawVars);
    }
    if (this.outputForResidual) {
      this.outputForResidual.dispose();
    }
  }

  /**
   * Runs the full layer stack on a masked pianoroll.
   *
   * @param pianoroll Pianoroll tensor (concatenated with `masks` on axis 3).
   * @param masks Tensor of the same shape as `pianoroll`; 1 marks a masked
   *     cell.
   * @returns Predictions as produced by `computePredictions`.
   */
  predictFromPianoroll(pianoroll, masks) {
    return tf.tidy(() => {
      let featuremaps = this.getConvnetInput(pianoroll, masks);
      const n = this.spec.layers.length;
      for (let i = 0; i < n; i++) {
        this.residualCounter += 1;
        this.residualSave(featuremaps);

        // Every (interleaveSplitEveryNLayers + 1)-th layer, excluding the
        // first layer and the final two, uses split pointwise convolutions.
        let numPointwiseSplits = null;
        if (this.spec.interleaveSplitEveryNLayers && i > 0 && i < n - 2 &&
            i % (this.spec.interleaveSplitEveryNLayers + 1) === 0) {
          numPointwiseSplits = this.spec.numPointwiseSplits;
        }

        const layer = this.spec.layers[i];
        featuremaps = this.applyConvolution(
            featuremaps, layer, i, i >= this.spec.numRegularConvLayers,
            numPointwiseSplits);
        featuremaps = this.applyResidual(featuremaps, i === 0, i === n - 1, i);
        featuremaps = this.applyActivation(featuremaps, layer, i);
        featuremaps = this.applyPooling(featuremaps, layer, i);
      }
      return this.computePredictions(featuremaps);
    });
  }

  /**
   * Converts logits to probabilities. With `useSoftmaxLoss` the softmax is
   * taken along axis 2: the transpose moves that axis last, where
   * `softmax()` operates by default, and then restores the layout.
   * Otherwise an element-wise sigmoid is used.
   */
  computePredictions(logits) {
    if (this.spec.useSoftmaxLoss) {
      return logits.transpose([0, 1, 3, 2]).softmax().transpose([0, 1, 3, 2]);
    }
    return logits.sigmoid();
  }

  /** Forgets the saved residual and restarts the residual counter at 0. */
  residualReset() {
    this.outputForResidual = null;
    this.residualCounter = 0;
  }

  /** Remembers `x` as the residual source when the counter is at phase 1. */
  residualSave(x) {
    if (this.residualCounter % this.residualPeriod === 1) {
      this.outputForResidual = x;
    }
  }

  /**
   * Adds the saved residual to `x` when one is available, the channel counts
   * match, the counter is at phase 0, and the layer is neither first nor
   * last. A channel mismatch resets the residual state instead.
   *
   * @param x Output of the current layer's convolution.
   * @param isFirst Whether this is the first layer.
   * @param isLast Whether this is the last layer.
   * @param i Layer index (unused; kept for interface compatibility).
   */
  applyResidual(x, isFirst, isLast, i) {
    const saved = this.outputForResidual;
    if (saved == null) {
      return x;
    }
    if (saved.shape[saved.shape.length - 1] !== x.shape[x.shape.length - 1]) {
      this.residualReset();
      return x;
    }
    if (this.residualCounter % this.residualPeriod === 0 && !isFirst &&
        !isLast) {
      x = x.add(saved);
    }
    return x;
  }

  /** Looks up `model/conv<layerNum>/<name>` in the weight map. */
  getVar(name, layerNum) {
    const varname = `model/conv${layerNum}/${name}`;
    return this.rawVars[varname];
  }

  /** Looks up `model/conv<layerNum>/SeparableConv2d/<name>`. */
  getSepConvVar(name, layerNum) {
    const varname = `model/conv${layerNum}/SeparableConv2d/${name}`;
    return this.rawVars[varname];
  }

  /** Looks up `model/conv<layerNum>/split_<layerNum>_<splitNum>/<name>`. */
  getPointwiseSplitVar(name, layerNum, splitNum) {
    const varname =
        `model/conv${layerNum}/split_${layerNum}_${splitNum}/${name}`;
    return this.rawVars[varname];
  }

  /**
   * Applies layer `i`'s convolution followed by batch normalisation.
   *
   * @param x Input feature maps (NHWC).
   * @param layer Layer spec; a layer without `filters` is a no-op.
   * @param i Layer index, used to look up weights.
   * @param depthwise Use a separable (depthwise + pointwise) convolution
   *     instead of a regular one.
   * @param numPointwiseSplits When truthy (and `depthwise`), the pointwise
   *     step is done as this many independent dense projections over equal
   *     channel groups, concatenated back together.
   */
  applyConvolution(x, layer, i, depthwise, numPointwiseSplits) {
    if (layer.filters == null) {
      return x;
    }
    const filterShape = layer.filters;
    const stride = layer.convStride || 1;
    const padding = layer.convPad ? layer.convPad.toLowerCase() : 'same';
    let conv = null;

    if (depthwise) {
      const dWeights = this.getSepConvVar('depthwise_weights', i);
      if (!numPointwiseSplits) {
        const pWeights = this.getSepConvVar('pointwise_weights', i);
        const biases = this.getSepConvVar('biases', i);
        const sepConv = tf.separableConv2d(
            x, dWeights, pWeights, [stride, stride], padding, layer.dilation,
            'NHWC');
        conv = sepConv.add(biases);
      } else {
        conv = tf.depthwiseConv2d(
            x, dWeights, [stride, stride], padding, 'NHWC', layer.dilation);
        const channelAxis = conv.rank - 1;
        const splits = tf.split(conv, numPointwiseSplits, channelAxis);
        // Channels handled by each split; identical for every split.
        const outputShape = filterShape[3] / numPointwiseSplits;
        const pointwiseSplits = [];
        for (let splitIdx = 0; splitIdx < numPointwiseSplits; splitIdx++) {
          const split = splits[splitIdx];
          const weights = this.getPointwiseSplitVar('kernel', i, splitIdx);
          const biases = this.getPointwiseSplitVar('bias', i, splitIdx);
          // A 1x1 convolution expressed as a matrix product over the
          // flattened spatial positions.
          const dot = tf.matMul(
              split.reshape([-1, outputShape]), weights, false, false);
          const bias = tf.add(dot, biases);
          pointwiseSplits.push(bias.reshape(
              [split.shape[0], split.shape[1], split.shape[2], outputShape]));
        }
        conv = tf.concat(pointwiseSplits, channelAxis);
      }
    } else {
      // Regular convolutions always use unit dilation; layer.dilation is not
      // consulted on this path.
      const weights = this.getVar('weights', i);
      conv = tf.conv2d(x, weights, [stride, stride], padding, 'NHWC', [1, 1]);
    }
    return this.applyBatchnorm(conv, i);
  }

  /**
   * Normalises `x` with layer `i`'s stored population statistics.
   * When `IS_IOS` is set the normalisation is computed by hand (standard
   * deviations are derived on the CPU) rather than through `tf.batchNorm`.
   */
  applyBatchnorm(x, i) {
    const gammas = this.getVar('gamma', i);
    const betas = this.getVar('beta', i);
    const mean = this.getVar('popmean', i);
    const variance = this.getVar('popvariance', i);
    if (IS_IOS) {
      const v = variance.arraySync()[0][0][0];
      const stdevs = tf.tensor(v.map(
          (value) => Math.sqrt(value + this.spec.batchNormVarianceEpsilon)));
      return x.sub(mean).mul(gammas.div(stdevs)).add(betas);
    }
    return tf.batchNorm(
        x, tf.squeeze(mean), tf.squeeze(variance), tf.squeeze(betas),
        tf.squeeze(gammas), this.spec.batchNormVarianceEpsilon);
  }

  /**
   * Applies ReLU unless the layer's activation is `'identity'`.
   * `i` is unused; kept for interface compatibility.
   */
  applyActivation(x, layer, i) {
    if (layer.activation === 'identity') {
      return x;
    }
    return x.relu();
  }

  /**
   * Applies max pooling when the layer declares a `pooling` size; the pool
   * size is also used as the stride. `i` is unused; kept for interface
   * compatibility.
   */
  applyPooling(x, layer, i) {
    if (layer.pooling == null) {
      return x;
    }
    const pooling = layer.pooling;
    const padding = layer.poolPad ? layer.poolPad.toLowerCase() : 'same';
    return tf.maxPool(
        x, [pooling[0], pooling[1]], [pooling[0], pooling[1]], padding);
  }

  /**
   * Builds the network input: the pianoroll with masked cells zeroed,
   * concatenated on axis 3 with the inverted mask (1 - masks).
   */
  getConvnetInput(pianoroll, masks) {
    const maskedPianoroll = tf.scalar(1, 'float32').sub(masks).mul(pianoroll);
    const invertedMasks = tf.scalar(1, 'float32').sub(masks);
    return maskedPianoroll.concat(invertedMasks, 3);
  }
}

/**
 * Coconet model: fills in the masked parts of a quantized NoteSequence by
 * repeatedly resampling masked cells from the ConvNet's predictions.
 */
class Coconet {
  /**
   * @param checkpointURL Base URL of the checkpoint; the weights manifest is
   *     read from `<checkpointURL>/weights_manifest.json`.
   */
  constructor(checkpointURL) {
    this.initialized = false;
    this.checkpointURL = checkpointURL;
    // Per-instance copy: instantiateFromSpec() writes `layers` into the spec
    // and must not mutate the shared module-level defaults.
    this.spec = Object.assign({}, DEFAULT_SPEC);
  }

  /**
   * Loads the checkpoint weights and builds the network. Any previously
   * loaded network is disposed first.
   */
  async initialize() {
    this.dispose();
    const startTime = performance.now();
    this.instantiateFromSpec();
    const response = await fetch(`${this.checkpointURL}/weights_manifest.json`);
    const manifest = await response.json();
    const vars = await tf.io.loadWeights(manifest, this.checkpointURL);
    this.convnet = new ConvNet(this.spec, vars);
    this.initialized = true;
    logging.logWithDuration('Initialized model', startTime, 'Coconet');
  }

  /** Disposes the network (if any) and marks the model uninitialized. */
  dispose() {
    if (this.convnet) {
      this.convnet.dispose();
    }
    this.initialized = false;
  }

  /** @returns Whether `initialize()` has completed since the last dispose. */
  isInitialized() {
    return this.initialized;
  }

  /**
   * Populates `this.spec.layers`: one 3x3 input layer, `numLayers - 3`
   * dilated 3x3 middle layers, and two 2x2 final layers, the last of which
   * outputs `numInstruments` channels with an identity activation.
   */
  instantiateFromSpec() {
    const nonFinalLayerFilterOuterSizes = 3;
    const finalTwoLayersFilterOuterSizes = 2;
    const { numInstruments, numFilters, numLayers, dilation } = this.spec;

    const layers = [];
    // The input has twice the instrument channels: pianoroll plus mask.
    layers.push({
      filters: [
        nonFinalLayerFilterOuterSizes, nonFinalLayerFilterOuterSizes,
        numInstruments * 2, numFilters
      ]
    });
    for (let i = 0; i < numLayers - 3; i++) {
      layers.push({
        filters: [
          nonFinalLayerFilterOuterSizes, nonFinalLayerFilterOuterSizes,
          numFilters, numFilters
        ],
        dilation: dilation ? dilation[i] : null
      });
    }
    layers.push({
      filters: [
        finalTwoLayersFilterOuterSizes, finalTwoLayersFilterOuterSizes,
        numFilters, numFilters
      ]
    });
    layers.push({
      filters: [
        finalTwoLayersFilterOuterSizes, finalTwoLayersFilterOuterSizes,
        numFilters, numInstruments
      ],
      activation: 'identity',
    });
    this.spec.layers = layers;
  }

  /**
   * Infills a quantized NoteSequence.
   *
   * @param sequence Relative-quantized NoteSequence with at least one note.
   * @param config Optional settings: `numIterations` (default 96),
   *     `temperature` (default 0.99) and `infillMask`, a list of
   *     `{step, voice}` cells to regenerate. Without a mask, cells with no
   *     active pitch are regenerated.
   * @returns The NoteSequence built from the sampled pianoroll.
   * @throws Error if the sequence has no notes.
   */
  async infill(sequence, config) {
    sequences.assertIsRelativeQuantizedSequence(sequence);
    if (sequence.notes.length === 0) {
      throw new Error(`NoteSequence ${sequence.id} does not have any notes to infill.`);
    }
    // Falls back to the end step of the final entry in `notes` when the
    // sequence does not carry a total step count.
    const numSteps = sequence.totalQuantizedSteps ||
        sequence.notes[sequence.notes.length - 1].quantizedEndStep;
    const pianoroll = sequenceToPianoroll(sequence, numSteps);

    let temperature = 0.99;
    let numIterations = 96;
    let outerMasks = null;
    let samples = null;
    try {
      if (config) {
        numIterations = config.numIterations || numIterations;
        temperature = config.temperature || temperature;
        outerMasks =
            this.getCompletionMaskFromInput(config.infillMask, pianoroll);
      } else {
        outerMasks = this.getCompletionMask(pianoroll);
      }
      samples =
          await this.run(pianoroll, numIterations, temperature, outerMasks);
      return pianorollToSequence(samples, numSteps);
    } finally {
      // Release tensors even when mask construction or sampling throws.
      pianoroll.dispose();
      if (samples) {
        samples.dispose();
      }
      if (outerMasks) {
        outerMasks.dispose();
      }
    }
  }

  /** Runs the sampler; currently always Gibbs sampling. */
  async run(pianorolls, numSteps, temperature, outerMasks) {
    return this.gibbs(pianorolls, numSteps, temperature, outerMasks);
  }

  /**
   * Builds the outer mask from an explicit list of `{step, voice}` cells,
   * broadcasting each cell over all `NUM_PITCHES` pitches. Falls back to
   * `getCompletionMask` when no list is given.
   *
   * @returns Tensor of shape [1, steps, NUM_PITCHES, instruments].
   */
  getCompletionMaskFromInput(masks, pianorolls) {
    if (!masks) {
      return this.getCompletionMask(pianorolls);
    }
    // Size the instrument axis from the pianoroll so the mask always lines
    // up with it.
    const buffer = tf.buffer([pianorolls.shape[1], pianorolls.shape[3]]);
    for (let i = 0; i < masks.length; i++) {
      buffer.set(1, masks[i].step, masks[i].voice);
    }
    return tf.tidy(() => {
      return buffer.toTensor()
          .expandDims(1)
          .tile([1, NUM_PITCHES, 1])
          .expandDims(0);
    });
  }

  /**
   * Builds an outer mask that is 1 wherever the pianoroll has no active
   * pitch (sum over axis 2 is zero), broadcast to the pianoroll's shape.
   */
  getCompletionMask(pianorolls) {
    return tf.tidy(() => {
      const isEmpty = pianorolls.sum(2, true).equal(tf.scalar(0, 'float32'));
      // Adding zerosLike broadcasts the mask to the full pianoroll shape.
      return tf.cast(isEmpty, 'float32').add(tf.zerosLike(pianorolls));
    });
  }

  /**
   * Gibbs sampling loop: on each step a random subset of the outer-masked
   * cells is resampled from the network's predictions.
   *
   * @param pianorolls Starting pianoroll (not modified; a clone is used).
   * @param numSteps Number of sampling iterations.
   * @param temperature Sampling temperature.
   * @param outerMasks Cells that are allowed to change.
   * @returns The final pianoroll; the caller owns and must dispose it.
   */
  async gibbs(pianorolls, numSteps, temperature, outerMasks) {
    const numStepsTensor = tf.scalar(numSteps, 'float32');
    let pianoroll = pianorolls.clone();
    for (let s = 0; s < numSteps; s++) {
      const pm = this.yaoSchedule(s, numStepsTensor);
      const innerMasks = this.bernoulliMask(pianoroll.shape, pm, outerMasks);
      // Yield to the browser between the heavy stages.
      await tf.nextFrame();
      const predictions = tf.tidy(() => {
        return this.convnet.predictFromPianoroll(pianoroll, innerMasks);
      });
      await tf.nextFrame();
      pianoroll = tf.tidy(() => {
        const samples = this.samplePredictions(predictions, temperature);
        const updatedPianorolls =
            tf.where(tf.cast(innerMasks, 'bool'), samples, pianoroll);
        // These were created outside this tidy, so free them explicitly.
        pianoroll.dispose();
        predictions.dispose();
        innerMasks.dispose();
        pm.dispose();
        return updatedPianorolls;
      });
      await tf.nextFrame();
    }
    numStepsTensor.dispose();
    return pianoroll;
  }

  /**
   * Masking probability for iteration `i` of `n`:
   * max(pmin, (pmax - (pmax - pmin) * i / n) / alpha)
   * with pmin = 0.1, pmax = 0.9, alpha = 0.7.
   *
   * @param i Current iteration (number).
   * @param n Total iterations (scalar tensor).
   */
  yaoSchedule(i, n) {
    return tf.tidy(() => {
      const pmin = tf.scalar(0.1, 'float32');
      const pmax = tf.scalar(0.9, 'float32');
      const alpha = tf.scalar(0.7, 'float32');
      const wat = pmax.sub(pmin).mul(tf.scalar(i, 'float32')).div(n);
      const secondArg = pmax.sub(wat).div(alpha);
      return pmin.reshape([1]).concat(secondArg.reshape([1])).max();
    });
  }

  /**
   * Draws a random mask that selects each (batch, step, instrument) cell
   * with probability `pm`, shared across all pitches, restricted to
   * `outerMasks`.
   */
  bernoulliMask(shape, pm, outerMasks) {
    return tf.tidy(() => {
      const [bb, tt, pp, ii] = shape;
      const probs = tf.tile(
          tf.randomUniform([bb, tt, 1, ii], 0, 1, 'float32'), [1, 1, pp, 1]);
      const masks = probs.less(pm);
      return tf.cast(masks, 'float32').mul(outerMasks);
    });
  }

  /**
   * Samples one pitch per (batch, step, instrument) from `predictions`
   * raised to the power 1 / temperature, via inverse-CDF sampling along
   * axis 2.
   *
   * @returns One-hot tensor with the same shape as `predictions`.
   */
  samplePredictions(predictions, temperature) {
    return tf.tidy(() => {
      const scaled =
          tf.pow(predictions, tf.scalar(1 / temperature, 'float32'));
      const cmf = tf.cumsum(scaled, 2, false, false);
      // The last cumulative entry along axis 2 is the unnormalised total.
      const totalMasses = cmf.slice(
          [0, 0, cmf.shape[2] - 1, 0],
          [cmf.shape[0], cmf.shape[1], 1, cmf.shape[3]]);
      const u = tf.randomUniform(totalMasses.shape, 0, 1, 'float32');
      // First index where the cumulative mass exceeds the random draw.
      const i = u.mul(totalMasses).less(cmf).argMax(2);
      return tf.oneHot(i.flatten(), scaled.shape[2], 1, 0)
          .reshape([
            scaled.shape[0], scaled.shape[1], scaled.shape[3], scaled.shape[2]
          ])
          .transpose([0, 1, 3, 2]);
    });
  }
}

export { Coconet };
//# sourceMappingURL=model.js.map