museaikit
Version:
A powerful music-focused AI toolkit
636 lines • 27.2 kB
JavaScript
import * as tf from '@tensorflow/tfjs';
import { NoteSequence } from '../protobuf/index';
import * as constants from './constants';
import { DEFAULT_DRUM_PITCH_CLASSES } from './constants';
import * as logging from './logging';
import { Melody, MelodyRhythm, MelodyShape } from './melodies';
import * as performance from './performance';
import * as sequences from './sequences';
export { DEFAULT_DRUM_PITCH_CLASSES };
/**
 * Builds the `DataConverter` subclass named by a converter spec.
 *
 * @param {{type: string, args: Object}} spec `type` selects the converter
 *     class; `args` is handed unchanged to that class's constructor.
 * @returns {DataConverter} A newly constructed converter.
 * @throws {Error} If `spec.type` does not name a converter handled here.
 */
export function converterFromSpec(spec) {
  switch (spec.type) {
    case 'MelodyConverter':
      return new MelodyConverter(spec.args);
    case 'MelodyRhythmConverter':
      return new MelodyRhythmConverter(spec.args);
    case 'MelodyShapeConverter':
      return new MelodyShapeConverter(spec.args);
    case 'DrumsConverter':
      return new DrumsConverter(spec.args);
    case 'DrumRollConverter':
      return new DrumRollConverter(spec.args);
    case 'TrioConverter':
      return new TrioConverter(spec.args);
    case 'TrioRhythmConverter':
      return new TrioRhythmConverter(spec.args);
    case 'DrumsOneHotConverter':
      return new DrumsOneHotConverter(spec.args);
    case 'MultitrackConverter':
      return new MultitrackConverter(spec.args);
    case 'GrooveConverter':
      return new GrooveConverter(spec.args);
    default:
      // Report the type name itself; interpolating the whole spec object
      // would only print "[object Object]".
      throw new Error(`Unknown DataConverter type: ${spec.type}`);
  }
}
/**
 * Base class for the converters in this file. It stores the step and segment
 * counts taken from the constructor arguments.
 *
 * Subclasses in this file override `NUM_SPLITS` (3 for the trio converters)
 * and `SEGMENTED_BY_TRACK` (true for `MultitrackConverter`).
 */
export class DataConverter {
  numSteps;
  numSegments;
  NUM_SPLITS = 0;
  SEGMENTED_BY_TRACK = false;
  /**
   * @param {{numSteps: number, numSegments: number}} args Converter settings;
   *     both values are stored as given, including `undefined`.
   */
  constructor(args) {
    const { numSteps, numSegments } = args;
    this.numSteps = numSteps;
    this.numSegments = numSegments;
  }
  /**
   * Returns the size of the tensor's first dimension as an int32 scalar.
   *
   * @param tensor Any object exposing a `shape` array.
   */
  tensorSteps(tensor) {
    const leadingDim = tensor.shape[0];
    return tf.scalar(leadingDim, 'int32');
  }
}
/**
 * Converts between quantized drum `NoteSequence`s and tensors.
 *
 * Pitches are grouped into classes (`pitchClasses`); every pitch listed in a
 * class maps to that class's index, and the first pitch of a class is the one
 * emitted when decoding.
 */
export class DrumsConverter extends DataConverter {
  pitchClasses;
  pitchToClass;
  depth;
  endTensor;
  /**
   * @param args Base converter args plus an optional `pitchClasses` array of
   *     pitch arrays; defaults to `DEFAULT_DRUM_PITCH_CLASSES`.
   */
  constructor(args) {
    super(args);
    this.pitchClasses = args.pitchClasses || DEFAULT_DRUM_PITCH_CLASSES;
    this.pitchToClass = new Map();
    for (let c = 0; c < this.pitchClasses.length; ++c) {
      this.pitchClasses[c].forEach((p) => {
        this.pitchToClass.set(p, c);
      });
    }
    // One column per class plus a trailing "no drums at this step" column.
    this.depth = this.pitchClasses.length + 1;
  }
  /**
   * Builds an int32 drum roll of shape `[numSteps, pitchClasses.length + 1]`.
   * Column `c` is 1 when a note of class `c` starts at that step; the last
   * column is 1 only for steps with no mapped drum hit.
   *
   * Notes whose pitch is in no class, or whose start step is outside
   * `[0, numSteps)`, are ignored.
   *
   * @param noteSequence A quantized sequence (validated via `sequences`).
   * @returns The drum roll tensor.
   */
  toTensor(noteSequence) {
    sequences.assertIsQuantizedSequence(noteSequence);
    const numSteps = this.numSteps || noteSequence.totalQuantizedSteps;
    const emptyColumn = this.pitchClasses.length;
    const drumRoll = tf.buffer([numSteps, emptyColumn + 1], 'int32');
    // Mark every step as empty up front; cleared below when a hit lands.
    // The column is addressed explicitly: `TensorBuffer.set` does not support
    // negative (from-the-end) indices, so `-1` would hit the previous row.
    for (let step = 0; step < numSteps; ++step) {
      drumRoll.set(1, step, emptyColumn);
    }
    for (const note of noteSequence.notes) {
      const step = note.quantizedStartStep;
      const classIndex = this.pitchToClass.get(note.pitch);
      const stepInRange = step >= 0 && step < numSteps;
      if (classIndex === undefined || !stepInRange) {
        continue;
      }
      drumRoll.set(1, step, classIndex);
      drumRoll.set(0, step, emptyColumn);
    }
    return drumRoll.toTensor();
  }
  /**
   * Decodes per-step labels into a quantized drum sequence. Each label (the
   * arg-max over axis 1 of `oh`) is read as a bit mask: bit `p` set means
   * class `p` is hit at that step.
   *
   * @param oh Tensor whose axis-1 arg-max yields one label per step.
   * @param stepsPerQuarter Forwarded to `createQuantizedNoteSequence`.
   * @param qpm Forwarded to `createQuantizedNoteSequence`.
   * @returns Promise of the decoded sequence; every note is one step long.
   */
  async toNoteSequence(oh, stepsPerQuarter, qpm) {
    const noteSequence = sequences.createQuantizedNoteSequence(stepsPerQuarter, qpm);
    const labelsTensor = oh.argMax(1);
    let labels;
    try {
      labels = await labelsTensor.data();
    } finally {
      // Release the intermediate tensor even if the read fails.
      labelsTensor.dispose();
    }
    for (let s = 0; s < labels.length; ++s) {
      for (let p = 0; p < this.pitchClasses.length; p++) {
        if ((labels[s] >> p) & 1) {
          noteSequence.notes.push(NoteSequence.Note.create({
            pitch: this.pitchClasses[p][0],
            quantizedStartStep: s,
            quantizedEndStep: s + 1,
            isDrum: true
          }));
        }
      }
    }
    noteSequence.totalQuantizedSteps = labels.length;
    return noteSequence;
  }
}
/**
 * Drum converter whose decoding input is a binary drum roll rather than
 * per-step labels.
 */
export class DrumRollConverter extends DrumsConverter {
  /**
   * Turns a drum roll into a quantized sequence: every truthy cell becomes a
   * one-step drum note using the first pitch of the cell's class.
   *
   * The flattened roll is read in rows of `pitchClasses.length` values.
   * NOTE(review): that is one less than the `pitchClasses.length + 1` columns
   * the inherited `toTensor` allocates — confirm which width callers pass.
   *
   * @param roll Tensor-like with `shape[0]` steps and an async `data()`.
   * @param stepsPerQuarter Forwarded to `createQuantizedNoteSequence`.
   * @param qpm Forwarded to `createQuantizedNoteSequence`.
   * @returns Promise of the decoded sequence.
   */
  async toNoteSequence(roll, stepsPerQuarter, qpm) {
    const result = sequences.createQuantizedNoteSequence(stepsPerQuarter, qpm);
    const cells = await roll.data();
    const rowWidth = this.pitchClasses.length;
    const stepCount = roll.shape[0];
    for (let step = 0; step < stepCount; ++step) {
      const rowStart = step * rowWidth;
      const row = cells.slice(rowStart, rowStart + rowWidth);
      for (let classIndex = 0; classIndex < row.length; ++classIndex) {
        if (!row[classIndex]) {
          continue;
        }
        const hit = NoteSequence.Note.create({
          pitch: this.pitchClasses[classIndex][0],
          quantizedStartStep: step,
          quantizedEndStep: step + 1,
          isDrum: true
        });
        result.notes.push(hit);
      }
    }
    result.totalQuantizedSteps = stepCount;
    return result;
  }
}
/**
 * Drum converter that encodes each step as a single one-hot label. The label
 * is a bit mask over pitch classes (bit `c` set means class `c` is hit), so
 * the one-hot depth is `2 ** pitchClasses.length`.
 */
export class DrumsOneHotConverter extends DrumsConverter {
  constructor(args) {
    super(args);
    this.depth = Math.pow(2, this.pitchClasses.length);
  }
  /**
   * Encodes a relative-quantized sequence as a one-hot tensor with one row
   * per step.
   *
   * Class bits are OR-ed together, so several notes of the same class on one
   * step count once instead of carrying into another class's bit. Notes with
   * an unmapped pitch, or a start step outside `[0, numSteps)`, are skipped
   * rather than turning the label into NaN or growing the label array.
   * Bit operations limit this to fewer than 31 pitch classes.
   *
   * @param noteSequence Sequence validated by
   *     `sequences.assertIsRelativeQuantizedSequence`.
   * @returns Result of `tf.oneHot` over the per-step labels.
   */
  toTensor(noteSequence) {
    sequences.assertIsRelativeQuantizedSequence(noteSequence);
    const numSteps = this.numSteps || noteSequence.totalQuantizedSteps;
    const labels = Array(numSteps).fill(0);
    for (const { pitch, quantizedStartStep } of noteSequence.notes) {
      const classIndex = this.pitchToClass.get(pitch);
      if (classIndex === undefined) {
        continue;
      }
      if (!(quantizedStartStep >= 0 && quantizedStartStep < numSteps)) {
        continue;
      }
      labels[quantizedStartStep] |= 1 << classIndex;
    }
    return tf.tidy(() => tf.oneHot(tf.tensor1d(labels, 'int32'), this.depth));
  }
}
/**
 * Converts between monophonic melodies and one-hot tensors, delegating the
 * event encoding to `Melody`.
 *
 * The one-hot depth is the pitch range size plus `FIRST_PITCH`.
 * NOTE(review): the labels below `FIRST_PITCH` are presumably the special
 * events defined by `Melody` (e.g. `NOTE_OFF`) — confirm there.
 */
export class MelodyConverter extends DataConverter {
  minPitch;
  maxPitch;
  ignorePolyphony;
  depth;
  endTensor;
  NOTE_OFF = 1;
  FIRST_PITCH = 2;
  /**
   * @param args Base converter args plus `minPitch`, `maxPitch` and
   *     `ignorePolyphony`.
   */
  constructor(args) {
    super(args);
    const { minPitch, maxPitch, ignorePolyphony } = args;
    this.minPitch = minPitch;
    this.maxPitch = maxPitch;
    this.ignorePolyphony = ignorePolyphony;
    this.depth = maxPitch - minPitch + 1 + this.FIRST_PITCH;
  }
  /**
   * One-hot encodes the melody events extracted from `noteSequence`.
   *
   * @returns Result of `tf.oneHot` over `melody.events` with `this.depth`.
   */
  toTensor(noteSequence) {
    const melody = Melody.fromNoteSequence(
        noteSequence, this.minPitch, this.maxPitch, this.ignorePolyphony,
        this.numSteps);
    return tf.tidy(() => {
      const eventIndices =
          tf.tensor(melody.events, [melody.events.length], 'int32');
      return tf.oneHot(eventIndices, this.depth);
    });
  }
  /**
   * Decodes a tensor back into a sequence: the axis-1 arg-max gives the event
   * labels, which `Melody` turns into notes.
   *
   * @returns Promise of the sequence built by `Melody#toNoteSequence`.
   */
  async toNoteSequence(oh, stepsPerQuarter, qpm) {
    const argMaxTensor = oh.argMax(1);
    const eventLabels = await argMaxTensor.data();
    argMaxTensor.dispose();
    const decoded = new Melody(eventLabels, this.minPitch, this.maxPitch);
    return decoded.toNoteSequence(stepsPerQuarter, qpm);
  }
}
/**
 * Base class for converters that turn a melody into a control signal. The
 * supplied `melodyControl` object provides the tensor `depth` and performs
 * the extraction.
 */
class MelodyControlConverter extends DataConverter {
  minPitch;
  maxPitch;
  ignorePolyphony;
  melodyControl;
  depth;
  endTensor;
  /**
   * @param args Base converter args plus `minPitch`, `maxPitch` and
   *     `ignorePolyphony`.
   * @param melodyControl Object with a `depth` property and an
   *     `extract(melody)` method.
   */
  constructor(args, melodyControl) {
    super(args);
    const { minPitch, maxPitch, ignorePolyphony } = args;
    this.minPitch = minPitch;
    this.maxPitch = maxPitch;
    this.ignorePolyphony = ignorePolyphony;
    this.melodyControl = melodyControl;
    this.depth = melodyControl.depth;
  }
  /**
   * Extracts the melody from `noteSequence` and returns whatever the melody
   * control produces for it.
   */
  toTensor(noteSequence) {
    const extractedMelody = Melody.fromNoteSequence(
        noteSequence, this.minPitch, this.maxPitch, this.ignorePolyphony,
        this.numSteps);
    return this.melodyControl.extract(extractedMelody);
  }
}
/**
 * Melody control converter that uses a `MelodyRhythm` control.
 */
export class MelodyRhythmConverter extends MelodyControlConverter {
  constructor(args) {
    super(args, new MelodyRhythm());
  }
  /**
   * Renders a rhythm tensor as a drum sequence: every truthy value becomes a
   * one-step drum note at that index, using the first pitch of the second
   * default drum class.
   *
   * @param tensor Tensor-like with an async `data()` giving one value per step.
   * @param stepsPerQuarter Forwarded to `createQuantizedNoteSequence`.
   * @param qpm Forwarded to `createQuantizedNoteSequence`.
   * @returns Promise of the rendered sequence.
   */
  async toNoteSequence(tensor, stepsPerQuarter, qpm) {
    const output = sequences.createQuantizedNoteSequence(stepsPerQuarter, qpm);
    const onsets = await tensor.data();
    for (let step = 0; step < onsets.length; ++step) {
      if (!onsets[step]) {
        continue;
      }
      const tick = NoteSequence.Note.create({
        pitch: DEFAULT_DRUM_PITCH_CLASSES[1][0],
        quantizedStartStep: step,
        quantizedEndStep: step + 1,
        isDrum: true
      });
      output.notes.push(tick);
    }
    output.totalQuantizedSteps = onsets.length;
    return output;
  }
}
/**
 * Melody control converter that uses a `MelodyShape` control.
 */
export class MelodyShapeConverter extends MelodyControlConverter {
  constructor(args) {
    super(args, new MelodyShape());
  }
  /**
   * Renders a shape tensor as a melody. Starting from the midpoint of the
   * pitch range, each step's label (axis-1 arg-max of `oh`) moves the pitch
   * down one semitone (label 0), up one semitone (label 2) or leaves it
   * unchanged (anything else). The pitch is clamped to
   * `[minPitch, maxPitch]`, logging whenever clamping happens. One note of
   * one step is emitted per label.
   *
   * @returns Promise of the rendered sequence.
   */
  async toNoteSequence(oh, stepsPerQuarter, qpm) {
    const output = sequences.createQuantizedNoteSequence(stepsPerQuarter, qpm);
    const argMaxTensor = oh.argMax(1);
    const contour = await argMaxTensor.data();
    argMaxTensor.dispose();
    let currentPitch = Math.round((this.maxPitch + this.minPitch) / 2);
    for (let step = 0; step < contour.length; ++step) {
      const direction = contour[step];
      if (direction === 0) {
        currentPitch -= 1;
        if (currentPitch < this.minPitch) {
          currentPitch = this.minPitch;
          logging.log('Pitch range exceeded when creating NoteSequence from shape.', 'MelodyShapeConverter');
        }
      } else if (direction === 2) {
        currentPitch += 1;
        if (currentPitch > this.maxPitch) {
          currentPitch = this.maxPitch;
          logging.log('Pitch range exceeded when creating NoteSequence from shape.', 'MelodyShapeConverter');
        }
      }
      output.notes.push(NoteSequence.Note.create({
        pitch: currentPitch,
        quantizedStartStep: step,
        quantizedEndStep: step + 1
      }));
    }
    output.totalQuantizedSteps = contour.length;
    return output;
  }
}
/**
 * Converts three-part sequences (melody, bass, drums) by concatenating the
 * outputs of a melody converter, a bass converter and a one-hot drums
 * converter along the last axis.
 *
 * Non-drum notes are assigned to melody or bass by MIDI program number
 * (`MEL_PROG_RANGE` / `BASS_PROG_RANGE`, both inclusive); drum notes go to the
 * drums converter.
 */
export class TrioConverter extends DataConverter {
  melConverter;
  bassConverter;
  drumsConverter;
  depth;
  endTensor;
  NUM_SPLITS = 3;
  MEL_PROG_RANGE = [0, 31];
  BASS_PROG_RANGE = [32, 39];
  /**
   * @param args Base converter args plus `melArgs`, `bassArgs` and
   *     `drumsArgs` for the three sub-converters. `args.numSteps` is applied
   *     to each sub-converter.
   * @throws {TypeError} If any of the three sub-argument objects is missing.
   */
  constructor(args) {
    super(args);
    // Build fresh argument objects instead of writing `numSteps` into the
    // caller's spec, so a shared spec object is never modified.
    const subArgsFor = (key) => {
      const subArgs = args[key];
      if (subArgs === null || subArgs === undefined) {
        throw new TypeError(`TrioConverter requires \`${key}\`.`);
      }
      return { ...subArgs, numSteps: args.numSteps };
    };
    this.melConverter = new MelodyConverter(subArgsFor('melArgs'));
    this.bassConverter = new MelodyConverter(subArgsFor('bassArgs'));
    this.drumsConverter = new DrumsOneHotConverter(subArgsFor('drumsArgs'));
    this.depth =
        (this.melConverter.depth + this.bassConverter.depth +
         this.drumsConverter.depth);
  }
  /**
   * Splits the sequence into melody, bass and drum parts, encodes each with
   * its sub-converter and concatenates the results along the last axis.
   *
   * @param noteSequence A quantized sequence (validated via `sequences`).
   */
  toTensor(noteSequence) {
    sequences.assertIsQuantizedSequence(noteSequence);
    const melSeq = sequences.clone(noteSequence);
    const bassSeq = sequences.clone(noteSequence);
    const drumsSeq = sequences.clone(noteSequence);
    const isPitchedInRange = (n, [lo, hi]) =>
        !n.isDrum && n.program >= lo && n.program <= hi;
    melSeq.notes = noteSequence.notes.filter(
        (n) => isPitchedInRange(n, this.MEL_PROG_RANGE));
    bassSeq.notes = noteSequence.notes.filter(
        (n) => isPitchedInRange(n, this.BASS_PROG_RANGE));
    drumsSeq.notes = noteSequence.notes.filter((n) => n.isDrum);
    return tf.tidy(() => tf.concat([
      this.melConverter.toTensor(melSeq),
      this.bassConverter.toTensor(bassSeq),
      this.drumsConverter.toTensor(drumsSeq)
    ], -1));
  }
  /**
   * Splits `th` back into its three parts along the last axis, decodes each
   * and merges the notes into the decoded melody sequence. Melody notes get
   * instrument 0 / program 0, bass notes instrument 1 / the first bass
   * program, drum notes instrument 2.
   *
   * @returns Promise of the merged sequence.
   */
  async toNoteSequence(th, stepsPerQuarter, qpm) {
    const ohs = tf.split(th, [
      this.melConverter.depth, this.bassConverter.depth,
      this.drumsConverter.depth
    ], -1);
    try {
      const ns = await this.melConverter.toNoteSequence(ohs[0], stepsPerQuarter, qpm);
      for (const n of ns.notes) {
        n.instrument = 0;
        n.program = 0;
      }
      const bassNs = await this.bassConverter.toNoteSequence(ohs[1], stepsPerQuarter, qpm);
      for (const n of bassNs.notes) {
        n.instrument = 1;
        n.program = this.BASS_PROG_RANGE[0];
        ns.notes.push(n);
      }
      const drumsNs = await this.drumsConverter.toNoteSequence(ohs[2], stepsPerQuarter, qpm);
      for (const n of drumsNs.notes) {
        n.instrument = 2;
        ns.notes.push(n);
      }
      return ns;
    } finally {
      // Free the split tensors even when one of the sub-converters fails.
      ohs.forEach((oh) => oh.dispose());
    }
  }
}
/**
 * Reduces a trio (melody, bass, drums) to three per-step onset flags by
 * running a `TrioConverter` and thresholding each part's event labels.
 */
export class TrioRhythmConverter extends DataConverter {
  trioConverter;
  depth;
  endTensor;
  NUM_SPLITS = 3;
  /**
   * @param args Passed to both the base class and the inner `TrioConverter`.
   */
  constructor(args) {
    super(args);
    this.trioConverter = new TrioConverter(args);
    this.depth = 3;
  }
  /**
   * Encodes the sequence as a `[steps, 3]` boolean-valued tensor: column 0 is
   * melody, 1 is bass, 2 is drums. A melody/bass step counts as an onset when
   * its event label is greater than 1; a drum step when its label is greater
   * than 0.
   */
  toTensor(noteSequence) {
    return tf.tidy(() => {
      const trioTensor = this.trioConverter.toTensor(noteSequence);
      const [melodyOh, bassOh, drumsOh] = tf.split(trioTensor, [
        this.trioConverter.melConverter.depth,
        this.trioConverter.bassConverter.depth,
        this.trioConverter.drumsConverter.depth
      ], 1);
      const melodyRhythm = tf.greater(tf.argMax(melodyOh, 1), 1);
      const bassRhythm = tf.greater(tf.argMax(bassOh, 1), 1);
      const drumsRhythm = tf.greater(tf.argMax(drumsOh, 1), 0);
      return tf.stack([melodyRhythm, bassRhythm, drumsRhythm], 1);
    });
  }
  /**
   * Renders the three onset columns as placeholder notes: pitch 72 for
   * melody, pitch 36 (program 32) for bass and the first pitch of the second
   * default drum class for drums. Every note lasts one step.
   *
   * When `numSteps` was not configured, the length of the tensor's first
   * column is used instead.
   *
   * @returns Promise of the rendered sequence.
   */
  async toNoteSequence(tensor, stepsPerQuarter, qpm) {
    const rhythmTensors = tf.split(tensor, 3, 1);
    let rhythms;
    try {
      rhythms = await Promise.all(rhythmTensors.map((t) => t.data()));
    } finally {
      // The split tensors are only needed for the reads above.
      rhythmTensors.forEach((t) => t.dispose());
    }
    const numSteps = this.numSteps || rhythms[0].length;
    const noteSequence = sequences.createQuantizedNoteSequence(stepsPerQuarter, qpm);
    for (let s = 0; s < numSteps; ++s) {
      if (rhythms[0][s]) {
        noteSequence.notes.push(NoteSequence.Note.create({
          pitch: 72,
          quantizedStartStep: s,
          quantizedEndStep: s + 1,
          instrument: 0,
          program: 0,
        }));
      }
      if (rhythms[1][s]) {
        noteSequence.notes.push(NoteSequence.Note.create({
          pitch: 36,
          quantizedStartStep: s,
          quantizedEndStep: s + 1,
          instrument: 1,
          program: 32,
        }));
      }
      if (rhythms[2][s]) {
        noteSequence.notes.push(NoteSequence.Note.create({
          pitch: DEFAULT_DRUM_PITCH_CLASSES[1][0],
          quantizedStartStep: s,
          quantizedEndStep: s + 1,
          instrument: 2,
          isDrum: true
        }));
      }
    }
    noteSequence.totalQuantizedSteps = numSteps;
    return noteSequence;
  }
}
/**
 * Converts multi-instrument sequences to and from a token representation
 * with one fixed-length segment per track.
 *
 * Token layout (one-hot index ranges, in order):
 *   - `numPitches` note-on tokens,
 *   - `numPitches` note-off tokens,
 *   - `totalSteps` time-shift tokens (shift of 1..totalSteps steps),
 *   - `numVelocityBins` velocity-change tokens (bins 1..numVelocityBins),
 *   - `numPrograms` program tokens, the last of which marks a drum track,
 *   - one end token.
 */
export class MultitrackConverter extends DataConverter {
  SEGMENTED_BY_TRACK = true;
  stepsPerQuarter;
  totalSteps;
  numVelocityBins;
  minPitch;
  maxPitch;
  numPitches;
  performanceEventDepth;
  numPrograms;
  endToken;
  depth;
  endTensor;
  /**
   * @param args Base converter args plus `stepsPerQuarter`, `totalSteps`,
   *     `numVelocityBins` and optional `minPitch` / `maxPitch` (falsy values
   *     fall back to the MIDI pitch limits from `constants`).
   */
  constructor(args) {
    super(args);
    this.stepsPerQuarter = args.stepsPerQuarter;
    this.totalSteps = args.totalSteps;
    this.numVelocityBins = args.numVelocityBins;
    this.minPitch = args.minPitch ? args.minPitch : constants.MIN_MIDI_PITCH;
    this.maxPitch = args.maxPitch ? args.maxPitch : constants.MAX_MIDI_PITCH;
    this.numPitches = this.maxPitch - this.minPitch + 1;
    this.performanceEventDepth =
        2 * this.numPitches + this.totalSteps + this.numVelocityBins;
    // One token per MIDI program plus one extra for drums.
    this.numPrograms =
        constants.MAX_MIDI_PROGRAM - constants.MIN_MIDI_PROGRAM + 2;
    this.endToken = this.performanceEventDepth + this.numPrograms;
    this.depth = this.endToken + 1;
    this.endTensor = tf.tidy(
        () => tf.oneHot(tf.tensor1d([this.endToken], 'int32'), this.depth)
                  .as1D());
  }
  /**
   * Encodes one track as a one-hot tensor padded to
   * `numSteps / numSegments` rows: a program token, the event tokens and an
   * end token. An undefined track is encoded as just the end token.
   *
   * Events beyond the per-track budget are removed from `track.events` in
   * place.
   *
   * @param track A performance track, or undefined for an empty slot.
   * @throws {Error} If an event has an unrecognized `type`.
   */
  trackToTensor(track) {
    const maxEventsPerTrack = this.numSteps / this.numSegments;
    let tokens = undefined;
    if (track) {
      // Two rows are reserved for the program and end tokens. Clamp at zero:
      // with a negative limit, `pop()` on an empty array would loop forever.
      const maxEvents = Math.max(0, maxEventsPerTrack - 2);
      while (track.events.length > maxEvents) {
        track.events.pop();
      }
      tokens = tf.buffer([track.events.length + 2], 'int32');
      const programIndex = track.isDrum ? this.numPrograms - 1 : track.program;
      tokens.set(this.performanceEventDepth + programIndex, 0);
      track.events.forEach((event, index) => {
        switch (event.type) {
          case 'note-on':
            tokens.set(event.pitch - this.minPitch, index + 1);
            break;
          case 'note-off':
            tokens.set(this.numPitches + event.pitch - this.minPitch, index + 1);
            break;
          case 'time-shift':
            tokens.set(2 * this.numPitches + event.steps - 1, index + 1);
            break;
          case 'velocity-change':
            tokens.set(
                2 * this.numPitches + this.totalSteps + event.velocityBin - 1,
                index + 1);
            break;
          default:
            // Serialize the event; interpolating the object directly would
            // only print "[object Object]".
            throw new Error(`Unrecognized performance event: ${JSON.stringify(event)}`);
        }
      });
      tokens.set(this.endToken, track.events.length + 1);
    } else {
      tokens = tf.buffer([1], 'int32', new Int32Array([this.endToken]));
    }
    return tf.tidy(() => {
      const oh = tf.oneHot(tokens.toTensor(), this.depth);
      return oh.pad([[0, maxEventsPerTrack - oh.shape[0]], [0, 0]]);
    });
  }
  /**
   * Encodes a relative-quantized sequence: notes outside the pitch range are
   * dropped, one track is extracted per instrument, tracks are ordered by
   * program with drums last, trimmed or padded to `numSegments` tracks, and
   * the per-track tensors are concatenated along axis 0.
   *
   * @throws {Error} If the sequence's steps-per-quarter differs from the
   *     converter's.
   */
  toTensor(noteSequence) {
    sequences.assertIsRelativeQuantizedSequence(noteSequence);
    if (noteSequence.quantizationInfo.stepsPerQuarter !==
        this.stepsPerQuarter) {
      throw new Error(`Steps per quarter note mismatch: ${noteSequence.quantizationInfo.stepsPerQuarter} != ${this.stepsPerQuarter}`);
    }
    const seq = sequences.clone(noteSequence);
    seq.notes = noteSequence.notes.filter(
        (note) => note.pitch >= this.minPitch && note.pitch <= this.maxPitch);
    const instruments = new Set(seq.notes.map((note) => note.instrument));
    const tracks = Array.from(instruments)
        .map((instrument) => performance.Performance.fromNoteSequence(
            seq, this.totalSteps, this.numVelocityBins, instrument));
    const sortedTracks = tracks.sort(
        (a, b) => b.isDrum ? -1 : (a.isDrum ? 1 : a.program - b.program));
    while (sortedTracks.length > this.numSegments) {
      sortedTracks.pop();
    }
    sortedTracks.forEach((track) => track.setNumSteps(this.totalSteps));
    while (sortedTracks.length < this.numSegments) {
      sortedTracks.push(undefined);
    }
    return tf.tidy(
        () => tf.concat(
            sortedTracks.map((track) => this.trackToTensor(track)), 0));
  }
  /**
   * Decodes one segment's tokens into a performance track. Tokens from the
   * first end token onward are ignored; the first program token (if any)
   * determines program and drum status.
   *
   * @param tokens Array-like of integer tokens for a single segment.
   * @throws {Error} If an event token falls outside every known range.
   */
  tokensToTrack(tokens) {
    const idx = tokens.indexOf(this.endToken);
    const endIndex = idx >= 0 ? idx : tokens.length;
    const trackTokens = tokens.slice(0, endIndex);
    const eventTokens =
        trackTokens.filter((token) => token < this.performanceEventDepth);
    const programTokens =
        trackTokens.filter((token) => token >= this.performanceEventDepth);
    let program = 0;
    let isDrum = false;
    if (programTokens.length) {
      const programIndex = programTokens[0] - this.performanceEventDepth;
      if (programIndex < this.numPrograms - 1) {
        program = programIndex;
      } else {
        isDrum = true;
      }
    }
    const timeShiftStart = 2 * this.numPitches;
    const velocityStart = timeShiftStart + this.totalSteps;
    const velocityEnd = velocityStart + this.numVelocityBins;
    const events = Array.from(eventTokens).map((token) => {
      if (token < this.numPitches) {
        return { type: 'note-on', pitch: this.minPitch + token };
      }
      if (token < timeShiftStart) {
        return {
          type: 'note-off',
          pitch: this.minPitch + token - this.numPitches
        };
      }
      if (token < velocityStart) {
        return { type: 'time-shift', steps: token - timeShiftStart + 1 };
      }
      if (token < velocityEnd) {
        return {
          type: 'velocity-change',
          velocityBin: token - 2 * this.numPitches - this.totalSteps + 1
        };
      }
      throw new Error(`Invalid performance event token: ${token}`);
    });
    return new performance.Performance(
        events, this.totalSteps, this.numVelocityBins, program, isDrum);
  }
  /**
   * Decodes a tensor into a sequence: the axis-1 arg-max is split into
   * `numSegments` token runs, each run becomes a track, and track `i`'s notes
   * are added with instrument index `i`.
   *
   * @returns Promise of the decoded sequence.
   */
  async toNoteSequence(oh, stepsPerQuarter = this.stepsPerQuarter, qpm) {
    const noteSequence =
        sequences.createQuantizedNoteSequence(stepsPerQuarter, qpm);
    noteSequence.totalQuantizedSteps = this.totalSteps;
    const tensors = tf.tidy(() => tf.split(oh.argMax(1), this.numSegments));
    const tracks = await Promise.all(tensors.map(async (tensor) => {
      try {
        const tokens = await tensor.data();
        return this.tokensToTrack(tokens);
      } finally {
        // Release the segment tensor whether or not decoding succeeded.
        tensor.dispose();
      }
    }));
    tracks.forEach((track, instrument) => {
      track.setNumSteps(this.totalSteps);
      noteSequence.notes.push(...track.toNoteSequence(instrument).notes);
    });
    return noteSequence;
  }
}
/**
 * Converts drum sequences to and from a "groove" representation holding,
 * for every step and drum class, a hit flag, a velocity and a timing offset.
 *
 * Options read from the constructor args: `humanize` (leave velocities and
 * offsets at zero when encoding), `tapify` (map every note to
 * `TAPIFY_CHANNEL` and leave velocities at zero) and `splitInstruments` (use
 * one row per step-and-drum pair instead of one row per step).
 */
export class GrooveConverter extends DataConverter {
  stepsPerQuarter;
  humanize;
  tapify;
  pitchClasses;
  pitchToClass;
  depth;
  endTensor;
  splitInstruments;
  TAPIFY_CHANNEL = 3;
  /**
   * @param args Base converter args plus optional `stepsPerQuarter`,
   *     `pitchClasses`, `humanize`, `tapify` and `splitInstruments`.
   */
  constructor(args) {
    super(args);
    this.stepsPerQuarter =
        args.stepsPerQuarter || constants.DEFAULT_STEPS_PER_QUARTER;
    this.pitchClasses = args.pitchClasses || DEFAULT_DRUM_PITCH_CLASSES;
    this.pitchToClass = new Map();
    for (let classIndex = 0; classIndex < this.pitchClasses.length; ++classIndex) {
      this.pitchClasses[classIndex].forEach(
          (pitch) => this.pitchToClass.set(pitch, classIndex));
    }
    this.humanize = args.humanize || false;
    this.tapify = args.tapify || false;
    this.splitInstruments = args.splitInstruments || false;
    // Three values per drum: hit, velocity, offset.
    this.depth = 3;
  }
  /**
   * Encodes a sequence (quantizing it first if necessary) as the
   * concatenation, along axis 1, of hit, velocity and offset matrices.
   *
   * For each step and drum class, the loudest note starting there is kept.
   * Velocity is scaled by `MAX_MIDI_VELOCITY`; the offset is twice the
   * difference between quantized and actual onset, in steps.
   *
   * @throws {Error} If a retained note starts at or after `numSteps`.
   */
  toTensor(ns) {
    const alreadyQuantized = sequences.isRelativeQuantizedSequence(ns);
    const qns = alreadyQuantized ?
        ns :
        sequences.quantizeNoteSequence(ns, this.stepsPerQuarter);
    const numSteps = this.numSteps;
    const hasTempo = qns.tempos && qns.tempos.length;
    const qpm = hasTempo ?
        qns.tempos[0].qpm :
        constants.DEFAULT_QUARTERS_PER_MINUTE;
    const stepLength = (60. / qpm) / this.stepsPerQuarter;
    // One map per step: drum class -> loudest note starting at that step.
    const stepNotes = [];
    while (stepNotes.length < numSteps) {
      stepNotes.push(new Map());
    }
    for (const n of qns.notes) {
      if (!this.tapify && !this.pitchToClass.has(n.pitch)) {
        continue;
      }
      const s = n.quantizedStartStep;
      if (s >= stepNotes.length) {
        throw Error(`Model does not support sequences with more than ${numSteps} steps (${numSteps * stepLength} seconds at qpm ${qpm}).`);
      }
      const d = this.tapify ? this.TAPIFY_CHANNEL : this.pitchToClass.get(n.pitch);
      const slot = stepNotes[s];
      if (!slot.has(d) || slot.get(d).velocity < n.velocity) {
        slot.set(d, n);
      }
    }
    const numDrums = this.pitchClasses.length;
    const hitVectors = tf.buffer([numSteps, numDrums]);
    const velocityVectors = tf.buffer([numSteps, numDrums]);
    const offsetVectors = tf.buffer([numSteps, numDrums]);
    const getOffset = (n) => {
      if (n.startTime === undefined) {
        return 0;
      }
      const tOnset = n.startTime;
      const qOnset = n.quantizedStartStep * stepLength;
      return 2 * (qOnset - tOnset) / stepLength;
    };
    stepNotes.forEach((slot, s) => {
      for (let d = 0; d < numDrums; ++d) {
        const note = slot.get(d);
        hitVectors.set(note ? 1 : 0, s, d);
        if (this.humanize) {
          // Humanize inputs carry hits only.
          continue;
        }
        if (!this.tapify) {
          velocityVectors.set(note ? note.velocity / constants.MAX_MIDI_VELOCITY : 0, s, d);
        }
        offsetVectors.set(note ? getOffset(note) : 0, s, d);
      }
    });
    return tf.tidy(() => {
      const outLength = this.splitInstruments ? numSteps * numDrums : numSteps;
      const parts = [hitVectors, velocityVectors, offsetVectors].map(
          (buffer) => buffer.toTensor().as2D(outLength, -1));
      return tf.concat(parts, 1);
    });
  }
  /**
   * Decodes a groove tensor into an unquantized drum sequence. A note is
   * emitted wherever the hit value exceeds 0.5; velocity is rescaled and
   * clipped to the MIDI range, and half the offset (clipped to +/-0.5 step)
   * is subtracted from the step position to get the start time.
   *
   * @param t Tensor-like with `shape[0]` rows and an async `data()`.
   * @param stepsPerQuarter Must be falsy or equal to the converter's value.
   * @param qpm Tempo used to convert steps to seconds.
   * @throws {Error} If `stepsPerQuarter` conflicts with the converter's.
   */
  async toNoteSequence(t, stepsPerQuarter, qpm = constants.DEFAULT_QUARTERS_PER_MINUTE) {
    if (stepsPerQuarter && stepsPerQuarter !== this.stepsPerQuarter) {
      throw Error('`stepsPerQuarter` is set by the model.');
    }
    const numDrums = this.pitchClasses.length;
    const numSteps = this.splitInstruments ? t.shape[0] / numDrums : t.shape[0];
    const stepLength = (60. / qpm) / this.stepsPerQuarter;
    const ns = NoteSequence.create({ totalTime: numSteps * stepLength, tempos: [{ qpm }] });
    const results = await t.data();
    const clip = (v, min, max) => Math.min(Math.max(v, min), max);
    const valuesPerStep = numDrums * this.depth;
    for (let s = 0; s < numSteps; ++s) {
      const stepResults = results.slice(s * valuesPerStep, (s + 1) * valuesPerStep);
      for (let d = 0; d < numDrums; ++d) {
        // Split layout stores [hit, velocity, offset] per drum; the joint
        // layout stores all hits, then all velocities, then all offsets.
        const [hitI, velI, offsetI] = this.splitInstruments ?
            [d * this.depth, d * this.depth + 1, d * this.depth + 2] :
            [d, numDrums + d, 2 * numDrums + d];
        if (!(stepResults[hitI] > 0.5)) {
          continue;
        }
        const velocity = clip(
            Math.round(stepResults[velI] * constants.MAX_MIDI_VELOCITY),
            constants.MIN_MIDI_VELOCITY, constants.MAX_MIDI_VELOCITY);
        const offset = clip(stepResults[offsetI] / 2, -0.5, 0.5);
        ns.notes.push(NoteSequence.Note.create({
          pitch: this.pitchClasses[d][0],
          startTime: (s - offset) * stepLength,
          endTime: (s - offset + 1) * stepLength,
          velocity,
          isDrum: true
        }));
      }
    }
    return ns;
  }
}
//# sourceMappingURL=data.js.map