UNPKG

@ai-on-browser/data-analysis-models

Version:

Data analysis model package without any dependencies

152 lines (140 loc) 3.93 kB
import { DecisionTreeClassifier, DecisionTreeRegression } from './decision_tree.js'

/**
 * Base class for random forest models.
 *
 * Holds a list of tree instances and trains each of them on its own random
 * subset of the training data. Tree instances are used through `init`, `fit`,
 * `depth` and `predict_prob` (subclasses may additionally use `predict`).
 */
class RandomForest {
	// see https://ja.wikipedia.org/wiki/%E3%83%A9%E3%83%B3%E3%83%80%E3%83%A0%E3%83%95%E3%82%A9%E3%83%AC%E3%82%B9%E3%83%88
	/**
	 * @param {number} tree_num Number of trees
	 * @param {number} [sampling_rate] Fraction of the training rows handed to each tree
	 * @param {typeof DecisionTreeClassifier | typeof DecisionTreeRegression} tree_class Tree class (constructor)
	 * @param {*[]} [tree_class_args] Arguments for constructor of tree class
	 */
	constructor(tree_num, sampling_rate = 0.8, tree_class, tree_class_args = null) {
		this._samplingRate = sampling_rate
		this._trees = []

		let built = 0
		while (built < tree_num) {
			// A missing argument list means "construct the tree with no arguments".
			const ctorArgs = tree_class_args || []
			this._trees.push(new tree_class(...ctorArgs))
			built++
		}
	}

	/**
	 * The max depth among the trees.
	 * @type {number}
	 */
	get depth() {
		const depths = this._trees.map(tree => tree.depth)
		return Math.max(...depths)
	}

	/**
	 * Draws a random subset of the indices `0 .. n - 1` without repetition.
	 *
	 * All indices are shuffled first; the leading `ceil(n * sampling rate)`
	 * entries of the shuffled list are returned.
	 * @param {number} n Number of available rows
	 * @returns {number[]} Selected row indices in random order
	 */
	_sample(n) {
		const indices = Array.from({ length: n }, (_unused, position) => position)

		// Fisher-Yates shuffle, walking from the last slot down to the second one.
		let last = n - 1
		while (last > 0) {
			const pick = Math.floor(Math.random() * (last + 1))
			const held = indices[last]
			indices[last] = indices[pick]
			indices[pick] = held
			last--
		}

		const keep = Math.ceil(n * this._samplingRate)
		return indices.slice(0, keep)
	}

	/**
	 * Initialize model.
	 *
	 * Every tree is initialized with its own independently drawn subset of rows.
	 * @param {Array<Array<number>>} datas Training data
	 * @param {*[]} targets Target values
	 */
	init(datas, targets) {
		for (const tree of this._trees) {
			const rows = this._sample(datas.length)
			const subsetData = rows.map(row => datas[row])
			const subsetTargets = rows.map(row => targets[row])
			tree.init(subsetData, subsetTargets)
		}
	}

	/**
	 * Fit model.
	 *
	 * Calls `fit` once on every tree.
	 */
	fit() {
		for (const tree of this._trees) {
			tree.fit()
		}
	}

	/**
	 * Returns probability of the datas.
	 *
	 * For each sample, the per-tree probabilities are summed by key and then
	 * divided by the number of trees.
	 * @param {Array<Array<number>>} datas Sample data
	 * @returns {Map<number, number>[]} Predicted values
	 */
	predict_prob(datas) {
		const perTree = this._trees.map(tree => tree.predict_prob(datas))
		const treeCount = perTree.length
		const result = []

		for (let row = 0; row < datas.length; row++) {
			const totals = new Map()
			for (const treeProbs of perTree) {
				treeProbs[row].forEach((value, label) => {
					const soFar = totals.get(label) || 0
					totals.set(label, soFar + value)
				})
			}

			const averaged = new Map()
			totals.forEach((total, label) => {
				averaged.set(label, total / treeCount)
			})
			result.push(averaged)
		}

		return result
	}
}

/**
 * Random forest classifier
 */
export class RandomForestClassifier extends RandomForest {
	// see https://ja.wikipedia.org/wiki/%E3%83%A9%E3%83%B3%E3%83%80%E3%83%A0%E3%83%95%E3%82%A9%E3%83%AC%E3%82%B9%E3%83%88
	/**
	 * @param {number} tree_num Number of trees
	 * @param {number} [sampling_rate] Sampling rate
	 * @param {'ID3' | 'CART'} [method] Method name, passed to each tree's constructor
	 */
	constructor(tree_num, sampling_rate = 0.8, method = 'CART') {
		super(tree_num, sampling_rate, DecisionTreeClassifier, [method])
	}

	/**
	 * Returns predicted values.
	 *
	 * Picks, per sample, the key with the largest averaged probability.
	 * Yields `-1` when no key has a probability greater than zero.
	 * @param {Array<Array<number>>} datas Sample data
	 * @returns {*[]} Predicted values
	 */
	predict(datas) {
		return this.predict_prob(datas).map(distribution => {
			let bestLabel = -1
			let bestScore = 0
			for (const [label, score] of distribution) {
				if (score > bestScore) {
					bestScore = score
					bestLabel = label
				}
			}
			return bestLabel
		})
	}
}

/**
 * Random forest regressor
 */
export class RandomForestRegressor extends RandomForest {
	// see https://ja.wikipedia.org/wiki/%E3%83%A9%E3%83%B3%E3%83%80%E3%83%A0%E3%83%95%E3%82%A9%E3%83%AC%E3%82%B9%E3%83%88
	/**
	 * @param {number} tree_num Number of trees
	 * @param {number} [sampling_rate] Sampling rate
	 */
	constructor(tree_num, sampling_rate = 0.8) {
		super(tree_num, sampling_rate, DecisionTreeRegression)
	}

	/**
	 * Returns predicted values.
	 *
	 * Each prediction is the mean of the individual trees' predictions for
	 * that sample.
	 * @param {Array<Array<number>>} datas Sample data
	 * @returns {number[]} Predicted values
	 */
	predict(datas) {
		const perTree = this._trees.map(tree => tree.predict(datas))
		const treeCount = perTree.length
		const means = []

		for (let row = 0; row < datas.length; row++) {
			let sum = 0
			for (const treePrediction of perTree) {
				sum += treePrediction[row]
			}
			means.push(sum / treeCount)
		}

		return means
	}
}