@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
95 lines (83 loc) • 2.91 kB
text/typescript
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENGINE} from '../engine';
import {keep, tidy} from '../globals';
import {scalar} from '../ops/ops';
import {ConfigDict, registerClass, Serializable, SerializableConstructor} from '../serialization';
import {Scalar} from '../tensor';
import {NamedTensor, NamedTensorMap} from '../tensor_types';
import {Optimizer} from './optimizer';
// Optimizer implementing plain stochastic gradient descent: every variable
// that has a gradient is updated in place to `value - learningRate * gradient`.
/** @doclink Optimizer */
export class SGDOptimizer extends Optimizer {
  /** @nocollapse */
  static className = 'SGD';  // Note: Name matters for Python compatibility.

  /** Kept scalar holding the negated learning rate (see setLearningRate). */
  protected c: Scalar;

  /**
   * @param learningRate Step size applied to every gradient.
   */
  constructor(protected learningRate: number) {
    super();
    this.setLearningRate(learningRate);
  }

  /**
   * Applies one update step to every registered variable named in
   * `variableGradients`.
   *
   * @param variableGradients Either a name -> gradient map or a list of
   *     `{name, tensor}` entries. Entries whose gradient is null/undefined
   *     are skipped.
   */
  applyGradients(variableGradients: NamedTensorMap|NamedTensor[]) {
    // Collect the variable names up front, whichever container we were given.
    let names: string[];
    if (Array.isArray(variableGradients)) {
      names = variableGradients.map(entry => entry.name);
    } else {
      names = Object.keys(variableGradients);
    }

    names.forEach((varName, index) => {
      const grad = !Array.isArray(variableGradients) ?
          variableGradients[varName] :
          variableGradients[index].tensor;
      if (grad != null) {
        const variable = ENGINE.registeredVariables[varName];
        // `c` already holds -learningRate, so this is value - lr * gradient.
        tidy(() => {
          variable.assign(this.c.mul(grad).add(variable));
        });
      }
    });

    this.incrementIterations();
  }

  /**
   * Sets the learning rate of the optimizer.
   *
   * Disposes the previously kept scalar (if there is one) and replaces it
   * with a new kept scalar holding `-learningRate`.
   */
  setLearningRate(learningRate: number) {
    this.learningRate = learningRate;
    const previous = this.c;
    if (previous != null) {
      previous.dispose();
    }
    const negatedRate = scalar(-learningRate);
    this.c = keep(negatedRate);
  }

  /** Disposes the kept learning-rate scalar. */
  dispose() {
    this.c.dispose();
  }

  /**
   * Resolves to a one-element list containing the value produced by
   * `saveIterations()`.
   */
  async getWeights(): Promise<NamedTensor[]> {
    const iterations = await this.saveIterations();
    return [iterations];
  }

  /**
   * Passes `weightValues` through `extractIterations()` and rejects if
   * anything is left over afterwards.
   *
   * @throws Error if the list returned by `extractIterations()` is non-empty.
   */
  async setWeights(weightValues: NamedTensor[]): Promise<void> {
    const remaining = await this.extractIterations(weightValues);
    if (remaining.length !== 0) {
      throw new Error('SGD optimizer does not have settable weights.');
    }
  }

  /** Returns the serializable configuration: just the learning rate. */
  getConfig(): ConfigDict {
    const config: ConfigDict = {'learningRate': this.learningRate};
    return config;
  }

  /**
   * Builds an instance of `cls` from a config such as the one produced by
   * `getConfig()`, passing `config['learningRate']` to the constructor.
   */
  /** @nocollapse */
  static fromConfig<T extends Serializable>(
      cls: SerializableConstructor<T>, config: ConfigDict): T {
    const learningRate = config['learningRate'];
    return new cls(learningRate);
  }
}
// Hand the class to the serialization module's registerClass().
// NOTE(review): presumably this makes the optimizer discoverable by its
// `className` ('SGD') when deserializing — confirm in ../serialization.
registerClass(SGDOptimizer);