webppl
Version:
Probabilistic programming for the web
74 lines • 2.21 kB
JavaScript
'use strict';
var ad = require('../ad');
var _ = require('lodash');
var base = require('./base');
var types = require('../types');
var util = require('../util');
var numeric = require('../math/numeric');
var T = ad.tensor;
/**
 * Draw an index in 0..theta.length-1 with probability proportional to
 * theta[i], by walking the cumulative sum until it exceeds a random
 * threshold.
 *
 * @param {Array<number>} theta - indexable collection of weights (read via
 *   theta[i] and theta.length); may be unnormalized.
 * @param {number} [thetaSum] - total of the weights. When undefined it is
 *   computed with numeric._sum(theta).
 * @returns {number} the selected index. If the threshold is never exceeded
 *   (rounding, or a thetaSum larger than the accumulated total), the last
 *   index with a positive weight is returned; if no weight is positive,
 *   theta.length - 1 is returned.
 */
function sample(theta, thetaSum) {
  var total = thetaSum === undefined ? numeric._sum(theta) : thetaSum;
  // presumably util.random() is uniform on [0, 1) -- confirm in ../util
  var x = util.random() * total;
  var k = theta.length;
  var probAccum = 0;
  // Fallback result: never hand back an index whose weight is zero when a
  // positively weighted index exists.
  var lastPositive = k - 1;
  for (var i = 0; i < k; i++) {
    var weight = theta[i];
    probAccum += weight;
    if (x < probAccum) {
      return i;
    }
    if (weight > 0) {
      lastPositive = i;
    }
  }
  return lastPositive;
}
/**
 * Log of the normalized weight at index `val`, for weights that are read
 * through T.get (the ad.tensor accessor).
 *
 * @param {number} val - index to look up.
 * @param {*} probs - weight container passed to T.get.
 * @param {*} norm - normalizing constant the weight is divided by.
 * @returns {*} result of ad.scalar.log(weight / norm).
 */
function scoreVector(val, probs, norm) {
  var weight = T.get(probs, val);
  var normalized = ad.scalar.div(weight, norm);
  return ad.scalar.log(normalized);
}
/**
 * Log of the normalized weight at index `val`, for weights stored in a
 * plain array (read as probs[val]).
 *
 * @param {number} val - index to look up.
 * @param {Array} probs - array of weights.
 * @param {*} norm - normalizing constant the weight is divided by.
 * @returns {*} result of ad.scalar.log(probs[val] / norm).
 */
function scoreArray(val, probs, norm) {
  var weight = probs[val];
  var normalized = ad.scalar.div(weight, norm);
  return ad.scalar.log(normalized);
}
/**
 * Test whether `val` is an integer index in the half-open range [0, dim).
 *
 * @param {*} val - candidate value.
 * @param {number} dim - number of outcomes.
 * @returns {boolean} true only for integer-valued numbers with 0 <= val < dim.
 */
function inSupport(val, dim) {
  // Strict equality with its own floor rejects non-numbers, NaN and fractions.
  if (val !== Math.floor(val)) {
    return false;
  }
  if (val < 0) {
    return false;
  }
  return val < dim;
}
/**
 * Produce the raw weights used for sampling.
 *
 * @param {*} x - either an array, whose elements are each passed through
 *   ad.value, or a single value whose ad.value(...) exposes a `data` field.
 * @returns {*} the mapped array, or the `data` field of ad.value(x).
 */
function toUnliftedArray(x) {
  if (_.isArray(x)) {
    return x.map(ad.value);
  }
  return ad.value(x).data;
}
// Distribution over the indices of `ps`, built through
// base.makeDistributionType. `ps` may be a plain array or a value handled by
// the ad.tensor helpers; the constructor caches the normalizer, the matching
// score function and the number of outcomes on the instance.
var Discrete = base.makeDistributionType({
  name: 'Discrete',
  desc: 'Distribution over ``{0,1,...,ps.length-1}`` with P(i) proportional to ``ps[i]``',
  params: [
    {
      name: 'ps',
      desc: 'probabilities (can be unnormalized)',
      type: types.nonNegativeVectorOrRealArray
    }
  ],
  wikipedia: 'Categorical_distribution',
  mixins: [base.finiteSupport],

  // Cache norm / scoreFn / dim according to how `ps` is represented.
  constructor: function () {
    var ps = this.params.ps;
    if (_.isArray(ps)) {
      this.norm = numeric.sum(ps);
      this.scoreFn = scoreArray;
      this.dim = ps.length;
    } else {
      this.norm = T.sumreduce(ps);
      this.scoreFn = scoreVector;
      this.dim = ad.value(ps).length;
    }
  },

  // Draw an index using the raw weights and the raw normalizer.
  sample: function () {
    var weights = toUnliftedArray(this.params.ps);
    var total = ad.value(this.norm);
    return sample(weights, total);
  },

  // Log-probability of `val`; values outside the support score -Infinity.
  score: function (val) {
    if (!inSupport(val, this.dim)) {
      return -Infinity;
    }
    return this.scoreFn(val, this.params.ps, this.norm);
  },

  // All outcomes: the integers 0 .. dim - 1.
  support: function () {
    var outcomes = _.range(this.dim);
    return outcomes;
  }
});
// Module interface: the Discrete distribution type plus the standalone
// index sampler `sample` defined in this file.
module.exports = {
Discrete: Discrete,
sample: sample
};