// Package: @tensorflow/tfjs-core — "Hardware-accelerated JavaScript library for
// machine intelligence." Compiled build artifact (adamax_optimizer.js) as served
// by UNPKG; original listing reported 117 lines / 5.57 kB. The code below is
// minified ES5/CommonJS output transpiled from TypeScript.
// AdamaxOptimizer — the AdaMax variant of the Adam optimizer (presumably per
// Kingma & Ba, "Adam: A Method for Stochastic Optimization", Algorithm 2 —
// the visible update rule matches it). Per-variable state: a first-moment
// estimate `accumulatedFirstMoment` (m) and an exponentially weighted
// infinity norm `accumulatedWeightedInfNorm` (u).
//
// Constructor(learningRate, beta1, beta2, epsilon = 1e-8, decay = 0.0):
//   - `c` is keep()-ed scalar(-learningRate): negated so each step can be
//     *added* to the weight value rather than subtracted.
//   - `iteration` (t, starts at 0) and `accBeta1` (running beta1^t) are
//     created as variables inside tidy() — intermediate scalars are released
//     by tidy, but variables persist across it.
//   - Every constant scalar (eps, beta1, beta2, decay, 1-beta1, 1) is wrapped
//     in keep() so tidy() elsewhere cannot free it; dispose() releases them.
//
// applyGradients(variableGradients): inside one tidy(), for each named
// gradient g over registered variable `value`:
//   lr_t  = -learningRate / (1 + decay * t)          (decay applied per step)
//   m     = beta1 * m + (1 - beta1) * g
//   u     = max(beta2 * u, |g|)
//   value += lr_t / (1 - beta1^t) * m / (epsilon + u)
// Accumulators are lazily created as non-trainable zero variables on first
// sight of a variable name. After the loop: t += 1 and accBeta1 *= beta1.
"use strict"; var __extends = (this && this.__extends) || (function () { var extendStatics = Object.setPrototypeOf || ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; }; return function (d, b) { extendStatics(d, b); function __() { this.constructor = d; } d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); }; })(); Object.defineProperty(exports, "__esModule", { value: true }); var environment_1 = require("../environment"); var globals_1 = require("../globals"); var ops_1 = require("../ops/ops"); var serialization_1 = require("../serialization"); var optimizer_1 = require("./optimizer"); var AdamaxOptimizer = (function (_super) { __extends(AdamaxOptimizer, _super); function AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay) { if (epsilon === void 0) { epsilon = 1e-8; } if (decay === void 0) { decay = 0.0; } var _this = _super.call(this) || this; _this.learningRate = learningRate; _this.beta1 = beta1; _this.beta2 = beta2; _this.epsilon = epsilon; _this.decay = decay; _this.accumulatedFirstMoment = {}; _this.accumulatedWeightedInfNorm = {}; _this.c = globals_1.keep(ops_1.scalar(-learningRate)); _this.epsScalar = globals_1.keep(ops_1.scalar(epsilon)); _this.beta1Scalar = globals_1.keep(ops_1.scalar(beta1)); _this.beta2Scalar = globals_1.keep(ops_1.scalar(beta2)); _this.decayScalar = globals_1.keep(ops_1.scalar(decay)); globals_1.tidy(function () { _this.iteration = ops_1.scalar(0).variable(); _this.accBeta1 = ops_1.scalar(beta1).variable(); }); _this.oneMinusBeta1 = globals_1.keep(ops_1.scalar(1 - beta1)); _this.one = globals_1.keep(ops_1.scalar(1)); return _this; } AdamaxOptimizer.prototype.applyGradients = function (variableGradients) { var _this = this; globals_1.tidy(function () { var oneMinusAccBeta1 = _this.one.sub(_this.accBeta1); var lr = _this.c.div(_this.one.add(_this.decayScalar.mul(_this.iteration))); for (var 
// (the for(var … in variableGradients) loop header is split across build lines
// here; the loop body below performs the per-variable AdaMax update)
variableName in variableGradients) { var value = environment_1.ENV.engine.registeredVariables[variableName]; if (_this.accumulatedFirstMoment[variableName] == null) { var trainable = false; _this.accumulatedFirstMoment[variableName] = ops_1.zerosLike(value).variable(trainable); } if (_this.accumulatedWeightedInfNorm[variableName] == null) { var trainable = false; _this.accumulatedWeightedInfNorm[variableName] = ops_1.zerosLike(value).variable(trainable); } var gradient = variableGradients[variableName]; var firstMoment = _this.accumulatedFirstMoment[variableName]; var weightedInfNorm = _this.accumulatedWeightedInfNorm[variableName]; var newFirstMoment = _this.beta1Scalar.mul(firstMoment) .add(_this.oneMinusBeta1.mul(gradient)); var ut0 = _this.beta2Scalar.mul(weightedInfNorm); var ut1 = gradient.abs(); var newWeightedInfNorm = ut0.maximum(ut1); _this.accumulatedFirstMoment[variableName].assign(newFirstMoment); _this.accumulatedWeightedInfNorm[variableName].assign(newWeightedInfNorm); var newValue = lr.div(oneMinusAccBeta1) .mul(newFirstMoment.div(_this.epsScalar.add(newWeightedInfNorm))) .add(value); value.assign(newValue); } _this.iteration.assign(_this.iteration.add(_this.one)); _this.accBeta1.assign(_this.accBeta1.mul(_this.beta1Scalar)); }); }; AdamaxOptimizer.prototype.dispose = function () { var _this = this; this.c.dispose(); this.epsScalar.dispose(); this.accBeta1.dispose(); this.beta1Scalar.dispose(); this.beta2Scalar.dispose(); this.oneMinusBeta1.dispose(); this.decayScalar.dispose(); this.iteration.dispose(); this.one.dispose(); if (this.accumulatedFirstMoment != null) { Object.keys(this.accumulatedFirstMoment) .forEach(function (name) { return _this.accumulatedFirstMoment[name].dispose(); }); } if (this.accumulatedWeightedInfNorm != null) { Object.keys(this.accumulatedWeightedInfNorm) .forEach(function (name) { return _this.accumulatedWeightedInfNorm[name].dispose(); }); } }; AdamaxOptimizer.prototype.getConfig = function () { return { learningRate: 
// getConfig()/fromConfig() give serialization round-trip support for the five
// constructor hyperparameters; `className` registers the class with
// SerializationMap so configs can be deserialized by name.
this.learningRate, beta1: this.beta1, beta2: this.beta2, epsilon: this.epsilon, decay: this.decay }; }; AdamaxOptimizer.fromConfig = function (cls, config) { return new cls(config.learningRate, config.beta1, config.beta2, config.epsilon, config.decay); }; AdamaxOptimizer.className = 'AdamaxOptimizer'; return AdamaxOptimizer; }(optimizer_1.Optimizer)); exports.AdamaxOptimizer = AdamaxOptimizer; serialization_1.SerializationMap.register(AdamaxOptimizer); //# sourceMappingURL=adamax_optimizer.js.map