catniff

A small Torch-like deep learning framework for Javascript

"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.Optim = void 0; const core_1 = require("./core"); class SGD { params; momentumBuffers = new Map(); lr; momentum; dampening; weightDecay; nesterov; constructor(params, options) { this.params = params; this.lr = options?.lr || 0.001; this.momentum = options?.momentum || 0; this.dampening = options?.dampening || 0; this.weightDecay = options?.weightDecay || 0; this.nesterov = options?.nesterov || false; } step() { for (const param of this.params) { if (!param.grad || !param.requiresGrad) continue; let grad = param.grad.detach(), detachedParam = param.detach(); // Apply weight decay (L2 regularization) if (this.weightDecay !== 0) { grad = grad.add(detachedParam.mul(this.weightDecay)); } // Apply momentum if (this.momentum !== 0) { let buf = this.momentumBuffers.get(param); if (!buf) { // First time: initialize momentum buffer with current gradient buf = grad.clone(); this.momentumBuffers.set(param, buf); } else { // Update momentum buffer: buf = momentum * buf + (1 - dampening) * grad buf = buf.mul(this.momentum).add(grad.mul(1 - this.dampening)); this.momentumBuffers.set(param, buf); } if (this.nesterov) { // Nesterov momentum: grad = grad + momentum * buf grad = grad.add(buf.mul(this.momentum)); } else { // Standard momentum: use momentum buffer as gradient grad = buf; } } // Update parameter: param = param - lr * grad const newParam = detachedParam.sub(grad.mul(this.lr)); param.replace(newParam); } } } class Adam { params; momentumBuffers = new Map(); // First moment (m_t) velocityBuffers = new Map(); // Second moment (v_t) stepCount = 0; lr; betas; eps; weightDecay; constructor(params, options) { this.params = params; this.lr = options?.lr || 0.001; this.betas = options?.betas || [0.9, 0.999]; this.eps = options?.eps || 1e-8; this.weightDecay = options?.weightDecay || 0; } step() { this.stepCount++; const beta1 = this.betas[0]; const beta2 = this.betas[1]; // Bias correction factors const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount); const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount); for (const param of this.params) { if (!param.grad || !param.requiresGrad) continue; let grad = param.grad.detach(), detachedParam = param.detach(); // Apply weight decay (L2 regularization) if (this.weightDecay !== 0) { grad = grad.add(detachedParam.mul(this.weightDecay)); } // Get or initialize first moment buffer (momentum) let momentumBuffer = this.momentumBuffers.get(param); if (!momentumBuffer) { momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad) this.momentumBuffers.set(param, momentumBuffer); } // Get or initialize second moment buffer (velocity) let velocityBuffer = this.velocityBuffers.get(param); if (!velocityBuffer) { velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad) this.velocityBuffers.set(param, velocityBuffer); } // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1)); this.momentumBuffers.set(param, momentumBuffer); // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2 velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2)); this.velocityBuffers.set(param, velocityBuffer); // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t) const correctedMomentum = momentumBuffer.div(biasCorrection1); // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t) const 
correctedVelocity = velocityBuffer.div(biasCorrection2); // Update parameters: θ_t = θ_{t-1} - α * m̂_t / (√v̂_t + ε) const denom = correctedVelocity.sqrt().add(this.eps); const stepSize = correctedMomentum.div(denom).mul(this.lr); const newParam = detachedParam.sub(stepSize); param.replace(newParam); } } } class Optim { static SGD = SGD; static Adam = Adam; } exports.Optim = Optim;
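
For reference, a minimal usage sketch follows. Only the optimizer constructors (an iterable of parameter tensors plus an options object) and the step() method are taken from this file; how tensors are created, how the loss is computed, and how gradients are populated and cleared are assumptions about the surrounding catniff API rather than behaviour documented here.

// Hypothetical sketch -- only Optim.SGD/Optim.Adam construction and step()
// come from this file; Tensor creation and backward() are assumed APIs.
const { Tensor, Optim } = require("catniff"); // assumed package entry point

// Assumed: tensors flagged to track gradients act as trainable parameters.
const w = new Tensor([0], { requiresGrad: true });
const b = new Tensor([0], { requiresGrad: true });

// Keep one optimizer instance across steps so momentum/velocity buffers persist.
const optimizer = new Optim.Adam([w, b], { lr: 0.01, weightDecay: 1e-4 });

for (let iter = 0; iter < 100; iter++) {
    // Forward pass and loss.backward() would go here (assumed autograd API),
    // leaving a .grad tensor on each parameter.
    optimizer.step(); // replaces each param with its updated value in place
    // step() does not clear gradients; reset them before the next backward
    // pass (the exact mechanism depends on catniff's autograd).
    w.grad = null;
    b.grad = null;
}

Note that the momentum and velocity buffers are keyed on the parameter objects themselves in a Map, which is presumably why step() updates parameters through param.replace() rather than swapping in new tensor objects: the same keys remain valid from one step to the next.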