@ai-on-browser/data-analysis-models
Version:
Data analysis model package without any dependencies
86 lines (75 loc) • 2.63 kB
JavaScript
import { onnx } from '../onnx_exporter.js'
/**
* Handle plu layer
*/
export default {
/**
* Export to onnx object.
* @param {onnx.ModelProto} model Model object
* @param {import("../../graph").LayerObject & {type: 'plu'}} obj Node object
*/
export(model, obj) {
const tensor_alpha = new onnx.TensorProto()
tensor_alpha.setName(`${obj.name}_alpha`)
tensor_alpha.setDataType(onnx.TensorProto.DataType.FLOAT)
tensor_alpha.setDimsList([1])
tensor_alpha.setFloatDataList([obj.alpha ?? 0.1])
const tensor_c = new onnx.TensorProto()
tensor_c.setName(`${obj.name}_c`)
tensor_c.setDataType(onnx.TensorProto.DataType.FLOAT)
tensor_c.setDimsList([1])
tensor_c.setFloatDataList([obj.c ?? 1])
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
const node_lb_add = new onnx.NodeProto()
node_lb_add.setOpType('Add')
node_lb_add.addInput(input)
node_lb_add.addInput(`${obj.name}_c`)
node_lb_add.addOutput(`${obj.name}_lb_add`)
const node_lb_mul = new onnx.NodeProto()
node_lb_mul.setOpType('Mul')
node_lb_mul.addInput(`${obj.name}_lb_add`)
node_lb_mul.addInput(`${obj.name}_alpha`)
node_lb_mul.addOutput(`${obj.name}_lb_mul`)
const node_lb = new onnx.NodeProto()
node_lb.setOpType('Sub')
node_lb.addInput(`${obj.name}_lb_mul`)
node_lb.addInput(`${obj.name}_c`)
node_lb.addOutput(`${obj.name}_lb`)
const node_ub_sub = new onnx.NodeProto()
node_ub_sub.setOpType('Sub')
node_ub_sub.addInput(input)
node_ub_sub.addInput(`${obj.name}_c`)
node_ub_sub.addOutput(`${obj.name}_ub_sub`)
const node_ub_mul = new onnx.NodeProto()
node_ub_mul.setOpType('Mul')
node_ub_mul.addInput(`${obj.name}_ub_sub`)
node_ub_mul.addInput(`${obj.name}_alpha`)
node_ub_mul.addOutput(`${obj.name}_ub_mul`)
const node_ub = new onnx.NodeProto()
node_ub.setOpType('Add')
node_ub.addInput(`${obj.name}_ub_mul`)
node_ub.addInput(`${obj.name}_c`)
node_ub.addOutput(`${obj.name}_ub`)
const node_min = new onnx.NodeProto()
node_min.setOpType('Min')
node_min.addInput(input)
node_min.addInput(`${obj.name}_ub`)
node_min.addOutput(`${obj.name}_min`)
const node_max = new onnx.NodeProto()
node_max.setOpType('Max')
node_max.addInput(`${obj.name}_min`)
node_max.addInput(`${obj.name}_lb`)
node_max.addOutput(obj.name)
const graph = model.getGraph()
graph.addInitializer(tensor_alpha)
graph.addInitializer(tensor_c)
graph.addNode(node_lb_add)
graph.addNode(node_lb_mul)
graph.addNode(node_lb)
graph.addNode(node_ub_sub)
graph.addNode(node_ub_mul)
graph.addNode(node_ub)
graph.addNode(node_min)
graph.addNode(node_max)
},
}