UNPKG

gradiatorjs

Version:

GradiatorJS is a lightweight, from-scratch autodiff engine and neural network library written in TypeScript. It features a powerful automatic differentiation engine that uses a computation graph to enable backpropagation on dynamic network architectures.

288 lines (240 loc) 11.1 kB
import { Val } from "./val.js"; import { Sequential } from "./model.js"; import { getStopTraining, endTraining, getIsPaused, getTrainingContext, advanceEpoch } from "./state_management.js"; import { calcBinaryAccuracy, calcMultiClassAccuracy } from "./accuracy.js"; import type { NetworkParams, Messenger } from "./types.js"; import { assert } from "./utils.js"; function createBatchVal(ogVal: Val, batchIndices: number[], currentBatchSize: number, features: number) { const batchShape = [currentBatchSize, ...ogVal.shape.slice(1)]; const batchVal = new Val(batchShape); for (let k=0; k<currentBatchSize; k++) { const ogIdx = batchIndices[k]!; const sourceOffset = ogIdx*features; const destOffset = k*features; batchVal.data.set(ogVal.data.subarray(sourceOffset, sourceOffset+features), destOffset); } return batchVal; } // generator that will yield mini-batches from the full dataset export function* getMiniBatch(X: Val, Y: Val, batchSize: number, shuffle: boolean = true){ const numSamples = X.shape[0]!; const indices = Array.from({ length: numSamples }, (_, i) => i); if (shuffle) { for (let i=indices.length-1; i>0; i--) { const j=Math.floor(Math.random()*(i+1)); [indices[i], indices[j]] = [indices[j]!, indices[i]!]; } } const xFeatures = X.size / numSamples; // H*W*C for images, or F for dense const yFeatures = Y.size / numSamples; // num classes for 1-hot labels for (let i=0; i<numSamples; i+=batchSize) { const batchIndices = indices.slice(i, i+batchSize); const currentBatchSize = batchIndices.length; const xBatchVal = createBatchVal(X, batchIndices, currentBatchSize, xFeatures) const yBatchVal = createBatchVal(Y, batchIndices, currentBatchSize, yFeatures) yield {x: xBatchVal, y: yBatchVal} } } function getActSample(batch: {x: Val, y: Val}): [Val, number] { const sampleX = new Val(batch.x.shape.slice(1)).reshape([1, ...batch.x.shape.slice(1)]); sampleX.data = batch.x.data.slice(0, batch.x.size / batch.x.shape[0]!); let sampleY_label = -1; const y_features = batch.y.size / 
batch.y.shape[0]!; if (y_features === 1) { // If label is a single number sampleY_label = batch.y.data[0]; } else { // If label is one-hot encoded const y_sample_data = batch.y.data.slice(0, y_features); sampleY_label = y_sample_data.indexOf(1); } return [sampleX, sampleY_label]; } export async function trainModel( model: Sequential, X_train: Val, Y_train: Val, params: NetworkParams, messenger?: Messenger, ) : Promise<void> { const {loss_fn, l_rate, epochs, batch_size, multiClass} = params console.log(`----starting training. ${epochs} epochs, batch size ${batch_size}----`); let totalProcessingTime = 0; let iteration = 0; try { for (let e=0; e<epochs; e++) { const batchGenerator = getMiniBatch(X_train, Y_train, batch_size) let batch_idx = 0; for (const batch of batchGenerator) { while(getIsPaused()) { if(getStopTraining()) { console.log("training stopped during pause") return; } await new Promise(resolve => setTimeout(resolve, 200)); } if (getStopTraining()) { console.log(`Training stopped at epoch ${e}`); return; } const iterStartTime = performance.now(); const { x: X_batch, y: Y_batch } = batch; model.zeroGrad(); const Y_pred = model.forward(X_batch); const loss = loss_fn(Y_pred, Y_batch); loss.backward(); const params = model.parameters(); for (const p of params) { if (!p.grad || p.data.length !== p.grad.length) { console.warn(`Skipping update for parameter due to missing/mismatched gradient at iteration ${e}. Param shape: ${p.shape}`); continue; } let hasInvalidGrad = false; for (const gradVal of p.grad) { if (isNaN(gradVal) || !isFinite(gradVal)) { console.warn("Invalid gradient found. 
halting update for this batch", p); hasInvalidGrad = true; break; } } if(hasInvalidGrad) continue; for (let j=0; j<p.data.length; j++) { if (j < p.grad.length) { p.data[j] -= l_rate * p.grad[j]!; } } } const iterEndTime = performance.now(); const iterTime = iterEndTime - iterStartTime; totalProcessingTime += iterTime; iteration++; if (messenger) { assert(model instanceof Sequential, ()=>`Model is not an instance of sequential.`) let accuracy = multiClass ? calcMultiClassAccuracy(Y_pred, Y_batch) : calcBinaryAccuracy(Y_pred, Y_batch); const [sampleX, sampleY_label] = getActSample(batch); const rawLayerOutputs = model.getLayerOutputs(sampleX).map(val => ({ Zdata: val['Z']?.data.buffer, Zshape: val['Z']?.shape, Adata: val['A']?.data.buffer, Ashape: val['A']?.shape })); const transferableBuffersSet = new Set<ArrayBuffer>(); rawLayerOutputs.forEach(layer => { if (layer.Zdata) transferableBuffersSet.add(layer.Zdata); if (layer.Adata) transferableBuffersSet.add(layer.Adata); }); transferableBuffersSet.add(sampleX.data.buffer); const transferableBuffers = Array.from(transferableBuffersSet); messenger.postMessage({ type: 'batchEnd', epoch: e, batch_idx: iteration, loss: loss.data[0], accuracy: accuracy, iterTime: iterTime, visData: { sampleX: { data: sampleX.data.buffer, shape: sampleX.shape }, sampleY_label: sampleY_label, layerOutputs: rawLayerOutputs } }, transferableBuffers); } batch_idx++; } } console.log(`Training finished.`); console.log(`Total processing time: ${(totalProcessingTime / 1000).toFixed(3)}s over ${iteration} iterations.`); console.log(`Average time per batch: ${(totalProcessingTime / iteration / 1000).toFixed(4)}s`); console.log(model) } catch (error) { console.error("Error during training loop execution:", error); throw error; } finally { endTraining(); } } export async function trainSingleBatch(messenger: Messenger): Promise<void> { const C = getTrainingContext(); if (getStopTraining() || getIsPaused() || !C.model || !C.batchGenerator || !C.params) { 
if (getStopTraining()) { endTraining(); messenger.postMessage({type: 'complete', reason: 'stopped by user'}); return; } return; } C.iteration++; const iterStartTime = performance.now(); const batchResult = C.batchGenerator.next(); if (batchResult.done) { if(advanceEpoch()) { messenger.postMessage({type: 'epochEnd', epoch: C.currentEpoch}); } else { messenger.postMessage({type: 'complete', reason: 'All epochs finished'}) } return; } const batch = batchResult.value; const { x: X_batch, y: Y_batch } = batch; const {loss_fn, l_rate, epochs, batch_size, multiClass} = C.params C.model.zeroGrad(); const Y_pred = C.model.forward(X_batch); const loss = loss_fn(Y_pred, Y_batch); loss.backward(); const modelParams = C.model.parameters(); for (const p of modelParams) { if (!p.grad || p.data.length !== p.grad.length) { console.warn(`Skipping update for parameter due to missing/mismatched gradient. Param shape: ${p.shape}`); continue; } let hasInvalidGrad = false; for (const gradVal of p.grad) { if (isNaN(gradVal) || !isFinite(gradVal)) { console.warn("Invalid gradient found. halting update for this batch", p); hasInvalidGrad = true; break; } } if(hasInvalidGrad) continue; for (let j=0; j<p.data.length; j++) { if (j < p.grad.length) { p.data[j] -= l_rate * p.grad[j]!; } } } const iterEndTime = performance.now(); const iterTime = iterEndTime - iterStartTime; assert(C.model instanceof Sequential, ()=>`Model is not an instance of sequential.`) let accuracy = multiClass ? 
calcMultiClassAccuracy(Y_pred, Y_batch) : calcBinaryAccuracy(Y_pred, Y_batch); const [sampleX, sampleY_label] = getActSample(batch); const rawLayerOutputs = C.model.getLayerOutputs(sampleX).map(val => ({ Zdata: val['Z']?.data.buffer, Zshape: val['Z']?.shape, Adata: val['A']?.data.buffer, Ashape: val['A']?.shape })); const transferableBuffersSet = new Set<ArrayBuffer>(); rawLayerOutputs.forEach(layer => { if (layer.Zdata) transferableBuffersSet.add(layer.Zdata); if (layer.Adata) transferableBuffersSet.add(layer.Adata); }); transferableBuffersSet.add(sampleX.data.buffer); const transferableBuffers = Array.from(transferableBuffersSet); messenger.postMessage({ type: 'batchEnd', epoch: C.currentEpoch, batch_idx: C.iteration, loss: loss.data[0], accuracy: accuracy, iterTime: iterTime, visData: { sampleX: { data: sampleX.data.buffer, shape: sampleX.shape }, sampleY_label: sampleY_label, layerOutputs: rawLayerOutputs } }, transferableBuffers); }