UNPKG

@ai-on-browser/data-analysis-models

Version:

Data analysis model package without any dependencies

192 lines (182 loc) 4.95 kB
import Layer from './base.js' import Matrix from '../../../util/matrix.js' /** * Graph convolutional layer */ export default class GraphConvolutionalLayer extends Layer { /** * @param {object} config object * @param {number} config.out_size Size of output * @param {number[][] | Matrix | string} [config.w] Weight of kernel * @param {number[][] | Matrix | string} [config.b] Bias * @param {string | object} [config.activation] Name of activation or activation layer object * @param {number} [config.l2_decay] L2 decay * @param {number} [config.l1_decay] L1 decay */ constructor({ out_size, w = null, b = null, activation = null, l2_decay = 0, l1_decay = 0, ...rest }) { super(rest) this._out_size = out_size this._w = null if (typeof w === 'string') { this._wname = w } else if (w) { this._w = Matrix.fromArray(w) } this._b = null if (typeof b === 'string') { this._bname = b } else if (b) { this._b = Matrix.fromArray(b) } if (typeof activation === 'string') { this._activation = Layer.fromObject({ type: activation }) } else if (activation) { this._activation = Layer.fromObject(activation) } this._l2_decay = l2_decay this._l1_decay = l1_decay this._aggregate = 'mean' } get dependentLayers() { const layers = [] if (this._wname) { layers.push(this._wname) } if (this._bname) { layers.push(this._bname) } if (this._activation) { layers.push(...this._activation.dependentLayers) } return layers } calc(x) { if (this._wname) { this._w = this.graph.getNode(this._wname).outputValue if (!this._out_size) { this._out_size = this._w.sizes.at(-1) } } if (this._bname) { this._b = this.graph.getNode(this._bname).outputValue } if (!this._w) { this._w = Matrix.randn(x.value[0].nodes[0].sizes.at(-1), this._out_size) } if (!this._b) { this._b = Matrix.zeros(1, this._out_size) } this._i = x this._cum = [] this._o = x.copy() for (let k = 0; k < x.length; k++) { const graph = x.value[k].copy() const cgraph = x.value[k].copy() for (let i = 0; i < graph.order; i++) { const v = x.value[k].getNode(i).copy() const adj = x.value[k].adjacencies(i, true, 'in') for (let j = 0; j < adj.length; j++) { v.add(x.value[k].getNode(adj[j])) } if (this._aggregate === 'mean') { v.div(adj.length + 1) } cgraph.nodes[i] = v let o = v.dot(this._w) o.broadcastOperate(this._b, (a, b) => a + b) if (this._activation) { o = this._activation.calc(o) } graph.nodes[i] = o } this._cum[k] = cgraph this._o.value[k] = graph } return this._o } grad(bo) { if (this._activation) { for (let k = 0; k < bo.length; k++) { for (let i = 0; i < bo.value[k].order; i++) { bo.value[k].nodes[i] = this._activation.grad(bo.value[k].nodes[i]) } } } this._dw = null this._db = null this._bi = this._i.copy() for (let k = 0; k < this._bi.length; k++) { const graph = this._i.value[k].copy() for (let i = 0; i < this._bi.value[k].order; i++) { graph.nodes[i] = this._i.value[k].nodes[i].copy() graph.nodes[i].fill(0) } this._bi.value[k] = graph } let count = 0 for (let k = 0; k < bo.length; k++) { for (let i = 0; i < bo.value[k].order; i++) { const dw = this._cum[k].nodes[i].tDot(bo.value[k].nodes[i]) const db = bo.value[k].nodes[i].sum(0) if (!this._dw) { this._dw = dw this._db = db } else { this._dw.add(dw) this._db.add(db) } count++ const bik = bo.value[k].nodes[i].dot(this._w.t) const adj = this._bi.value[k].adjacencies(i, true, 'in') if (this._aggregate === 'mean') { bik.div(adj.length + 1) } this._bi.value[k].nodes[i].broadcastOperate(bik, (a, b) => a + b) for (let j = 0; j < adj.length; j++) { this._bi.value[k].nodes[adj[j]].broadcastOperate(bik, (a, b) => a + b) } } } this._dw.div(count) if (this._l2_decay > 0 || this._l1_decay > 0) { for (let i = 0; i < this._dw.rows; i++) { for (let j = 0; j < this._dw.cols; j++) { const v = this._w.at(i, j) this._dw.addAt(i, j, v * this._l2_decay + Math.sign(v) * this._l1_decay) } } } this._db.div(count) if (this._wname || this._bname) { const gp = {} if (this._wname) { gp[this._wname] = this._dw } if (this._bname) { gp[this._bname] = this._db } return [this._bi, gp] } return this._bi } update(optimizer) { if (!this._wname) { this._w.sub(optimizer.delta('w', this._dw)) } if (!this._bname) { this._b.sub(optimizer.delta('b', this._db)) } } toObject() { return { type: 'graph_conv', out_size: this._out_size, w: this._wname || this._w?.toArray(), b: this._bname || this._b?.toArray(), activation: this._activation?.toObject(), l2_decay: this._l2_decay, l1_decay: this._l1_decay, } } } GraphConvolutionalLayer.registLayer('graph_conv')