UNPKG

webppl

Version:

Probabilistic programming for the web

210 lines (182 loc) 7.49 kB
//////////////////////////////////////////////////////////////////// // Asynchronous Anytime SMC. // http://arxiv.org/abs/1407.2864 // bufferSize: queue size // numParticles: total number of particles to run 'use strict'; var _ = require('lodash'); var util = require('../util'); var numeric = require('../math/numeric'); var CountAggregator = require('../aggregation/CountAggregator'); module.exports = function(env) { function copyOneParticle(particle) { return { continuation: particle.continuation, weight: particle.weight, completed: particle.completed, factorIndex: particle.factorIndex, value: particle.value, numChildrenToSpawn: 1, multiplicity: particle.multiplicity, store: _.clone(particle.store) }; } function initParticle(s, cont) { return { continuation: cont, weight: 0, completed: false, factorIndex: undefined, value: undefined, numChildrenToSpawn: 0, multiplicity: 1, store: _.clone(s) }; } function AsyncPF(s, k, a, wpplFn, options) { this.numParticles = 0; // K_0 -- initialized here, set in run this.bufferSize = options.bufferSize == undefined ? options.particles : options.bufferSize; // \rho this.initNumParticles = Math.floor(this.bufferSize * (1 / 2)); // \rho_0 this.exitK = function(s) {return wpplFn(s, env.exit, a);}; this.store = s; this.buffer = []; for (var i = 0; i < this.initNumParticles; i++) { this.buffer.push(initParticle(this.store, this.exitK)); } this.obsWeights = {}; this.exitedParticles = 0; this.hist = new CountAggregator(); // Move old coroutine out of the way and install this as current handler. this.k = k; this.oldCoroutine = env.coroutine; env.coroutine = this; this.oldStore = _.clone(s); // will be reinstated at the end } AsyncPF.prototype.run = function(numP) { // allows for continuing pf this.numParticles = (numP == undefined) ? this.numParticles : this.numParticles + numP; // launch a new particle OR continue an existing one var p, launchP; var i = Math.floor((this.buffer.length + 1) * util.random()); if (i == this.buffer.length) { // generate new particle p = initParticle(this.store, this.exitK); } else { // launch particle in queue launchP = this.buffer[i]; if (launchP.numChildrenToSpawn > 1) { p = copyOneParticle(launchP); launchP.numChildrenToSpawn -= 1; } else { p = launchP; this.buffer = util.deleteIndex(this.buffer, i); } } this.activeParticle = p; return p.continuation(p.store); }; AsyncPF.prototype.sample = function(s, cc, a, dist) { return cc(s, dist.sample()); }; AsyncPF.prototype.factor = function(s, cc, a, score) { this.activeParticle.weight += score; this.activeParticle.continuation = cc; this.activeParticle.store = s; var fi = this.activeParticle.factorIndex; var newFI = fi == undefined ? 0 : fi + 1; this.activeParticle.factorIndex = newFI; this.branching(newFI); // compute branching and #children return this.run(); // return to run }; AsyncPF.prototype.branching = function(factorIndex) { // find weights at current observation var lk = this.obsWeights[factorIndex]; if (lk == undefined) { // 1st particle at observation var det = { wbar: this.activeParticle.weight, mnk: 1 }; this.obsWeights[factorIndex] = [det]; this.activeParticle.numChildrenToSpawn = 1; } else { // 2nd or greater particle at observation var currMultiplicity = this.activeParticle.multiplicity; var currWeight = this.activeParticle.weight; var denom = lk.length + currMultiplicity; // k - 1 + Ckn var prevWBar = lk[lk.length - 1].wbar; var wbar = -Math.log(denom) + numeric._logsumexp([Math.log(lk.length) + prevWBar, Math.log(currMultiplicity) + currWeight]); if (wbar > 0) throw new Error('Positive weight!!'); // sanity check var logRatio = currWeight - wbar; var numChildrenAndWeight = []; // compute number of children and their weights if (logRatio < 0) { numChildrenAndWeight = Math.log(util.random()) < logRatio ? [1, wbar] : [0, -Infinity]; } else { var totalChildren = 0; for (var v = 0; v < lk.length; v++) totalChildren += lk[v].mnk; // \sum M^k_n var minK = Math.min(this.numParticles, lk.length); // min(K_0, k-1) // if all previous particles have -Infinity *and* current weight is -Infinity // rnk = lim(x->0) x/x = 1 => [1, wbar] = [1, -Infinity] var rnk = isNaN(logRatio) ? 1 : Math.exp(logRatio); var clampedRnk = totalChildren <= minK ? Math.ceil(rnk) : Math.floor(rnk); numChildrenAndWeight = [clampedRnk, currWeight - Math.log(clampedRnk)]; } var det2 = { wbar: wbar, mnk: numChildrenAndWeight[0] }; this.obsWeights[factorIndex] = lk.concat([det2]); if (numChildrenAndWeight[0] > 0) { // there are children if (this.buffer.length < this.bufferSize) { // buffer can be added to this.activeParticle.numChildrenToSpawn = numChildrenAndWeight[0]; this.activeParticle.weight = numChildrenAndWeight[1]; } else { // buffer full, update multiplicty this.activeParticle.multiplicity *= numChildrenAndWeight[0]; this.activeParticle.numChildrenToSpawn = 1; this.activeParticle.weight = numChildrenAndWeight[1]; } this.buffer.push(this.activeParticle); // push into buffer } } }; AsyncPF.prototype.exit = function(s, retval) { this.activeParticle.value = retval; this.activeParticle.completed = true; // correct weight with multiplicity this.activeParticle.weight += Math.log(this.activeParticle.multiplicity); this.exitedParticles += 1; this.hist.add(retval); if (this.exitedParticles < this.numParticles) { return this.run(); } else { var dist = this.hist.toDist(); var lastFactorIndex = this.activeParticle.factorIndex; var olk = this.obsWeights[lastFactorIndex]; var Kn = _.reduce(olk, function(a, b) {return a + b.mnk;}, 0) dist.normalizationConstant = Math.log(Kn) - // K_n Math.log(this.numParticles) + // K_0 olk[olk.length - 1].wbar; // Wbar^k_n // allow for continuing pf var currCoroutine = this; dist.continue = function(s, k, a, numP) { currCoroutine.k = k; currCoroutine.oldCoroutine = env.coroutine; env.coroutine = currCoroutine; currCoroutine.oldStore = _.clone(s); // will be reinstated at the end return currCoroutine.run(numP); }; // Reinstate previous coroutine: env.coroutine = this.oldCoroutine; // Return from particle filter by calling original continuation: return this.k(this.oldStore, dist); } }; AsyncPF.prototype.incrementalize = env.defaultCoroutine.incrementalize; function asyncPF(s, cc, a, wpplFn, options) { options = options || {}; return new AsyncPF(s, cc, a, wpplFn, options).run(options.particles); } return { AsyncPF: asyncPF }; };