@ai-on-browser/data-analysis-models
Version:
Data analysis model package without any dependencies
207 lines (178 loc) • 6.59 kB
JavaScript
import { onnx } from '../onnx_exporter.js'
import { getConstNodeName } from '../utils.js'
/**
* Handle gelu layer
*/
export default {
/**
* Export to onnx object.
* @param {onnx.ModelProto} model Model object
* @param {import("../../graph").LayerObject & {type: 'gelu'}} obj Node object
*/
export(model, obj) {
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
const opset = model.getOpsetImportList()[0]
if (opset.getDomain() === '' && opset.getVersion() >= 20) {
const node = new onnx.NodeProto()
node.setOpType('Gelu')
node.addInput(input)
node.addOutput(obj.name)
const graph = model.getGraph()
graph.addNode(node)
} else {
const tensor1 = getConstNodeName(model, 1)
const tensor2 = getConstNodeName(model, 2)
const tensor_p = getConstNodeName(model, 0.3275911)
const tensor_a1 = getConstNodeName(model, 0.254829592)
const tensor_a2 = getConstNodeName(model, -0.284496736)
const tensor_a3 = getConstNodeName(model, 1.421413741)
const tensor_a4 = getConstNodeName(model, -1.453152027)
const tensor_a5 = getConstNodeName(model, 1.061405429)
const node_sqrt2 = new onnx.NodeProto()
node_sqrt2.setOpType('Sqrt')
node_sqrt2.addInput(tensor2)
node_sqrt2.addOutput(`${obj.name}_sqrt2`)
const node_v = new onnx.NodeProto()
node_v.setOpType('Div')
node_v.addInput(input)
node_v.addInput(`${obj.name}_sqrt2`)
node_v.addOutput(`${obj.name}_v`)
const node_abs = new onnx.NodeProto()
node_abs.setOpType('Abs')
node_abs.addInput(`${obj.name}_v`)
node_abs.addOutput(`${obj.name}_abs`)
const node_mul_p = new onnx.NodeProto()
node_mul_p.setOpType('Mul')
node_mul_p.addInput(`${obj.name}_abs`)
node_mul_p.addInput(tensor_p)
node_mul_p.addOutput(`${obj.name}_mul_p`)
const node_add1 = new onnx.NodeProto()
node_add1.setOpType('Add')
node_add1.addInput(`${obj.name}_mul_p`)
node_add1.addInput(tensor1)
node_add1.addOutput(`${obj.name}_add1`)
const node_t = new onnx.NodeProto()
node_t.setOpType('Reciprocal')
node_t.addInput(`${obj.name}_add1`)
node_t.addOutput(`${obj.name}_t`)
const node_mul_a5 = new onnx.NodeProto()
node_mul_a5.setOpType('Mul')
node_mul_a5.addInput(`${obj.name}_t`)
node_mul_a5.addInput(tensor_a5)
node_mul_a5.addOutput(`${obj.name}_mul_a5`)
const node_add_a4 = new onnx.NodeProto()
node_add_a4.setOpType('Add')
node_add_a4.addInput(`${obj.name}_mul_a5`)
node_add_a4.addInput(tensor_a4)
node_add_a4.addOutput(`${obj.name}_add_a4`)
const node_mul_a4 = new onnx.NodeProto()
node_mul_a4.setOpType('Mul')
node_mul_a4.addInput(`${obj.name}_add_a4`)
node_mul_a4.addInput(`${obj.name}_t`)
node_mul_a4.addOutput(`${obj.name}_mul_a4`)
const node_add_a3 = new onnx.NodeProto()
node_add_a3.setOpType('Add')
node_add_a3.addInput(`${obj.name}_mul_a4`)
node_add_a3.addInput(tensor_a3)
node_add_a3.addOutput(`${obj.name}_add_a3`)
const node_mul_a3 = new onnx.NodeProto()
node_mul_a3.setOpType('Mul')
node_mul_a3.addInput(`${obj.name}_add_a3`)
node_mul_a3.addInput(`${obj.name}_t`)
node_mul_a3.addOutput(`${obj.name}_mul_a3`)
const node_add_a2 = new onnx.NodeProto()
node_add_a2.setOpType('Add')
node_add_a2.addInput(`${obj.name}_mul_a3`)
node_add_a2.addInput(tensor_a2)
node_add_a2.addOutput(`${obj.name}_add_a2`)
const node_mul_a2 = new onnx.NodeProto()
node_mul_a2.setOpType('Mul')
node_mul_a2.addInput(`${obj.name}_add_a2`)
node_mul_a2.addInput(`${obj.name}_t`)
node_mul_a2.addOutput(`${obj.name}_mul_a2`)
const node_add_a1 = new onnx.NodeProto()
node_add_a1.setOpType('Add')
node_add_a1.addInput(`${obj.name}_mul_a2`)
node_add_a1.addInput(tensor_a1)
node_add_a1.addOutput(`${obj.name}_add_a1`)
const node_mul_a1 = new onnx.NodeProto()
node_mul_a1.setOpType('Mul')
node_mul_a1.addInput(`${obj.name}_add_a1`)
node_mul_a1.addInput(`${obj.name}_t`)
node_mul_a1.addOutput(`${obj.name}_mul_a1`)
const node_pow = new onnx.NodeProto()
node_pow.setOpType('Pow')
node_pow.addInput(`${obj.name}_v`)
node_pow.addInput(tensor2)
node_pow.addOutput(`${obj.name}_pow`)
const node_neg = new onnx.NodeProto()
node_neg.setOpType('Neg')
node_neg.addInput(`${obj.name}_pow`)
node_neg.addOutput(`${obj.name}_neg`)
const node_exp = new onnx.NodeProto()
node_exp.setOpType('Exp')
node_exp.addInput(`${obj.name}_neg`)
node_exp.addOutput(`${obj.name}_exp`)
const node_mul = new onnx.NodeProto()
node_mul.setOpType('Mul')
node_mul.addInput(`${obj.name}_mul_a1`)
node_mul.addInput(`${obj.name}_exp`)
node_mul.addOutput(`${obj.name}_mul`)
const node_erf = new onnx.NodeProto()
node_erf.setOpType('Sub')
node_erf.addInput(tensor1)
node_erf.addInput(`${obj.name}_mul`)
node_erf.addOutput(`${obj.name}_erf`)
const node_sign = new onnx.NodeProto()
node_sign.setOpType('Sign')
node_sign.addInput(`${obj.name}_v`)
node_sign.addOutput(`${obj.name}_sign`)
const node_sign_erf = new onnx.NodeProto()
node_sign_erf.setOpType('Mul')
node_sign_erf.addInput(`${obj.name}_erf`)
node_sign_erf.addInput(`${obj.name}_sign`)
node_sign_erf.addOutput(`${obj.name}_sign_erf`)
const node_erf_add1 = new onnx.NodeProto()
node_erf_add1.setOpType('Add')
node_erf_add1.addInput(tensor1)
node_erf_add1.addInput(`${obj.name}_sign_erf`)
node_erf_add1.addOutput(`${obj.name}_erf_add1`)
const node_mul_v = new onnx.NodeProto()
node_mul_v.setOpType('Mul')
node_mul_v.addInput(input)
node_mul_v.addInput(`${obj.name}_erf_add1`)
node_mul_v.addOutput(`${obj.name}_mul_v`)
const node_div2 = new onnx.NodeProto()
node_div2.setOpType('Div')
node_div2.addInput(`${obj.name}_mul_v`)
node_div2.addInput(tensor2)
node_div2.addOutput(obj.name)
const graph = model.getGraph()
graph.addNode(node_sqrt2)
graph.addNode(node_v)
graph.addNode(node_abs)
graph.addNode(node_mul_p)
graph.addNode(node_add1)
graph.addNode(node_t)
graph.addNode(node_mul_a5)
graph.addNode(node_add_a4)
graph.addNode(node_mul_a4)
graph.addNode(node_add_a3)
graph.addNode(node_mul_a3)
graph.addNode(node_add_a2)
graph.addNode(node_mul_a2)
graph.addNode(node_add_a1)
graph.addNode(node_mul_a1)
graph.addNode(node_pow)
graph.addNode(node_neg)
graph.addNode(node_exp)
graph.addNode(node_mul)
graph.addNode(node_erf)
graph.addNode(node_sign)
graph.addNode(node_sign_erf)
graph.addNode(node_erf_add1)
graph.addNode(node_mul_v)
graph.addNode(node_div2)
}
},
}