UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

34 lines 1.6 kB
import { ConvNode } from '../nodes/conv/conv';
import { SequenceOptimization } from './optimization';

/**
 * Graph optimization that folds a `BatchNormalization` node directly following
 * a `Conv` node into the convolution's kernel and bias, replacing the pair with
 * a single `ConvNode`.
 *
 * With s = scaleBN / sqrt(varianceBN + eps), the fused parameters are
 *   kernel' = kernel * s            (s padded with trailing 1-axes to the kernel's rank)
 *   bias'   = biasBN - meanBN * s   (+ convBias * s when the conv has a bias)
 */
export class ConvBatchNorm extends SequenceOptimization {
  constructor() {
    super(['Conv', 'BatchNormalization']);
  }

  /**
   * Builds the fused convolution node for a matched `[Conv, BatchNormalization]` pair.
   *
   * Every intermediate tensor created here is released with `delete()` once it is
   * no longer needed. Only the fused kernel and bias survive; they are handed to
   * the returned node.
   *
   * @param {Array} nodes - The matched sequence: `[convNode, batchNormNode]`.
   * @param {Function} resolveConstant - Maps a node input to its constant tensor.
   *   The conv bias is treated as absent when this returns `undefined`.
   *   NOTE(review): assumes missing optional inputs resolve to `undefined`; confirm.
   * @param {*} constants - Forwarded unchanged to the `ConvNode` constructor.
   * @param {*} onnxVersion - Forwarded unchanged to the `ConvNode` constructor.
   * @returns {ConvNode} A conv node that reads the conv's data input (`conv.inputs[0]`)
   *   and produces the batch norm's outputs.
   */
  apply(nodes, resolveConstant, constants, onnxVersion) {
    const [conv, batchNorm] = nodes;

    const kernelConv = resolveConstant(conv.inputs[1]);
    const biasConv = resolveConstant(conv.inputs[2]);

    const scaleBN = resolveConstant(batchNorm.inputs[1]);
    const biasBN = resolveConstant(batchNorm.inputs[2]);
    const meanBN = resolveConstant(batchNorm.inputs[3]);
    const varianceBN = resolveConstant(batchNorm.inputs[4]);

    // scale = scaleBN / sqrt(varianceBN + eps); each temporary is freed after use.
    const varPlusEps = varianceBN.add(batchNorm.epsTensor);
    const varSqrt = varPlusEps.sqrt();
    varPlusEps.delete();
    const scale = scaleBN.divide(varSqrt);
    varSqrt.delete();

    // bias = biasBN - meanBN * scale
    const meanScaled = meanBN.multiply(scale);
    const bias = biasBN.subtract(meanScaled);
    meanScaled.delete();

    // Pad scale's shape with trailing 1s up to the kernel's rank so the multiply
    // broadcasts scale along the kernel's leading axis.
    const scaleShape = scale.getShape();
    const paddedShape = [
      ...scaleShape,
      ...new Array(kernelConv.getShape().length - scaleShape.length).fill(1),
    ];
    // The reshaped tensor (copy = false) is used only for this multiply and is
    // not deleted separately. NOTE(review): presumably it shares storage with
    // `scale`, so freeing it as well could double-free; confirm before changing.
    const newKernel = kernelConv.multiply(scale.reshape(paddedShape, false));

    let newBias = bias;
    if (biasConv !== undefined) {
      const scaledBias = biasConv.multiply(scale);
      newBias = bias.add(scaledBias);
      scaledBias.delete();
      // `bias` is superseded by `newBias` and would otherwise leak.
      bias.delete();
    }
    scale.delete();

    return new ConvNode(
      Object.values(conv.attributes),
      [conv.inputs[0]],
      batchNorm.outputs,
      constants,
      onnxVersion,
      conv.mode,
      newKernel,
      newBias
    );
  }
}
//# sourceMappingURL=convBatchnorm.js.map