@ai-on-browser/data-analysis-models
Version:
Data analysis model package without any dependencies
59 lines (51 loc) • 1.26 kB
JavaScript
import NeuralNetwork from '../../neuralnetwork.js'
import Layer from './base.js'
import Matrix from '../../../util/matrix.js'
/**
 * Include layer
 *
 * Wraps a whole network so it can be used as a single layer: `calc`, `grad`
 * and `update` are all delegated to the wrapped network.
 */
export default class IncludeLayer extends Layer {
	/**
	 * @param {object} config object
	 * @param {NeuralNetwork | object[]} config.net Included network. Anything that is not a `NeuralNetwork` instance is handed to `NeuralNetwork.fromObject`.
	 * @param {string} [config.input_to] Input name of the network
	 * @param {boolean} [config.train] Train included network or not
	 */
	constructor({ net, input_to = null, train = true, ...rest }) {
		super(rest)
		this._model = net instanceof NeuralNetwork ? net : NeuralNetwork.fromObject(net)
		this._input_to = input_to
		this._train = train
		this._org_i = null
		this._org_t = null
	}

	/**
	 * Remember the input and supervisor values of the enclosing network.
	 *
	 * @param {object} values Bound values
	 * @param {*} values.input Input of the enclosing network; consulted by `calc` when `input_to` is set
	 * @param {*} values.supervisor Supervisor of the enclosing network; only stored by this class
	 */
	bind({ input, supervisor }) {
		this._org_i = input
		this._org_t = supervisor
	}

	/**
	 * Run the included network.
	 *
	 * When `input_to` is set and the bound input is not a `Matrix`, the value
	 * given to this layer is forwarded to the included network under the name
	 * `input_to`, together with the entries of the bound input.
	 *
	 * @param {*} x Value given to this layer
	 * @returns {*} Result of the included network's `calc`
	 */
	calc(x) {
		if (!(this._org_i instanceof Matrix) && this._input_to) {
			const bound = this._org_i
			const isDictionary = bound == null || Object.getPrototypeOf(bound) === Object.prototype
			if (isDictionary) {
				// Merge into a fresh object so the caller-owned input dictionary is
				// never modified. A missing bound input counts as an empty dictionary.
				x = { ...bound, [this._input_to]: x }
			} else {
				// Arrays and other non-plain values keep the historical in-place
				// behaviour so the type handed to the included network is unchanged.
				bound[this._input_to] = x
				x = bound
			}
		}
		return this._model.calc(x)
	}

	/**
	 * Back-propagate through the included network.
	 *
	 * @param {*} bo Gradient given to this layer
	 * @returns {*} Result of the included network's `grad`
	 */
	grad(bo) {
		return this._model.grad(bo)
	}

	/**
	 * Update the included network with the optimizer's `lr`, unless training
	 * of the included network is disabled.
	 *
	 * @param {object} optimizer Optimizer; only its `lr` property is read here
	 */
	update(optimizer) {
		if (this._train) {
			this._model.update(optimizer.lr)
		}
	}

	/**
	 * Serialize this layer, including the wrapped network.
	 *
	 * @returns {{type: string, net: *, input_to: ?string, train: boolean}} Plain description of this layer
	 */
	toObject() {
		return {
			type: 'include',
			net: this._model.toObject(),
			input_to: this._input_to,
			train: this._train,
		}
	}
}
// Runs once at module load. `registLayer` is not defined on IncludeLayer itself,
// so it comes from the Layer base class; presumably it registers this class so
// serialized layers can be rebuilt by type — confirm in ./base.js.
IncludeLayer.registLayer()