UNPKG

@ai-on-browser/data-analysis-models

Version:

Data analysis model package without any dependencies

99 lines (88 loc) 3.14 kB
import { onnx } from '../onnx_exporter.js'
import { getConstNodeName } from '../utils.js'

/**
 * Create a single-element (dims [1]) FLOAT initializer holding one scalar value.
 * @param {string} name Tensor name
 * @param {number} value Value to store
 * @returns {onnx.TensorProto} Initializer tensor
 */
const floatScalar = (name, value) => {
	const tensor = new onnx.TensorProto()
	tensor.setName(name)
	tensor.setDataType(onnx.TensorProto.DataType.FLOAT)
	tensor.setDimsList([1])
	tensor.setFloatDataList([value])
	return tensor
}

/**
 * Create a single-output node.
 * @param {string} opType ONNX operator type
 * @param {string[]} inputs Input names, in the operator's argument order
 * @param {string} output Output name
 * @returns {onnx.NodeProto} Node
 */
const makeNode = (opType, inputs, output) => {
	const node = new onnx.NodeProto()
	node.setOpType(opType)
	for (const input of inputs) {
		node.addInput(input)
	}
	node.addOutput(output)
	return node
}

/**
 * Handle hexpo layer
 *
 * The exported graph computes, element-wise,
 *   f(x) = -a * (exp(-x / b) - 1)   if x >= 0
 *   f(x) =  c * (exp( x / d) - 1)   otherwise
 * where a, b, c, d default to 1 when absent on the layer object.
 */
export default {
	/**
	 * Export to onnx object.
	 * @param {onnx.ModelProto} model Model object
	 * @param {import("../../graph").LayerObject & {type: 'hexpo'}} obj Node object
	 */
	export(model, obj) {
		const graph = model.getGraph()
		const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
		const zero = getConstNodeName(model, 0)
		// NOTE(review): assumed to be a FLOAT constant like `zero`, which is compared
		// against the float input below — confirm in getConstNodeName.
		const one = getConstNodeName(model, 1)
		const name = (suffix) => obj.name + suffix

		graph.addInitializer(floatScalar(name('_a'), obj.a ?? 1))
		graph.addInitializer(floatScalar(name('_b'), obj.b ?? 1))
		graph.addInitializer(floatScalar(name('_c'), obj.c ?? 1))
		graph.addInitializer(floatScalar(name('_d'), obj.d ?? 1))

		// Nodes are added in topological order.
		graph.addNode(makeNode('Neg', [name('_a')], name('_-a')))
		graph.addNode(makeNode('Neg', [name('_b')], name('_-b')))
		// Branch mask: true where x >= 0.
		graph.addNode(makeNode('GreaterOrEqual', [input, zero], name('_posneg')))
		// Per-element divisor: -b on the x >= 0 branch, d otherwise.
		graph.addNode(makeNode('Where', [name('_posneg'), name('_-b'), name('_d')], name('_where1')))
		graph.addNode(makeNode('Div', [input, name('_where1')], name('_div')))
		// exp(t) - 1 is built explicitly instead of via Elu: Elu equals exp(t) - 1 only
		// for t < 0 (identity otherwise), which would be wrong whenever b or d is negative.
		graph.addNode(makeNode('Exp', [name('_div')], name('_exp')))
		graph.addNode(makeNode('Sub', [name('_exp'), one], name('_expm1')))
		// Per-element scale: -a on the x >= 0 branch, c otherwise.
		graph.addNode(makeNode('Where', [name('_posneg'), name('_-a'), name('_c')], name('_where2')))
		graph.addNode(makeNode('Mul', [name('_expm1'), name('_where2')], obj.name))
	},
}