// webppl
// Version:
// Probabilistic programming for the web
// 211 lines (172 loc) • 6.75 kB
// JavaScript
// In this file, we test our samplers by running them many times for various
// parameter values and comparing the resulting *sample* statistics against
// mathematically derived *population* statistics. We also check that every
// sample is in the support of the distribution, so that users aren't bitten
// by underflow or overflow.
;
var _ = require('lodash');
var seedrandom = require('seedrandom');
var util = require('../src/util');
var webppl = require('../src/main');
var helpers = require('./helpers/helpers');
var statistics = require('../src/math/statistics');
// Calls f() repeatedly (while the counter is below n) and returns the results,
// in call order, as a Float64Array. A typed array is used because node can run
// out of memory easily with lots of big ordinary arrays.
var repeat = function(n, f) {
  var results = new Float64Array(n);
  var idx = 0;
  while (idx < n) {
    results[idx] = f();
    idx += 1;
  }
  return results;
};
// Short local names for the Math functions used in the tolerance formulas.
var ln = Math.log;
var pow = Math.pow;
var sqrt = Math.sqrt;
var abs = Math.abs;
// Local shorthands for the estimators provided by the statistics module.
var mean = statistics.mean;
var variance = statistics.variance;
var sd = statistics.sd;
var skew = statistics.skew;
var kurtosis = statistics.kurtosis;
var mode = statistics.kdeMode;

// Lookup table from a statistic's name to the estimator that is applied to
// the drawn samples. The 'mode' entry is the KDE-based estimator.
var sampleStatisticFunctions = {
  mean: statistics.mean,
  variance: statistics.variance,
  skew: statistics.skew,
  kurtosis: statistics.kurtosis,
  mode: statistics.kdeMode
};
// One metadata module per distribution under test. Elsewhere in this file the
// following fields of each module are read: name, settings, sampler,
// inSupport, populationStatisticFunctions, moment and type.
var distMetadataList = [
require('./test-data/sampler/gamma'),
require('./test-data/sampler/binomial'),
require('./test-data/sampler/beta'),
require('./test-data/sampler/gaussian'),
require('./test-data/sampler/poisson'),
require('./test-data/sampler/logNormal')
];
/**
 * Builds the test group for one distribution at one parameter setting.
 *
 * The returned object has three members: `setUp` (seeds the RNG through
 * util.seedRNG), `tearDown` (calls util.resetRNG) and `test`. The test draws
 * samples with `distMetadata.sampler`, checks each one against
 * `distMetadata.inSupport`, and then compares every non-skipped sample
 * statistic with its population value via helpers.testWithinTolerance.
 *
 * @param {*} seed - value handed to util.seedRNG before the test runs
 * @param {Object} distMetadata - reads `sampler`, `inSupport`,
 *     `populationStatisticFunctions`, `moment` and `type`
 * @param {Object} settings - reads:
 *     - `params`: params to the dist
 *     - `n`: number of samples to draw
 *     - `skip` (optional): names of statistics that must not be tested
 *     - `reltol` / `abstol` (optional): per-statistic tolerances
 * @return {Object} group with `test`, `setUp` and `tearDown` functions
 */
var generateSettingTest = function(seed, distMetadata, settings) {
  var params = settings.params;
  var n = settings.n;
  var moment = distMetadata.moment;
  var allPopulationStatisticFunctions = distMetadata.populationStatisticFunctions;

  // only test the stats that aren't blacklisted
  var populationStatisticFunctions = _.pickBy(allPopulationStatisticFunctions,
      function(v, k) {
        return !_.includes(settings.skip, k);
      });

  // we want tests to fail with probability 1/10000 (succeed with probability 0.9999)
  // set the error tolerance to be 4 sd's;
  // about 0.999937 of the probability mass of a normal distribution lies within
  // 4 standard deviations.
  // but the sampling distributions are only asymptotically normal
  // so let's give them some breathing room
  var autoToleranceMultiple = {
    mean: 8,
    variance: 8,
    skew: 100,
    kurtosis: 100
  };

  // The population variance is looked up in the *unfiltered* table: it is an
  // ingredient of the automatic tolerances, so it must stay available even
  // when 'variance' itself is listed in settings.skip.
  var populationVariance = function() {
    return allPopulationStatisticFunctions.variance(params);
  };

  // Variance of the sampling distribution of the named statistic, or NaN for
  // statistics that have no formula here (e.g. 'mode').
  var samplingDistVariance = function(statName) {
    var sigma;
    if (statName === 'mean') {
      return populationVariance() / n;
    }
    if (statName === 'variance') {
      // sample variance is asymptotically normally distributed
      // http://stats.stackexchange.com/a/105338/71884
      sigma = sqrt(populationVariance());
      return moment(params, 4) / n - pow(sigma, 4) * (n - 3) / (n * (n - 1));
    }
    if (statName === 'skew') {
      // HT https://en.wikipedia.org/wiki/Skewness#Sample_skewness
      // formula assumes normal distribution
      // thankfully, van der Vaart tells us that sample skew is asymptotically
      // normally distributed (page 29 of Asymptotic Statistics)
      return 6 * n * (n - 1) / ((n - 2) * (n + 1) * (n + 3));
    }
    if (statName === 'kurtosis') {
      // HT https://en.wikipedia.org/wiki/Kurtosis#Sample_kurtosis
      return 24 * n * (n - 1) * (n - 1) / ((n - 3) * (n - 2) * (n + 3) * (n + 5));
    }
    return NaN;
  };

  // Picks the tolerance for one statistic: an explicit relative tolerance
  // wins, then an explicit absolute one, and only otherwise is the automatic
  // tolerance computed. A statistic without an automatic tolerance and
  // without an explicit one ends up with NaN.
  var toleranceFor = function(statName, expectedResult) {
    if (settings.reltol && settings.reltol[statName]) {
      return abs(settings.reltol[statName] * expectedResult);
    }
    if (settings.abstol && settings.abstol[statName]) {
      return settings.abstol[statName];
    }
    return autoToleranceMultiple[statName] * sqrt(samplingDistVariance(statName));
  };

  // Chooses the estimator applied to the samples; for 'mode' the choice
  // depends on whether the distribution is declared as integer-valued.
  var sampleStatisticFunctionFor = function(statName) {
    if (statName === 'mode') {
      return (distMetadata.type === 'integer') ? statistics.mode : statistics.kdeMode;
    }
    return sampleStatisticFunctions[statName];
  };

  var group = {};

  group['test'] = function(test) {
    var samples = repeat(n, function() {
      return distMetadata.sampler.apply(null, params);
    });

    // first check support
    // use for loop because some nodes don't define map()
    // for Float64Array
    var outsideSupport = [];
    for (var i = 0, ii = samples.length; i < ii; i++) {
      if (!distMetadata.inSupport(params, samples[i])) {
        outsideSupport.push(samples[i]);
      }
    }
    test.ok(outsideSupport.length === 0,
            'support test failed ' + outsideSupport.slice(0, 10).join(', '));

    // then check each populationStatisticFunction
    _.each(populationStatisticFunctions, function(statFn, statName) {
      var expectedResult = statFn(params);
      var actualResult = sampleStatisticFunctionFor(statName)(samples);
      var tolerance = toleranceFor(statName, expectedResult);
      helpers.testWithinTolerance(test,
                                  actualResult,
                                  expectedResult,
                                  tolerance,
                                  statName,
                                  'verbose');
    });

    test.done();
  };

  group.setUp = function(callback) {
    util.seedRNG(seed);
    callback();
  };

  group.tearDown = function(callback) {
    util.resetRNG();
    callback();
  };

  return group;
};
var generateTestCases = function(seed) {
var oldSuppressWarnings = !!global.suppressWarnings;
var oldStackTraceLimit = Error.stackTraceLimit;
exports.setUp = function(callback) {
// suppress warnings (for, e.g., underflow)
global.suppressWarnings = true;
// less noise from stack trace
Error.stackTraceLimit = 2;
callback()
}
exports.tearDown = function(callback) {
global.suppressWarnings = oldSuppressWarnings;
Error.stackTraceLimit = oldStackTraceLimit;
callback()
}
_.each(distMetadataList, function(distMetadata) {
var group = {};
_.map(distMetadata.settings, function(settings) {
group[settings.params.join(',')] = generateSettingTest(seed, distMetadata, settings)
});
exports[distMetadata.name] = group;
});
};
// Use the seed reported by helpers.getRandomSeedFromEnv(); when that is falsy,
// draw a fresh non-negative one. The seed is logged and then used to generate
// every test case.
var seed = helpers.getRandomSeedFromEnv();
if (!seed) {
  seed = abs(seedrandom().int32());
}
console.log('Random seed: ' + seed);
generateTestCases(seed);