@ai-on-browser/data-analysis-models
Version:
Data analysis model package without any dependencies
71 lines (66 loc) • 2.19 kB
JavaScript
import { onnx } from '../onnx_exporter.js'
/**
* Handle less layer
*/
export default {
	/**
	 * Export to onnx object.
	 *
	 * Appends to the model's graph the nodes that evaluate the chained comparison
	 * `input[0] < input[1] < ... < input[n - 1]`:
	 * - one input: a `Shape` node feeding a `ConstantOfShape` node whose `value`
	 *   attribute is a BOOL tensor holding 1;
	 * - two inputs: a single `Less` node;
	 * - three or more inputs: one `Less` node per adjacent pair (outputs
	 *   `<name>_lt_<i>`), folded together with `And` nodes (intermediate outputs
	 *   `<name>_and_<i>`).
	 * In every case the last node added writes to the tensor named `obj.name`.
	 * @param {onnx.ModelProto} model Model object
	 * @param {import("../../graph").LayerObject & {type: 'less'}} obj Node object
	 * @returns {{type: onnx.TensorProto.DataType; size: number[]}} Output information of this layer (only `type`, always BOOL, is populated here)
	 * @throws {Error} If `obj.input` is not an array, or is an empty array
	 */
	export(model, obj) {
		const inputs = obj.input
		if (!Array.isArray(inputs)) {
			throw new Error(`Invalid attribute 'input' value ${inputs}.`)
		}
		if (inputs.length === 0) {
			// Without this guard an empty list would fall through to the chained branch
			// and emit an `And` node whose inputs (`_lt_0`, `_lt_-2`) no node produces.
			throw new Error("Invalid attribute 'input' value []. At least one input is required.")
		}

		const graph = model.getGraph()

		// Builds a node with the given op type, inputs and single output, then appends it to the graph.
		const appendNode = (opType, inputNames, outputName, attribute) => {
			const node = new onnx.NodeProto()
			node.setOpType(opType)
			for (const inputName of inputNames) {
				node.addInput(inputName)
			}
			if (attribute) {
				node.addAttribute(attribute)
			}
			node.addOutput(outputName)
			graph.addNode(node)
		}

		if (inputs.length === 1) {
			// A lone operand has nothing to compare against: emit a constant tensor
			// shaped like the input instead of a comparison.
			const shapeName = `${obj.name}_shape`
			appendNode('Shape', [inputs[0]], shapeName)

			const tensor = new onnx.TensorProto()
			tensor.setDataType(onnx.TensorProto.DataType.BOOL)
			tensor.setDimsList([1])
			// The BOOL fill value is supplied through the int32 data list.
			tensor.setInt32DataList([1])
			const attrValue = new onnx.AttributeProto()
			attrValue.setName('value')
			attrValue.setType(onnx.AttributeProto.AttributeType.TENSOR)
			attrValue.setT(tensor)
			appendNode('ConstantOfShape', [shapeName], obj.name, attrValue)
		} else if (inputs.length === 2) {
			appendNode('Less', [inputs[0], inputs[1]], obj.name)
		} else {
			// One `Less` per adjacent pair: pair i compares input[i] with input[i + 1].
			const pairCount = inputs.length - 1
			for (let i = 0; i < pairCount; i++) {
				appendNode('Less', [inputs[i], inputs[i + 1]], `${obj.name}_lt_${i}`)
			}

			// Fold all but the last pair result into a running `And`; the last pair is
			// combined by the final node so that it can write to `obj.name`.
			let accumulated = `${obj.name}_lt_0`
			for (let i = 1; i < pairCount - 1; i++) {
				const andName = `${obj.name}_and_${i - 1}`
				appendNode('And', [accumulated, `${obj.name}_lt_${i}`], andName)
				accumulated = andName
			}
			appendNode('And', [accumulated, `${obj.name}_lt_${pairCount - 1}`], obj.name)
		}

		return { type: onnx.TensorProto.DataType.BOOL }
	},
}