UNPKG

gaussian-mixture

Version:

An implementation of a Gaussian Mixture class in one dimension, that allows to fit models with an Expectation Maximization algorithm.

359 lines (296 loc) 10.2 kB
'use strict'; var test = require('tap').test; var data = require('./fixtures/data.json'); // 200 samples from refGmm var GMM = require('../index'); var Histogram = require('../index').Histogram; test('Initialization of a new GMM object.', function (t) { t.plan(2); function f(k) { return function () { var gmm = new GMM(3, k); gmm.sample(0); }; } t.throws(f([1, 2])); t.doesNotThrow(f([1, 2, 3])); }); test('Random sampling.', function (t) { t.plan(1); var gmm = new GMM(3); t.equal(5, gmm.sample(5).length); }); test('Gaussians of a mixture model.', function (t) { t.plan(6); var gmm = new GMM(3, [1 / 3, 1 / 3, 1 / 3], [0, 10, 20], [1, 2, 0.5]); var gaussians = gmm._gaussians(); t.equal(gaussians[0].mean, 0); t.equal(gaussians[1].mean, 10); t.equal(gaussians[2].mean, 20); t.equal(gaussians[0].variance, 1); t.equal(gaussians[1].variance, 2); t.equal(gaussians[2].variance, 0.5); }); test('Computing membership for one datapoint', function (t) { t.plan(2); var gmm = new GMM(3, undefined, [0, 10, 20]); t.equal(gmm.membership(5)[0], gmm.membership(5)[1]); t.equal(gmm.membership(0)[0] > 0.99, true); }); test('Shape of the membership matrix', function (t) { t.plan(2); var gmm = new GMM(5); var memberships = gmm.memberships([1, 2, 3, 4, 5, 6]); t.equal(memberships.length, 6); t.equal(memberships[0].length, 5); }); test('Convergence of model update', function (t) { t.plan(9); var refGmm = new GMM(3, [0.2, 0.5, 0.3], [0, 10, 30], [1, 2, 4]); var testGmm = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1]); for (var i = 0; i < 200; i++) { testGmm._updateModel(data); } for (var j = 0; j < 3; j++) { t.equal(Math.abs(testGmm.weights[j] - refGmm.weights[j]) < 0.1, true); t.equal(Math.abs(testGmm.means[j] - refGmm.means[j]) < 1, true); t.equal(Math.abs(testGmm.vars[j] - refGmm.vars[j]) < 1, true); } }); test('log likelihood function', function (t) { t.plan(15); var gmm = new GMM(3, undefined, [-1, 15, 37], [1, 1, 1]); var l = -Infinity; var temp; for (var i = 0; i < 15; i++) { gmm._updateModel(data); temp = gmm.logLikelihood(data); t.equal(temp - l >= -1e-5, true); l = temp; } }); test('EM full optimization', function (t) { t.plan(11); var gmm = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1]); var gmm2 = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1]); for (var i = 0; i < 200; i++) { gmm._updateModel(data); } var counter = gmm2.optimize(data); t.equal(counter, 3); for (var j = 0; j < 3; j++) { t.equal(Math.abs(gmm.weights[j] - gmm2.weights[j]) < 1e-7, true); t.equal(Math.abs(gmm.means[j] - gmm2.means[j]) < 1e-7, true); t.equal(Math.abs(gmm.vars[j] - gmm2.vars[j]) < 1e-7, true); } var gmm3 = new GMM(3, [0.4, 0.3, 0.4], [-10, -5, -1], [1, 1, 1], {initialize: true}); gmm3.optimize([1, 2, 3, 1, 2, 3, 1, 1, 2, 2, 3, 3]); t.same(gmm3.means, [1, 2, 3]); }); test('Variance prior', function (t) { t.plan(3); var options = { variancePrior: 3, variancePriorRelevance: 0.01 }; var options2 = { variancePrior: 3, variancePriorRelevance: 1 }; var options3 = { variancePrior: 3, variancePriorRelevance: 1000000 }; var gmm = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1], options); var gmm2 = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1], options2); var gmm3 = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1], options3); gmm.optimize(data); gmm2.optimize(data); gmm3.optimize(data); var cropFloat = function (a) { return Number(a.toFixed(1)); }; t.same(gmm.vars.map(cropFloat), [1.1, 2.5, 4.6]); t.same(gmm2.vars.map(cropFloat), [2.7, 2.8, 3.4]); t.same(gmm3.vars.map(cropFloat), [3, 3, 3]); }); test('Separation prior', function (t) { t.plan(3); var options = { separationPrior: 3, separationPriorRelevance: 0.01 }; var options2 = { separationPrior: 3, separationPriorRelevance: 1 }; var options3 = { separationPrior: 3, separationPriorRelevance: 1000000 }; var gmm = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1], options); var gmm2 = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1], options2); var gmm3 = new GMM(3, undefined, [-1, 13, 25], [1, 1, 1], options3); gmm.optimize(data); gmm2.optimize(data); gmm3.optimize(data); var cropFloat = function (a) { return Number(a.toFixed(1)); }; t.same(gmm.means.map(cropFloat), [0.7, 10.2, 29.8]); t.same(gmm2.means.map(cropFloat), [11.6, 13.5, 17.8]); t.same(gmm3.means.map(cropFloat), [11.4, 14.4, 17.4]); }); test('Model', function (t) { t.plan(2); var gmm = new GMM(3, [0.4, 0.2, 0.4], [-1, 13, 25], [1, 2, 1]); var model = { nComponents: 3, weights: [0.4, 0.2, 0.4], means: [-1, 13, 25], vars: [1, 2, 1] }; t.same(gmm.model(), model); t.same(GMM.fromModel(model), gmm); }); test('Barycenter method', function (t) { t.plan(5); var cropFloat = function (a) { return Number(a.toFixed(5)); }; t.equal(GMM._barycenter([1, 2], [0.5, 0.5]), 1.5); t.same(cropFloat(GMM._barycenter([1, 2], [0.1, 0.9])), 1.9); t.equal(GMM._barycenter([1, 2, 3], [0.3, 0.4, 0.3]), 2); t.equal(GMM._barycenter([1, 2, 3], [3, 4, 3]), 2); t.equal(GMM._barycenter([1, 2, 3], [0, 0, 0.01]), 3); }); test('Km++ Initialization', function (t) { var gmm = new GMM(3, [0.4, 0.2, 0.4], [-1, 13, 25], [1, 2, 1]); var means = gmm._initialize([1, 3, 3, 3, 2, 2, 1, 1, 3, 2, 2, 1, 3, 3, 3, 2, 1]); t.same(means, [1, 2, 3]); t.same(gmm._initialize([1, 1, 1, 1]), [1, 1, 1]); t.same(gmm._initialize([1, 1, 1, 2, 17]), [1, 2, 17]); t.throws(function () { gmm._initialize([1]); }, new Error('Data must have more points than the number of components in the model.')); t.end(); }); test('Km++ Initialization - Histogram', function (t) { var gmm = new GMM(3, [0.4, 0.2, 0.4], [-1, 13, 25], [1, 2, 1]); var h = Histogram.fromData([1, 3, 3, 3, 2, 2, 1, 1, 3, 2, 2, 1, 3, 3, 3, 2, 1]); var means = gmm._initializeHistogram(h); t.same(means, [1, 2, 3]); t.same(gmm._initializeHistogram(Histogram.fromData([1, 1, 1, 1])), [1, 1, 1]); t.same(gmm._initializeHistogram(Histogram.fromData([1, 1, 1, 2, 17])), [1, 2, 17]); t.throws(function () { gmm._initializeHistogram(Histogram.fromData([1])); }, new Error('Data must have more points than the number of components in the model.')); h = new Histogram({ counts: {'A': 10000, 'B': 0.001, 'C': 10, 'D': 10}, bins: {'A': [0, 1], 'B': [1, 2], 'C': [3, 4], 'D': [4, 5]} }); t.same(gmm._initializeHistogram(h), [0.5, 3.5, 4.5]); t.end(); }); test('memberships - histogram', function (t) { var h = Histogram.fromData([1, 2, 5, 5.4, 5.5, 6, 7, 7]); var gmm = GMM.fromModel({ means: [1, 5, 7], vars: [2, 2, 2], weights: [0.3, 0.5, 0.2], nComponents: 3 }); var memberships = gmm._membershipsHistogram(h); var expected = { 1: [0.9818947940807183, 0.01798403047511045, 0.00012117544417123207], 2: [0.8788782427321509, 0.11894323591065209, 0.0021785213571970234], 5: [0.013212886953789417, 0.7213991842739687, 0.265387928772242], 6: [0.0012378419366357771, 0.49938107903168216, 0.49938107903168216], 7: [0.00009021165708731931, 0.268917159718714, 0.7309926286241988] }; t.equal(Object.keys(memberships).length, 5); var equal = true; Object.keys(memberships).forEach(function (k) { if (!equal) return; var m = memberships[k]; var e = expected[k]; if (m.length !== e.length) { equal = false; return; } for (var i = 0; i < m.length; i++) { if (Math.abs(m[i] - e[i]) > 0.00001) { equal = false; return; } } }); t.ok(equal); t.end(); }); test('log likelihood - histogram', function (t) { var h = Histogram.fromData([1, 2, 5, 5, 5, 6, 7, 7]); var gmm = GMM.fromModel({ means: [1, 5, 7], vars: [2, 2, 2], weights: [0.3, 0.5, 0.2], nComponents: 3 }); t.equal(gmm.logLikelihood(h), gmm.logLikelihood([1, 2, 5, 5, 5, 6, 7, 7])); t.end(); }); test('optimize - histogram', function (t) { var h = Histogram.fromData([1, 2, 5, 5, 5, 6, 7, 7]); var gmm = GMM.fromModel({ means: [1, 5, 7], vars: [2, 2, 2], weights: [0.3, 0.5, 0.2], nComponents: 3 }); var gmm2 = GMM.fromModel({ means: [1, 5, 7], vars: [2, 2, 2], weights: [0.3, 0.5, 0.2], nComponents: 3 }); gmm._optimizeHistogram(h); gmm2._optimize([1, 2, 5, 5, 5, 6, 7, 7]); var round = x => Number(x.toFixed(5)); t.same(gmm.model().means.map(round), gmm2.model().means.map(round)); t.same(gmm.model().vars.map(round), gmm2.model().vars.map(round)); t.same(gmm.model().weights.map(round), gmm2.model().weights.map(round)); var options = { variancePrior: 3, variancePriorRelevance: 0.5, separationPrior: 3, separationPriorRelevance: 1, initialize: true }; gmm.options = options; gmm2.options = options; gmm.optimize(h); gmm2._optimize([1, 2, 5, 5, 5, 6, 7, 7]); t.same(gmm.model().means.map(round), gmm2.model().means.map(round)); t.same(gmm.model().vars.map(round), gmm2.model().vars.map(round)); t.same(gmm.model().weights.map(round), gmm2.model().weights.map(round)); t.end(); }); test('histogram total', function (t) { var d = [1, 2, 3, 4, 5, 5, 6, 6, 6]; var h = Histogram.fromData(d); t.equal(Histogram._total(h), 9); t.end(); }); test('histogram classify', function (t) { t.equal(Histogram._classify(3.4), '3'); t.equal(Histogram._classify(3.4, {'A': [1, 2], 'B': [3, 3.4], 'C': [3.4, 5], 'D': [5, 6]}), 'C'); t.same(Histogram._classify(7, {'A': [1, 2], 'B': [3, 3.4], 'C': [3.4, 5], 'D': [5, 6]}), null); t.end(); }); test('histogram value', function (t) { var h = new Histogram({ bins: {'A': [1, 2], 'B': [3, 3.4], 'C': [3.4, 5], 'D': [5, 6]}, counts: {'A': 5, 'B': 3} }); t.equal(h.value('A'), 1.5); t.equal(h.value('B'), 3.2); t.throws(() => h.value('E')); t.end(); }); test('histogram flatten', function (t) { var h = new Histogram({ bins: {'A': [1, 2], 'B': [3, 3.4], 'C': [3.4, 5], 'D': [5, 6]}, counts: {'A': 3, 'B': 2} }); t.same(h.flatten(), [1.5, 1.5, 1.5, 3.2, 3.2]); t.end(); });