UNPKG

webppl

Version:

Probabilistic programming for the web

270 lines (223 loc) 7.96 kB
// Neal, Radford M. "MCMC using Hamiltonian dynamics." Handbook of // Markov Chain Monte Carlo 2 (2011). // http://arxiv.org/abs/1206.1901 'use strict'; var _ = require('lodash'); var assert = require('assert'); var util = require('../util'); var gaussian = require('../dists/gaussian'); var tensorGaussian = require('../dists/tensorGaussian'); var Trace = require('../trace'); var ad = require('../ad'); var Tensor = require('../tensor'); var generic = require('../math/genericArithmetic'); var addEq = generic.addEq; var add = generic.add; var mul = generic.mul; var sum = generic.sum; module.exports = function(env) { function makeHMCKernel(options) { options = util.mergeDefaults(options, { steps: 5, stepSize: 0.1 }, 'HMC kernel'); assert.ok(options.steps > 0); var kernel = function(cont, oldTrace, runOpts) { return new HMCKernel(cont, oldTrace, options, runOpts).run(); }; kernel.adRequired = true; return kernel; } function HMCKernel(cont, oldTrace, options, runOpts) { this.steps = options.steps; this.stepSize = options.stepSize; runOpts = util.mergeDefaults(runOpts, { exitFactor: 0 }); this.exitFactor = runOpts.exitFactor; this.cont = cont; this.oldTrace = oldTrace; this.a = oldTrace.baseAddress; // Support relative addressing. this.oldCoroutine = env.coroutine; env.coroutine = this; } HMCKernel.prototype.sample = function(s, k, a, dist, options) { var prevChoice = this.prevTrace.findChoice(a); if (!prevChoice) { throw new Error('HMC does not support structural continuous variables.'); } var val; if (dist.isContinuous) { if (dist.noHMC) { throw new Error('HMC does not yet support the ' + dist.meta.name + ' distribution'); } var prevVal = ad.value(prevChoice.val); var _val = add(prevVal, mul(this.momentum[a], this.stepSize)); // Handle constraints. // We only have constraints on scalar valued distributions at // present. The following is implemented using scalar math ops // for readability. if (dist.support) { var support = dist.support(); var lower = ad.value(support.lower); var upper = ad.value(support.upper); while (_val < lower || _val > upper) { if (_val < lower) { _val = lower + (lower - _val); this.momentum[a] *= -1; } if (_val > upper) { _val = upper - (_val - upper); this.momentum[a] *= -1; } } } val = ad.lift(_val); } else { val = prevChoice.val; } this.trace.addChoice(dist, val, a, s, k, options); return k(s, val); }; HMCKernel.prototype.factor = function(s, k, a, score) { this.trace.numFactors += 1; this.trace.score = ad.scalar.add(this.trace.score, score); if (this.trace.numFactors === this.exitFactor) { this.trace.saveContinuation(s, k); return this.exit(s, undefined, true); } return k(s); }; HMCKernel.prototype.run = function() { // Immediately return from coroutine if there are no continuous // random choices to propose to. var numChoices = this.oldTrace.choices.reduce(function(acc, c) { return acc + (c.dist.isContinuous ? 1 : 0); }, 0); if (numChoices === 0) { return this.continue(this.oldTrace); } // Zero derivatives left over from previous HMC iterations, or // from the rejuvenation of a particle which shares parts of the // ad graph which this trace. if (ad.isLifted(this.oldTrace.score)) { this.oldTrace.score.zeroDerivatives(); } // Initialize momentum. this.momentum = sampleMomentum(this.oldTrace); // Compute current value of H. var oldH = computeH(this.oldTrace, this.momentum); this.momentumStep(this.oldTrace, 0.5); // Half-step. (Modifies momentum in-place.) // Main HMC loop. // The leapfrog method. (See page 8 of "MCMC using Hamiltonian // dynamics".) return util.cpsIterate( this.steps - 1, this.oldTrace, this.leapFrogStep.bind(this), function(trace) { // Final position step: return this.positionStep(function(finalTrace) { // Final momentum half-step. this.momentumStep(finalTrace, 0.5); var newH = computeH(finalTrace, this.momentum); // Accept/reject. var p = Math.min(1, Math.exp(newH - oldH)); var accept = util.random() < p; return this.finish(accept ? finalTrace : this.oldTrace, accept); }.bind(this), trace); }.bind(this)); }; function stdGaussianSampleLike(x) { // Sample a value with same type (scalar/tensor) and dimension as // x from a standard Gaussian. return x instanceof Tensor ? tensorGaussian.sample(0, 1, x.dims) : gaussian.sample(0, 1); } function sampleMomentum(trace) { var momentum = {}; _.each(trace.choices, function(choice) { if (choice.dist.isContinuous) { momentum[choice.address] = stdGaussianSampleLike(ad.value(choice.val)); } }); return momentum; } HMCKernel.prototype.leapFrogStep = function(cont, trace) { return this.positionStep(function(newTrace) { this.momentumStep(newTrace, 1); return cont(newTrace); }.bind(this), trace); }; HMCKernel.prototype.positionStep = function(cont, trace) { // Run the program creating a new trace with updated (continuous) // variables. this.prevTrace = trace; this.trace = this.prevTrace.fresh(); // Once the WebPPL program has finished we need to call cont to // continue inference. Since the program will call env.exit once // finished, we save cont here in order to resume inference as // desired. Note that we can't pass a continuation other than // env.exit to the program. This is because the continuation is // stored as part of the trace, and when invoked by a different // MCMC kernel execution would jump back here. this.positionStepCont = cont; env.query.clear(); return this.trace.continue(); }; HMCKernel.prototype.exit = function(k, val, earlyExit) { if (!earlyExit) { this.trace.complete(val); } else { assert(this.trace.store); assert(this.trace.k); assert(!this.trace.isComplete()); } var cont = this.positionStepCont; this.positionStepCont = undefined; return cont(this.trace); }; HMCKernel.prototype.momentumStep = function(trace, scaleFactor) { if (ad.isLifted(trace.score)) { // Compute gradient of score w.r.t. the continuous variables. trace.score.backprop(); var stepSize = this.stepSize * scaleFactor; _.each(trace.choices, function(choice) { if (choice.dist.isContinuous) { this.momentum[choice.address] = addEq( this.momentum[choice.address], mul(ad.derivative(choice.val), stepSize)); } }.bind(this)); } }; function computeH(trace, momentum) { var score = ad.value(trace.score); var kinetic = 0.5 * _.reduce(momentum, function(memo, p) { return memo + sum(mul(p, p)); }, 0); var x = score - kinetic; return x; } HMCKernel.prototype.finish = function(trace, accepted) { assert(_.isBoolean(accepted)); if (accepted && trace.value === env.query) { trace.value = env.query.getTable(); } if (this.oldTrace.info) { var oldInfo = this.oldTrace.info; trace.info = { accepted: oldInfo.accepted + accepted, total: oldInfo.total + 1 }; } return this.continue(trace); }; HMCKernel.prototype.continue = function(trace) { env.coroutine = this.oldCoroutine; return this.cont(trace); }; HMCKernel.prototype.incrementalize = env.defaultCoroutine.incrementalize; return makeHMCKernel; };