UNPKG

webppl

Version:

Probabilistic programming for the web

120 lines (106 loc) 3.17 kB
'use strict'; var _ = require('lodash'); var estraverse = require('estraverse'); var build = require('ast-types').builders; var types = require('ast-types').types; var isPrimitive = require('../syntax').isPrimitive; var Syntax = estraverse.Syntax; var cacheExempt = [ 'flip', 'categorical', 'uniformDraw', 'bernoulli', 'beta', 'binomial', 'cauchy', 'delta', 'diagCovGaussian', 'dirichlet', 'discrete', 'exponential', 'gamma', 'gaussian', 'ispNormal', 'laplace', 'logisticNormal', 'logitNormal', 'multinomial', 'multivariateBernoulli', 'multivariateGaussian', 'poisson', 'randomInteger', 'tensorGaussian', 'tensorLaplace', 'uniform', 'factor', 'sample', 'observe', 'sampleWithFactor' ]; var cacheExemptTable = {}; _.each(cacheExempt, function(funcName) { cacheExemptTable[funcName] = true; }); cacheExempt = cacheExemptTable; function shouldCache(callee) { // Don't cache 'primitive' functions. It actually could be beneficial to cache // these in some cases, but correctly binding 'this' will require some // systemic changes that I don't want to deal with right now. if (isPrimitive(callee)) return false; // Don't cache sampling helpers or other coroutine functions that // deal with distributions. // Why do this? If the cache adaptation decides to remove one of these functions, // then that function will have the same address as the distribution it's dealing with, // so the adapter will also try to remove the distribution. // Basically, a core assumption of IncrementalMH is that all cache nodes have unique // addresses. if (callee.type === Syntax.Identifier && cacheExempt[callee.name]) return false; // Otherwise, go ahead return true; } function exit(node) { if (node.type === Syntax.CallExpression && shouldCache(node.callee)) { return build.callExpression( build.identifier('incrementalize'), [node.callee, build.arrayExpression(node.arguments)] ); } } function cachingMain(node) { return estraverse.replace(node, { leave: exit }); } function isImhIdentifier(node) { return node.type === 'Identifier' && node.name === 'IncrementalMH'; } function isImhInferMethodOption(node) { return node.type === 'Property' && ((node.key.type === 'Identifier' && node.key.name === 'method') || (node.key.type === 'Literal' && node.key.value === 'method')) && (node.value.type === 'Literal' && node.value.value === 'incrementalMH'); } function transformRequired(programAST) { var flag = false; estraverse.traverse(programAST, { enter: function(node) { if (isImhIdentifier(node) || isImhInferMethodOption(node)) { flag = true; this.break(); } } }); return flag; } function hasNoCachingDirective(ast) { return ast.body.length > 0 && ast.body[0].type === Syntax.ExpressionStatement && ast.body[0].expression.type === Syntax.Literal && ast.body[0].expression.value === 'no caching'; } module.exports = { transform: cachingMain, transformRequired: transformRequired, hasNoCachingDirective: hasNoCachingDirective };