UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

65 lines 2.82 kB
// TypeScript-emitted helper that runs a generator function as an async function:
// every yielded value is awaited and fed back into the generator, and the returned
// Promise resolves with the generator's return value or rejects with any error the
// generator does not catch.
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); };
import { Variable } from '../../autograd/variable';
import { CPUTensor } from '../../tensor/cpu/tensor';
import { toCPU, toGPU, toWASM } from '../../util/convert';
import { OnnxNode } from '../node';
/**
 * ONNX `BatchNormalization` node.
 *
 * Attributes:
 *  - `epsilon`  (default 1e-5): passed to `x.normalize` in `forward`.
 *  - `momentum` (default 0.9): read and stored, but not used by `forward` in this file.
 *
 * `epsTensor` is a one-element CPU tensor holding `epsilon`. In `'train'` mode it is
 * wrapped in an autograd `Variable`. It is moved between backends by
 * `toCPU`/`toWASM`/`toGPU` and released by `delete`. Note that `forward` passes the
 * scalar `this.epsilon`, not this tensor.
 */
export class BatchNormalizationNode extends OnnxNode {
    constructor(attributes, inputs, outputs, constants, onnxVersion, mode) {
        super(attributes, inputs, outputs, constants, onnxVersion, mode);
        // Explicit null checks rather than `||`, so an attribute explicitly set to 0
        // is honoured instead of being silently replaced by the default.
        const epsilon = this.getAttributeFloat('epsilon');
        this.epsilon = epsilon == null ? 1e-5 : epsilon;
        const momentum = this.getAttributeFloat('momentum');
        this.momentum = momentum == null ? 0.9 : momentum;
        this.epsTensor = new CPUTensor([1], [this.epsilon]);
        if (mode === 'train') {
            this.epsTensor = new Variable(this.epsTensor);
        }
        //TODO: Handle lower onnxversions here
    }
    /**
     * Normalizes `x` with the supplied per-channel statistics and affine parameters.
     *
     * @param inputs `[x, scale, B, mean, variance]`. The channel count C is read from
     *   `scale.getShape()[0]`. The other four inputs are presumably 1-D of length C
     *   (confirm against callers).
     * @returns Promise resolving to a one-element array holding the normalized tensor.
     * @throws RangeError (surfaced as a rejected promise) when `x` has rank < 2.
     */
    forward(inputs) {
        return __awaiter(this, void 0, void 0, function* () {
            const x = inputs[0];
            //TODO: Handle lower onnx versions here
            const C = inputs[1].getShape()[0];
            const rank = x.getShape().length;
            if (rank < 2) {
                // Without this guard `new Array(rank - 2)` would throw an opaque
                // "Invalid array length" RangeError.
                throw new RangeError(`BatchNormalization expects an input of rank >= 2 (N x C x ...), got rank ${rank}`);
            }
            // Reshape the per-channel parameters to [1, C, 1, ..., 1] (same rank as x),
            // so they line up with axis 1 of x (presumably broadcast by normalize).
            // NOTE(review): the meaning of reshape's `false` flag is defined by the
            // tensor implementation (presumably "do not copy"); confirm.
            const newShape = [1, C, ...new Array(rank - 2).fill(1)];
            const scale = inputs[1].reshape(newShape, false);
            const B = inputs[2].reshape(newShape, false);
            const mean = inputs[3].reshape(newShape, false);
            const variance = inputs[4].reshape(newShape, false);
            const result = x.normalize(mean, variance, this.epsilon, scale, B);
            return [result];
        });
    }
    getType() {
        return 'BatchNormalization';
    }
    /** Moves `epsTensor` to the CPU backend. */
    toCPU() {
        return __awaiter(this, void 0, void 0, function* () {
            this.epsTensor = yield toCPU(this.epsTensor);
        });
    }
    /** Moves `epsTensor` to the WASM backend. */
    toWASM() {
        return __awaiter(this, void 0, void 0, function* () {
            this.epsTensor = yield toWASM(this.epsTensor);
        });
    }
    /** Moves `epsTensor` to the GPU backend. */
    toGPU() {
        return __awaiter(this, void 0, void 0, function* () {
            this.epsTensor = yield toGPU(this.epsTensor);
        });
    }
    /** Releases `epsTensor` by calling its `delete()`. */
    delete() {
        this.epsTensor.delete();
    }
}
//# sourceMappingURL=batchNormalization.js.map