/*
 * @tensorflow/tfjs-core
 * Version:
 * Hardware-accelerated JavaScript library for machine intelligence
 * 117 lines • 5.57 kB
 * JavaScript
 */
;
/**
 * Prototype-inheritance helper used by the ES5-style class definitions in this
 * file. An `__extends` already present on `this` is reused; otherwise a fresh
 * helper is built.
 *
 * The helper links the derived constructor to the base constructor (so static
 * members are reachable) and installs a prototype object on the derived
 * constructor that inherits from the base prototype.
 */
var __extends = (this && this.__extends) || (function () {
    // Fallback 1: engines where an object literal's `__proto__` key sets the prototype.
    function linkThroughProto(derived, base) {
        derived.__proto__ = base;
    }
    // Fallback 2: copy the base's own enumerable members one by one.
    function copyOwnMembers(derived, base) {
        for (var key in base) {
            if (base.hasOwnProperty(key)) {
                derived[key] = base[key];
            }
        }
    }
    var protoLiteralWorks = { __proto__: [] } instanceof Array;
    // Preference order: native setPrototypeOf, then `__proto__`, then copying.
    var extendStatics = Object.setPrototypeOf ||
        (protoLiteralWorks && linkThroughProto) ||
        copyOwnMembers;
    return function (derived, base) {
        extendStatics(derived, base);
        if (base === null) {
            // Extending `null`: the prototype gets no parent and no `constructor`.
            derived.prototype = Object.create(base);
        }
        else {
            // The surrogate lets us inherit from base.prototype without running
            // the base constructor, while pointing `constructor` at `derived`.
            var Surrogate = function () { this.constructor = derived; };
            Surrogate.prototype = base.prototype;
            derived.prototype = new Surrogate();
        }
    };
})();
Object.defineProperty(exports, "__esModule", { value: true });
var environment_1 = require("../environment");
var globals_1 = require("../globals");
var ops_1 = require("../ops/ops");
var serialization_1 = require("../serialization");
var optimizer_1 = require("./optimizer");
/**
 * Optimizer applying the Adamax update rule to registered variables.
 *
 * For every variable with gradient `g`, `applyGradients` computes
 *
 *   m <- beta1 * m + (1 - beta1) * g
 *   u <- max(beta2 * u, |g|)
 *   value <- value + (-learningRate / (1 + decay * iteration))
 *                      / (1 - accBeta1) * m / (epsilon + u)
 *
 * where `accBeta1` starts at `beta1` and is multiplied by `beta1` after each
 * call, and `iteration` starts at 0 and is incremented after each call.
 */
var AdamaxOptimizer = (function (_super) {
    __extends(AdamaxOptimizer, _super);
    // Fields holding tensors that `dispose` releases, in disposal order.
    var OWNED_TENSOR_FIELDS = [
        'c', 'epsScalar', 'accBeta1', 'beta1Scalar', 'beta2Scalar',
        'oneMinusBeta1', 'decayScalar', 'iteration', 'one'
    ];
    // Builds a scalar tensor from `value` and passes it through `globals_1.keep`.
    function keptScalar(value) {
        return globals_1.keep(ops_1.scalar(value));
    }
    // Returns `store[name]`, first creating a zero-filled, non-trainable
    // variable shaped like `like` when no entry exists yet.
    function ensureAccumulator(store, name, like) {
        if (store[name] == null) {
            var trainable = false;
            store[name] = ops_1.zerosLike(like).variable(trainable);
        }
        return store[name];
    }
    // Disposes every entry of an accumulator map; a null/undefined map is ignored.
    function disposeEntries(store) {
        if (store == null) {
            return;
        }
        Object.keys(store).forEach(function (name) {
            store[name].dispose();
        });
    }
    /**
     * @param learningRate Step size; stored negated as the scalar `c`.
     * @param beta1 Decay factor of the first-moment accumulator.
     * @param beta2 Decay factor of the weighted infinity-norm accumulator.
     * @param epsilon Added to the infinity norm before dividing. Defaults to 1e-8.
     * @param decay Learning-rate decay per iteration. Defaults to 0.0.
     */
    function AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay) {
        // Defaults only replace arguments that are strictly `undefined`.
        epsilon = epsilon === void 0 ? 1e-8 : epsilon;
        decay = decay === void 0 ? 0.0 : decay;
        var optimizer = _super.call(this) || this;
        // Raw hyperparameters, kept so `getConfig` can report them.
        optimizer.learningRate = learningRate;
        optimizer.beta1 = beta1;
        optimizer.beta2 = beta2;
        optimizer.epsilon = epsilon;
        optimizer.decay = decay;
        // Per-variable state, keyed by variable name and filled in lazily.
        optimizer.accumulatedFirstMoment = {};
        optimizer.accumulatedWeightedInfNorm = {};
        // Scalar tensors reused by every `applyGradients` call.
        optimizer.c = keptScalar(-learningRate);
        optimizer.epsScalar = keptScalar(epsilon);
        optimizer.beta1Scalar = keptScalar(beta1);
        optimizer.beta2Scalar = keptScalar(beta2);
        optimizer.decayScalar = keptScalar(decay);
        // The step counter and running power of beta1 are variables because
        // they are reassigned at the end of every `applyGradients` call.
        globals_1.tidy(function () {
            optimizer.iteration = ops_1.scalar(0).variable();
            optimizer.accBeta1 = ops_1.scalar(beta1).variable();
        });
        optimizer.oneMinusBeta1 = keptScalar(1 - beta1);
        optimizer.one = keptScalar(1);
        return optimizer;
    }
    /**
     * Applies one Adamax step to each variable named in `variableGradients`.
     *
     * @param variableGradients Map from variable name to that variable's
     *     gradient. Names are looked up in `ENV.engine.registeredVariables`.
     */
    AdamaxOptimizer.prototype.applyGradients = function (variableGradients) {
        var optimizer = this;
        globals_1.tidy(function () {
            // 1 - accBeta1: bias-correction denominator for this step.
            var biasCorrection = optimizer.one.sub(optimizer.accBeta1);
            // Decayed (negative) step size: c / (1 + decay * iteration).
            var decayedSteps = optimizer.decayScalar.mul(optimizer.iteration);
            var decayDenominator = optimizer.one.add(decayedSteps);
            var stepSize = optimizer.c.div(decayDenominator);
            for (var variableName in variableGradients) {
                var target = environment_1.ENV.engine.registeredVariables[variableName];
                var firstMoment = ensureAccumulator(optimizer.accumulatedFirstMoment, variableName, target);
                var infNorm = ensureAccumulator(optimizer.accumulatedWeightedInfNorm, variableName, target);
                var grad = variableGradients[variableName];
                // m <- beta1 * m + (1 - beta1) * g
                var decayedMoment = optimizer.beta1Scalar.mul(firstMoment);
                var gradContribution = optimizer.oneMinusBeta1.mul(grad);
                var updatedMoment = decayedMoment.add(gradContribution);
                // u <- max(beta2 * u, |g|)
                var decayedNorm = optimizer.beta2Scalar.mul(infNorm);
                var gradMagnitude = grad.abs();
                var updatedNorm = decayedNorm.maximum(gradMagnitude);
                firstMoment.assign(updatedMoment);
                infNorm.assign(updatedNorm);
                // value <- value + stepSize / (1 - accBeta1) * m / (epsilon + u)
                var correctedStep = stepSize.div(biasCorrection);
                var stabilizedNorm = optimizer.epsScalar.add(updatedNorm);
                var direction = updatedMoment.div(stabilizedNorm);
                var updatedValue = correctedStep.mul(direction).add(target);
                target.assign(updatedValue);
            }
            // Advance the step counter and the running power of beta1.
            optimizer.iteration.assign(optimizer.iteration.add(optimizer.one));
            optimizer.accBeta1.assign(optimizer.accBeta1.mul(optimizer.beta1Scalar));
        });
    };
    /**
     * Disposes every scalar and variable created by the constructor, then every
     * per-variable accumulator created by `applyGradients`.
     */
    AdamaxOptimizer.prototype.dispose = function () {
        var optimizer = this;
        OWNED_TENSOR_FIELDS.forEach(function (field) {
            optimizer[field].dispose();
        });
        disposeEntries(this.accumulatedFirstMoment);
        disposeEntries(this.accumulatedWeightedInfNorm);
    };
    /**
     * @returns A plain object holding the five constructor hyperparameters.
     */
    AdamaxOptimizer.prototype.getConfig = function () {
        var config = {
            learningRate: this.learningRate,
            beta1: this.beta1,
            beta2: this.beta2,
            epsilon: this.epsilon,
            decay: this.decay
        };
        return config;
    };
    /**
     * Builds an optimizer from a config object shaped like `getConfig`'s result.
     *
     * @param cls Constructor to instantiate.
     * @param config Object with `learningRate`, `beta1`, `beta2`, `epsilon`
     *     and `decay` properties.
     * @returns A new `cls` instance.
     */
    AdamaxOptimizer.fromConfig = function (cls, config) {
        var learningRate = config.learningRate;
        var beta1 = config.beta1;
        var beta2 = config.beta2;
        var epsilon = config.epsilon;
        var decay = config.decay;
        return new cls(learningRate, beta1, beta2, epsilon, decay);
    };
    // Name under which the class is identified for serialization.
    AdamaxOptimizer.className = 'AdamaxOptimizer';
    return AdamaxOptimizer;
}(optimizer_1.Optimizer));
// Expose the class on this module's CommonJS `exports` object.
exports.AdamaxOptimizer = AdamaxOptimizer;
// Register the class with the SerializationMap imported from "../serialization".
serialization_1.SerializationMap.register(AdamaxOptimizer);
//# sourceMappingURL=adamax_optimizer.js.map