@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
49 lines (48 loc) • 2.03 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
var tf = require("../index");
var test_util_1 = require("../test_util");
var jasmine_util_1 = require("../jasmine_util");
/**
 * Specs for the optimizer returned by `tf.train.momentum`, registered through
 * `describeWithFlags` with `ALL_ENVS`.
 *
 * Both specs minimize f(x) = sum(x^2) starting from x = [1, 2] with a learning
 * rate of 0.1 and a momentum of 0.5. They check the values held by x after two
 * consecutive `minimize` calls, and the tensor count reported by `tf.memory()`
 * along the way.
 */
jasmine_util_1.describeWithFlags('MomentumOptimizer', test_util_1.ALL_ENVS, function () {
    /** Reads the tensor count currently reported by `tf.memory()`. */
    function countTensors() {
        return tf.memory().numTensors;
    }
    /**
     * Runs two `minimize` calls on f(x) = sum(x^2) from x = [1, 2], checking the
     * values of x and the tensor count after each call.
     *
     * @param optimizer Optimizer under test; it is disposed before returning.
     * @param afterFirstStep Values x is expected to hold after the first call.
     * @param afterSecondStep Values x is expected to hold after the second call.
     */
    function checkTwoSteps(optimizer, afterFirstStep, afterSecondStep) {
        var x = tf.tensor1d([1, 2]).variable();
        function loss() {
            return x.square().sum();
        }
        // First call: the cost is requested (second argument is true).
        var tensorsBefore = countTensors();
        var cost = optimizer.minimize(loss, true);
        // Exactly two tensors more than before the call.
        // NOTE(review): presumably the returned cost plus the optimizer's
        // internal accumulator - confirm against the optimizer implementation.
        expect(countTensors()).toBe(tensorsBefore + 2);
        test_util_1.expectArraysClose(x, afterFirstStep);
        cost.dispose();
        // Second call: no cost requested, so null must come back and the
        // tensor count must not move.
        tensorsBefore = countTensors();
        cost = optimizer.minimize(loss, false);
        test_util_1.expectArraysClose(x, afterSecondStep);
        expect(countTensors()).toBe(tensorsBefore);
        expect(cost).toBe(null);
        x.dispose();
        optimizer.dispose();
        // A single tensor is expected to be left.
        // NOTE(review): presumably the tensor1d that variable() was called on,
        // which is never disposed here - confirm.
        expect(countTensors()).toBe(1);
    }
    it('basic', function () {
        var learningRate = 0.1;
        var momentum = 0.5;
        var optimizer = tf.train.momentum(learningRate, momentum);
        checkTwoSteps(optimizer, [0.8, 1.6], [0.54, 1.08]);
    });
    it('basic - with Nesterov', function () {
        var learningRate = 0.1;
        var momentum = 0.5;
        var useNesterov = true;
        var optimizer = tf.train.momentum(learningRate, momentum, useNesterov);
        checkTwoSteps(optimizer, [0.7, 1.4], [0.44, 0.88]);
    });
});