UNPKG

mini-jstorch

Version:

A lightweight JavaScript neural network library for rapid frontend AI experimentation on low-resource devices. Inspired by PyTorch.

76 lines (68 loc) 2.31 kB
// JSTORCH TESTS TEMPLATE FILE
// CIRCLE CLASSIFICATION TEST
//
// Trains a small MLP to decide whether a 2-D point lies inside a circle
// centred at the origin, then prints predictions for a few fixed points.

import { Sequential, Linear, ReLU, Sigmoid, CrossEntropyLoss, Adam } from '../src/MainEngine.js';

// === Experiment settings ===
const RADIUS = 0.5;         // radius of the circle used for labelling
const NUM_SAMPLES = 300;    // number of training points
const EPOCHS = 2000;        // full-batch training iterations
const LOG_EVERY = 500;      // print the loss every LOG_EVERY epochs
const LEARNING_RATE = 0.01; // kept small so the model does not get stuck

/**
 * One-hot label for a point: [1, 0] (class 0) when the point lies strictly
 * inside the circle of the given radius centred at the origin, otherwise
 * [0, 1] (class 1).
 * @param {number} x
 * @param {number} y
 * @param {number} radius
 * @returns {number[]} one-hot label
 */
function circleLabel(x, y, radius) {
  return x * x + y * y < radius * radius ? [1, 0] : [0, 1];
}

/**
 * Generate n points uniformly from [-1, 1) x [-1, 1) with circle labels.
 * With radius 0.5 only about (pi * 0.25) / 4, roughly 20% of the points, fall
 * inside the circle, so the two classes are imbalanced.
 * @param {number} n - number of samples to generate
 * @param {number} [radius=0.5] - circle radius used for labelling
 * @returns {{ X: number[][], Y: number[][] }} inputs [x, y] and one-hot targets
 */
function generateCircleData(n, radius = 0.5) {
  const X = [];
  const Y = [];
  for (let i = 0; i < n; i++) {
    const x = Math.random() * 2 - 1; // -1 to 1
    const y = Math.random() * 2 - 1;
    X.push([x, y]);
    Y.push(circleLabel(x, y, radius));
  }
  return { X, Y };
}

/** Index of the largest value (the first one on ties). */
const argmax = (values) => values.indexOf(Math.max(...values));

const { X, Y } = generateCircleData(NUM_SAMPLES, RADIUS);

// === Build Model (bigger hidden layers) ===
// NOTE(review): the network ends in Sigmoid but is trained with
// CrossEntropyLoss. If MainEngine's CrossEntropyLoss applies softmax to its
// input internally (as PyTorch's does with raw logits), the Sigmoid may be
// redundant or may slow learning. Confirm against src/MainEngine.js.
const model = new Sequential([
  new Linear(2, 16),
  new ReLU(),
  new Linear(16, 8),
  new ReLU(),
  new Linear(8, 2),
  new Sigmoid(),
]);

// Loss & optimizer
const lossFn = new CrossEntropyLoss();
const optimizer = new Adam(model.parameters(), LEARNING_RATE);

// === Training (full batch) ===
console.log("=== Circle Classification Training (Fixed) ===");
const start = Date.now();
for (let epoch = 1; epoch <= EPOCHS; epoch++) {
  const pred = model.forward(X);
  const loss = lossFn.forward(pred, Y);
  const gradLoss = lossFn.backward();
  // NOTE(review): no gradient reset happens before backward(). This assumes
  // MainEngine's backward() overwrites gradients rather than accumulating them.
  // Confirm before reusing this loop.
  model.backward(gradLoss);
  optimizer.step();
  if (epoch % LOG_EVERY === 0) {
    console.log(`Epoch ${epoch}, Loss: ${loss.toFixed(4)}`);
  }
}
// Report the total wall-clock time so you can check how lightweight training
// is on the current device.
console.log(`Training Time: ${((Date.now() - start) / 1000).toFixed(3)}s`);

// === Predictions ===
// Class 0 = inside the circle, class 1 = outside (see circleLabel).
console.log("\n=== Predictions ===");
const testInputs = [
  [0, 0],
  [0.7, 0.7],
  [0.2, 0.2],
  [0.9, 0.1],
  [-0.5, -0.5],
  [0.6, -0.2],
];
for (const inp of testInputs) {
  const out = model.forward([inp])[0];
  const predClass = argmax(out);
  const [px, py] = inp;
  const expectedClass = argmax(circleLabel(px, py, RADIUS));
  console.log(
    `Input: ${inp}, Pred: ${predClass}, Expected: ${expectedClass}, Raw: ${out.map((v) => v.toFixed(3))}`
  );
}

/**
 * === How to Run in Your VS Code Project ===
 * 1. Make sure Node.js (v20+ recommended) is installed on your system.
 * 2. Open this file in VS Code (or any editor).
 * 3. Open a terminal in the project root folder.
 * 4. Run: node tests/proj3.js
 *
 * You should see the training logs and prediction outputs directly.
 */