UNPKG

@ai-on-browser/data-analysis-models

Version:

Data analysis model package without any dependencies

184 lines (173 loc) 7 kB
import { onnx } from '../onnx_exporter.js'

/**
 * Handle batch normalization layer
 */
export default {
	/**
	 * Export to onnx object.
	 *
	 * Appends a `BatchNormalization` node (with `training_mode` set to 0) to the model graph,
	 * together with the supporting nodes and initializers it needs:
	 * - for channel-last layouts (`channel_dim` is nullish or -1) a `Transpose` before and after
	 *   the normalization node;
	 * - float initializers for `scale` / `offset` when they are not given as tensor names;
	 * - float initializers for `input_mean` / `input_var` when they are given as values, or
	 *   `ReduceMean` / `Sub` / `Mul` nodes that derive them from the input when they are absent.
	 * @param {onnx.ModelProto} model Model object
	 * @param {import("../../graph").LayerObject & {type: 'batch_normalization'}} obj Node object
	 * @param {{[key: string]: {type: onnx.TensorProto.DataType; size: number[]}}} info Output information of other layers
	 * @throws {Error} If `channel_dim` is neither 1, -1 nor nullish, or if `scale` / `offset` is a scalar
	 *   and the size of the channel dimension is unknown.
	 */
	export(model, obj, info) {
		const graph = model.getGraph()
		const source = Array.isArray(obj.input) ? obj.input[0] : obj.input
		const shape = info[source].size.slice()
		const rank = shape.length

		// Every auxiliary tensor name is the layer name followed by a fixed suffix.
		const named = suffix => obj.name + suffix

		// Build an integer-valued attribute.
		const intAttribute = (name, value) => {
			const attribute = new onnx.AttributeProto()
			attribute.setName(name)
			attribute.setType(onnx.AttributeProto.AttributeType.INT)
			attribute.setI(value)
			return attribute
		}

		// Build (but do not register) a node with the given inputs and a single output.
		const createNode = (opType, inputs, output) => {
			const created = new onnx.NodeProto()
			created.setOpType(opType)
			for (const name of inputs) {
				created.addInput(name)
			}
			created.addOutput(output)
			return created
		}

		// Register a Transpose node applying `perm` to `from`, producing `to`.
		const addTranspose = (from, to, perm) => {
			const transpose = createNode('Transpose', [from], to)
			const permAttribute = new onnx.AttributeProto()
			permAttribute.setName('perm')
			permAttribute.setType(onnx.AttributeProto.AttributeType.INTS)
			permAttribute.setIntsList(perm)
			transpose.addAttribute(permAttribute)
			graph.addNode(transpose)
		}

		// Register a one-dimensional float initializer and hand back its name.
		const addFloatInitializer = (name, values) => {
			const tensor = new onnx.TensorProto()
			tensor.setName(name)
			tensor.setDataType(onnx.TensorProto.DataType.FLOAT)
			tensor.setDimsList([values.length])
			tensor.setFloatDataList(values)
			graph.addInitializer(tensor)
			return name
		}

		const bn = new onnx.NodeProto()
		bn.setOpType('BatchNormalization')

		const channelFirst = obj.channel_dim === 1
		const channelLast = obj.channel_dim == null || obj.channel_dim === -1
		if (!channelFirst && !channelLast) {
			throw new Error(`Not implemented value of attribute 'channel_dim' ${obj.channel_dim}.`)
		}

		if (channelFirst) {
			bn.addInput(source)
			bn.addOutput(obj.name)
		} else {
			// Move the last axis to position 1 before normalizing, e.g. [0, 3, 1, 2] for rank 4.
			const toChannelFirst = Array.from({ length: rank }, (_, i) => i - 1)
			toChannelFirst[0] = 0
			toChannelFirst[1] = rank - 1
			addTranspose(source, named('_t1'), toChannelFirst)

			bn.addInput(named('_t1'))
			bn.addOutput(named('_ap'))

			// Move axis 1 back to the end afterwards, e.g. [0, 2, 3, 1] for rank 4.
			const toChannelLast = Array.from({ length: rank }, (_, i) => i + 1)
			toChannelLast[0] = 0
			toChannelLast[toChannelLast.length - 1] = 1
			addTranspose(named('_ap'), obj.name, toChannelLast)
		}

		bn.addAttribute(intAttribute('training_mode', 0))

		// Channel axis of the original (untransposed) input.
		const channelAxis = channelFirst ? 1 : rank - 1

		// scale and offset follow the same rules and differ only in suffix, default and message.
		const affineParams = [
			{
				suffix: '_scale',
				value: obj.scale,
				fallback: 1,
				scalarError: 'Size of channel dim must be specified if scale is scalar.',
			},
			{
				suffix: '_offset',
				value: obj.offset,
				fallback: 0,
				scalarError: 'Size of channel dim must be specified if offset is scalar.',
			},
		]
		for (const { suffix, value, fallback, scalarError } of affineParams) {
			if (typeof value === 'string') {
				// A string is used directly as the name of an existing tensor.
				bn.addInput(value)
				continue
			}
			const isList = Array.isArray(value)
			if (!isList && shape[channelAxis] == null) {
				throw new Error(scalarError)
			}
			// A scalar (or missing) value is repeated once per channel.
			const values = isList ? value : Array(shape[channelAxis]).fill(value ?? fallback)
			bn.addInput(addFloatInitializer(named(suffix), values))
		}

		// The reduction-axes initializer is registered at most once, and only if a statistic
		// actually has to be computed from the input.
		let reduceAxesReady = false
		const ensureReduceAxes = () => {
			if (reduceAxesReady) {
				return
			}
			reduceAxesReady = true
			// Every axis except the channel axis.
			const axes = Array.from({ length: rank }, (_, i) => i).filter(i => i !== channelAxis)
			const tensor = new onnx.TensorProto()
			tensor.setName(named('_reduce_axis'))
			tensor.setDataType(onnx.TensorProto.DataType.INT64)
			tensor.setDimsList([axes.length])
			tensor.setInt64DataList(axes)
			graph.addInitializer(tensor)
		}

		// Register a ReduceMean over the shared reduction axes, optionally with one attribute.
		const addReduceMean = (from, to, attribute) => {
			const reduce = createNode('ReduceMean', [from, named('_reduce_axis')], to)
			if (attribute) {
				reduce.addAttribute(attribute)
			}
			graph.addNode(reduce)
		}

		if (typeof obj.input_mean === 'string') {
			bn.addInput(obj.input_mean)
		} else if (obj.input_mean) {
			bn.addInput(addFloatInitializer(named('_input_mean'), obj.input_mean))
		} else {
			// No mean supplied: take it from the original input.
			ensureReduceAxes()
			addReduceMean(source, named('_input_mean'), intAttribute('keepdims', 0))
			bn.addInput(named('_input_mean'))
		}

		if (typeof obj.input_var === 'string') {
			bn.addInput(obj.input_var)
		} else if (obj.input_var) {
			bn.addInput(addFloatInitializer(named('_input_var'), obj.input_var))
		} else {
			// No variance supplied: compute mean((x - mean(x))^2) from the original input.
			ensureReduceAxes()
			const dropReducedAxes = intAttribute('keepdims', 0)

			// This first reduction carries no keepdims attribute; its result is subtracted from
			// the full input below. NOTE(review): presumably relies on the operator default
			// keeping the reduced axes so that Sub broadcasts - confirm against the ONNX spec.
			addReduceMean(source, named('_input_var_mean'))
			graph.addNode(createNode('Sub', [source, named('_input_var_mean')], named('_input_var_sub')))
			graph.addNode(
				createNode('Mul', [named('_input_var_sub'), named('_input_var_sub')], named('_input_var_mul'))
			)
			addReduceMean(named('_input_var_mul'), named('_input_var'), dropReducedAxes)
			bn.addInput(named('_input_var'))
		}

		// The normalization node itself is registered after all of its producers.
		graph.addNode(bn)
	},
}