// mini-jstorch
// Version: (unspecified)
// A lightweight JavaScript neural-network library for rapid frontend AI
// experimentation on low-resource devices, inspired by PyTorch.
// 86 lines (74 loc) • 3.75 kB
// JavaScript
// TEST JSTORCH WHOLE SYSTEMS AT ONCE
import { Tensor, Linear, Sequential, ReLU, Sigmoid, Tanh, LeakyReLU, GELU, Dropout, Conv2D, MSELoss, CrossEntropyLoss, Adam, SGD, saveModel, loadModel, flattenBatch, reshape, stack, concat, eye } from '../src/MainEngine.js';
// ---------------------- Linear Test ----------------------
// Exercise a single fully-connected layer: forward pass on a 2-sample batch,
// then backward pass with an upstream gradient of the same shape as the output.
console.log("=== Linear Test ===");
const linearLayer = new Linear(3, 2);
const linearBatch = [[1, 2, 3], [4, 5, 6]]; // 2 samples x 3 features
const linearForward = linearLayer.forward(linearBatch);
console.log("Linear forward:", linearForward);
// Upstream gradient matches the forward output shape (2 samples x 2 units).
const upstreamLinearGrad = [[0.1, 0.2], [0.3, 0.4]];
const linearGradInput = linearLayer.backward(upstreamLinearGrad);
console.log("Linear backward gradInput:", linearGradInput);
// ---------------------- Sequential + Activations Test ----------------------
// Chain linear layers with several activation types and verify that gradients
// propagate back through the whole container.
console.log("\n=== Sequential + Activations Test ===");
const stackedModel = new Sequential([
  new Linear(2, 2),
  new ReLU(),
  new Linear(2, 1),
  new Sigmoid(),
]);
const stackedInput = [[0.5, 1.0], [1.5, 2.0]]; // 2 samples x 2 features
const stackedForward = stackedModel.forward(stackedInput);
console.log("Sequential forward:", stackedForward);
// One gradient value per sample, matching the single-unit output.
const stackedUpstream = [[0.1], [0.2]];
const stackedGradInput = stackedModel.backward(stackedUpstream);
console.log("Sequential backward gradInput:", stackedGradInput);
// ---------------------- Conv2D Test ----------------------
// Single-channel 3x3 convolution over a single-image batch, forward then
// backward with a gradient shaped like the forward output.
console.log("\n=== Conv2D Test ===");
const convLayer = new Conv2D(1, 1, 3);
// Shape: [batch=1][inChannels=1][H=3][W=3]
const convImage = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]];
const convForward = convLayer.forward(convImage);
console.log("Conv2D forward:", convForward);
// Conv2D backward test — upstream gradient mirrors the output layout.
const convUpstream = [[[[0.1, 0.2, 0.1], [0.2, 0.3, 0.2], [0.1, 0.2, 0.1]]]];
const convGradInput = convLayer.backward(convUpstream);
console.log("Conv2D backward gradInput:", convGradInput);
// ---------------------- Tensor & Broadcast Test ----------------------
// Elementwise addition of a random 2x3 tensor and an all-ones 2x3 tensor.
console.log("\n=== Tensor & Broadcast Test ===");
const randomTensor = Tensor.random(2, 3);
const onesTensor = Tensor.ones(2, 3);
const broadcastSum = randomTensor.add(onesTensor);
console.log("Tensor add broadcast:", broadcastSum);
// ---------------------- Loss + Optimizer Test ----------------------
// Full mini training step: forward, loss, backward, one Adam update.
// NOTE(review): raw Linear output is fed to CrossEntropyLoss — presumably the
// loss applies softmax internally (PyTorch-style); confirm in MainEngine.js.
console.log("\n=== Loss + Optimizer Test ===");
const trainModel = new Sequential([new Linear(2, 2)]);
const logits = trainModel.forward([[1, 2]]);
const oneHotTarget = [[0, 1]]; // single sample, class index 1
const criterion = new CrossEntropyLoss();
const lossValue = criterion.forward(logits, oneHotTarget);
console.log("CrossEntropyLoss value:", lossValue);
// Propagate the loss gradient through the model, then take one optimizer step.
const lossGrad = criterion.backward();
trainModel.backward(lossGrad);
const adamOpt = new Adam(trainModel.parameters());
adamOpt.step();
console.log("Updated parameters after Adam:", trainModel.parameters());
// ---------------------- Dropout Test ----------------------
// Forward (stochastic masking at p=0.5) followed by backward through the
// same mask.
console.log("\n=== Dropout Test ===");
const dropoutLayer = new Dropout(0.5);
const dropoutInput = [[1, 2], [3, 4]];
const dropoutForward = dropoutLayer.forward(dropoutInput);
console.log("Dropout forward:", dropoutForward);
const dropoutGradInput = dropoutLayer.backward([[0.1, 0.2], [0.3, 0.4]]);
console.log("Dropout backward:", dropoutGradInput);
// ---------------------- Save / Load Model Test ----------------------
// Serialize one model's parameters to JSON, then restore them into a freshly
// constructed model with the same architecture.
console.log("\n=== Save / Load Model Test ===");
const sourceModel = new Sequential([new Linear(2, 2)]);
const serialized = saveModel(sourceModel);
console.log("Saved model JSON:", serialized);
const restoredModel = new Sequential([new Linear(2, 2)]);
loadModel(restoredModel, serialized);
console.log("Loaded model parameters:", restoredModel.parameters());
// ---------------------- Advanced Utils Test ----------------------
// Smoke-test the tensor utility helpers: flatten, identity, reshape, stack,
// and concatenation along both axes.
console.log("\n=== Advanced Utils Test ===");
const nestedBatch = [
  [[1, 2], [3, 4]],
  [[5, 6], [7, 8]],
];
console.log("Flatten batch:", flattenBatch(nestedBatch));
console.log("Eye 3:", eye(3));
// reshape accepts a tensor-like object carrying a `data` field.
console.log("Reshape:", reshape({ data: [[1, 2, 3, 4]] }, 2, 2));
console.log("Stack:", stack([Tensor.ones(2, 2), Tensor.zeros(2, 2)]));
const left = [[1, 2], [3, 4]];
const right = [[5, 6], [7, 8]];
console.log("Concat axis0:", concat(left, right, 0));
console.log("Concat axis1:", concat(left, right, 1));