UNPKG

gradiatorjs

Version:

GradiatorJS is a lightweight, from-scratch autodiff engine and a neural network library written in typescript. Featuring a powerful automatic differentiation engine using a computation graph to enable backpropagation on dynamic network architectures. You

138 lines (122 loc) 4.11 kB
import type { Sequential } from "./model.js"; import { getMiniBatch } from "./train.js"; import type { NetworkParams } from "./types.js"; import type { Val } from "./val.js"; let isTraining = false; let stopTraining = false; let isPaused = false; const trainingContext = { model: null as Sequential | null, X_train: null as Val | null, Y_train: null as Val | null, params: null as NetworkParams | null, currentEpoch: 0, batchGenerator: null as Generator<any, void, unknown> | null, iteration: 0 }; export type trainingState = 'IDLE' | 'TRAINING' | 'PAUSED' | 'STOPPING'; export const getIsTraining = (): boolean => isTraining; export const getIsPaused = (): boolean => isPaused; export const getStopTraining = (): boolean => stopTraining; export function getTrainingContext() { return trainingContext; } export function startTraining(): void { isTraining = true; stopTraining = false; isPaused = false; console.log("training started"); } export function requestStopTraining(): void { if (!isTraining) { console.log("training stop requested when it wasnt running in the first place") return; } stopTraining = true; isPaused = false; console.log("stop requested"); } export function endTraining(): void { isTraining = false; stopTraining = false; isPaused = false; trainingContext.model = null; trainingContext.X_train = null; trainingContext.Y_train = null; trainingContext.params = null; trainingContext.currentEpoch = 0; trainingContext.batchGenerator = null; console.log("training finished"); } export function requestPause(): void { if (isTraining && !isPaused) { isPaused = true; console.log("pausing training"); } } export function requestResume(): void { if (isTraining && isPaused) { isPaused = false; console.log("resume requested") } } export function setTrainingState(newState: trainingState): void { switch(newState) { case 'IDLE': isTraining = false; stopTraining = false; isPaused = false; break; case 'TRAINING': if (!isTraining) { isTraining = true; stopTraining = false; isPaused = false; } else if (isPaused) { // already paused, resuming isPaused = false; } break; case 'PAUSED': if (isTraining && !isPaused) { isPaused = true; } break; case 'STOPPING': if (isTraining) { stopTraining = true; isPaused = false; } break; } console.log("State changed to: ", newState); } export function setupTrainingContext(model: Sequential, x: Val, y: Val, params: NetworkParams): void { trainingContext.model = model; trainingContext.X_train = x; trainingContext.Y_train = y; trainingContext.params = params; trainingContext.currentEpoch = 0; trainingContext.iteration = 0; trainingContext.batchGenerator = getMiniBatch(x, y, params.batch_size); startTraining(); console.log("training context has been set up for Epoch 0."); } export function advanceEpoch(): boolean { if (!trainingContext.params || !trainingContext.X_train || !trainingContext.Y_train) { return false; } trainingContext.currentEpoch++; if (trainingContext.currentEpoch >= trainingContext.params.epochs) { console.log("Max epochs reached."); endTraining(); return false; } // Create a new generator for the new epoch trainingContext.batchGenerator = getMiniBatch( trainingContext.X_train, trainingContext.Y_train, trainingContext.params.batch_size ); console.log(`Advanced to Epoch ${trainingContext.currentEpoch}`); return true; }