/*
 * multivariate-normal
 * Version:
 * Port of NumPy's random.multivariate_normal to Node.js
 * 101 lines (81 loc) • 3.21 kB
 * JavaScript
 */
Object.defineProperty(exports, "__esModule", {
value: true
});
var _validation = require("./validation.js");
var _gaussian = require("gaussian");
var _gaussian2 = _interopRequireDefault(_gaussian);
var _numeric = require("numeric");
var _numeric2 = _interopRequireDefault(_numeric);
// Babel interop helper: wraps a CommonJS export as `{ default: export }`
// unless it is already flagged as a transpiled ES module (`__esModule`), in
// which case it is returned unchanged. Kept as a function declaration so it
// is hoisted above its first uses near the top of the file.
function _interopRequireDefault(obj) {
  var isEsModule = obj && obj.__esModule;
  if (isEsModule) {
    return obj;
  }
  return { default: obj };
}
// Low-level distribution constructor. NOT a public API.
// (This comment documents `Distribution`, which is defined further below.)
//
// n: dimensionality of the distribution
// mean: mean vector, of length n.
// cov: covariance matrix, of size n-by-n.
// svd: { u, s, v } from the singular value decomposition of cov
//
// preconditions:
// - mean and cov have been validated
// - mean and cov are frozen
// Shared N(0, 1) distribution object from the `gaussian` package (mean 0,
// variance 1); standardNormalVector draws from it via its ppf method.
// The `(0, f)(...)` form is Babel's way of calling the imported function
// without binding `this` to the module object.
var standardNormalDist = (0, _gaussian2.default)(0, 1);
/**
 * Draws `length` independent standard normal variates by inverse-transform
 * sampling: each uniform variate p is mapped through standardNormalDist.ppf
 * (the `gaussian` package's percent point function, i.e. inverse CDF).
 *
 * @param {number} length - number of variates to draw.
 * @returns {number[]} a new array of `length` samples.
 */
var standardNormalVector = function standardNormalVector(length) {
  var ary = [];
  for (var i = 0; i < length; i++) {
    // Math.random() yields values in [0, 1). The normal quantile at p = 0 is
    // -Infinity, so a draw of exactly 0 would produce a non-finite (or
    // absurdly large, if the ppf clamps) sample. Redraw until p is in (0, 1).
    var p;
    do {
      p = Math.random();
    } while (p === 0);
    ary.push(standardNormalDist.ppf(p));
  }
  return ary;
};
/**
 * Low-level constructor for a multivariate normal distribution object.
 * Not a public API: mean and cov are expected to be validated (and frozen)
 * by the caller, and `_ref` to hold their SVD.
 *
 * @param {number} n - dimensionality of the distribution.
 * @param {number[]} mean - mean vector, length n (returned as-is by getMean).
 * @param {number[][]} cov - covariance matrix, n-by-n. Only stored for
 *   getCov()/setMean(); sampling uses the SVD factors instead.
 * @param {{u: *, s: number[], v: number[][]}} _ref - SVD of cov. Only s and v
 *   are used for sampling; u is carried along unchanged by setMean().
 * @returns {Object} an object with sample(), getMean(), setMean(mean),
 *   getCov() and setCov(cov). The setters return NEW distribution objects and
 *   leave this one untouched.
 */
var Distribution = function Distribution(n, mean, cov, _ref) {
  var u = _ref.u,
      s = _ref.s,
      v = _ref.v;

  // The linear map applied to each standard normal vector depends only on
  // s and v, which never change for this object, so build it once on first
  // use instead of on every sample() call. Lazy (not eager) so that setMean()
  // chains stay cheap when nobody samples the intermediate objects.
  var transform = null;
  var getTransform = function getTransform() {
    if (transform === null) {
      // From numpy: np.sqrt(s)[:, None] * v. Without numpy broadcasting we
      // scale each column colIdx of v by sqrt(s[colIdx]) manually.
      var sqrtS = s.map(Math.sqrt);
      var scaledV = v.map(function (row) {
        return row.map(function (val, colIdx) {
          return val * sqrtS[colIdx];
        });
      });
      transform = _numeric2.default.transpose(scaledV);
    }
    return transform;
  };

  return {
    /**
     * Draws one sample: x = z · (V·diag(sqrt(s)))ᵀ + mean, with z a vector of
     * n standard normal draws. If z's entries are i.i.d. N(0, 1), x has
     * covariance V·diag(s)·Vᵀ, which equals cov when v holds cov's singular
     * vectors as columns (presumably how validation.js builds it — confirm).
     *
     * From numpy (paraphrased):
     *   x = standard_normal(n)
     *   x = np.dot(x, np.sqrt(s)[:, None] * v)
     *   x += mean
     * https://github.com/numpy/numpy/blob/a835270d718d299535606d7104fd86d9b2aa68a6/numpy/random/mtrand/mtrand.pyx
     *
     * @returns {number[]} a new array of length n.
     */
    sample: function sample() {
      var standardNormal = standardNormalVector(n);
      // Correlate the independent draws according to the covariance.
      var variants = _numeric2.default.dot(standardNormal, getTransform());
      // Shift by the mean.
      return variants.map(function (variant, idx) {
        return variant + mean[idx];
      });
    },
    /** @returns {number[]} the mean vector this object was built with. */
    getMean: function getMean() {
      return mean;
    },
    /**
     * Validates `unvalidatedMean` with validateMean and returns a new
     * distribution sharing this one's cov and SVD.
     */
    setMean: function setMean(unvalidatedMean) {
      var newMean = (0, _validation.validateMean)(unvalidatedMean, n);
      return Distribution(n, newMean, cov, { u: u, s: s, v: v });
    },
    /** @returns {number[][]} the covariance matrix this object was built with. */
    getCov: function getCov() {
      return cov;
    },
    /**
     * Validates `unvalidatedCov` and obtains its SVD via validateCovAndGetSVD,
     * then returns a new distribution with the same mean.
     */
    setCov: function setCov(unvalidatedCov) {
      var _validateCovAndGetSVD = (0, _validation.validateCovAndGetSVD)(unvalidatedCov, n),
          newCov = _validateCovAndGetSVD.cov,
          newSVD = _validateCovAndGetSVD.svd;
      return Distribution(n, mean, newCov, newSVD);
    }
  };
};
// CommonJS default export (Babel interop style) of the low-level constructor.
exports.default = Distribution;