webppl
Version:
Probabilistic programming for the web
250 lines (207 loc) • 8.44 kB
JavaScript
'use strict';
var _ = require('lodash');
var assert = require('assert');
var util = require('../util');
var numeric = require('../math/numeric');
var ad = require('../ad');
module.exports = function(env) {
var drift = require('./driftKernel')(env);
function makeMHKernel(options) {
options = util.mergeDefaults(options, {
adRequired: false,
permissive: false,
discreteOnly: false
}, 'MH kernel');
return function(cont, oldTrace, runOpts) {
return new MHKernel(cont, oldTrace, options, runOpts).run();
};
}
function MHKernel(cont, oldTrace, options, runOpts) {
this.discreteOnly = options.discreteOnly;
this.adRequired = options.adRequired;
if (!options.permissive) {
assert.notStrictEqual(oldTrace.score, -Infinity);
}
runOpts = util.mergeDefaults(runOpts, {
proposalBoundary: 0,
exitFactor: 0,
factorCoeff: 1,
allowHardFactors: true
});
this.proposalBoundary = runOpts.proposalBoundary;
this.exitFactor = runOpts.exitFactor;
this.factorCoeff = runOpts.factorCoeff;
assert.ok(0 <= this.factorCoeff && this.factorCoeff <= 1);
this.allowHardFactors = runOpts.allowHardFactors;
this.cont = cont;
this.oldTrace = oldTrace;
this.a = oldTrace.baseAddress; // Support relative addressing.
this.reused = {};
this.oldCoroutine = env.coroutine;
env.coroutine = this;
}
MHKernel.prototype.run = function() {
this.regenFrom = this.sampleRegenChoice(this.oldTrace);
if (this.regenFrom < 0) {
// Immediately return from coroutine if there are no random
// choices to propose to.
return this.continue(this.oldTrace);
}
env.query.clear();
this.trace = this.oldTrace.upto(this.regenFrom);
var regen = this.oldTrace.choiceAtIndex(this.regenFrom);
return this.resample(_.clone(regen.store), regen.k, regen.address, regen.dist, regen.options);
};
MHKernel.prototype.factor = function(s, k, a, score) {
// Optimization: Bail early if we know acceptProb will be zero.
if (ad.value(score) === -Infinity) {
if (!this.allowHardFactors) {
throw new Error('Hard factor statements are not allowed.');
}
return this.finish(this.oldTrace, false);
}
this.trace.numFactors += 1;
this.trace.score = ad.scalar.add(this.trace.score, score);
if (this.trace.numFactors === this.exitFactor) {
this.trace.saveContinuation(s, k);
return this.exit(s, undefined, true);
}
return k(s);
};
MHKernel.prototype.sample = function(s, k, a, dist, options) {
var prevChoice = this.oldTrace.findChoice(a);
var val;
if (prevChoice) {
val = prevChoice.val; // Will be a tape if continuous.
this.reused[a] = true;
} else {
var _val = dist.sample();
val = this.adRequired && dist.isContinuous ? ad.lift(_val) : _val;
}
return this.addChoiceToTrace(s, k, a, dist, options, val);
};
// Generation of a new proposal begins here, by re-sampling a value
// for the random choice selected as the regen point.
MHKernel.prototype.resample = function(s, k, a, dist, options) {
var prevChoice = this.oldTrace.findChoice(a);
assert(prevChoice);
return drift.getProposalDist(s, a, dist, options, prevChoice.val, function(s, fwdProposalDist) {
var _val = fwdProposalDist.sample();
var val = this.adRequired && fwdProposalDist.isContinuous ? ad.lift(_val) : _val;
// Optimization: Bail early if same value is re-sampled.
if (!fwdProposalDist.isContinuous && prevChoice.val === val) {
return this.finish(this.oldTrace, true);
}
return drift.getProposalDist(s, a, dist, options, val, function(s, revProposalDist) {
// Store references to the proposal distributions. Getting our
// hands on them again later (from the non-CPS acceptance
// probability code) would be tricky.
this.fwdProposalDist = fwdProposalDist;
this.revProposalDist = revProposalDist;
return this.addChoiceToTrace(s, k, a, dist, options, val, true);
}.bind(this));
}.bind(this));
};
MHKernel.prototype.addChoiceToTrace = function(s, k, a, dist, options, val, atResample) {
this.trace.addChoice(dist, val, a, s, k, options);
if (ad.value(this.trace.score) === -Infinity) {
if (atResample && _.has(options, 'driftKernel')) {
drift.proposalWarning(dist);
}
return this.finish(this.oldTrace, false);
}
return k(s, val);
};
MHKernel.prototype.exit = function(s, val, earlyExit) {
if (!earlyExit) {
this.trace.complete(val);
} else {
assert(this.trace.store);
assert(this.trace.k);
assert(!this.trace.isComplete());
}
var prob = this.acceptProb(this.trace, this.oldTrace);
var accept = util.random() < prob;
return this.finish(accept ? this.trace : this.oldTrace, accept);
};
MHKernel.prototype.finish = function(trace, accepted) {
assert(_.isBoolean(accepted));
if (accepted && trace.value === env.query) {
trace.value = _.assign({}, this.oldTrace.value, env.query.getTable());
}
if (this.oldTrace.info) {
var oldInfo = this.oldTrace.info;
trace.info = {
accepted: oldInfo.accepted + accepted,
total: oldInfo.total + 1
};
}
return this.continue(trace);
};
MHKernel.prototype.continue = function(trace) {
env.coroutine = this.oldCoroutine;
return this.cont(trace);
};
MHKernel.prototype.incrementalize = env.defaultCoroutine.incrementalize;
MHKernel.prototype.proposableDiscreteDistIndices = function(trace) {
return _.range(this.proposalBoundary, trace.length).filter(function(i) {
return !trace.choices[i].dist.isContinuous;
});
};
MHKernel.prototype.numRegenChoices = function(trace) {
if (this.discreteOnly) {
return this.proposableDiscreteDistIndices(trace).length;
} else {
return trace.length - this.proposalBoundary;
}
};
MHKernel.prototype.sampleRegenChoice = function(trace) {
return this.discreteOnly ?
this.sampleRegenChoiceDiscrete(trace) :
this.sampleRegenChoiceAny(trace);
};
MHKernel.prototype.sampleRegenChoiceDiscrete = function(trace) {
var indices = this.proposableDiscreteDistIndices(trace);
return indices.length > 0 ? indices[Math.floor(util.random() * indices.length)] : -1;
};
MHKernel.prototype.sampleRegenChoiceAny = function(trace) {
var numChoices = trace.length - this.proposalBoundary;
return numChoices > 0 ? this.proposalBoundary + Math.floor(util.random() * numChoices) : -1;
};
MHKernel.prototype.acceptProb = function(trace, oldTrace) {
// assert.notStrictEqual(trace, undefined);
// assert.notStrictEqual(oldTrace, undefined);
// assert.notStrictEqual(this.fwdProposalDist, undefined);
// assert.notStrictEqual(this.revProposalDist, undefined);
// assert(_.isNumber(ad.value(trace.score)));
// assert(_.isNumber(ad.value(oldTrace.score)));
// assert(_.isNumber(this.regenFrom));
// assert(_.isNumber(this.proposalBoundary));
var fw = this.transitionProb(oldTrace, trace, this.fwdProposalDist);
var bw = this.transitionProb(trace, oldTrace, this.revProposalDist);
var newTraceScore, oldTraceScore;
if (this.factorCoeff == 1) {
// Optimise for the common case.
newTraceScore = ad.value(trace.score);
oldTraceScore = ad.value(oldTrace.score);
} else {
newTraceScore = ad.value(trace.scoreAllChoices()) + this.factorCoeff * ad.value(trace.scoreAllFactors());
oldTraceScore = ad.value(oldTrace.scoreAllChoices()) + this.factorCoeff * ad.value(oldTrace.scoreAllFactors());
}
var p = Math.exp(newTraceScore - oldTraceScore + bw - fw);
assert(!isNaN(p));
return Math.min(1, p);
};
MHKernel.prototype.transitionProb = function(fromTrace, toTrace, proposalDist) {
var regenChoice = toTrace.choiceAtIndex(this.regenFrom);
var score = ad.value(proposalDist.score(regenChoice.val));
// Rest of the trace.
score += numeric._sum(toTrace.choices.slice(this.regenFrom + 1).map(function(choice) {
return this.reused.hasOwnProperty(choice.address) ? 0 : ad.value(choice.dist.score(choice.val));
}, this));
score -= Math.log(this.numRegenChoices(fromTrace));
assert(!isNaN(score));
return score;
};
return makeMHKernel;
};