ts-trueskill
Version:
Port of the Python trueskill package to TypeScript
182 lines (181 loc) • 5.45 kB
JavaScript
import { v1 as uuid } from 'uuid';
import { SkillGaussian } from './mathematics.js';
/**
 * Variable node of the factor graph.
 *
 * A Variable is itself a SkillGaussian (its current value) and additionally
 * records, in `messages`, the most recent message received from each
 * connected factor, keyed by that factor's `toString()` value.
 */
export class Variable extends SkillGaussian {
  messages = {};

  /**
   * Copy `pi` and `tau` from `val` into this variable.
   * @param {SkillGaussian} val - the new value
   * @returns {number} `this.delta(val)`, measured before the copy happens
   */
  setVal(val) {
    const change = this.delta(val);
    this.pi = val.pi;
    this.tau = val.tau;
    return change;
  }

  /**
   * Magnitude of the difference between this variable and `other`:
   * the larger of |Δtau| and sqrt(|Δpi|), or 0 when |Δpi| is infinite.
   * @param {SkillGaussian} other
   * @returns {number}
   */
  delta(other) {
    const piGap = Math.abs(this.pi - other.pi);
    if (piGap !== Infinity) {
      return Math.max(Math.abs(this.tau - other.tau), Math.sqrt(piGap));
    }
    return 0;
  }

  /**
   * Replace the message stored for `factor`, then update this variable's
   * value by dividing out the previous message and multiplying in the new one.
   *
   * @param {Factor} factor - sending factor; its toString() is the message key
   * @param {number} [pi=0] - precision of the new message (used when `message` is falsy)
   * @param {number} [tau=0] - tau (pi * mu) of the new message (used when `message` is falsy)
   * @param {SkillGaussian} [message] - explicit new message; takes precedence when truthy
   * @returns {number} the delta reported by setVal
   */
  updateMessage(factor, pi = 0, tau = 0, message) {
    // NOTE(review): constructor args appear to be (mu, sigma, pi, tau) — confirm in mathematics.js.
    const incoming = message || new SkillGaussian(null, null, pi, tau);
    const key = factor.toString();
    const previous = this.messages[key];
    this.messages[key] = incoming;
    const next = this.div(previous).mul(incoming);
    return this.setVal(next);
  }

  /**
   * Set this variable's value directly to `value`, and rewrite the message
   * stored for `factor` as `value * oldMessage / this` (computed before the
   * value changes).
   *
   * @param {Factor} factor - factor whose stored message is rewritten
   * @param {number} [pi=0] - precision of the new value (used when `value` is falsy)
   * @param {number} [tau=0] - tau of the new value (used when `value` is falsy)
   * @param {SkillGaussian} [value] - explicit new value; takes precedence when truthy
   * @returns {number} the delta reported by setVal
   */
  updateValue(factor, pi = 0, tau = 0, value) {
    const target = value || new SkillGaussian(null, null, pi, tau);
    const key = factor.toString();
    const previous = this.messages[key];
    this.messages[key] = target.mul(previous).div(this);
    return this.setVal(target);
  }

  /**
   * @returns {string} e.g. `<Variable ... with 2 connections>`
   */
  toString() {
    const connections = Object.keys(this.messages).length;
    const plural = connections === 1 ? '' : 's';
    return `<Variable ${super.toString()} with ${connections} connection${plural}>`;
  }
}
/**
 * Base class for factor nodes in the factor graph.
 *
 * Each factor gets a per-instance id from `uuid()`, which is embedded in its
 * `toString()` value. That string is the key under which connected variables
 * store this factor's messages. On construction, every connected variable is
 * given an initial `new SkillGaussian()` message under that key.
 *
 * Subclasses override `down()` / `up()`; the defaults do nothing and return 0.
 */
export class Factor {
  vars;
  uuid;

  /**
   * @param {Variable[]} vars - variables this factor is connected to
   */
  constructor(vars) {
    this.vars = vars;
    this.uuid = uuid();
    const key = this.toString();
    vars.forEach(variable => {
      variable.messages[key] = new SkillGaussian();
    });
  }

  /** Default downward pass: no-op. @returns {number} always 0 */
  down() {
    return 0;
  }

  /** Default upward pass: no-op. @returns {number} always 0 */
  up() {
    return 0;
  }

  /**
   * The single connected variable, for factors with exactly one connection.
   * @returns {Variable}
   * @throws {Error} if the factor is connected to zero or several variables
   */
  get v() {
    const count = this.vars.length;
    if (count !== 1) {
      // The old message ('Too long') was wrong when the factor had no variables.
      throw new Error(`Factor must have exactly one variable, got ${count}`);
    }
    return this.vars[0];
  }

  /**
   * @returns {string} e.g. `<Factor with 2 connections <uuid>>`; also used as the message key
   */
  toString() {
    const s = this.vars.length === 1 ? '' : 's';
    return `<Factor with ${this.vars.length} connection${s} ${this.uuid}>`;
  }
}
/**
 * Single-variable factor that sets its variable's value from a prior
 * Gaussian whose standard deviation is widened by `dynamic`.
 */
export class PriorFactor extends Factor {
  val;
  dynamic;

  /**
   * @param {Variable} v - the variable this prior is attached to
   * @param {SkillGaussian} val - prior; its `mu` and `sigma` are read in down()
   * @param {number} [dynamic=0] - extra standard deviation combined with
   *   `val.sigma` as sqrt(sigma^2 + dynamic^2) (presumably TrueSkill's
   *   dynamics term — confirm with caller)
   */
  constructor(v, val, dynamic = 0) {
    super([v]);
    this.val = val;
    this.dynamic = dynamic;
  }

  /**
   * Push the widened prior into the variable via updateValue.
   * @returns {number} the delta reported by the variable
   */
  down() {
    const { sigma: priorSigma, mu: priorMu } = this.val;
    const widenedSigma = Math.sqrt(priorSigma ** 2 + this.dynamic ** 2);
    const prior = new SkillGaussian(priorMu, widenedSigma);
    return this.v.updateValue(this, undefined, undefined, prior);
  }
}
/**
 * Two-variable factor linking a `mean` variable and a `value` variable
 * with a fixed `variance`. `down()` sends a message from mean to value and
 * `up()` sends one from value to mean; both use the same scaling rule.
 */
export class LikelihoodFactor extends Factor {
  mean;
  value;
  variance;

  /**
   * @param {Variable} mean
   * @param {Variable} value
   * @param {number} variance
   */
  constructor(mean, value, variance) {
    super([mean, value]);
    this.mean = mean;
    this.value = value;
    this.variance = variance;
  }

  /**
   * Scaling factor applied to the outgoing message's pi and tau.
   * @param {SkillGaussian} v
   * @returns {number} 1 / (variance * v.pi + 1)
   */
  calcA(v) {
    return 1.0 / (this.variance * v.pi + 1.0);
  }

  /** Update `value` from `mean`. @returns {number} delta from updateMessage */
  down() {
    return this.#relay(this.mean, this.value);
  }

  /** Update `mean` from `value`. @returns {number} delta from updateMessage */
  up() {
    return this.#relay(this.value, this.mean);
  }

  // Divide this factor's own stored message out of `source`, scale the
  // result by calcA, and deliver it to `target` as this factor's message.
  #relay(source, target) {
    const key = this.toString();
    const excluded = source.div(source.messages[key]);
    const scale = this.calcA(excluded);
    return target.updateMessage(this, scale * excluded.pi, scale * excluded.tau);
  }
}
/**
 * Factor expressing `sum = Σ coeffs[i] * terms[i]`.
 *
 * `down()` updates `sum` from the terms. `up(index)` solves the same linear
 * relation for `terms[index]` and updates it from `sum` and the other terms.
 */
export class SumFactor extends Factor {
  sum;
  terms;
  coeffs;

  /**
   * @param {Variable} sum - variable holding the weighted sum
   * @param {Variable[]} terms - summed variables
   * @param {number[]} coeffs - one coefficient per term
   */
  constructor(sum, terms, coeffs) {
    super([sum, ...terms]);
    this.sum = sum;
    this.terms = terms;
    this.coeffs = coeffs;
  }

  /** Update `sum` from the terms. @returns {number} delta from updateMessage */
  down() {
    const key = this.toString();
    const msgs = this.terms.map(v => v.messages[key]);
    return this.update(this.sum, this.terms, msgs, this.coeffs);
  }

  /**
   * Update `terms[index]` from `sum` and the remaining terms.
   *
   * Rearranging the sum gives
   * `terms[index] = sum / c - Σ_{i≠index} (coeffs[i] / c) * terms[i]`,
   * where `c = coeffs[index]`. When `c` is 0, or a ratio is not finite, that
   * coefficient is replaced with 0.
   *
   * @param {number} [index=0]
   * @returns {number} delta from updateMessage
   */
  up(index = 0) {
    const coeff = this.coeffs[index];
    const coeffs = this.coeffs.map((c, i) => {
      if (coeff === 0) {
        return 0;
      }
      const p = i === index ? 1.0 / coeff : -c / coeff;
      return Number.isFinite(p) ? p : 0;
    });
    // Position `index` now holds `sum`, whose coefficient is 1 / coeff.
    const vals = [...this.terms];
    vals[index] = this.sum;
    const key = this.toString();
    const msgs = vals.map(v => v.messages[key]);
    return this.update(this.terms[index], vals, msgs, coeffs);
  }

  /**
   * Combine `vals` (each with this factor's own message divided out) into a
   * message for `v`.
   * The message has mean `Σ coeffs[i] * mu_i`, inverse precision
   * `Σ coeffs[i]^2 / pi_i`, and `tau = pi * mu`.
   *
   * @param {Variable} v - receiving variable
   * @param {Variable[]} vals - contributing variables
   * @param {SkillGaussian[]} msgs - this factor's stored message on each of `vals`
   * @param {number[]} coeffs - weight of each contribution
   * @returns {number} delta from `v.updateMessage`
   */
  update(v, vals, msgs, coeffs) {
    let piInv = 0;
    let mu = 0;
    for (let i = 0; i < vals.length; i++) {
      const coeff = coeffs[i];
      const div = vals[i].div(msgs[i]);
      mu += coeff * div.mu;
      if (!Number.isFinite(piInv)) {
        continue;
      }
      // A zero-precision (infinite-variance) input makes the combination
      // infinitely uncertain whatever its coefficient. Dividing directly would
      // give NaN for coeff === 0 (0 / 0) and -Infinity for pi === -0. Treat
      // both as Infinity, as the Python reference does on ZeroDivisionError.
      piInv = div.pi === 0 ? Infinity : piInv + coeff ** 2 / div.pi;
    }
    const pi = 1.0 / piInv;
    const tau = pi * mu;
    return v.updateMessage(this, pi, tau);
  }
}
/**
 * Single-variable factor that corrects its variable using the supplied
 * `vFunc` / `wFunc` functions and a `drawMargin`.
 * NOTE(review): these are presumably TrueSkill's V/W truncation corrections
 * for a win or a draw — confirm against the caller.
 */
export class TruncateFactor extends Factor {
  vFunc;
  wFunc;
  drawMargin;

  /**
   * @param {Variable} v - the variable being corrected
   * @param {(t: number, eps: number) => number} vFunc
   * @param {(t: number, eps: number) => number} wFunc
   * @param {number} drawMargin
   */
  constructor(v, vFunc, wFunc, drawMargin) {
    super([v]);
    this.vFunc = vFunc;
    this.wFunc = wFunc;
    this.drawMargin = drawMargin;
  }

  /**
   * Divide this factor's message out of the variable and evaluate vFunc/wFunc
   * at (tau / sqrt(pi), drawMargin * sqrt(pi)). Then set the variable's value to
   * pi' = pi / (1 - w), tau' = (tau + sqrt(pi) * v) / (1 - w).
   * @returns {number} delta from updateValue
   */
  up() {
    const variable = this.v;
    const excluded = variable.div(variable.messages[this.toString()]);
    const { pi: basePi, tau: baseTau } = excluded;
    const rootPi = Math.sqrt(basePi);
    const t = baseTau / rootPi;
    const eps = this.drawMargin * rootPi;
    const vCorrection = this.vFunc(t, eps);
    const wCorrection = this.wFunc(t, eps);
    const scale = 1.0 - wCorrection;
    return variable.updateValue(
      this,
      basePi / scale,
      (baseTau + rootPi * vCorrection) / scale,
    );
  }
}