webppl
Version:
Probabilistic programming for the web
166 lines (131 loc) • 5.24 kB
JavaScript
'use strict';
var _ = require('lodash');
var util = require('../../util');
var Trace = require('../../trace');
var guide = require('../../guide');
module.exports = function(env) {
// This coroutine generates samples from something like "the
// posterior predictive distribution over local random choices".
// This amounts to sampling global choices (those outside of
// mapData) from the guide, and local choices (those inside mapData)
// from the target.
// The trace data structure is only used as a dictionary in which
// sampled choices are stored for later look up. In particular
// `trace.score` is not maintained by this coroutine. All
// scores/gradients are computed by a separate coroutine.
function dreamSample(s, k, a, wpplFn) {
this.wpplFn = wpplFn;
this.s = s;
this.k = k;
this.a = a;
// A 'record' stores random choices (in the trace) and the
// fantasized data.
var trace = new Trace(this.wpplFn, s, env.exit, a);
this.record = {trace: trace, data: []};
this.guideRequired = true;
this.isParamBase = true;
this.insideMapData = false;
this.oldCoroutine = env.coroutine;
env.coroutine = this;
}
dreamSample.prototype = {
run: function() {
return this.wpplFn(_.clone(this.s), function(s, val) {
env.coroutine = this.oldCoroutine;
return this.k(this.s, this.record);
}.bind(this), this.a);
},
sample: function(s, k, a, dist, options) {
var sampleFn = this.insideMapData ? this.sampleLocal : this.sampleGlobal;
return sampleFn.call(this, s, a, dist, options, function(s, val) {
this.record.trace.addChoice(dist, val, a, s, k, options);
return k(s, val);
}.bind(this));
},
sampleLocal: function(s, a, targetDist, options, k) {
return k(s, targetDist.sample());
},
sampleGlobal: function(s, a, dist, options, k) {
options = options || {};
return guide.getDist(
options.guide, options.noAutoGuide, dist, env, s, a,
function(s, guideDist) {
if (!guideDist) {
throw new Error('dream: No guide distribution specified.');
}
return k(s, guideDist.sample());
});
},
factor: function(s, k, a) {
throw new Error('dream: factor not supported, use observe instead.');
},
observe: function(s, k, a, dist, options) {
if (!this.insideMapData) {
throw new Error('dream: observe can only be used within mapData with this estimator.');
}
var val = dist.sample();
this.observations.push(val);
return k(s, val);
},
mapDataEnter: function() {
this.observations = [];
},
mapDataLeave: function() {
if (_.isEmpty(this.observations)) {
throw new Error('dream: expected at least one observation to be made.');
}
// If there was only a single observation, unwrap it from the
// array.
var datum = this.observations.length === 1 ? this.observations[0] : this.observations;
this.record.data.push(datum);
},
mapDataFetch: function(data, opts, a) {
if (this.insideMapData) {
throw new Error('dream: nested mapData is not supported by this estimator.');
}
this.insideMapData = true;
this.guideRequired = false;
var batchSize = _.has(opts, 'dreamBatchSize') ? opts.dreamBatchSize : 1;
if (!(util.isInteger(batchSize) && batchSize >= 0)) {
throw new Error('dream: dreamBatchSize should be a non negative integer.');
}
// The current implementation yields elements of the actual data
// set to the observation function while fantasizing. This is to
// support models that use the structure of an observation as
// part of the generative process.
// (Time series models that generate a sequence by mapping over
// a sequence of observations for example.)
// This may be unnecessary in many cases -- optimization may
// work fine if we instead yielded e.g. `undefined`. But these
// models that rely on the structure of an observation would
// error out if we didn't do this.
if (_.isEmpty(data)) {
throw new Error('dream: data should be non empty.');
}
var ix = _.times(batchSize, function() {
return Math.floor(util.random() * data.length);
});
var batch = _.at(data, ix);
// We extend the address used to enter mapData so that addresses
// used while fantasizing don't overlap with those used when
// mapping over the real data.
// We don't return the original indices since it's not important
// that these are included in the address used to call the
// observation function.
return {data: batch, ix: null, address: a + '_dream'};
},
mapDataFinal: function() {
this.insideMapData = false;
this.guideRequired = true;
},
incrementalize: env.defaultCoroutine.incrementalize,
constructor: dreamSample
};
return {
dreamSample: function() {
var coroutine = Object.create(dreamSample.prototype);
dreamSample.apply(coroutine, arguments);
return coroutine.run();
}
};
};