UNPKG

webppl

Version:

Probabilistic programming for the web

211 lines (172 loc) 6.75 kB
// In this file, we test our samplers by running them a bunch for various // sample values and comparing the resulting *sample* statistics against mathematically // derived *population* statistics. We also check that every sample is in the // support of the distribution, so that users aren't bit by underflow or overflow 'use strict'; var _ = require('lodash'); var seedrandom = require('seedrandom'); var util = require('../src/util'); var webppl = require('../src/main'); var helpers = require('./helpers/helpers'); var statistics = require('../src/math/statistics'); var repeat = function(n, f) { // used typedarray because node can run out of memory easily with lots of big arrays var a = new Float64Array(n); for (var i = 0; i < n; i++) { a[i] = f() } return a; } var ln = Math.log, pow = Math.pow, sqrt = Math.sqrt, abs = Math.abs; var mean = statistics.mean; var variance = statistics.variance; var sd = statistics.sd; var skew = statistics.skew; var kurtosis = statistics.kurtosis; var mode = statistics.kdeMode; var sampleStatisticFunctions = { mean: mean, variance: variance, skew: skew, kurtosis: kurtosis, mode: mode } var distMetadataList = [ require('./test-data/sampler/gamma'), require('./test-data/sampler/binomial'), require('./test-data/sampler/beta'), require('./test-data/sampler/gaussian'), require('./test-data/sampler/poisson'), require('./test-data/sampler/logNormal') ]; var generateSettingTest = function(seed, distMetadata, settings) { // settings includes: // - params to the dist // - inference params (e.g., number of samples) // - test params (e.g., relative tolerance) var params = settings.params; var n = settings.n; // only test the stats that aren't blacklisted var populationStatisticFunctions = _.pickBy(distMetadata.populationStatisticFunctions, function(v, k) { return !_.includes(settings.skip, k) }); var group = {}; var moment = distMetadata.moment; group['test'] = function(test) { var samples = repeat(n, function() { return distMetadata.sampler.apply(null, params); }); // first check support // use for loop because some nodes don't define map() // for Float64Array var allInSupport = true; var outsideSupport = []; for (var i = 0, ii = samples.length; i < ii; i++) { var inSupport = distMetadata.inSupport(params, samples[i]); allInSupport = allInSupport && inSupport; if (!inSupport) { outsideSupport.push(samples[i]) } } test.ok(allInSupport, 'support test failed ' + outsideSupport.slice(0, 10).join(', ')); // then check each populationStatisticFunction _.each(populationStatisticFunctions, function(statFn, statName) { var expectedResult = statFn(params); // compute an automatic tolerance for mean, variance, skew, kurtosis var autoTolerance; var variance = populationStatisticFunctions.variance(params) var sigma = sqrt(variance); var samplingDistVariance; if (statName == 'mean') { samplingDistVariance = variance / n; } else if (statName == 'variance') { // sample variance is asymptotically normally distributed // http://stats.stackexchange.com/a/105338/71884 samplingDistVariance = moment(params, 4) / n - pow(sigma, 4) * (n - 3) / (n * (n - 1)); } else if (statName == 'skew') { // HT https://en.wikipedia.org/wiki/Skewness#Sample_skewness // formula assumes normal distribution // thankfully, van der Vaart tells us that sample skew is asymptotically // normally distributed (page 29 of Asymptotic Statistics) samplingDistVariance = 6 * n * (n - 1) / ((n - 2) * (n + 1) * (n + 3)); } else if (statName == 'kurtosis') { // HT https://en.wikipedia.org/wiki/Kurtosis#Sample_kurtosis samplingDistVariance = 24 * n * (n - 1) * (n - 1) / ((n - 3) * (n - 2) * (n + 3) * (n + 5)) } // we want tests to fail with probability 1/10000 (succeed with probability 0.9999) // set the error tolerance to be 4 sd's; // 0.999367 of the probability mass of a normal distribution lies within // 4 standard deviations. // but the sampling distributions are only asymptotically normal // so let's give them some breathing room var autoToleranceMultiple = { mean: 8, variance: 8, skew: 100, kurtosis: 100 }; autoTolerance = autoToleranceMultiple[statName] * sqrt(samplingDistVariance); var sampleStatisticFunction; if (statName == 'mode') { sampleStatisticFunction = (distMetadata.type == 'integer') ? statistics.mode : statistics.kdeMode } else { sampleStatisticFunction = sampleStatisticFunctions[statName]; } var actualResult = sampleStatisticFunction(samples); var tolerance; if (settings.reltol && settings.reltol[statName]) { tolerance = abs(settings.reltol[statName] * expectedResult); } else if (settings.abstol && settings.abstol[statName]) { tolerance = settings.abstol[statName]; } else { tolerance = autoTolerance; } helpers.testWithinTolerance(test, actualResult, expectedResult, tolerance, statName, 'verbose'); }); test.done(); }; group.setUp = function(callback) { util.seedRNG(seed); callback(); }; group.tearDown = function(callback) { util.resetRNG(); callback(); }; return group; } var generateTestCases = function(seed) { var oldSuppressWarnings = !!global.suppressWarnings; var oldStackTraceLimit = Error.stackTraceLimit; exports.setUp = function(callback) { // suppress warnings (for, e.g., underflow) global.suppressWarnings = true; // less noise from stack trace Error.stackTraceLimit = 2; callback() } exports.tearDown = function(callback) { global.suppressWarnings = oldSuppressWarnings; Error.stackTraceLimit = oldStackTraceLimit; callback() } _.each(distMetadataList, function(distMetadata) { var group = {}; _.map(distMetadata.settings, function(settings) { group[settings.params.join(',')] = generateSettingTest(seed, distMetadata, settings) }); exports[distMetadata.name] = group; }); }; var seed = helpers.getRandomSeedFromEnv() || abs(seedrandom().int32()); console.log('Random seed: ' + seed); generateTestCases(seed);