/*
 * @ai-on-browser/data-analysis-models
 * Version:
 * Data analysis model package without any dependencies
 * 78 lines (76 loc) • 2.52 kB
 * JavaScript
 */
import { onnx } from '../onnx_importer.js'
import { loadTensor, loadAttribute } from '../utils.js'
/**
* Handle batch-normalization operator
* @module HandleONNXBatchNormalizationOperator
* @see https://github.com/onnx/onnx/blob/main/docs/Operators.md#BatchNormalization
*/
const batchNormalization = {
	/**
	 * Import from onnx object.
	 *
	 * Converts an ONNX `BatchNormalization` node into layer descriptions.
	 * Each of the inputs 1..4 (scale, bias, mean, variance) is resolved to the
	 * loaded tensor when the graph has an initializer with the same name;
	 * otherwise the input name itself is passed through as a reference.
	 * When the node declares a second or third output, `identity` layers are
	 * appended that expose the `[mean]` / `[var]` sub-outputs of the
	 * batch-normalization layer under those output names.
	 *
	 * @param {onnx.ModelProto} model Model object
	 * @param {onnx.NodeProto} node Node object
	 * @returns {object[]} Objects represented a layer
	 * @throws {Error} If the `training_mode` attribute is set (truthy); only inference mode is supported.
	 */
	import(model, node) {
		// Defaults used when the node does not specify the attribute.
		const attrs = { epsilon: 1e-5, momentum: 0.9, training_mode: 0 }
		for (const attribute of node.getAttributeList()) {
			attrs[attribute.getName()] = loadAttribute(attribute)
		}
		// Reject unsupported nodes before decoding any tensors.
		if (attrs.training_mode) {
			throw new Error(`Invalid attribute 'training_mode' value ${attrs.training_mode}.`)
		}

		const inputList = node.getInputList()
		const initializers = {}
		for (const initializer of model.getGraph().getInitializerList()) {
			const initializerName = initializer.getName()
			if (initializerName === inputList[1]) {
				initializers.scale = loadTensor(initializer)
			} else if (initializerName === inputList[2]) {
				initializers.b = loadTensor(initializer)
			} else if (initializerName === inputList[3]) {
				initializers.inputMean = loadTensor(initializer)
			} else if (initializerName === inputList[4]) {
				initializers.inputVar = loadTensor(initializer)
			}
		}

		const outputList = node.getOutputList()
		const ret = [
			{
				type: 'batch_normalization',
				input: [inputList[0]],
				name: outputList[0],
				scale: initializers.scale ?? inputList[1],
				offset: initializers.b ?? inputList[2],
				epsilon: attrs.epsilon ?? 1.0e-5,
				// Mean and variance must come from their own initializers, not from the bias tensor.
				input_mean: initializers.inputMean ?? inputList[3],
				input_var: initializers.inputVar ?? inputList[4],
				// Channel axis is 1 per the ONNX operator specification (input is N x C x ...).
				channel_dim: 1,
			},
		]
		if (outputList.length >= 2) {
			ret.push({
				type: 'identity',
				input: [`${outputList[0]}[mean]`],
				name: outputList[1],
			})
		}
		if (outputList.length >= 3) {
			ret.push({
				type: 'identity',
				input: [`${outputList[0]}[var]`],
				name: outputList[2],
			})
		}
		return ret
	},
}

export default batchNormalization