UNPKG

webppl

Version:

Probabilistic programming for the web

68 lines 2.29 kB
'use strict'; var ad = require('../ad'); var _ = require('lodash'); var base = require('./base'); var types = require('../types'); var util = require('../util'); var numeric = require('../math/numeric'); var Marginal = require('./marginal').Marginal; var T = ad.tensor; var Categorical = base.makeDistributionType({ name: 'Categorical', desc: 'Distribution over elements of ``vs`` with ``P(vs[i])`` proportional to ``ps[i]``. ' + '``ps`` may be omitted, in which case a uniform distribution over ``vs`` is returned.', params: [ { name: 'ps', desc: 'probabilities (can be unnormalized)', type: types.nonNegativeVectorOrRealArray, optional: true }, { name: 'vs', desc: 'support', type: types.array(types.any) } ], wikipedia: true, mixins: [base.finiteSupport], constructor: function () { if (ad.scalar.peq(this.params.ps, undefined)) { this.params = { ps: _.fill(Array(this.params.vs.length), 1), vs: this.params.vs }; } var ps = this.params.ps; var vs = this.params.vs; if (ad.scalar.pneq(vs.length, ad.value(ps).length)) { throw new Error('Parameters ps and vs should have the same length.'); } if (ad.scalar.peq(vs.length, 0)) { throw new Error('Parameters ps and vs should have length > 0.'); } var dist = {}; var norm = _.isArray(ps) ? numeric.sum(ps) : T.sumreduce(ps); for (var i in vs) { var val = vs[i]; var k = util.serialize(val); if (!_.has(dist, k)) { dist[k] = { val: val, prob: 0 }; } dist[k].prob = ad.scalar.add(dist[k].prob, ad.scalar.div(_.isArray(ps) ? ps[i] : T.get(ps, i), norm)); } this.marginal = new Marginal({ dist: dist }); }, sample: function () { return this.marginal.sample(); }, score: function (val) { return this.marginal.score(val); }, support: function () { return this.marginal.support(); } }); module.exports = { Categorical: Categorical };