@hoff97/tensor-js

PyTorch-like deep learning inference library

import { GPUTensor } from '../../../tensor/gpu/tensor';
import { Optimizer } from '../optimizer';
import { defaultUpdateMomentsD } from './updateMoments';
import { defaultUpdateValueD } from './updateParams';

/**
 * Implements the Adam optimizer.
 *
 * This is currently quite slow on the CPU and WASM backends. On the GPU
 * backend, one update step is only slightly slower than an SGD update step,
 * and Adam converges a lot quicker.
 */
export class Adam extends Optimizer {
  constructor(model, lr = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
    super(model);
    this.lr = lr;
    this.beta1 = beta1;
    this.beta2 = beta2;
    this.epsilon = epsilon;
    this.t = 0;
    const params = this.parameters;
    if (params[0].value instanceof GPUTensor) {
      // GPU backend: the moment state of each parameter is packed into a
      // single zero-initialized tensor with a trailing dimension of size 4,
      // so a fused kernel can update it in one call per step.
      this.moments = new Array(params.length);
      for (let i = 0; i < params.length; i++) {
        this.moments[i] = new GPUTensor(
          new Array(params[i].value.size * 4).fill(0),
          [...params[i].getShape(), 4],
          params[i].value.dtype
        );
      }
    } else {
      // CPU/WASM backends: the first and second moment estimates are kept
      // in separate tensors, created lazily on the first update step.
      this.moment1 = new Array(params.length).fill(undefined);
      this.moment2 = new Array(params.length).fill(undefined);
    }
  }

  step() {
    this.t++;
    if (this.moment1 !== undefined && this.moment2 !== undefined) {
      // CPU/WASM path
      for (let i = 0; i < this.parameters.length; i++) {
        const parameter = this.parameters[i];
        if (parameter.grad !== undefined) {
          const oldValue = parameter.value;
          const { newValue, moment1, moment2 } = this.paramStep(
            parameter.value,
            parameter.grad,
            this.moment1[i],
            this.moment2[i]
          );
          parameter.value = newValue;
          this.moment1[i] = moment1;
          this.moment2[i] = moment2;
          oldValue.delete();
        }
      }
    } else if (this.moments !== undefined) {
      // GPU path
      for (let i = 0; i < this.parameters.length; i++) {
        const parameter = this.parameters[i];
        if (parameter.grad !== undefined) {
          const oldValue = parameter.value;
          const { newValue, moments } = this.gpuParamStep(
            parameter.value,
            parameter.grad,
            this.moments[i]
          );
          parameter.value = newValue;
          this.moments[i] = moments;
          oldValue.delete();
        }
      }
    }
  }

  updateMoments(grad, moment1, moment2) {
    // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    let moment1New;
    if (moment1 === undefined) {
      moment1New = grad.multiplyScalar(1 - this.beta1);
    } else {
      const oldMoment1 = moment1;
      moment1New = moment1.add(grad, this.beta1, 1 - this.beta1);
      oldMoment1.delete();
    }
    // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    let moment2New;
    if (moment2 === undefined) {
      moment2New = grad.multiply(grad, 1 - this.beta2);
    } else {
      const gradSquared = grad.multiply(grad);
      const oldMoment2 = moment2;
      moment2New = moment2.add(gradSquared, this.beta2, 1 - this.beta2);
      gradSquared.delete();
      oldMoment2.delete();
    }
    return { moment1New, moment2New };
  }

  getCorrectedMoments(moment1, moment2) {
    // Bias correction: mHat = m / (1 - beta1^t), vHat = v / (1 - beta2^t)
    const correctMoment1 = moment1.addMultiplyScalar(
      1 / (1 - Math.pow(this.beta1, this.t)),
      0
    );
    const correctMoment2 = moment2.addMultiplyScalar(
      1 / (1 - Math.pow(this.beta2, this.t)),
      0
    );
    return { correctMoment1, correctMoment2 };
  }

  paramStep(value, grad, moment1, moment2) {
    const { moment1New, moment2New } = this.updateMoments(grad, moment1, moment2);
    // This is not 100% correct: in the original paper the epsilon occurs
    // outside of the square root. It does not make much of a difference
    // though, and is slightly faster.
    const correctMoment2 = moment2New.addMultiplyScalar(
      1 / (1 - Math.pow(this.beta2, this.t)),
      this.epsilon
    );
    const moment2Sqrt = correctMoment2.sqrt();
    correctMoment2.delete();
    // The learning rate and the bias correction of the first moment are
    // folded into the scale factor of the division.
    const step = moment1New.divide(
      moment2Sqrt,
      -this.lr / (1 - Math.pow(this.beta1, this.t))
    );
    moment2Sqrt.delete();
    const newValue = value.add(step);
    step.delete();
    return { newValue, moment1: moment1New, moment2: moment2New };
  }

  gpuParamStep(value, grad, moments) {
    // Two fused GPU kernels: the first updates the packed moment tensor,
    // the second applies the Adam step to the parameter value.
    const newMoments = defaultUpdateMomentsD.calc(
      {
        Grad: grad,
        Moments: moments,
        beta1: this.beta1,
        beta2: this.beta2,
        t: this.t,
      },
      value.dtype
    );
    moments.delete();
    const newValue = defaultUpdateValueD.calc(
      {
        Value: value,
        Moments: newMoments,
        alpha: this.lr,
        epsilon: this.epsilon,
      },
      value.dtype
    );
    return { newValue, moments: newMoments };
  }
}
//# sourceMappingURL=Adam.js.map
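
For orientation, here is a minimal usage sketch. It is an assumption-laden illustration, not this library's documented API: the model, the loss computation, and the backward call are hypothetical placeholders, and the import path may differ depending on how the package exposes this module. Only the Adam constructor signature and step() come from the file above.

// Hypothetical usage sketch; only new Adam(...) and optimizer.step()
// are taken from the file above. The model, loss computation, and
// backward call are placeholders, not this library's verified API.
import { Adam } from './Adam'; // actual package export path may differ

const optimizer = new Adam(model, 0.001, 0.9, 0.999, 1e-8);

for (let i = 0; i < steps; i++) {
  const loss = computeLoss(model, batch); // placeholder forward pass
  loss.backward();                        // assumed to populate parameter.grad
  optimizer.step();                       // one Adam update over all parameters
}

Note that step() replaces each parameter's value tensor and delete()s the old one, so callers should not hold on to references to pre-step parameter tensors.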