UNPKG

webppl

Version:

Probabilistic programming for the web

1,129 lines (992 loc) 39 kB
////////////////////////////////////////////////////////////////////
// Incrementalized (i.e. caching) Lightweight MH

'use strict';

var _ = require('lodash');
var assert = require('assert');
var util = require('../util');
var Hashtable = require('../hashtable').Hashtable;
var Query = require('../query').Query;
var CountAggregator = require('../aggregation/CountAggregator');
var cb = require('./callbacks');

module.exports = function(env) {

var drift = require('./driftKernel')(env);

// ------------------------------------------------------------------

// Debugging output

// Current debug verbosity; overwritten from opts.debuglevel when an
// IncrementalMH coroutine is constructed.
var DEBUG = 0;

// Logs all arguments after the first via console.log, but only when
// DEBUG is at least `debuglevel`.
function debuglog(debuglevel) {
  if (DEBUG >= debuglevel) {
    var args = Array.prototype.slice.call(arguments, 1);
    console.log.apply(console, args);
  }
}

// Like debuglog, but prefixes the message with one space per level of
// `depth` followed by '[depth] ', so output is indented by cache depth.
function tabbedlog(debuglevel, depth) {
  if (DEBUG >= debuglevel) {
    var args = Array.prototype.slice.call(arguments, 2);
    var pad = '';
    for (var i = 0; i < depth; i++) pad += ' ';
    pad += '[' + depth + '] ';
    console.log.apply(console, [pad].concat(args));
  }
}

// ------------------------------------------------------------------

// A sort of 'copy on write' system for cache nodes

// Own-property test that is safe even if a snapshot stores a falsy or
// undefined original value.
function snapshotHas(snapshot, prop) {
  return Object.prototype.hasOwnProperty.call(snapshot, prop);
}

// Marks `node` as modified during the current proposal: the first time a
// node is touched it is registered with its coroutine and given a fresh
// snapshot object (which always records reachable = true).
function touch(node) {
  if (!node.__snapshot) {
    node.coroutine.touch(node);
    node.__snapshot = { reachable: true };
  }
}

// Sets node[prop] = val, first saving the pre-proposal value in the
// node's snapshot so that it can be restored on rejection.
//
// Only the FIRST update of a property within a proposal may record the
// snapshot. This is decided by an own-property check rather than by
// truthiness: a truthiness check would treat a saved original value of
// 0, false, null or undefined as 'not yet saved' and overwrite it with an
// intermediate value on the next update (e.g. a child whose original
// index is 0, or a node whose original score is 0).
function updateProperty(node, prop, val) {
  touch(node);
  if (!snapshotHas(node.__snapshot, prop))
    node.__snapshot[prop] = node[prop];
  node[prop] = val;
}

// Returns true if `prop` of `node` has already been saved in the node's
// snapshot during the current proposal.
function hasSnapshotForProperty(node, prop) {
  return node.__snapshot !== undefined &&
      snapshotHas(node.__snapshot, prop);
}

// Copies every saved property back onto the node and drops the snapshot.
function restoreSnapshot(node) {
  for (var prop in node.__snapshot) {
    node[prop] = node.__snapshot[prop];
  }
  node.__snapshot = undefined;
}

// Drops the snapshot, keeping the node's current property values.
function discardSnapshot(node) {
  node.__snapshot = undefined;
}

// ------------------------------------------------------------------

// A cached distribution call
//
// coroutine: the owning coroutine (must provide touch, addDist, score).
// parent:    the enclosing cache node (provides depth and nextChildIdx).
// s, k, a:   store, continuation and address of the sample call.
// dist:      the distribution; a value is drawn with dist.sample() here.
// args:      extra arguments; args[0] is read as the sample options.
function DistNode(coroutine, parent, s, k, a, dist, args) {
  this.coroutine = coroutine;

  this.store = _.clone(s);
  this.continuation = k;
  this.address = a;
  this.dist = dist;
  this.args = args;

  this.parent = parent;
  this.depth = parent.depth + 1;
  this.index = parent.nextChildIdx;

  this.reachable = true;
  this.needsUpdate = false;

  this.val = dist.sample();
  // Start from 0 so that rescore() adds the full score to the coroutine.
  this.score = 0;
  this.rescore();

  // Add this to the master list of distribution nodes
  this.coroutine.addDist(this);
}

// Debug dump of this node (always printed: level 0).
DistNode.prototype.print = function() {
  tabbedlog(0, this.depth, 'DistNode', this.dist, this.val,
            this.reachable ? '' : '!!UNREACHABLE!!');
};

// Re-scores the value if the distribution changed, then either exits the
// coroutine (score is -Infinity) or continues with the cached value.
DistNode.prototype.execute = function() {
  tabbedlog(4, this.depth, 'execute distribution');
  if (this.needsUpdate) {
    tabbedlog(4, this.depth, 'yes, dist changed');
    tabbedlog(5, this.depth, 'old dist:',
              this.__snapshot ? this.__snapshot.dist : undefined,
              'new dist:', this.dist);
    this.needsUpdate = false;
    this.rescore();
  } else {
    tabbedlog(4, this.depth, 'no, dist has not changed');
    tabbedlog(5, this.depth, 'dist:', this.dist);
  }
  // Bail out early if we know proposal will be rejected
  if (this.score === -Infinity) {
    tabbedlog(4, this.depth, 'score became -Infinity; bailing out early');
    return this.coroutine.exit();
  } else {
    return this.kontinue();
  }
};

// Called when this cached node is looked up again: records the new store,
// continuation, index and args, marks the node reachable, and flags it for
// re-scoring if the distribution differs from the cached one.
DistNode.prototype.registerInputChanges = function(s, k, dist, args) {
  updateProperty(this, 'store', _.clone(s));
  updateProperty(this, 'continuation', k);
  updateProperty(this, 'index', this.parent.nextChildIdx);
  updateProperty(this, 'args', args);
  this.reachable = true;
  if (!distEqual(dist, this.dist)) {
    this.needsUpdate = true;
    updateProperty(this, 'dist', dist);
  }
};

// Notifies the parent that this child ran, then resumes the program.
DistNode.prototype.kontinue = function() {
  this.parent.notifyChildExecuted(this);
  // Call continuation
  // Copies store, so that we maintain a pristine copy of this.store
  return this.continuation(_.clone(this.store), this.val);
};

// A DistNode is itself a leaf: removing it means removing it from the
// coroutine's list of distributions.
DistNode.prototype.killDescendantLeaves = function() {
  this.coroutine.removeDist(this);
};

// Proposes a new value for this node and resumes execution from it.
// The forward proposal distribution is obtained for the old value and the
// reverse one for the new value; their log-probabilities are stored on the
// coroutine as fwdPropLP / rvsPropLP.
DistNode.prototype.propose = function() {
  var oldval = this.val;
  var sampleOptions = this.args[0];
  return drift.getProposalDist(
      this.store, this.address, this.dist, sampleOptions, oldval,
      function(s, fwdPropDist) {
        var newval = fwdPropDist.sample();
        tabbedlog(4, this.depth, 'proposing change to distribution.',
                  'oldval:', oldval, 'newval:', newval);
        // If the value didn't change, then just bail out (we know
        // the proposal will be accepted)
        if (oldval === newval) {
          tabbedlog(4, this.depth,
                    "proposal didn't change value; bailing out early");
          tabbedlog(5, this.depth, 'value:', this.val);
          return this.coroutine.exit();
        } else {
          updateProperty(this, 'store', _.clone(this.store));
          updateProperty(this, 'val', newval);
          this.rescore();
          if (this.score === -Infinity &&
              _.has(sampleOptions, 'driftKernel')) {
            drift.proposalWarning(this.dist);
          }
          return drift.getProposalDist(
              s, this.address, this.dist, sampleOptions, newval,
              function(s, rvsPropDist) {
                this.coroutine.rvsPropLP = rvsPropDist.score(oldval);
                this.coroutine.fwdPropLP = fwdPropDist.score(newval);
                tabbedlog(1, this.depth,
                          'initial rvsPropLP:', this.coroutine.rvsPropLP,
                          'initial fwdPropLP:', this.coroutine.fwdPropLP);
                // Already re-scored above; execute() must not do it again.
                this.needsUpdate = false;
                if (this.coroutine.doFullRerun) {
                  // Mark every node above this one as needing update, then
                  // re-run the program from the start
                  for (var node = this.parent; node !== null; node = node.parent)
                    node.needsUpdate = true;
                  return this.coroutine.runFromStart();
                } else {
                  this.parent.notifyChildChanged(this);
                  // Restore node stack up to this point
                  this.coroutine.restoreStackUpTo(this.parent);
                  return this.execute();
                }
              }.bind(this));
        }
      }.bind(this));
};

// Recomputes this node's score for its current value and applies the
// difference to the coroutine's running total score.
DistNode.prototype.rescore = function() {
  var oldscore = this.score;
  updateProperty(this, 'score', this.dist.score(this.val));
  this.coroutine.score += this.score - oldscore;
};

// This is used to decide whether to re-score a distribution. Only a
// shallow check for parameter equality is performed, as a deep
// check is unlikely to be any faster than re-scoring.
function distEqual(dist1, dist2) {
  return dist1.constructor === dist2.constructor &&
      distParamsEqual(dist1.params, dist2.params);
}

// Shallow comparison of two parameter objects, driven by the own keys of
// p1. Numbers are compared with ===, arrays with _.isEqual; a parameter of
// any other type makes the result false (i.e. 'assume changed').
function distParamsEqual(p1, p2) {
  if (p1 === p2) {
    return true;
  }
  //assert.strictEqual(_.size(p1), _.size(p2));
  for (var k in p1) {
    if (p1.hasOwnProperty(k)) {
      //assert.ok(p2.hasOwnProperty(k));
      var v1 = p1[k], v2 = p2[k];
      if (typeof v1 === 'number') {
        if (v1 !== v2) {
          return false;
        }
      } else if (_.isArray(v1)) {
        if (!_.isEqual(v1, v2)) {
          return false;
        }
      } else {
        return false;
      }
    }
  }
  return true;
}

// ------------------------------------------------------------------

// A cached factor call
//
// Same constructor signature as the other node types; the `unused` slot
// is ignored and args[0] is the factor score.
function FactorNode(coroutine, parent, s, k, a, unused, args) {
  this.coroutine = coroutine;

  this.store = s;
  this.continuation = k;
  this.address = a;

  this.parent = parent;
  this.depth = parent.depth + 1;
  this.index = parent.nextChildIdx;

  this.reachable = true;

  // Old score is 0 for a new node, so the full score is added.
  this.rescore(0, args[0]);

  tabbedlog(3, this.depth, 'new factor');
}

// Debug dump of this node (always printed: level 0).
FactorNode.prototype.print = function() {
  tabbedlog(0, this.depth, 'FactorNode',
            this.reachable ? '' : '!!UNREACHABLE!!');
};

FactorNode.prototype.execute = function() {
  // Bail out early if we know proposal will be rejected
  if (this.score === -Infinity) {
    tabbedlog(4, this.depth, 'score became -Infinity; bailing out early');
    return this.coroutine.exit();
  } else {
    return this.kontinue();
  }
};

// Records the new store/continuation/index for a re-encountered factor and
// re-scores it if the factor value (args[0]) differs from the cached score.
FactorNode.prototype.registerInputChanges = function(s, k, unused, args) {
  updateProperty(this, 'store', s);
  updateProperty(this, 'continuation', k);
  updateProperty(this, 'index', this.parent.nextChildIdx);
  this.reachable = true;
  if (this.score !== args[0])
    this.rescore(this.score, args[0]);
};

// Notifies the parent that this child ran, then resumes the program.
FactorNode.prototype.kontinue = function() {
  this.parent.notifyChildExecuted(this);
  return this.continuation(this.store);
};

// A FactorNode is a leaf: removing it subtracts its score from the
// coroutine's total.
FactorNode.prototype.killDescendantLeaves = function() {
  tabbedlog(3, this.depth, 'kill factor', this.address);
  this.coroutine.score -= this.score;
};

// Stores the new score and applies (score - oldscore) to the coroutine.
FactorNode.prototype.rescore = function(oldscore, score) {
  updateProperty(this, 'score', score);
  this.coroutine.score += score - oldscore;
};

// ------------------------------------------------------------------

// Comparing two stores for shallow equality
// (every enumerable property of either store must be === in the other)
function storesEqual(s1, s2) {
  var prop;
  for (prop in s1) {
    if (s1[prop] !== s2[prop])
      return false;
  }
  for (prop in s2) {
    if (s1[prop] !== s2[prop])
      return false;
  }
  return true;
}

// Checks whether two functions are equivalent
var fnEquivCache = {};
function fnsEqual(f1, f2) {
  // If the two functions are literally the same closure, then of course
  // they are equivalent.
  if (f1 === f2) return true;
  // Otherwise, they're equivalent if they come from the same source location
  // and the values of the variables they close over are the same.
  // First, we check if the functions actually have this metadata. External
  // header functions do not, so we must return false, to ensure correct
  // behavior.
  if (f1.__lexid === undefined || f2.__lexid === undefined) return false;
  // We cache this check, because situations often arise where we're checking
  // the same pair of functions over and over again.
  if (f1.__lexid === f2.__lexid) {
    var key = JSON.stringify([f1.__uniqueid, f2.__uniqueid]);
    var val = fnEquivCache[key];
    if (val === undefined) {
      val = true;
      for (var i = 0; i < f1.__freeVarVals.length; i++) {
        var v1 = f1.__freeVarVals[i];
        var v2 = f2.__freeVarVals[i];
        // If they're both functions, recursively apply this check
        var eq = (_.isFunction(v1) && _.isFunction(v2)) ?
            fnsEqual(v1, v2) : v1 === v2;
        if (!eq) {
          val = false;
          break;
        }
      }
      fnEquivCache[key] = val;
    }
    return val;
  }
  return false;
}

// A cached, general WebPPL function call
//
// parent may be null/falsy for the cache root, in which case depth is 0
// and index is undefined.
function FunctionNode(coroutine, parent, s, k, a, fn, args) {
  this.coroutine = coroutine;

  this.continuation = k;
  this.address = a;
  this.func = fn;

  this.parent = parent;
  this.depth = parent ? parent.depth + 1 : 0;
  this.index = parent ? parent.nextChildIdx : undefined;
  this.children = [];
  this.nextChildIdx = 0;

  this.reachable = true;
  // A new node has never run, so it must execute on first use.
  this.needsUpdate = true;

  this.inStore = _.clone(s);
  this.args = args;

  this.initialized = false;
  this.retval = undefined;
  this.outStore = null;
}

// Debug dump of this node and, recursively, all of its children.
FunctionNode.prototype.print = function() {
  tabbedlog(0, this.depth, 'FunctionNode', this.args, this.retval,
            this.reachable ? '' : '!!UNREACHABLE!!');
  for (var i = 0; i < this.children.length; i++)
    this.children[i].print();
};

// Runs the cached function if its inputs changed; otherwise continues
// immediately with the cached return value and output store.
FunctionNode.prototype.execute = function() {
  tabbedlog(4, this.depth, 'execute function');
  if (this.needsUpdate) {
    if (this.initialized)
      this.coroutine.cacheAdapter.registerMiss(this);
    tabbedlog(4, this.depth, 'yes, function args changed; executing');
    tabbedlog(5, this.depth, 'old args:',
              this.__snapshot ? this.__snapshot.args : undefined,
              'new args:', this.args);
    this.needsUpdate = false;
    // Keep track of program stack
    this.coroutine.nodeStack.push(this);
    // Reset nextChildIdx
    this.nextChildIdx = 0;
    // Mark all children as unreachable; execution will then determine which
    // ones are actually reachable
    var nchildren = this.children.length;
    for (var i = 0; i < nchildren; i++) {
      touch(this.children[i]);
      this.children[i].reachable = false;
    }
    tabbedlog(4, this.depth, 'Children marked unreachable on execute:',
              nchildren);
    // Preserve reference to coroutine object so
    // continuation can refer to it.
    var coroutine = this.coroutine;
    // Also preserve references to this node's continuation and address.
    // We need these to bypass the cache 'on-exit' processing stuff and just
    // invoke the original continuation if and when nodes with this address
    // stop being cached (by the cache adapter).
    var continuation = this.continuation;
    var address = this.address;
    // Record the fact that we entered this function.
    this.entered = true;
    return this.func.apply(global, [
      this.inStore,
      function(s, retval) {
        // If we've stopped caching nodes with this address, just immediately
        // continue
        if (!coroutine.cacheAdapter.shouldCache(address))
          return continuation(s, retval);
        // Recover a reference to 'this'
        var that = coroutine.nodeStack.pop();
        tabbedlog(4, that.depth, 'continue from function');
        that.initialized = true;
        // Clear out any children that have become unreachable
        var newchildren = [];
        var nchildren = that.children.length;
        var ii = 0;
        for (var i = 0; i < nchildren; i++) {
          var child = that.children[i];
          if (!child.reachable)
            child.killDescendantLeaves();
          else {
            updateProperty(child, 'index', ii++);
            newchildren.push(child);
          }
        }
        updateProperty(that, 'children', newchildren);
        // If the return value and output store haven't changed, then we can
        // bail early.
        // We can only do this if this call is returning from a change
        // somewhere below it (i.e. that.entered == false). Otherwise, we need
        // to keep running.
        if (!that.entered && that.retval === retval &&
            storesEqual(that.outStore, s)) {
          tabbedlog(4, that.depth,
                    'bailing b/c function return val not changed');
          tabbedlog(5, that.depth, 'return val:', retval);
          coroutine.cacheAdapter.registerHit(that);
          return coroutine.exit();
        }
        if (!that.entered && that.parent !== null)
          that.parent.notifyChildChanged(that);
        if (!that.entered)
          coroutine.cacheAdapter.registerMiss(that);
        that.entered = false;
        tabbedlog(4, that.depth,
                  'function return val has changed, or cannot bail');
        tabbedlog(5, that.depth, 'old ret val:', that.retval,
                  'new ret val:', retval);
        // Update output values
        updateProperty(that, 'retval', retval);
        updateProperty(that, 'outStore', _.clone(s));
        // Continue execution
        return that.kontinue();
      },
      this.address
    ].concat(this.args));
  } else {
    this.coroutine.cacheAdapter.registerHit(this);
    tabbedlog(4, this.depth, 'no, function args have not changed; continuing');
    tabbedlog(5, this.depth, 'args:', this.args);
    return this.kontinue();
  }
};

// Called when this cached node is looked up again: records the new
// continuation/index and sets needsUpdate if the function, the argument
// list (compared element-wise with ===) or the input store changed.
FunctionNode.prototype.registerInputChanges = function(s, k, fn, args) {
  updateProperty(this, 'continuation', k);
  if (this.parent)
    updateProperty(this, 'index', this.parent.nextChildIdx);
  this.reachable = true;
  // Check fn for changes
  if (!fnsEqual(fn, this.func)) {
    this.needsUpdate = true;
    updateProperty(this, 'func', fn);
  }
  // Check args for changes
  if (this.args.length !== args.length) {
    this.needsUpdate = true;
    updateProperty(this, 'args', args);
  } else {
    for (var i = 0; i < args.length; i++) {
      if (args[i] !== this.args[i]) {
        this.needsUpdate = true;
        updateProperty(this, 'args', args);
        break;
      }
    }
  }
  // Check store for changes
  if (!storesEqual(this.inStore, s)) {
    this.needsUpdate = true;
    updateProperty(this, 'inStore', _.clone(s));
  }
};

// Walks the subtree below this node iteratively and calls
// killDescendantLeaves on every leaf. Leaves are recognised by having a
// defined `score` property; other nodes are expanded via `children`.
FunctionNode.prototype.killDescendantLeaves = function() {
  tabbedlog(3, this.depth, 'kill function (and all descendant leaves)',
            this.address);
  var stack = [this];
  while (stack.length > 0) {
    var node = stack.pop();
    if (node.score !== undefined)
      node.killDescendantLeaves();
    else {
      var n = node.children.length;
      while (n--) stack.push(node.children[n]);
    }
  }
};

// Notifies the parent (if any) that this child ran, then resumes the
// program with the cached output store and return value.
FunctionNode.prototype.kontinue = function() {
  if (this.parent !== null)
    this.parent.notifyChildExecuted(this);
  // Call continuation
  // Copies outStore, so we maintain a pristine record of it.
  return this.continuation(_.clone(this.outStore), this.retval);
};

// The next lookup under this node starts right after the executed child.
FunctionNode.prototype.notifyChildExecuted = function(child) {
  this.nextChildIdx = child.index + 1;
};

FunctionNode.prototype.notifyChildChanged = function(child) {
  // Children later in the execution order may become unreachable due
  // to this change, so we mark them all as unreachable and see which
  // ones we hit.
  var nchildren = this.children.length;
  var totalmarked = 0;
  for (var i = child.index + 1; i < nchildren; i++) {
    touch(this.children[i]);
    this.children[i].reachable = false;
    totalmarked++;
  }
  tabbedlog(4, this.depth, 'Children marked unreachable on child change:',
            totalmarked);
};

// Called by the cache adapter when it determines that this node isn't giving
// good enough cache efficiency and should be removed from the cache.
FunctionNode.prototype.removeFromCache = function() {
  // First, verify that this node actually *is* in the cache (i.e. hasn't
  // already been removed via rejection), otherwise the subsequent
  // ops we perform will screw things up.
  if (this.parent.children[this.index] === this) {
    // Correct the various metadata on the children of this node.
    var n = this.children.length;
    for (var i = 0; i < n; i++) {
      var child = this.children[i];
      child.parent = this.parent;
      child.index = this.index + i;
      // Recursively adjust depth of all descendants
      var stack = [child];
      while (stack.length > 0) {
        var node = stack.pop();
        node.depth--;
        if (node.children) {
          var nn = node.children.length;
          while (nn--) stack.push(node.children[nn]);
        }
      }
    }
    // Correct the indices of the subsequent siblings of this node, to
    // account for the children that are about to be moved up.
    for (var j = this.index + 1; j < this.parent.children.length; j++) {
      this.parent.children[j].index = j + n - 1;
    }
    // Remove this and move up this.children into this.parent.children
    // (this.children is turned into the splice argument list:
    //  [start, deleteCount, ...children])
    this.children.unshift(this.index, 1);
    Array.prototype.splice.apply(this.parent.children, this.children);
  }
};

// ------------------------------------------------------------------

// Abstraction representing a master list of distributions
// (lets us abstract over whether we're using an array or a hash table)

function ArrayDistMasterList() {
  this.distNodes = [];
}

ArrayDistMasterList.prototype.size = function() {
  return this.distNodes.length;
};

// Size of the list as of the last preProposal(); undefined before the
// first one.
ArrayDistMasterList.prototype.oldSize = function() {
  return this.oldDistNodes === undefined ? undefined :
      this.oldDistNodes.length;
};

ArrayDistMasterList.prototype.addDist = function(node) {
  this.distNodes.push(node);
};

ArrayDistMasterList.prototype.removeDist = function(node) {
  // Set it up to be removed as a post-process
  touch(node);
  node.reachable = false;
};

ArrayDistMasterList.prototype.preProposal = function() {
  this.oldDistNodes = this.distNodes.slice();
};

// Drops every node that was marked unreachable during the proposal.
ArrayDistMasterList.prototype.postProposal = function() {
  this.distNodes = _.filter(this.distNodes, function(node) {
    return node.reachable;
  });
};

ArrayDistMasterList.prototype.getRandom = function() {
  var idx = Math.floor(util.random() * this.distNodes.length);
  return this.distNodes[idx];
};

ArrayDistMasterList.prototype.restoreOnReject = function() {
  this.distNodes = this.oldDistNodes;
};

// Hash-table variant, keyed by node address. Additions and removals made
// since the last preProposal() are logged so they can be undone.
function HashtableDistMasterList() {
  this.distNodeMap = new Hashtable();
  this.distsAdded = [];
  this.distsRemoved = [];
  this.numDists = 0;
}

HashtableDistMasterList.prototype.size = function() {
  return this.numDists;
};

HashtableDistMasterList.prototype.oldSize = function() {
  return this.oldNumDists;
};

HashtableDistMasterList.prototype.addDist = function(node) {
  this.distNodeMap.put(node.address, node);
  this.distsAdded.push(node);
  this.numDists++;
  // this.checkConsistency("addDist");
};

HashtableDistMasterList.prototype.removeDist = function(node) {
  this.distNodeMap.remove(node.address);
  this.distsRemoved.push(node);
  this.numDists--;
  // this.checkConsistency("removeDist");
};

HashtableDistMasterList.prototype.preProposal = function() {
  this.oldNumDists = this.numDists;
  this.distsAdded = [];
  this.distsRemoved = [];
};

HashtableDistMasterList.prototype.postProposal = function() {};

HashtableDistMasterList.prototype.getRandom = function() {
  return this.distNodeMap.getRandom();
};

// Undoes the proposal: removes everything that was added and re-inserts
// everything that was removed since the last preProposal().
HashtableDistMasterList.prototype.restoreOnReject = function() {
  // this.checkConsistency("restoreOnReject");
  this.numDists = this.oldNumDists;
  var node;
  var n = this.distsAdded.length;
  while (n--) {
    node = this.distsAdded[n];
    this.distNodeMap.remove(node.address);
  }
  n = this.distsRemoved.length;
  while (n--) {
    node = this.distsRemoved[n];
    this.distNodeMap.put(node.address, node);
  }
};

// For debugging
HashtableDistMasterList.prototype.checkConsistency = function(tag) {
  var addr;
  for (var i = 0; i < this.distsAdded.length; i++) {
    addr = this.distsAdded[i].address;
    if (!this.distNodeMap.get(addr))
      throw new Error('WTF - hash table doesn\'t contain node ' + addr + ' that we added (' + tag + ')');
  }
  for (var j = 0; j < this.distsRemoved.length; j++) {
    addr = this.distsRemoved[j].address;
    if (this.distNodeMap.get(addr))
      throw new Error('WTF - hash table contains node ' + addr + ' that we removed (' + tag + ')');
  }
};

// ------------------------------------------------------------------

// Tracks statistics on how the cache is performing, so we can make
// decisions about when to stop caching certain functions
//
// minHitRate:     ids whose hit rate falls below this get flagged.
// fuseLength:     minimum number of recorded lookups before flagging.
// iterFuseLength: number of elapsed iterations that must be exceeded
//                 before flagging.
function CacheAdapter(minHitRate, fuseLength, iterFuseLength) {
  this.minHitRate = minHitRate;
  this.fuseLength = fuseLength;
  this.iterFuseLength = iterFuseLength;
  this.addrToId = {};
  this.stats = {};
  this.idsToRemove = {};
  this.hasIdsToRemove = false;
}

// Maps an address to its id: the part after the last '_'. Memoized.
CacheAdapter.prototype.id = function(addr) {
  var id = this.addrToId[addr];
  if (id === undefined) {
    var arr = addr.split('_');
    id = arr[arr.length - 1];
    this.addrToId[addr] = id;
  }
  return id;
};

// Returns (creating on first use) the stats record for the address's id.
CacheAdapter.prototype.getStats = function(addr) {
  var id = this.id(addr);
  var stats = this.stats[id];
  if (stats === undefined) {
    stats = {shouldCache: true, hits: 0, total: 0};
    this.stats[id] = stats;
  }
  return stats;
};

CacheAdapter.prototype.shouldCache = function(addr) {
  return this.getStats(addr).shouldCache;
};

// Shared by registerHit / registerMiss: flags the node's id for removal
// once all thresholds are met and the hit rate is below minHitRate.
function flagIfInefficient(adapter, node, stats) {
  if (node.parent !== null && // Can't remove the cache root
      (node.coroutine.totalIterations - node.coroutine.iterations) >
          adapter.iterFuseLength &&
      stats.total >= adapter.fuseLength &&
      stats.hits / stats.total < adapter.minHitRate) {
    adapter.idsToRemove[adapter.id(node.address)] = true;
    adapter.hasIdsToRemove = true;
  }
}

CacheAdapter.prototype.registerHit = function(node) {
  var stats = this.getStats(node.address);
  stats.hits++;
  stats.total++;
  flagIfInefficient(this, node, stats);
};

CacheAdapter.prototype.registerMiss = function(node) {
  var stats = this.getStats(node.address);
  stats.total++;
  flagIfInefficient(this, node, stats);
};

// If any ids were flagged, turns off caching for them and removes every
// node with such an id from the cache tree rooted at cacheRoot.
CacheAdapter.prototype.adapt = function(cacheRoot) {
  if (this.hasIdsToRemove) {
    for (var id in this.idsToRemove) {
      var s = this.stats[id];
      debuglog(5, 'Cache adapter removing nodes w/ id', id,
               'hit rate:', s.hits / s.total);
      this.stats[id].shouldCache = false;
    }
    // Traverse cache and remove all nodes with this id
    var stack = [cacheRoot];
    while (stack.length > 0) {
      var node = stack.pop();
      // Children are pushed before the node is (possibly) removed.
      if (node.children !== undefined) {
        var n = node.children.length;
        while (n--) stack.push(node.children[n]);
      }
      if (this.idsToRemove[this.id(node.address)]) {
        node.removeFromCache();
      }
    }
    this.idsToRemove = {};
    this.hasIdsToRemove = false;
    if (DEBUG >= 6) {
      debuglog(6, '=== Post-adaptation cache status ===');
      cacheRoot.print();
    }
  }
};

// Prints the stats record of every id seen so far.
CacheAdapter.prototype.report = function() {
  for (var id in this.stats) {
    console.log(id, ':', this.stats[id]);
  }
};

// ------------------------------------------------------------------

// The incremental MH coroutine. Constructing it installs it as
// env.coroutine (the previous coroutine is kept in this.oldCoroutine).
//
// s, k, a: store, continuation and address of the inference call.
// wpplFn:  the program to run.
// opts:    merged with the defaults below by util.mergeDefaults.
function IncrementalMH(s, k, a, wpplFn, opts) {
  opts = util.mergeDefaults(opts, {
    samples: 1,
    dontAdapt: false,
    debuglevel: 0,
    verbose: false,
    doFullRerun: false,
    onlyMAP: false,
    cacheMinHitRate: 0.00000001,
    cacheFuseLength: 50,
    lag: 0,
    cacheIterFuseLength: 10,
    burn: 0,
    verboseLag: 1,
    callbacks: []
  }, 'IncrementalMH');

  // Doing a full re-run doesn't really jive with the heuristic we use for
  // adaptive caching, so disable adaptation in this case.
  if (opts.doFullRerun)
    opts.dontAdapt = true;

  DEBUG = opts.debuglevel;
  this.verbose = opts.verbose;

  this.k = k;
  this.oldStore = s;
  // Remaining iterations; counts down to 0 in exit().
  this.iterations = opts.samples * (opts.lag + 1) + opts.burn;
  this.wpplFn = wpplFn;
  this.s = s;
  this.a = a;

  this.aggregator = new CountAggregator(opts.onlyMAP);

  this.totalIterations = this.iterations;
  this.acceptedProps = 0;
  this.lag = opts.lag;
  this.burn = opts.burn;
  this.verboseLag = opts.verboseLag;
  this.callbacks = cb.prepare(opts.callbacks);

  this.doFullRerun = opts.doFullRerun;

  this.doAdapt = !opts.dontAdapt;
  this.cacheAdapter = new CacheAdapter(opts.cacheMinHitRate,
                                       opts.cacheFuseLength,
                                       opts.cacheIterFuseLength);

  // Move old coroutine out of the way and install this as the current
  // handler.
  this.oldCoroutine = env.coroutine;
  env.coroutine = this;
}

// Resets all cache and proposal state, then runs the program from scratch.
IncrementalMH.prototype.run = function() {
  this.cacheRoot = null;
  this.distMasterList = new HashtableDistMasterList();
  // this.distMasterList = new ArrayDistMasterList();
  this.touchedNodes = [];
  this.score = 0;
  this.fwdPropLP = 0;
  this.rvsPropLP = 0;
  this.query = new Query();
  env.query.clear();
  debuglog(1, '-------------------------------------');
  debuglog(1, 'RUN FROM START');
  return this.runFromStart();
};

IncrementalMH.prototype.runFromStart = function() {
  this.nodeStack = [];
  // Cache the top-level function, so that we always have a valid
  // cache root.
  return this.incrementalize(this.s, env.exit, this.a, this.wpplFn, []);
};

// Handles a factor statement through a cached FactorNode.
IncrementalMH.prototype.factor = function(s, k, a, score) {
  return this.cachelookup(FactorNode, s, k, a, null, [score]).execute();
};

// Handles a sample statement through a cached DistNode.
IncrementalMH.prototype.sample = function(s, k, a, dist, options, name) {
  var n = this.cachelookup(DistNode, s, k, a, dist, [options]);
  n.name = name;
  return n.execute();
};

// A node should call this on itself if it makes some change to itself.
IncrementalMH.prototype.touch = function(node) { this.touchedNodes.push(node); }; function acceptProb(currScore, oldScore, currN, oldN, rvsPropLP, fwdPropLP) { if (oldScore === undefined) { return 1; } // init if (currScore === -Infinity) return 0; // auto-reject debuglog(1, 'currScore:', currScore, 'oldScore', oldScore); debuglog(1, 'rvsPropLP:', rvsPropLP, 'fwdPropLP:', fwdPropLP); var fw = -Math.log(oldN) + fwdPropLP; var bw = -Math.log(currN) + rvsPropLP; var p = Math.exp(currScore - oldScore + bw - fw); assert.ok(!isNaN(p)); var acceptance = Math.min(1, p); return acceptance; } // Returns true if we've successfully rejection initialized. IncrementalMH.prototype.isInitialized = function() { return this.iterations < this.totalIterations; } IncrementalMH.prototype.exit = function() { if (this.iterations > 0) { // Initialization: Keep rejection sampling until we get a trace with // non-zero probability if (!this.isInitialized() && this.score === -Infinity) { return this.run(); } else { var iternum = this.totalIterations - this.iterations; debuglog(1, 'iteration', iternum); if (this.verbose && (this.iterations % this.verboseLag === 0)) { if (iternum > this.burn) { console.log('IncrementalMH iteration ' + (iternum - this.burn) + ' / ' + (this.totalIterations - this.burn)); } else { console.log('IncrementalMH burnin ' + iternum + ' / ' + this.burn); } } // Continue proposing as normal this.iterations--; this.distMasterList.postProposal(); debuglog(2, 'Num vars:', this.distMasterList.size()); debuglog(2, 'Touched nodes:', this.touchedNodes.length); // Accept/reject the current proposal var acceptance = acceptProb(this.score, this.oldScore, this.distMasterList.size(), this.distMasterList.oldSize(), this.rvsPropLP, this.fwdPropLP); debuglog(1, 'num vars:', this.distMasterList.size(), 'old num vars:', this.distMasterList.oldSize()); debuglog(1, 'acceptance prob:', acceptance); if (util.random() >= acceptance) { debuglog(1, 'REJECT'); this.score = this.oldScore; var n = 
this.touchedNodes.length; while (n--) restoreSnapshot(this.touchedNodes[n]); this.distMasterList.restoreOnReject(); } else { debuglog(1, 'ACCEPT'); var n = this.touchedNodes.length; while (n--) discardSnapshot(this.touchedNodes[n]); this.acceptedProps++; this.query.addAll(env.query); } env.query.clear(); var val = this.cacheRoot.retval; debuglog(1, 'return val:', val); // Record this sample, if lag allows for it and not in burnin period if ((iternum % (this.lag + 1) === 0) && (iternum >= this.burn)) { // Replace val with accumulated query, if need be. if (val === env.query) val = this.query.getTable(); // add val to hist: this.aggregator.add(val, this.score); this.callbacks.sample({value: val, score: this.score}); } if (DEBUG >= 6) { debuglog(6, '=== Cache status ==='); this.cacheRoot.print(); } // this.checkReachabilityConsistency(); if (this.doAdapt) this.cacheAdapter.adapt(this.cacheRoot); if (this.distMasterList.numDists > 0) { // Prepare to make a new proposal this.oldScore = this.score; this.distMasterList.preProposal(); this.touchedNodes = []; this.fwdPropLP = 0; this.rvsPropLP = 0; // Select distribution to change. 
var propnode = this.distMasterList.getRandom(); // Propose change and resume execution debuglog(1, '----------------------------------------------------------------------'); debuglog(1, 'PROPOSAL', 'type:', propnode.dist, 'address:', propnode.address); return propnode.propose(); } else { return this.runFromStart(); } } } else { // Reinstate previous coroutine: var k = this.k; env.coroutine = this.oldCoroutine; if (DEBUG >= 5) { console.log('Acceptance ratio: ' + this.acceptedProps / this.totalIterations); this.cacheAdapter.report(); } this.callbacks.finish(); // Return by calling original continuation: return k(this.oldStore, this.aggregator.toDist()); } }; IncrementalMH.prototype.incrementalize = function(s, k, a, fn, args) { if (this.cacheAdapter.shouldCache(a)) return this.cachelookup(FunctionNode, s, k, a, fn, args).execute(); else return env.defaultCoroutine.incrementalize(s, k, a, fn, args); }; // Returns a cache node IncrementalMH.prototype.cachelookup = function(NodeType, s, k, a, fn, args) { var cacheNode; // If the cache is empty, then initialize it. if (this.cacheRoot === null) { cacheNode = new NodeType(this, null, s, k, a, fn, args); this.cacheRoot = cacheNode; // If the node stack is empty, then we must be looking up the root on a // re-run from start } else if (this.nodeStack.length === 0) { if (a !== this.cacheRoot.address) throw new Error('Wrong address for cache root lookup'); cacheNode = this.cacheRoot; cacheNode.registerInputChanges(s, k, fn, args); // Otherwise, do the general thing. } else { var currNode = this.nodeStack[this.nodeStack.length - 1]; tabbedlog(3, currNode.depth, 'lookup', NodeType.name, a); // Look for cache node among the children of currNode cacheNode = this.findNode(currNode, a); if (cacheNode) { // Lookup successful; check for changes to store/args and move on. 
tabbedlog(3, currNode.depth, 'found'); cacheNode.registerInputChanges(s, k, fn, args); } else { // Lookup failed; create new node and insert it into currNode.children if (DEBUG) { var addrs = _.map(_.filter(currNode.children, function(node) { return node instanceof NodeType; }), function(node) { return node.address; }); tabbedlog(3, currNode.depth, '*not* found'); tabbedlog(4, currNode.depth, 'options were', addrs); } cacheNode = new NodeType(this, currNode, s, k, a, fn, args); var insertidx = currNode.nextChildIdx; // Copy the children array if we don't already have a snapshot for it // Kind of annoying that this somewhat breaks the abstraction of snapshots, but // I think it's worth it. if (!hasSnapshotForProperty(currNode, 'children')) updateProperty(currNode, 'children', currNode.children.slice()); currNode.children.splice(insertidx, 0, cacheNode); } } return cacheNode; }; IncrementalMH.prototype.findNode = function(parentNode, address) { // If we haven't initialized yet (i.e. we're running the program for // the first time), then don't even bother looking. if (!this.isInitialized()) return undefined; // Need to snapshot the children array, since we perform swaps on it to keep nodes // in execution order. var nodes = parentNode.children; var nexti = parentNode.nextChildIdx; for (var i = nexti; i < nodes.length; i++) { // for (var i = 0; i < nodes.length; i++) { if (nodes[i].address === address) { // Keep nodes ordered according to execution order: if // i !== nexti, then swap those two. 
if (i !== nexti) { // if (i < nexti) throw new Error('WTF - cache node found *before* first possible location'); if (!hasSnapshotForProperty(parentNode, 'children')) { nodes = nodes.slice(); updateProperty(parentNode, 'children', nodes); } var tmp = nodes[i]; nodes[i] = nodes[nexti]; nodes[nexti] = tmp; } return nodes[nexti]; } } }; IncrementalMH.prototype.addDist = function(node) { tabbedlog(3, node.depth, 'new distribution'); this.distMasterList.addDist(node); this.fwdPropLP += node.score; }; IncrementalMH.prototype.removeDist = function(node) { tabbedlog(3, node.depth, 'kill distribution', node.address); this.distMasterList.removeDist(node); this.rvsPropLP += node.score; this.score -= node.score; }; // Restore this.nodeStack up to the specified node IncrementalMH.prototype.restoreStackUpTo = function(node) { this.nodeStack = []; while (node !== null) { this.nodeStack.push(node); node = node.parent; } this.nodeStack.reverse(); }; // For debugging purposes, mostly IncrementalMH.prototype.addressInCache = function(addr) { var stack = [this.cacheRoot]; while (stack.length > 0) { var node = stack.pop(); if (node.address === addr) return true; if (node.children !== undefined) for (var i = 0; i < node.children.length; i++) stack.push(node.children[i]); } return false; }; // Also for debugging IncrementalMH.prototype.checkReachabilityConsistency = function() { var stack = [this.cacheRoot]; while (stack.length > 0) { var node = stack.pop(); if (!node.reachable) throw new Error('WTF - found unreachable node in cache.'); if (node.children !== undefined) for (var i = 0; i < node.children.length; i++) stack.push(node.children[i]); } return false; }; // ------------------------------------------------------------------ function imh(s, cc, a, wpplFn, opts) { opts = opts || {}; return new IncrementalMH(s, cc, a, wpplFn, opts).run(); } return { IncrementalMH: imh }; };