@ai-on-browser/data-analysis-models
Version:
Data analysis model package without any dependencies
207 lines (178 loc) • 6.59 kB
JavaScript
import { onnx } from '../onnx_exporter.js'
import { getConstNodeName } from '../utils.js'
/**
* Handle gelu layer
*/
export default {
/**
* Export to onnx object.
* @param {onnx.ModelProto} model Model object
* @param {import("../../graph").LayerObject & {type: 'gelu'}} obj Node object
*/
export(model, obj) {
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
const opset = model.getOpsetImportList()[0]
if (opset.getDomain() === '' && opset.getVersion() >= 20) {
const node = new onnx.NodeProto()
node.setOpType('Gelu')
node.addInput(input)
node.addOutput(obj.name)
const graph = model.getGraph()
graph.addNode(node)
} else {
const tensor1 = getConstNodeName(model, 1)
const tensor2 = getConstNodeName(model, 2)
const tensor_p = getConstNodeName(model, 0.3275911)
const tensor_a1 = getConstNodeName(model, 0.254829592)
const tensor_a2 = getConstNodeName(model, -0.284496736)
const tensor_a3 = getConstNodeName(model, 1.421413741)
const tensor_a4 = getConstNodeName(model, -1.453152027)
const tensor_a5 = getConstNodeName(model, 1.061405429)
const node_sqrt2 = new onnx.NodeProto()
node_sqrt2.setOpType('Sqrt')
node_sqrt2.addInput(tensor2)
node_sqrt2.addOutput(obj.name + '_sqrt2')
const node_v = new onnx.NodeProto()
node_v.setOpType('Div')
node_v.addInput(input)
node_v.addInput(obj.name + '_sqrt2')
node_v.addOutput(obj.name + '_v')
const node_abs = new onnx.NodeProto()
node_abs.setOpType('Abs')
node_abs.addInput(obj.name + '_v')
node_abs.addOutput(obj.name + '_abs')
const node_mul_p = new onnx.NodeProto()
node_mul_p.setOpType('Mul')
node_mul_p.addInput(obj.name + '_abs')
node_mul_p.addInput(tensor_p)
node_mul_p.addOutput(obj.name + '_mul_p')
const node_add1 = new onnx.NodeProto()
node_add1.setOpType('Add')
node_add1.addInput(obj.name + '_mul_p')
node_add1.addInput(tensor1)
node_add1.addOutput(obj.name + '_add1')
const node_t = new onnx.NodeProto()
node_t.setOpType('Reciprocal')
node_t.addInput(obj.name + '_add1')
node_t.addOutput(obj.name + '_t')
const node_mul_a5 = new onnx.NodeProto()
node_mul_a5.setOpType('Mul')
node_mul_a5.addInput(obj.name + '_t')
node_mul_a5.addInput(tensor_a5)
node_mul_a5.addOutput(obj.name + '_mul_a5')
const node_add_a4 = new onnx.NodeProto()
node_add_a4.setOpType('Add')
node_add_a4.addInput(obj.name + '_mul_a5')
node_add_a4.addInput(tensor_a4)
node_add_a4.addOutput(obj.name + '_add_a4')
const node_mul_a4 = new onnx.NodeProto()
node_mul_a4.setOpType('Mul')
node_mul_a4.addInput(obj.name + '_add_a4')
node_mul_a4.addInput(obj.name + '_t')
node_mul_a4.addOutput(obj.name + '_mul_a4')
const node_add_a3 = new onnx.NodeProto()
node_add_a3.setOpType('Add')
node_add_a3.addInput(obj.name + '_mul_a4')
node_add_a3.addInput(tensor_a3)
node_add_a3.addOutput(obj.name + '_add_a3')
const node_mul_a3 = new onnx.NodeProto()
node_mul_a3.setOpType('Mul')
node_mul_a3.addInput(obj.name + '_add_a3')
node_mul_a3.addInput(obj.name + '_t')
node_mul_a3.addOutput(obj.name + '_mul_a3')
const node_add_a2 = new onnx.NodeProto()
node_add_a2.setOpType('Add')
node_add_a2.addInput(obj.name + '_mul_a3')
node_add_a2.addInput(tensor_a2)
node_add_a2.addOutput(obj.name + '_add_a2')
const node_mul_a2 = new onnx.NodeProto()
node_mul_a2.setOpType('Mul')
node_mul_a2.addInput(obj.name + '_add_a2')
node_mul_a2.addInput(obj.name + '_t')
node_mul_a2.addOutput(obj.name + '_mul_a2')
const node_add_a1 = new onnx.NodeProto()
node_add_a1.setOpType('Add')
node_add_a1.addInput(obj.name + '_mul_a2')
node_add_a1.addInput(tensor_a1)
node_add_a1.addOutput(obj.name + '_add_a1')
const node_mul_a1 = new onnx.NodeProto()
node_mul_a1.setOpType('Mul')
node_mul_a1.addInput(obj.name + '_add_a1')
node_mul_a1.addInput(obj.name + '_t')
node_mul_a1.addOutput(obj.name + '_mul_a1')
const node_pow = new onnx.NodeProto()
node_pow.setOpType('Pow')
node_pow.addInput(obj.name + '_v')
node_pow.addInput(tensor2)
node_pow.addOutput(obj.name + '_pow')
const node_neg = new onnx.NodeProto()
node_neg.setOpType('Neg')
node_neg.addInput(obj.name + '_pow')
node_neg.addOutput(obj.name + '_neg')
const node_exp = new onnx.NodeProto()
node_exp.setOpType('Exp')
node_exp.addInput(obj.name + '_neg')
node_exp.addOutput(obj.name + '_exp')
const node_mul = new onnx.NodeProto()
node_mul.setOpType('Mul')
node_mul.addInput(obj.name + '_mul_a1')
node_mul.addInput(obj.name + '_exp')
node_mul.addOutput(obj.name + '_mul')
const node_erf = new onnx.NodeProto()
node_erf.setOpType('Sub')
node_erf.addInput(tensor1)
node_erf.addInput(obj.name + '_mul')
node_erf.addOutput(obj.name + '_erf')
const node_sign = new onnx.NodeProto()
node_sign.setOpType('Sign')
node_sign.addInput(obj.name + '_v')
node_sign.addOutput(obj.name + '_sign')
const node_sign_erf = new onnx.NodeProto()
node_sign_erf.setOpType('Mul')
node_sign_erf.addInput(obj.name + '_erf')
node_sign_erf.addInput(obj.name + '_sign')
node_sign_erf.addOutput(obj.name + '_sign_erf')
const node_erf_add1 = new onnx.NodeProto()
node_erf_add1.setOpType('Add')
node_erf_add1.addInput(tensor1)
node_erf_add1.addInput(obj.name + '_sign_erf')
node_erf_add1.addOutput(obj.name + '_erf_add1')
const node_mul_v = new onnx.NodeProto()
node_mul_v.setOpType('Mul')
node_mul_v.addInput(input)
node_mul_v.addInput(obj.name + '_erf_add1')
node_mul_v.addOutput(obj.name + '_mul_v')
const node_div2 = new onnx.NodeProto()
node_div2.setOpType('Div')
node_div2.addInput(obj.name + '_mul_v')
node_div2.addInput(tensor2)
node_div2.addOutput(obj.name)
const graph = model.getGraph()
graph.addNode(node_sqrt2)
graph.addNode(node_v)
graph.addNode(node_abs)
graph.addNode(node_mul_p)
graph.addNode(node_add1)
graph.addNode(node_t)
graph.addNode(node_mul_a5)
graph.addNode(node_add_a4)
graph.addNode(node_mul_a4)
graph.addNode(node_add_a3)
graph.addNode(node_mul_a3)
graph.addNode(node_add_a2)
graph.addNode(node_mul_a2)
graph.addNode(node_add_a1)
graph.addNode(node_mul_a1)
graph.addNode(node_pow)
graph.addNode(node_neg)
graph.addNode(node_exp)
graph.addNode(node_mul)
graph.addNode(node_erf)
graph.addNode(node_sign)
graph.addNode(node_sign_erf)
graph.addNode(node_erf_add1)
graph.addNode(node_mul_v)
graph.addNode(node_div2)
}
},
}