UNPKG

ts-trueskill

Version:

Port of the Python trueskill package to TypeScript

182 lines (181 loc) 5.45 kB
import { v1 as uuid } from 'uuid';
import { SkillGaussian } from './mathematics.js';

/**
 * A variable node of the factor graph.
 *
 * The node is itself a SkillGaussian holding the current marginal (stored as
 * natural parameters `pi`/`tau`). It also remembers the most recent message
 * received from every connected factor, keyed by that factor's `toString()`.
 */
export class Variable extends SkillGaussian {
  /** Last message received per connected factor, keyed by `factor.toString()`. */
  messages = {};

  /**
   * Overwrite this variable's natural parameters with those of `val`.
   *
   * @param {SkillGaussian} val - The new marginal.
   * @returns {number} Magnitude of the change, as computed by `delta`.
   */
  setVal(val) {
    const delta = this.delta(val);
    this.pi = val.pi;
    this.tau = val.tau;
    return delta;
  }

  /**
   * Size of the difference between this distribution and `other`: the larger
   * of |Δtau| and sqrt(|Δpi|). Returns 0 when the precision difference is
   * infinite.
   *
   * @param {SkillGaussian} other
   * @returns {number}
   */
  delta(other) {
    const piDelta = Math.abs(this.pi - other.pi);
    if (piDelta === Infinity) {
      return 0;
    }
    return Math.max(Math.abs(this.tau - other.tau), Math.sqrt(piDelta));
  }

  /**
   * Replace the message stored for `factor` and fold the change into the
   * marginal: marginal' = marginal / oldMessage * newMessage.
   *
   * @param {Factor} factor - The factor sending the message.
   * @param {number} [pi=0] - Precision of the new message; used only when
   *   `message` is not given.
   * @param {number} [tau=0] - Precision-adjusted mean of the new message; used
   *   only when `message` is not given.
   * @param {SkillGaussian} [message] - A ready-made message; takes precedence
   *   over `pi`/`tau`.
   * @returns {number} Magnitude of the change in the marginal.
   */
  updateMessage(factor, pi = 0, tau = 0, message) {
    const newMessage = message || new SkillGaussian(null, null, pi, tau);
    const key = factor.toString();
    const oldMessage = this.messages[key];
    this.messages[key] = newMessage;
    return this.setVal(this.div(oldMessage).mul(newMessage));
  }

  /**
   * Set the marginal directly to a new value and store, for `factor`, the
   * message consistent with it: newValue * oldMessage / oldMarginal.
   *
   * @param {Factor} factor - The factor imposing the value.
   * @param {number} [pi=0] - Precision of the new value; used only when
   *   `value` is not given.
   * @param {number} [tau=0] - Precision-adjusted mean of the new value; used
   *   only when `value` is not given.
   * @param {SkillGaussian} [value] - A ready-made value; takes precedence over
   *   `pi`/`tau`.
   * @returns {number} Magnitude of the change in the marginal.
   */
  updateValue(factor, pi = 0, tau = 0, value) {
    const newValue = value || new SkillGaussian(null, null, pi, tau);
    const key = factor.toString();
    const oldMessage = this.messages[key];
    // Computed before setVal() so that `this` is still the old marginal.
    this.messages[key] = newValue.mul(oldMessage).div(this);
    return this.setVal(newValue);
  }

  /** Human-readable description including the number of connected factors. */
  toString() {
    const count = Object.keys(this.messages).length;
    const s = count === 1 ? '' : 's';
    const val = super.toString();
    return `<Variable ${val} with ${count} connection${s}>`;
  }
}

/**
 * Base class for factor nodes.
 *
 * On construction, the factor registers itself with every connected variable
 * by storing an initial `new SkillGaussian()` message under its own
 * `toString()` key. The embedded UUID keeps that key distinct between factors
 * with the same number of connections.
 */
export class Factor {
  vars;
  uuid;

  /**
   * @param {Variable[]} vars - Variables this factor connects.
   */
  constructor(vars) {
    this.vars = vars;
    this.uuid = uuid();
    const k = this.toString();
    vars.forEach(v => {
      v.messages[k] = new SkillGaussian();
    });
  }

  /** Send messages "down" the schedule; the base implementation does nothing and returns 0. */
  down() {
    return 0;
  }

  /** Send messages "up" the schedule; the base implementation does nothing and returns 0. */
  up() {
    return 0;
  }

  /**
   * The single connected variable.
   *
   * @throws {Error} If this factor is not connected to exactly one variable.
   */
  get v() {
    if (this.vars.length !== 1) {
      throw new Error('Too long');
    }
    return this.vars[0];
  }

  /** Description that also serves as this factor's message key on its variables. */
  toString() {
    const s = this.vars.length === 1 ? '' : 's';
    return `<Factor with ${this.vars.length} connection${s} ${this.uuid}>`;
  }
}

/**
 * Factor that fixes a variable to a prior Gaussian, widened by `dynamic`.
 */
export class PriorFactor extends Factor {
  val;
  dynamic;

  /**
   * @param {Variable} v - The variable receiving the prior.
   * @param {SkillGaussian} val - The prior (read via `.mu` and `.sigma`).
   * @param {number} [dynamic=0] - Extra standard deviation added in quadrature
   *   to the prior's sigma (presumably the TrueSkill dynamics term; confirm with
   *   caller).
   */
  constructor(v, val, dynamic = 0) {
    super([v]);
    this.val = val;
    this.dynamic = dynamic;
  }

  /**
   * Push the (widened) prior into the variable.
   *
   * @returns {number} Magnitude of the change in the variable's marginal.
   */
  down() {
    // sigma' = sqrt(sigma^2 + dynamic^2)
    const sigma = Math.sqrt(this.val.sigma ** 2 + this.dynamic ** 2);
    const value = new SkillGaussian(this.val.mu, sigma);
    return this.v.updateValue(this, undefined, undefined, value);
  }
}

/**
 * Factor linking a `mean` variable and a `value` variable with a fixed
 * `variance` between them.
 */
export class LikelihoodFactor extends Factor {
  mean;
  value;
  variance;

  /**
   * @param {Variable} mean
   * @param {Variable} value
   * @param {number} variance
   */
  constructor(mean, value, variance) {
    super([mean, value]);
    this.mean = mean;
    this.value = value;
    this.variance = variance;
  }

  /** Scaling term 1 / (1 + variance * v.pi) applied to the outgoing message. */
  calcA(v) {
    return 1.0 / (this.variance * v.pi + 1.0);
  }

  /** Update `value` from `mean` (excluding this factor's own previous message). */
  down() {
    const msg = this.mean.div(this.mean.messages[this.toString()]);
    const a = this.calcA(msg);
    return this.value.updateMessage(this, a * msg.pi, a * msg.tau);
  }

  /** Update `mean` from `value` (excluding this factor's own previous message). */
  up() {
    const msg = this.value.div(this.value.messages[this.toString()]);
    const a = this.calcA(msg);
    return this.mean.updateMessage(this, a * msg.pi, a * msg.tau);
  }
}

/**
 * Factor expressing `sum = Σ coeffs[i] * terms[i]`.
 */
export class SumFactor extends Factor {
  sum;
  terms;
  coeffs;

  /**
   * @param {Variable} sum
   * @param {Variable[]} terms
   * @param {number[]} coeffs - One coefficient per term.
   */
  constructor(sum, terms, coeffs) {
    super([sum].concat(terms));
    this.sum = sum;
    this.terms = terms;
    this.coeffs = coeffs;
  }

  /** Update `sum` from all the terms. */
  down() {
    const k = this.toString();
    const msgs = this.terms.map(v => v.messages[k]);
    return this.update(this.sum, this.terms, msgs, this.coeffs);
  }

  /**
   * Update `terms[index]` by solving the sum equation for it:
   * term_i = sum / c_i - Σ_{j≠i} (c_j / c_i) * term_j.
   *
   * @param {number} [index=0] - Which term to update.
   */
  up(index = 0) {
    const coeff = this.coeffs[index];
    const coeffs = this.coeffs.map((c, i) => {
      // Dividing by a zero coefficient has no meaningful result; use 0 as the
      // reference implementation does for ZeroDivisionError.
      if (coeff === 0) {
        return 0;
      }
      // Position `index` is occupied by `sum` in `vals` below, hence 1/c_i.
      const p = i === index ? 1.0 / coeff : -c / coeff;
      return Number.isFinite(p) ? p : 0;
    });
    const vals = [...this.terms];
    vals[index] = this.sum;
    const k = this.toString();
    const msgs = vals.map(v => v.messages[k]);
    return this.update(this.terms[index], vals, msgs, coeffs);
  }

  /**
   * Send `v` the message for the linear combination Σ coeffs[i] * (vals[i] / msgs[i]).
   *
   * @param {Variable} v - The variable to update.
   * @param {Variable[]} vals - Input variables.
   * @param {SkillGaussian[]} msgs - This factor's previous message on each input.
   * @param {number[]} coeffs - Linear coefficients, one per input.
   * @returns {number} Magnitude of the change in `v`'s marginal.
   */
  update(v, vals, msgs, coeffs) {
    let piInv = 0;
    let mu = 0;
    for (let i = 0; i < vals.length; i++) {
      const div = vals[i].div(msgs[i]);
      const coeff = coeffs[i];
      mu += coeff * div.mu;
      if (!Number.isFinite(piInv)) {
        continue;
      }
      // A zero precision means an infinite variance. The Python reference gets
      // a ZeroDivisionError here and sets pi_inv = inf. Plain JS division would
      // instead yield NaN for 0/0 (coeff === 0) or -Infinity for x / -0, and
      // that would corrupt the outgoing message.
      piInv = div.pi === 0 ? Infinity : piInv + coeff ** 2 / div.pi;
    }
    const pi = 1.0 / piInv;
    const tau = pi * mu;
    return v.updateMessage(this, pi, tau);
  }
}

/**
 * Factor that truncates a single variable using correction functions
 * `vFunc`/`wFunc`, both called as `fn(t, margin)`.
 */
export class TruncateFactor extends Factor {
  vFunc;
  wFunc;
  drawMargin;

  /**
   * @param {Variable} v - The variable being truncated.
   * @param {(t: number, margin: number) => number} vFunc - Mean correction.
   * @param {(t: number, margin: number) => number} wFunc - Variance correction.
   * @param {number} drawMargin - Margin passed (scaled by sqrt(pi)) to both
   *   functions.
   */
  constructor(v, vFunc, wFunc, drawMargin) {
    super([v]);
    this.vFunc = vFunc;
    this.wFunc = wFunc;
    this.drawMargin = drawMargin;
  }

  /**
   * Replace the variable's marginal with its truncated approximation.
   *
   * @returns {number} Magnitude of the change in the variable's marginal.
   */
  up() {
    const val = this.v;
    const msg = this.v.messages[this.toString()];
    const div = val.div(msg);
    const sqrtPi = Math.sqrt(div.pi);
    const t = div.tau / sqrtPi;
    const margin = this.drawMargin * sqrtPi;
    const v = this.vFunc(t, margin);
    const w = this.wFunc(t, margin);
    const denom = 1.0 - w;
    const pi = div.pi / denom;
    const tau = (div.tau + sqrtPi * v) / denom;
    return val.updateValue(this, pi, tau);
  }
}