UNPKG

@magenta/music

Version:

Make music with machine learning, in the browser.

343 lines 14.2 kB
import * as tf from '@tensorflow/tfjs';
import { fetch } from '../core/compat/global';
import { logging } from '../core';

/** Multiplier turning a delta time in seconds into a (fractional) time-bin index. */
const DATA_TIME_QUANTIZE_RATE = 31.25;
/** Largest time-bin index; longer gaps between presses are clamped to this bin. */
const DATA_MAX_DISCRETE_TIMES = 32;
/** Number of stacked LSTM cells in the decoder. */
const RNN_NLAYERS = 2;
/** Width of each LSTM cell's `c` and `h` state. */
const RNN_NUNITS = 128;
/** Number of input buttons; button ids are scaled assuming the range 0..NUM_BUTTONS-1. */
const NUM_BUTTONS = 8;
/** Number of piano keys the decoder produces logits for. */
const NUM_PIANOKEYS = 88;

/**
 * Allocates an all-zero LSTM state.
 *
 * @returns {{c: tf.Tensor[], h: tf.Tensor[]}} One `[1, RNN_NUNITS]` float32
 *     tensor per layer for both `c` and `h`. The caller owns the tensors and
 *     must release them with `disposeState`.
 */
function createZeroState() {
  const state = { c: [], h: [] };
  for (let layer = 0; layer < RNN_NLAYERS; ++layer) {
    state.c.push(tf.zeros([1, RNN_NUNITS], 'float32'));
    state.h.push(tf.zeros([1, RNN_NUNITS], 'float32'));
  }
  return state;
}

/**
 * Disposes every tensor held by an LSTM state.
 *
 * @param {{c: tf.Tensor[], h: tf.Tensor[]}} state State to release.
 */
function disposeState(state) {
  for (const tensor of [...state.c, ...state.h]) {
    tensor.dispose();
  }
}

/**
 * Builds a float32 one-hot vector for a 1-based category value.
 *
 * The value is shifted down by one before encoding, so the `None` members of
 * `PitchClass` / `ChordFamily` (value 0) become index -1.
 * NOTE(review): tf.oneHot presumably encodes an out-of-range index as all
 * zeros - confirm against the tfjs version in use.
 *
 * Allocates intermediates; call it inside `tf.tidy`.
 *
 * @param {number} value 1-based category value (0 meaning "none").
 * @param {number} depth Length of the resulting one-hot vector.
 * @returns {tf.Tensor} float32 tensor of length `depth`.
 */
function oneHotFromOneBased(value, depth) {
  const zeroBased = tf.sub(tf.scalar(value, 'int32'), tf.scalar(1, 'int32'));
  return tf.cast(tf.oneHot(zeroBased, depth), 'float32');
}

/**
 * Picks an index from a 1-D logits tensor.
 *
 * Allocates intermediates; call it inside `tf.tidy`.
 *
 * @param {tf.Tensor} logits 1-D tensor of unnormalized scores.
 * @param {number} [temperature=1] Value in [0, 1]. 0 selects the arg max;
 *     values below 1 sharpen the distribution by dividing the logits.
 * @param {number} [seed] Seed forwarded to `tf.multinomial`.
 * @returns {tf.Tensor} Scalar tensor holding the chosen index.
 * @throws {Error} If `temperature` is NaN or outside [0, 1].
 */
function sampleLogits(logits, temperature = 1, seed) {
  // Written as a negated conjunction so that NaN is rejected as well.
  if (!(temperature >= 0 && temperature <= 1)) {
    throw new Error('Invalid temperature specified');
  }
  // With a single candidate the outcome is always index 0, so skip sampling.
  // NOTE(review): tf.multinomial is believed to reject distributions with
  // fewer than two outcomes - confirm against the tfjs version in use.
  if (temperature === 0 || logits.size === 1) {
    return tf.argMax(logits, 0);
  }
  const scaledLogits = temperature < 1 ?
      tf.div(logits, tf.scalar(temperature, 'float32')) :
      logits;
  const scores = tf.reshape(tf.softmax(scaledLogits, 0), [1, -1]);
  const sample = tf.multinomial(scores, 1, seed, true);
  return tf.reshape(sample, []);
}

/**
 * Creates an object mapping each name to its position in `names` and each
 * position back to its name (the layout TypeScript emits for numeric enums).
 *
 * @param {string[]} names Member names, in value order starting at 0.
 * @returns {Object} Bidirectional name/value lookup object.
 */
function createEnum(names) {
  const enumObject = {};
  names.forEach((name, value) => {
    enumObject[name] = value;
    enumObject[value] = name;
  });
  return enumObject;
}

/**
 * Decoder shared by all Piano Genie variants: maps a button press to a piano
 * key index while carrying LSTM state from one call to the next.
 */
class PianoGenieBase {
  /**
   * @param {string} [checkpointURL] Base URL of the checkpoint directory. May
   *     be omitted when the weights are handed to `initialize` directly.
   */
  constructor(checkpointURL) {
    this.checkpointURL = checkpointURL;
    this.initialized = false;
  }

  /** @returns {boolean} Whether `initialize` has completed. */
  isInitialized() {
    return this.initialized;
  }

  /**
   * Loads the weights, builds the LSTM cells and runs one warm-up step.
   *
   * If the model is already initialized it is disposed first.
   *
   * @param {Object<string, tf.Tensor>} [staticVars] Pre-loaded weights keyed
   *     by variable name. When omitted, the weights are fetched from
   *     `checkpointURL`. Note that `dispose` disposes every tensor in this
   *     map as well.
   * @returns {Promise<void>}
   * @throws {Error} If neither a checkpoint URL nor `staticVars` is
   *     available, or if the weights manifest request is not successful.
   */
  async initialize(staticVars) {
    if (this.initialized) {
      this.dispose();
    }
    if (this.checkpointURL === undefined && staticVars === undefined) {
      throw new Error('Need to specify either URI or static variables');
    }

    if (staticVars === undefined) {
      const manifestURL = `${this.checkpointURL}/weights_manifest.json`;
      const response = await fetch(manifestURL);
      // Fail with a clear message instead of a JSON parse error on an
      // HTTP error page.
      if (!response.ok) {
        throw new Error(
            `Failed to fetch ${manifestURL} (HTTP status ${response.status})`);
      }
      const manifest = await response.json();
      this.modelVars = await tf.io.loadWeights(manifest, this.checkpointURL);
    } else {
      this.modelVars = staticVars;
    }

    this.decLSTMCells = [];
    this.decForgetBias = tf.scalar(1, 'float32');
    for (let layer = 0; layer < RNN_NLAYERS; ++layer) {
      const cellPrefix =
          `phero_model/decoder/rnn/rnn/multi_rnn_cell/cell_${layer}/lstm_cell/`;
      this.decLSTMCells.push(
          (data, c, h) => tf.basicLSTMCell(
              this.decForgetBias, this.modelVars[`${cellPrefix}kernel`],
              this.modelVars[`${cellPrefix}bias`], data, c, h));
    }

    this.resetState();
    this.initialized = true;
    // Run a single throw-away step, then reset so that it leaves no trace in
    // the state.
    this.next(0);
    this.resetState();
  }

  /**
   * Builds the RNN input features for the current button.
   *
   * @returns {tf.Tensor} 1-D float32 tensor holding `this.button` rescaled as
   *     `2 * button / (NUM_BUTTONS - 1) - 1`. The caller owns the tensor.
   */
  getRnnInputFeats() {
    return tf.tidy(() => {
      const buttonTensor = tf.tensor1d([this.button], 'float32');
      const buttonScaled =
          tf.sub(tf.mul(2., tf.div(buttonTensor, NUM_BUTTONS - 1)), 1);
      return buttonScaled.as1D();
    });
  }

  /**
   * Advances the model by one button press, sampling over all piano keys.
   *
   * @param {number} button Pressed button id.
   * @param {number} [temperature] Sampling temperature in [0, 1].
   * @param {number} [seed] Random seed for sampling.
   * @returns {number} Index of the chosen piano key.
   */
  next(button, temperature, seed) {
    const sampleFunc = (logits) => sampleLogits(logits, temperature, seed);
    return this.nextWithCustomSamplingFunction(button, sampleFunc);
  }

  /**
   * Advances the model by one button press, restricting the output to the
   * keys listed in `keyList`.
   *
   * @param {number} button Pressed button id.
   * @param {number[]} keyList Piano key indices the output may take.
   * @param {number} [temperature] Sampling temperature in [0, 1].
   * @param {number} [seed] Random seed for sampling.
   * @returns {number} One of the entries of `keyList`.
   */
  nextFromKeyList(button, keyList, temperature, seed) {
    const sampleFunc = (logits) => {
      const keySubsetTensor = tf.tensor1d(keyList, 'int32');
      const subsetLogits = tf.gather(logits, keySubsetTensor);
      // The sample indexes into the subset; map it back to a piano key.
      const subsetIndex = sampleLogits(subsetLogits, temperature, seed);
      const key1d = tf.gather(keySubsetTensor, tf.reshape(subsetIndex, [1]));
      return tf.reshape(key1d, []);
    };
    return this.nextWithCustomSamplingFunction(button, sampleFunc);
  }

  /**
   * Deprecated alias of `nextFromKeyList`; logs a deprecation notice and
   * forwards all arguments.
   *
   * @deprecated Use `nextFromKeyList` instead.
   * @param {number} button Pressed button id.
   * @param {number[]} keyList Piano key indices the output may take.
   * @param {number} [temperature] Sampling temperature in [0, 1].
   * @param {number} [seed] Random seed for sampling.
   * @returns {number} One of the entries of `keyList`.
   */
  nextFromKeyWhitelist(button, keyList, temperature, seed) {
    // Concatenated rather than backslash-continued so that no source
    // indentation ends up inside the logged message.
    logging.log(
        'nextFromKeyWhitelist() is deprecated, and will be removed in a ' +
            'future version. Please use nextFromKeyList() instead',
        'PianoGenie', 5);
    return this.nextFromKeyList(button, keyList, temperature, seed);
  }

  /**
   * Advances the model by one button press using a caller-supplied sampler.
   *
   * @param {number} button Pressed button id.
   * @param {function(tf.Tensor): tf.Tensor} sampleFunc Receives the 1-D
   *     logits over the piano keys and returns a tensor whose first element
   *     is the chosen key. It is invoked inside `tf.tidy`, so tensors it
   *     allocates are released automatically.
   * @returns {number} Value produced by `sampleFunc`.
   * @throws {Error} If the model is not initialized.
   */
  nextWithCustomSamplingFunction(button, sampleFunc) {
    // Checked before any features are built, so nothing is allocated for an
    // uninitialized model.
    if (!this.initialized) {
      throw new Error('Model is not initialized.');
    }
    const previousState = this.lastState;
    this.button = button;
    const rnnInput = this.getRnnInputFeats();
    let nextState;
    let output;
    try {
      [nextState, output] =
          this.evaluateModelAndSample(rnnInput, previousState, sampleFunc);
    } finally {
      // Release the input even when evaluation or sampling throws.
      rnnInput.dispose();
    }
    disposeState(this.lastState);
    this.lastState = nextState;
    return output;
  }

  /**
   * Runs one decoder step and samples an output from the resulting logits.
   *
   * @param {tf.Tensor} rnnInput1d 1-D input feature tensor.
   * @param {{c: tf.Tensor[], h: tf.Tensor[]}} initialState LSTM state to
   *     start from; it is not disposed here.
   * @param {function(tf.Tensor): tf.Tensor} sampleFunc Sampler applied to the
   *     `[NUM_PIANOKEYS]` logits.
   * @returns {Array} Pair `[finalState, output]`: the new LSTM state (owned
   *     by the caller) and the sampled value as a number.
   * @throws {Error} If the model is not initialized.
   */
  evaluateModelAndSample(rnnInput1d, initialState, sampleFunc) {
    if (!this.initialized) {
      throw new Error('Model is not initialized.');
    }
    return tf.tidy(() => {
      // Dense projection of the input features.
      let rnnInput = tf.matMul(
          tf.expandDims(rnnInput1d, 0),
          this.modelVars['phero_model/decoder/rnn_input/dense/kernel']);
      rnnInput = tf.add(
          rnnInput, this.modelVars['phero_model/decoder/rnn_input/dense/bias']);

      const [c, h] = tf.multiRNNCell(
          this.decLSTMCells, rnnInput, initialState.c, initialState.h);
      const finalState = { c, h };

      // Dense projection of the top layer's output onto the piano keys.
      let logits = tf.matMul(
          h[RNN_NLAYERS - 1],
          this.modelVars['phero_model/decoder/pitches/dense/kernel']);
      logits = tf.add(
          logits, this.modelVars['phero_model/decoder/pitches/dense/bias']);
      const logits1D = tf.reshape(logits, [NUM_PIANOKEYS]);

      // dataSync() reads the value back synchronously.
      const output = sampleFunc(logits1D).dataSync()[0];
      return [finalState, output];
    });
  }

  /** Replaces the current LSTM state with zeros, disposing the old state. */
  resetState() {
    if (this.lastState !== undefined) {
      disposeState(this.lastState);
    }
    this.lastState = createZeroState();
  }

  /**
   * Disposes the weights, the forget-bias scalar and the LSTM state, and
   * marks the model as uninitialized. Does nothing if not initialized.
   */
  dispose() {
    if (!this.initialized) {
      return;
    }
    Object.keys(this.modelVars).forEach((name) => this.modelVars[name].dispose());
    this.decForgetBias.dispose();
    disposeState(this.lastState);
    // Forget the disposed state so a later resetState() does not dispose it
    // a second time.
    this.lastState = undefined;
    this.initialized = false;
  }
}

/**
 * Variant that additionally feeds the previous output and the time elapsed
 * since the previous button press into the RNN.
 */
class PianoGenieAutoregressiveDeltaTime extends PianoGenieBase {
  /**
   * @returns {tf.Tensor} 1-D concatenation of the base features, a one-hot of
   *     the previous output (length NUM_PIANOKEYS + 1) and a one-hot of the
   *     quantized delta time (length DATA_MAX_DISCRETE_TIMES + 1).
   */
  getRnnInputFeats() {
    return tf.tidy(() => {
      const featsArr = [super.getRnnInputFeats()];

      const lastOutput = this.lastOutput;
      const lastTime = this.lastTime;
      const time = this.time;

      let deltaTime;
      if (this.deltaTimeOverride === undefined) {
        // Date differences are in milliseconds; the features use seconds.
        deltaTime = (time.getTime() - lastTime.getTime()) / 1000;
      } else {
        // An override applies to a single step only.
        deltaTime = this.deltaTimeOverride;
        this.deltaTimeOverride = undefined;
      }

      // Shift by one so that the initial lastOutput of -1 maps to index 0.
      const lastOutputInc =
          tf.add(tf.scalar(lastOutput, 'int32'), tf.scalar(1, 'int32'));
      featsArr.push(
          tf.cast(tf.oneHot(lastOutputInc, NUM_PIANOKEYS + 1), 'float32'));

      // Quantize the delta time into a bin, clamped to the largest bin.
      const deltaTimeTensor = tf.scalar(deltaTime, 'float32');
      const deltaTimeBin =
          tf.round(tf.mul(deltaTimeTensor, DATA_TIME_QUANTIZE_RATE));
      const deltaTimeTrunc = tf.minimum(deltaTimeBin, DATA_MAX_DISCRETE_TIMES);
      // The small offset guards the int cast against a value landing just
      // below a whole number.
      const deltaTimeInt = tf.cast(tf.add(deltaTimeTrunc, 1e-4), 'int32');
      const deltaTimeOh = tf.oneHot(deltaTimeInt, DATA_MAX_DISCRETE_TIMES + 1);
      featsArr.push(tf.cast(deltaTimeOh, 'float32'));

      this.lastTime = time;

      return tf.concat1d(featsArr);
    });
  }

  /**
   * Timestamps the press, delegates to the base implementation and records
   * the output for the next step.
   *
   * @param {number} button Pressed button id.
   * @param {function(tf.Tensor): tf.Tensor} sampleFunc Sampler over logits.
   * @returns {number} Value produced by `sampleFunc`.
   */
  nextWithCustomSamplingFunction(button, sampleFunc) {
    this.time = new Date();
    const output = super.nextWithCustomSamplingFunction(button, sampleFunc);
    this.lastOutput = output;
    this.lastTime = this.time;
    return output;
  }

  /**
   * @param {number} lastOutput Value to use as the previous output on the
   *     next step.
   */
  overrideLastOutput(lastOutput) {
    this.lastOutput = lastOutput;
  }

  /**
   * @param {number} deltaTime Delta time in seconds to use for the next step
   *     only, instead of the measured wall-clock difference.
   */
  overrideDeltaTime(deltaTime) {
    this.deltaTimeOverride = deltaTime;
  }

  /** Resets the LSTM state, the previous output and the timestamps. */
  resetState() {
    super.resetState();
    this.lastOutput = -1;
    // Backdate the last press by 100000 seconds so the first delta time is
    // clamped to the largest bin.
    this.lastTime = new Date();
    this.lastTime.setSeconds(this.lastTime.getSeconds() - 100000);
    this.time = new Date();
  }
}

/** Pitch classes; `None` is 0 and `C` through `B` are 1 through 12. */
const PitchClass = createEnum([
  'None', 'C', 'Cs', 'D', 'Eb', 'E', 'F', 'Fs', 'G', 'Ab', 'A', 'Bb', 'B',
]);

/** Chord families; `None` is 0 and `Maj` through `Min7b5` are 1 through 8. */
const ChordFamily = createEnum([
  'None', 'Maj', 'Min', 'Aug', 'Dim', 'Seven', 'Maj7', 'Min7', 'Min7b5',
]);

/** Variant conditioned on a chord root and a chord family. */
class PianoGenieAutoregressiveDeltaTimeChord extends
    PianoGenieAutoregressiveDeltaTime {
  /**
   * @returns {tf.Tensor} Parent features followed by one-hots of the chord
   *     root (length 12) and the chord family (length 8).
   */
  getRnnInputFeats() {
    return tf.tidy(() => {
      const featsArr = [
        super.getRnnInputFeats(),
        oneHotFromOneBased(this.chordRoot, 12),
        oneHotFromOneBased(this.chordFamily, 8),
      ];
      return tf.concat1d(featsArr);
    });
  }

  /** @param {number} chordRoot A `PitchClass` value. */
  setChordRoot(chordRoot) {
    this.chordRoot = chordRoot;
  }

  /** @param {number} chordFamily A `ChordFamily` value. */
  setChordFamily(chordFamily) {
    this.chordFamily = chordFamily;
  }

  /** Resets the parent state and clears the chord root and family. */
  resetState() {
    super.resetState();
    this.chordRoot = PitchClass.None;
    this.chordFamily = ChordFamily.None;
  }
}

/** Variant conditioned on a key signature. */
class PianoGenieAutoregressiveDeltaTimeKeysig extends
    PianoGenieAutoregressiveDeltaTime {
  /**
   * @returns {tf.Tensor} Parent features followed by a one-hot of the key
   *     signature (length 12).
   */
  getRnnInputFeats() {
    return tf.tidy(() => {
      const featsArr = [
        super.getRnnInputFeats(),
        oneHotFromOneBased(this.keySignature, 12),
      ];
      return tf.concat1d(featsArr);
    });
  }

  /** @param {number} keySignature A `PitchClass` value. */
  setKeySignature(keySignature) {
    this.keySignature = keySignature;
  }

  /** Resets the parent state and clears the key signature. */
  resetState() {
    super.resetState();
    this.keySignature = PitchClass.None;
  }
}

/** Variant conditioned on a key signature, a chord root and a chord family. */
class PianoGenieAutoregressiveDeltaTimeKeysigChord extends
    PianoGenieAutoregressiveDeltaTimeKeysig {
  /**
   * @returns {tf.Tensor} Parent (key signature) features followed by
   *     one-hots of the chord root (length 12) and chord family (length 8).
   */
  getRnnInputFeats() {
    return tf.tidy(() => {
      const featsArr = [
        super.getRnnInputFeats(),
        oneHotFromOneBased(this.chordRoot, 12),
        oneHotFromOneBased(this.chordFamily, 8),
      ];
      return tf.concat1d(featsArr);
    });
  }

  /** @param {number} chordRoot A `PitchClass` value. */
  setChordRoot(chordRoot) {
    this.chordRoot = chordRoot;
  }

  /** @param {number} chordFamily A `ChordFamily` value. */
  setChordFamily(chordFamily) {
    this.chordFamily = chordFamily;
  }

  /** Resets the parent state and clears the chord root and family. */
  resetState() {
    super.resetState();
    this.chordRoot = PitchClass.None;
    this.chordFamily = ChordFamily.None;
  }
}

/** Variant conditioned on a key signature and a chord family. */
class PianoGenieAutoregressiveDeltaTimeKeysigChordFamily extends
    PianoGenieAutoregressiveDeltaTimeKeysig {
  /**
   * @returns {tf.Tensor} Parent (key signature) features followed by a
   *     one-hot of the chord family (length 8).
   */
  getRnnInputFeats() {
    return tf.tidy(() => {
      const featsArr = [
        super.getRnnInputFeats(),
        oneHotFromOneBased(this.chordFamily, 8),
      ];
      return tf.concat1d(featsArr);
    });
  }

  /** @param {number} chordFamily A `ChordFamily` value. */
  setChordFamily(chordFamily) {
    this.chordFamily = chordFamily;
  }

  /** Resets the parent state and clears the chord family. */
  resetState() {
    super.resetState();
    this.chordFamily = ChordFamily.None;
  }
}

// Public names for the variants above.
class PianoGenie extends PianoGenieAutoregressiveDeltaTime {}
class PianoGenieChord extends PianoGenieAutoregressiveDeltaTimeChord {}
class PianoGenieKeysig extends PianoGenieAutoregressiveDeltaTimeKeysig {}
class PianoGenieKeysigChord extends
    PianoGenieAutoregressiveDeltaTimeKeysigChord {}
class PianoGenieKeysigChordFamily extends
    PianoGenieAutoregressiveDeltaTimeKeysigChordFamily {}

export {
  PianoGenie,
  PianoGenieChord,
  PianoGenieKeysig,
  PianoGenieKeysigChord,
  PianoGenieKeysigChordFamily,
  PitchClass,
  ChordFamily,
};
//# sourceMappingURL=model.js.map