@ai-on-browser/data-analysis-models
Version:
Data analysis model package without any dependencies
24 lines (20 loc) • 438 B
JavaScript
import Matrix from '../../../util/matrix.js'
import Tensor from '../../../util/tensor.js'
import Layer from './base.js'
/**
 * Shape layer.
 *
 * Emits the dimensions of its input (not its values) as a 1-D tensor.
 */
export default class ShapeLayer extends Layer {
	/**
	 * Caches the input and returns a copy of its sizes as a tensor.
	 * @param {Matrix | Tensor} x Input data; must expose a `sizes` array
	 * @returns {Tensor} Tensor built from a copy of `x.sizes`
	 */
	calc(x) {
		// NOTE(review): `_i` is not read in this file; presumably consumed by the base class or tooling — confirm before removing.
		this._i = x
		const dims = [...x.sizes]
		this._size = dims
		return Tensor.fromArray(dims)
	}

	/**
	 * Returns an all-zero value whose sizes match the input last passed to `calc`.
	 * The output of `calc` depends only on the input's shape, not its values,
	 * which is why no upstream gradient is consulted here.
	 * Must be called after `calc`, since it reads the cached sizes.
	 * @returns {Matrix | Tensor} A `Matrix` when the cached input was 2-D, otherwise a `Tensor`
	 */
	grad() {
		const dims = this._size
		return dims.length === 2 ? Matrix.zeros(dims) : Tensor.zeros(dims)
	}
}

ShapeLayer.registLayer()