webppl
Version:
Probabilistic programming for the web
80 lines • 3 kB
JavaScript
'use strict';
var ad = require('../../ad');
var assert = require('assert');
var _ = require('lodash');
var guide = require('../../guide');
module.exports = function (env) {
function dreamGradients(wpplFn, record, s, a, cont) {
this.wpplFn = wpplFn;
this.record = record;
this.s = s;
this.a = a;
this.cont = cont;
this.guideRequired = true;
this.isParamBase = true;
this.insideMapData = false;
this.oldCoroutine = env.coroutine;
env.coroutine = this;
}
dreamGradients.prototype = {
run: function () {
return this.estimateGradient(function (grad, objVal) {
env.coroutine = this.oldCoroutine;
return this.cont(grad, objVal);
}.bind(this));
},
estimateGradient: function (cont) {
this.paramsSeen = {};
this.logq = 0;
return this.wpplFn(_.clone(this.s), function (s, val) {
var objective = ad.scalar.neg(this.logq);
if (ad.isLifted(objective)) {
objective.backprop();
}
var grads = _.mapValues(this.paramsSeen, ad.derivative);
return cont(grads, ad.scalar.neg(ad.value(this.logq)));
}.bind(this), this.a);
},
sample: function (s, k, a, dist, options) {
options = options || {};
var choice = this.record.trace.findChoice(a);
assert.ok(ad.scalar.pneq(choice, undefined), 'dream: No entry for this choice in the trace.');
var val = choice.val;
if (this.insideMapData) {
return guide.getDist(options.guide, options.noAutoGuide, dist, env, s, a, function (s, guideDist) {
if (!guideDist) {
throw new Error('dream: No guide distribution specified.');
}
this.logq = ad.scalar.add(this.logq, guideDist.score(val));
return k(s, val);
}.bind(this));
} else {
return k(s, val);
}
},
factor: function (s, k, a, score) {
return k(s);
},
mapDataFetch: function (data, opts, a) {
if (this.insideMapData) {
throw new Error('dream: nested mapData is not supported by this estimator.');
}
this.insideMapData = true;
return {
data: this.record.data,
ix: null,
address: ad.scalar.add(a, '_dream')
};
},
mapDataFinal: function () {
this.insideMapData = false;
},
incrementalize: env.defaultCoroutine.incrementalize,
constructor: dreamGradients
};
return function () {
var coroutine = Object.create(dreamGradients.prototype);
dreamGradients.apply(coroutine, arguments);
return coroutine.run();
};
};