UNPKG

multivariate-normal

Version:

Port of NumPy's random.multivariate_normal to Node.js

101 lines (81 loc) 3.21 kB
"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});

var _validation = require("./validation.js");

var _gaussian = require("gaussian");

var _gaussian2 = _interopRequireDefault(_gaussian);

var _numeric = require("numeric");

var _numeric2 = _interopRequireDefault(_numeric);

function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; }

// Low-level distribution constructor. NOT a public API.
//
// n: dimensionality of the distribution
// mean: vector of means, of length n.
// cov: covariance matrix, of size n-by-n.
// svd: { u, s, v } from decomposition of cov
//
// preconditions:
// - mean and cov have been validated
// - mean and cov are frozen

// Standard normal distribution (mean 0, variance 1) used to draw the
// uncorrelated components of each sample.
var standardNormalDist = (0, _gaussian2.default)(0, 1);

// Returns a uniform random number on the open interval (0, 1).
//
// Math.random() yields values in [0, 1). A draw of exactly 0 is redrawn
// because the standard-normal quantile at 0 is -Infinity, so ppf(0) cannot
// yield a meaningful finite sample. (What the gaussian package actually
// returns for ppf(0) is not visible from this file.)
var openUnitRandom = function openUnitRandom() {
  var p = Math.random();
  while (p === 0) {
    p = Math.random();
  }
  return p;
};

// Returns an array of `length` independent draws from the standard normal
// distribution, generated by inverse-transform sampling (ppf of a uniform).
var standardNormalVector = function standardNormalVector(length) {
  var ary = [];
  for (var i = 0; i < length; i++) {
    ary.push(standardNormalDist.ppf(openUnitRandom()));
  }
  return ary;
};

// Builds the matrix that maps a row vector of independent standard normals
// to a correlated vector (before the mean is added).
//
// From numpy (paraphrased):
//   x = standard_normal(n)
//   x = np.dot(x, np.sqrt(s)[:, None] * v)
//   x += mean
//
// https://github.com/numpy/numpy/blob/a835270d718d299535606d7104fd86d9b2aa68a6/numpy/random/mtrand/mtrand.pyx
//
// numpy's `v` from np.linalg.svd is V^T, so `np.sqrt(s)[:, None] * v`
// scales row i of V^T by sqrt(s[i]). Here `v` is assumed to hold V itself
// (right singular vectors as columns, the numeric.svd convention; confirm
// against validation.js). So we scale column j by sqrt(s[j]) and then
// transpose, which produces the same matrix diag(sqrt(s)) * V^T.
var buildTransform = function buildTransform(s, v) {
  var sqrtS = s.map(Math.sqrt);
  var scaledV = v.map(function (row) {
    return row.map(function (val, colIdx) {
      return val * sqrtS[colIdx];
    });
  });
  return _numeric2.default.transpose(scaledV);
};

var Distribution = function Distribution(n, mean, cov, _ref) {
  // `u` is not needed for sampling; it is only carried along so that
  // setMean can hand the same decomposition to the new Distribution.
  var u = _ref.u,
      s = _ref.s,
      v = _ref.v;

  // The sampling transform depends only on s and v, which are never
  // reassigned in this closure (setMean/setCov build new Distributions).
  // It is therefore built on the first sample() call and reused afterwards,
  // instead of being rebuilt for every sample. This assumes the svd arrays
  // are not mutated by the caller after construction.
  var transform = null;

  return {
    // Draws one sample and returns it as a new array.
    sample: function sample() {
      if (transform === null) {
        transform = buildTransform(s, v);
      }

      // Row vector of independent standard normals (mean 0, variance 1).
      var standardNormal = standardNormalVector(n);

      // Compute the correlated distribution based on the covariance matrix.
      var variants = _numeric2.default.dot(standardNormal, transform);

      // Add the mean.
      return variants.map(function (variant, idx) {
        return variant + mean[idx];
      });
    },

    getMean: function getMean() {
      return mean;
    },

    // Validates the new mean and returns a new Distribution that shares this
    // one's covariance and decomposition; this Distribution is left unchanged.
    setMean: function setMean(unvalidatedMean) {
      var newMean = (0, _validation.validateMean)(unvalidatedMean, n);
      return Distribution(n, newMean, cov, { u: u, s: s, v: v });
    },

    getCov: function getCov() {
      return cov;
    },

    // Validates the new covariance, obtains its decomposition from the
    // validation module, and returns a new Distribution with the same mean;
    // this Distribution is left unchanged.
    setCov: function setCov(unvalidatedCov) {
      var _validateCovAndGetSVD = (0, _validation.validateCovAndGetSVD)(unvalidatedCov, n),
          newCov = _validateCovAndGetSVD.cov,
          newSVD = _validateCovAndGetSVD.svd;

      return Distribution(n, mean, newCov, newSVD);
    }
  };
};

exports.default = Distribution;