@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
25 lines • 701 B
JavaScript
import { Optimizer } from './optimizer';
/**
 * Plain stochastic gradient descent (SGD) optimizer.
 *
 * Each call to {@link SGD#step} replaces every parameter's value with a new
 * tensor derived from its current value and gradient, scaled by `lr`.
 * Parameters whose gradient is `undefined` are left untouched.
 */
export class SGD extends Optimizer {
  /**
   * Creates an SGD optimizer bound to a model.
   *
   * @param model Model whose parameters will be optimized; forwarded
   *              unchanged to the `Optimizer` base constructor.
   * @param lr    Learning rate, i.e. the step size used on every update.
   *              Defaults to 0.001.
   */
  constructor(model, lr = 0.001) {
    super(model);
    this.lr = lr;
  }

  /**
   * Performs one in-place update over `this.parameters`, which is populated
   * by the `Optimizer` base class (presumably from the model; verify there).
   */
  step() {
    for (const param of this.parameters) {
      // Nothing to apply for parameters without a gradient.
      if (param.grad === undefined) {
        continue;
      }

      const previous = param.value;
      // NOTE(review): subtract(other, alpha, beta) is assumed to compute
      // alpha * this - beta * other, i.e. value - lr * grad — confirm
      // against the Tensor API.
      param.value = previous.subtract(param.grad, 1, this.lr);

      // The old tensor has been replaced; release it explicitly
      // (presumably frees its backing storage — see Tensor.delete).
      previous.delete();
    }
  }
}
//# sourceMappingURL=SGD.js.map