UNPKG

@ai-on-browser/data-analysis-models

Version:

Data analysis model package without any dependencies

207 lines (178 loc) 6.59 kB
import { onnx } from '../onnx_exporter.js' import { getConstNodeName } from '../utils.js' /** * Handle gelu layer */ export default { /** * Export to onnx object. * @param {onnx.ModelProto} model Model object * @param {import("../../graph").LayerObject & {type: 'gelu'}} obj Node object */ export(model, obj) { const input = Array.isArray(obj.input) ? obj.input[0] : obj.input const opset = model.getOpsetImportList()[0] if (opset.getDomain() === '' && opset.getVersion() >= 20) { const node = new onnx.NodeProto() node.setOpType('Gelu') node.addInput(input) node.addOutput(obj.name) const graph = model.getGraph() graph.addNode(node) } else { const tensor1 = getConstNodeName(model, 1) const tensor2 = getConstNodeName(model, 2) const tensor_p = getConstNodeName(model, 0.3275911) const tensor_a1 = getConstNodeName(model, 0.254829592) const tensor_a2 = getConstNodeName(model, -0.284496736) const tensor_a3 = getConstNodeName(model, 1.421413741) const tensor_a4 = getConstNodeName(model, -1.453152027) const tensor_a5 = getConstNodeName(model, 1.061405429) const node_sqrt2 = new onnx.NodeProto() node_sqrt2.setOpType('Sqrt') node_sqrt2.addInput(tensor2) node_sqrt2.addOutput(`${obj.name}_sqrt2`) const node_v = new onnx.NodeProto() node_v.setOpType('Div') node_v.addInput(input) node_v.addInput(`${obj.name}_sqrt2`) node_v.addOutput(`${obj.name}_v`) const node_abs = new onnx.NodeProto() node_abs.setOpType('Abs') node_abs.addInput(`${obj.name}_v`) node_abs.addOutput(`${obj.name}_abs`) const node_mul_p = new onnx.NodeProto() node_mul_p.setOpType('Mul') node_mul_p.addInput(`${obj.name}_abs`) node_mul_p.addInput(tensor_p) node_mul_p.addOutput(`${obj.name}_mul_p`) const node_add1 = new onnx.NodeProto() node_add1.setOpType('Add') node_add1.addInput(`${obj.name}_mul_p`) node_add1.addInput(tensor1) node_add1.addOutput(`${obj.name}_add1`) const node_t = new onnx.NodeProto() node_t.setOpType('Reciprocal') node_t.addInput(`${obj.name}_add1`) node_t.addOutput(`${obj.name}_t`) const node_mul_a5 = new onnx.NodeProto() node_mul_a5.setOpType('Mul') node_mul_a5.addInput(`${obj.name}_t`) node_mul_a5.addInput(tensor_a5) node_mul_a5.addOutput(`${obj.name}_mul_a5`) const node_add_a4 = new onnx.NodeProto() node_add_a4.setOpType('Add') node_add_a4.addInput(`${obj.name}_mul_a5`) node_add_a4.addInput(tensor_a4) node_add_a4.addOutput(`${obj.name}_add_a4`) const node_mul_a4 = new onnx.NodeProto() node_mul_a4.setOpType('Mul') node_mul_a4.addInput(`${obj.name}_add_a4`) node_mul_a4.addInput(`${obj.name}_t`) node_mul_a4.addOutput(`${obj.name}_mul_a4`) const node_add_a3 = new onnx.NodeProto() node_add_a3.setOpType('Add') node_add_a3.addInput(`${obj.name}_mul_a4`) node_add_a3.addInput(tensor_a3) node_add_a3.addOutput(`${obj.name}_add_a3`) const node_mul_a3 = new onnx.NodeProto() node_mul_a3.setOpType('Mul') node_mul_a3.addInput(`${obj.name}_add_a3`) node_mul_a3.addInput(`${obj.name}_t`) node_mul_a3.addOutput(`${obj.name}_mul_a3`) const node_add_a2 = new onnx.NodeProto() node_add_a2.setOpType('Add') node_add_a2.addInput(`${obj.name}_mul_a3`) node_add_a2.addInput(tensor_a2) node_add_a2.addOutput(`${obj.name}_add_a2`) const node_mul_a2 = new onnx.NodeProto() node_mul_a2.setOpType('Mul') node_mul_a2.addInput(`${obj.name}_add_a2`) node_mul_a2.addInput(`${obj.name}_t`) node_mul_a2.addOutput(`${obj.name}_mul_a2`) const node_add_a1 = new onnx.NodeProto() node_add_a1.setOpType('Add') node_add_a1.addInput(`${obj.name}_mul_a2`) node_add_a1.addInput(tensor_a1) node_add_a1.addOutput(`${obj.name}_add_a1`) const node_mul_a1 = new onnx.NodeProto() node_mul_a1.setOpType('Mul') node_mul_a1.addInput(`${obj.name}_add_a1`) node_mul_a1.addInput(`${obj.name}_t`) node_mul_a1.addOutput(`${obj.name}_mul_a1`) const node_pow = new onnx.NodeProto() node_pow.setOpType('Pow') node_pow.addInput(`${obj.name}_v`) node_pow.addInput(tensor2) node_pow.addOutput(`${obj.name}_pow`) const node_neg = new onnx.NodeProto() node_neg.setOpType('Neg') node_neg.addInput(`${obj.name}_pow`) node_neg.addOutput(`${obj.name}_neg`) const node_exp = new onnx.NodeProto() node_exp.setOpType('Exp') node_exp.addInput(`${obj.name}_neg`) node_exp.addOutput(`${obj.name}_exp`) const node_mul = new onnx.NodeProto() node_mul.setOpType('Mul') node_mul.addInput(`${obj.name}_mul_a1`) node_mul.addInput(`${obj.name}_exp`) node_mul.addOutput(`${obj.name}_mul`) const node_erf = new onnx.NodeProto() node_erf.setOpType('Sub') node_erf.addInput(tensor1) node_erf.addInput(`${obj.name}_mul`) node_erf.addOutput(`${obj.name}_erf`) const node_sign = new onnx.NodeProto() node_sign.setOpType('Sign') node_sign.addInput(`${obj.name}_v`) node_sign.addOutput(`${obj.name}_sign`) const node_sign_erf = new onnx.NodeProto() node_sign_erf.setOpType('Mul') node_sign_erf.addInput(`${obj.name}_erf`) node_sign_erf.addInput(`${obj.name}_sign`) node_sign_erf.addOutput(`${obj.name}_sign_erf`) const node_erf_add1 = new onnx.NodeProto() node_erf_add1.setOpType('Add') node_erf_add1.addInput(tensor1) node_erf_add1.addInput(`${obj.name}_sign_erf`) node_erf_add1.addOutput(`${obj.name}_erf_add1`) const node_mul_v = new onnx.NodeProto() node_mul_v.setOpType('Mul') node_mul_v.addInput(input) node_mul_v.addInput(`${obj.name}_erf_add1`) node_mul_v.addOutput(`${obj.name}_mul_v`) const node_div2 = new onnx.NodeProto() node_div2.setOpType('Div') node_div2.addInput(`${obj.name}_mul_v`) node_div2.addInput(tensor2) node_div2.addOutput(obj.name) const graph = model.getGraph() graph.addNode(node_sqrt2) graph.addNode(node_v) graph.addNode(node_abs) graph.addNode(node_mul_p) graph.addNode(node_add1) graph.addNode(node_t) graph.addNode(node_mul_a5) graph.addNode(node_add_a4) graph.addNode(node_mul_a4) graph.addNode(node_add_a3) graph.addNode(node_mul_a3) graph.addNode(node_add_a2) graph.addNode(node_mul_a2) graph.addNode(node_add_a1) graph.addNode(node_mul_a1) graph.addNode(node_pow) graph.addNode(node_neg) graph.addNode(node_exp) graph.addNode(node_mul) graph.addNode(node_erf) graph.addNode(node_sign) graph.addNode(node_sign_erf) graph.addNode(node_erf_add1) graph.addNode(node_mul_v) graph.addNode(node_div2) } }, }