/*
 * @ai-on-browser/data-analysis-models
 * Version: (not specified)
 * Data analysis model package without any dependencies
 * 176 lines (168 loc) • 5.13 kB
 * JavaScript
 */
/**
 * Semi-supervised naive bayes.
 *
 * Multinomial naive Bayes trained with EM on a mix of labeled and unlabeled
 * documents. Unlabeled documents contribute to the parameter estimates with
 * weight `lambda` times their current class posteriors.
 */
export default class SemiSupervisedNaiveBayes {
	// Text classification from labeled and unlabeled documents using EM.
	// https://www.ri.cmu.edu/pub_files/pub1/nigam_k_1999_1/nigam_k_1999_1.pdf
	// https://github.com/jmatayoshi/semi-supervised-naive-bayes
	/**
	 * @param {number} [lambda] Weight applied to the contribution of the unlabeled data
	 */
	constructor(lambda = 1) {
		this._lambda = lambda
		// Dirichlet prior parameter used by logLikelihood; alpha - 1 = 1 corresponds to the
		// add-one smoothing applied in _estimateParameters.
		this._alpha = 2
	}

	/**
	 * Initialize model.
	 *
	 * Builds the vocabulary from all documents, splits the documents into labeled and
	 * unlabeled sets, and estimates initial parameters from the labeled documents only.
	 * @param {Array<Array<string>>} datas Training data
	 * @param {(* | null)[]} labels Target values. `null` (or a missing entry) marks an unlabeled document.
	 */
	init(datas, labels) {
		this._vocabulary = [...new Set(datas.flat())]
		this._vocabIndex = new Map(this._vocabulary.map((w, i) => [w, i]))
		this._labeled_data = { w: [], i: [] }
		this._unlabeled_data = { w: [], i: [] }
		this._labels = []
		for (let i = 0; i < datas.length; i++) {
			const di = datas[i].map(w => this._vocabIndex.get(w))
			if (labels[i] == null) {
				this._unlabeled_data.w.push(datas[i])
				this._unlabeled_data.i.push(di)
			} else {
				this._labeled_data.w.push(datas[i])
				this._labeled_data.i.push(di)
				this._labels.push(labels[i])
			}
		}
		this._classes = [...new Set(this._labels)]
		// Initial M-step from labeled documents only, so priors are normalized over the labeled count.
		this._estimateParameters(null)
	}

	/**
	 * Fit model.
	 *
	 * Runs one EM iteration. The E-step computes class posteriors of the unlabeled
	 * documents under the current parameters. The M-step re-estimates the parameters
	 * from the labeled documents plus the unlabeled documents, the latter weighted by
	 * `lambda` times their posteriors.
	 */
	fit() {
		const unlabeledProbs = this._posteriors(this._unlabeled_data.i)
		this._estimateParameters(unlabeledProbs)
	}

	/**
	 * Returns predicted probabilities.
	 *
	 * Words that do not appear in the training vocabulary are ignored.
	 * @param {Array<Array<string>>} datas Sample data
	 * @returns {Array<Array<number>>} Predicted values. One row per document, one probability per class.
	 */
	probability(datas) {
		return this._posteriors(datas.map(doc => this._toIndices(doc)))
	}

	/**
	 * Returns the log likelihood of the current parameters.
	 *
	 * Sum of three terms:
	 * - the Dirichlet prior term;
	 * - the log likelihood of the labeled documents under their known labels;
	 * - the expected log likelihood of the unlabeled documents, weighted by `lambda`.
	 * @returns {number} Log likelihood value
	 */
	logLikelihood() {
		const { logPc, logPwc } = this._logParameters()
		const classIndex = new Map(this._classes.map((c, k) => [c, k]))
		// log P(c_k) + sum_w log P(w | c_k) for a document given as vocabulary indices
		const logJoint = (doc, k) => doc.reduce((s, w) => s + logPwc[k][w], logPc[k])

		let llh = 0
		for (let k = 0; k < this._classes.length; k++) {
			llh += (this._alpha - 1) * logPc[k]
			for (const lp of logPwc[k]) {
				llh += (this._alpha - 1) * lp
			}
		}
		// Labels of labeled documents are known, so they enter with an indicator rather than a posterior.
		this._labeled_data.i.forEach((doc, i) => {
			llh += logJoint(doc, classIndex.get(this._labels[i]))
		})
		const z = this._posteriors(this._unlabeled_data.i)
		this._unlabeled_data.i.forEach((doc, i) => {
			for (let k = 0; k < this._classes.length; k++) {
				llh += this._lambda * z[i][k] * logJoint(doc, k)
			}
		})
		return llh
	}

	/**
	 * Returns predicted categories.
	 * @param {Array<Array<string>>} datas Sample data
	 * @returns {*[]} Predicted values. `null` for a document when the model has no classes.
	 */
	predict(datas) {
		return this.probability(datas).map(p => {
			let maxP = 0
			let maxC = -1
			for (let k = 0; k < this._classes.length; k++) {
				if (p[k] > maxP) {
					maxP = p[k]
					maxC = k
				}
			}
			return maxC < 0 ? null : this._classes[maxC]
		})
	}

	/**
	 * Maps the words of a document to vocabulary indices, dropping unknown words.
	 * @param {Array<string>} doc Document
	 * @returns {number[]} Vocabulary indices
	 */
	_toIndices(doc) {
		const indices = []
		for (const w of doc) {
			const idx = this._vocabIndex.get(w)
			if (idx !== undefined) {
				indices.push(idx)
			}
		}
		return indices
	}

	/**
	 * Returns the natural logs of the class priors and of the word-given-class probabilities.
	 * @returns {{ logPc: number[], logPwc: number[][] }} Log parameters
	 */
	_logParameters() {
		return {
			logPc: this._prob_c.map(v => Math.log(v)),
			logPwc: this._prob_wc.map(row => row.map(v => Math.log(v))),
		}
	}

	/**
	 * Computes class posteriors for documents given as vocabulary indices.
	 *
	 * Works in log space and normalizes with log-sum-exp, so long documents do not
	 * underflow to 0/0.
	 * @param {Array<Array<number>>} docs Documents as vocabulary indices
	 * @returns {Array<Array<number>>} Posterior probabilities per document and class
	 */
	_posteriors(docs) {
		const { logPc, logPwc } = this._logParameters()
		return docs.map(doc => {
			const logJoint = logPc.map((lp, k) => doc.reduce((s, w) => s + logPwc[k][w], lp))
			const max = Math.max(...logJoint)
			const scaled = logJoint.map(v => Math.exp(v - max))
			const total = scaled.reduce((s, v) => s + v, 0)
			return scaled.map(v => v / total)
		})
	}

	/**
	 * M-step: sets add-one smoothed class priors (`_prob_c`) and word probabilities (`_prob_wc`).
	 * @param {Array<Array<number>> | null} unlabeledProbs Posteriors of the unlabeled documents,
	 *   or `null` to estimate from the labeled documents only
	 */
	_estimateParameters(unlabeledProbs) {
		const nClasses = this._classes.length
		const nWords = this._vocabulary.length
		const classIndex = new Map(this._classes.map((c, k) => [c, k]))
		const wordCounts = Array.from({ length: nClasses }, () => Array(nWords).fill(0))
		const docCounts = Array(nClasses).fill(0)

		this._labeled_data.i.forEach((doc, i) => {
			const k = classIndex.get(this._labels[i])
			for (const w of doc) {
				wordCounts[k][w]++
			}
			docCounts[k]++
		})

		let totalDocs = this._labeled_data.i.length
		if (unlabeledProbs) {
			this._unlabeled_data.i.forEach((doc, i) => {
				for (let k = 0; k < nClasses; k++) {
					const weight = this._lambda * unlabeledProbs[i][k]
					for (const w of doc) {
						wordCounts[k][w] += weight
					}
					docCounts[k] += weight
				}
			})
			// Posteriors of each unlabeled document sum to 1, so the weighted mass is lambda * count.
			totalDocs += this._lambda * this._unlabeled_data.i.length
		}

		this._prob_wc = wordCounts.map(counts => {
			const total = counts.reduce((s, v) => s + v, 0)
			return counts.map(v => (1 + v) / (nWords + total))
		})
		this._prob_c = docCounts.map(v => (1 + v) / (nClasses + totalDocs))
	}
}