mini-jstorch
Version:
A lightweight JavaScript neural network library for rapid frontend AI experimentation on low-resource devices, inspired by PyTorch.
76 lines (68 loc) • 2.31 kB
JavaScript
// JSTORCH TESTS TEMPLATE FILES
// CIRCLE CLASSIFICATION TESTS FILES
import { Sequential, Linear, ReLU, Sigmoid, CrossEntropyLoss, Adam } from '../src/MainEngine.js';
// === Generate Circle Dataset ===
/**
 * Generate a 2-D binary classification dataset of points in the square
 * [-1, 1) x [-1, 1), labelled by whether they fall inside a circle centred
 * on the origin.
 *
 * @param {number} n - Number of samples to generate.
 * @param {number} [radius=0.5] - Circle radius; points strictly inside it
 *   (x^2 + y^2 < radius^2) are labelled as the first class.
 * @param {() => number} [rng=Math.random] - Random source, expected to return
 *   values in [0, 1). Pass a seeded generator for reproducible datasets.
 * @returns {{X: number[][], Y: number[][]}} `X` holds the `[x, y]` points and
 *   `Y` the matching one-hot labels: `[1, 0]` inside the circle, `[0, 1]`
 *   otherwise.
 */
function generateCircleData(n, radius = 0.5, rng = Math.random) {
  // Compare squared distances so no square root is needed per sample.
  const radiusSquared = radius * radius;
  const X = [];
  const Y = [];
  for (let i = 0; i < n; i++) {
    // Map each draw from [0, 1) to [-1, 1); x is always drawn before y.
    const x = rng() * 2 - 1;
    const y = rng() * 2 - 1;
    const isInside = x * x + y * y < radiusSquared;
    X.push([x, y]);
    Y.push(isInside ? [1, 0] : [0, 1]);
  }
  return { X, Y };
}
// Sample 300 random points together with their one-hot labels.
const { X, Y } = generateCircleData(300);

// === Build Model (bigger hidden layers) ===
// Layer stack: 2 inputs -> 16 -> 8 -> 2 outputs, with ReLU after each hidden
// Linear layer and a Sigmoid on the output.
// NOTE(review): the Sigmoid output is fed into CrossEntropyLoss below; confirm
// in MainEngine.js whether that loss expects raw scores or probabilities.
const layers = [
  new Linear(2, 16),
  new ReLU(),
  new Linear(16, 8),
  new ReLU(),
  new Linear(8, 2),
  new Sigmoid(),
];
const model = new Sequential(layers);

// Loss & Optimizer
const lossFn = new CrossEntropyLoss();
// A small learning rate, chosen by the author so that training does not get stuck.
const learningRate = 0.01;
const optimizer = new Adam(model.parameters(), learningRate);
// === Training ===
// Full-batch training: every epoch runs the whole dataset through the model,
// back-propagates the loss gradient and applies one optimizer step.
console.log("=== Circle Classification Training (Fixed) ===");
const start = Date.now();
for (let epoch = 1; epoch <= 2000; epoch++) {
  // Forward pass through the model, then through the loss.
  const loss = lossFn.forward(model.forward(X), Y);

  // Backward pass: loss gradient first, then propagate it through the model.
  model.backward(lossFn.backward());
  optimizer.step();

  // Report progress only on every 500th epoch.
  const isReportEpoch = epoch % 500 === 0;
  if (!isReportEpoch) {
    continue;
  }
  console.log(`Epoch ${epoch}, Loss: ${loss.toFixed(4)}`);
}

// Wall-clock duration of the loop above, printed to show how long training takes.
const elapsedSeconds = (Date.now() - start) / 1000;
console.log(`Training Time: ${elapsedSeconds.toFixed(3)}s`);
// === Predictions ===
// Run a handful of hand-picked points through the trained model and print the
// predicted class index (position of the largest output) next to the raw outputs.
console.log("\n=== Predictions ===");
const testInputs = [
  [0, 0],
  [0.7, 0.7],
  [0.2, 0.2],
  [0.9, 0.1],
  [-0.5, -0.5],
  [0.6, -0.2],
];
for (const inp of testInputs) {
  // The model is called with a batch of one sample; take that sample's output row.
  const out = model.forward([inp])[0];
  const highest = Math.max(...out);
  const predClass = out.indexOf(highest);
  const raw = out.map((v) => v.toFixed(3));
  console.log(
    `Input: ${inp}, Pred: ${predClass}, Raw: ${raw}`
  );
}
/**
* === How to Run On Your VS Code Projects ===
* 1. Make sure Node.js (v20+ recommended) is installed on your system.
* 2. Open this file in VS Code (or any editor).
* 3. Open the terminal in the project root folder.
* 4. Run: node tests/proj3.js
*
* You should see the training logs and prediction outputs directly.
*/