UNPKG

@ai-on-browser/data-analysis-models

Version:

Data analysis model package without any dependencies

96 lines (84 loc) 3.02 kB
import { onnx } from '../onnx_exporter.js' import { getConstNodeName } from '../utils.js' /** * Handle pdelu layer */ export default { /** * Export to onnx object. * @param {onnx.ModelProto} model Model object * @param {import("../../graph").LayerObject & {type: 'pdelu'}} obj Node object */ export(model, obj) { const tensor0 = getConstNodeName(model, 0) const tensor1 = getConstNodeName(model, 1) const tensor_alpha = new onnx.TensorProto() tensor_alpha.setName(obj.name + '_alpha') tensor_alpha.setDataType(onnx.TensorProto.DataType.FLOAT) tensor_alpha.setDimsList([1]) tensor_alpha.setFloatDataList([obj.alpha ?? 1]) const tensor_t = new onnx.TensorProto() tensor_t.setName(obj.name + '_t') tensor_t.setDataType(onnx.TensorProto.DataType.FLOAT) tensor_t.setDimsList([1]) tensor_t.setFloatDataList([obj.t ?? 0.1]) const node_sub_t = new onnx.NodeProto() node_sub_t.setOpType('Sub') node_sub_t.addInput(tensor1) node_sub_t.addInput(obj.name + '_t') node_sub_t.addOutput(obj.name + '_1-t') const input = Array.isArray(obj.input) ? obj.input[0] : obj.input const node_mul = new onnx.NodeProto() node_mul.setOpType('Mul') node_mul.addInput(input) node_mul.addInput(obj.name + '_1-t') node_mul.addOutput(obj.name + '_(1-t)*v') const node_add_1 = new onnx.NodeProto() node_add_1.setOpType('Add') node_add_1.addInput(obj.name + '_(1-t)*v') node_add_1.addInput(tensor1) node_add_1.addOutput(obj.name + '_1+(1-t)*v') const node_inv = new onnx.NodeProto() node_inv.setOpType('Reciprocal') node_inv.addInput(obj.name + '_1-t') node_inv.addOutput(obj.name + '_1/(1-t)') const node_pow = new onnx.NodeProto() node_pow.setOpType('Pow') node_pow.addInput(obj.name + '_1+(1-t)*v') node_pow.addInput(obj.name + '_1/(1-t)') node_pow.addOutput(obj.name + '_pow') const node_sub = new onnx.NodeProto() node_sub.setOpType('Sub') node_sub.addInput(obj.name + '_pow') node_sub.addInput(tensor1) node_sub.addOutput(obj.name + '_sub') const node_mul_alpha = new onnx.NodeProto() node_mul_alpha.setOpType('Mul') node_mul_alpha.addInput(obj.name + '_sub') node_mul_alpha.addInput(obj.name + '_alpha') node_mul_alpha.addOutput(obj.name + '_mul_alpha') const node_posneg = new onnx.NodeProto() node_posneg.setOpType('Greater') node_posneg.addInput(input) node_posneg.addInput(tensor0) node_posneg.addOutput(obj.name + '_posneg') const node_where = new onnx.NodeProto() node_where.setOpType('Where') node_where.addInput(obj.name + '_posneg') node_where.addInput(input) node_where.addInput(obj.name + '_mul_alpha') node_where.addOutput(obj.name) const graph = model.getGraph() graph.addInitializer(tensor_alpha) graph.addInitializer(tensor_t) graph.addNode(node_sub_t) graph.addNode(node_mul) graph.addNode(node_add_1) graph.addNode(node_inv) graph.addNode(node_pow) graph.addNode(node_sub) graph.addNode(node_mul_alpha) graph.addNode(node_posneg) graph.addNode(node_where) }, }