// mini-jstorch
// A lightweight JavaScript neural network library for rapid frontend AI
// experimentation on low-resource devices, inspired by PyTorch.
/**
* Mini-JSTorch Next Word Prediction Test (Self-contained Softmax)
* - Softmax function defined in this file
* - Beam search prediction
* - Large vocab & diverse sentences
* - Sequence length 2
* - Training loop 1600 epochs
*/
import { Sequential, Linear, ReLU, CrossEntropyLoss, Adam } from "../src/MainEngine.js";
// ------------------------
// Softmax Function
// ------------------------
/**
 * Numerically stable softmax over a vector of logits.
 * The max logit is subtracted before exponentiating so Math.exp never
 * overflows for large inputs.
 * @param {number[]} logits - raw model outputs
 * @returns {number[]} probabilities summing to 1
 */
function softmaxVector(logits) {
  const peak = Math.max(...logits);
  let total = 0;
  const scaled = [];
  for (const value of logits) {
    const e = Math.exp(value - peak);
    scaled.push(e);
    total += e;
  }
  return scaled.map((e) => e / total);
}
// ------------------------
// Vocabulary & Tokenization
// ------------------------
// Vocabulary. Wrapped in a Set to deduplicate: the original list contained
// "drink" twice, so word2idx mapped it to the later index and the earlier
// index became a dead class the model could never correctly predict.
// Set preserves first-occurrence order, so surviving indices stay stable.
const vocab = [...new Set([
"i","you","he","she","we","they",
"like","love","hate","pizza","coding","game","movie","music","coffee","tea",
"run","walk","play","read","book","eat","drink","watch","listen","reads",
"drinks"
])];
// word -> index and index -> word lookup tables derived from vocab.
const word2idx = Object.fromEntries(vocab.map((w,i)=>[w,i]));
const idx2word = Object.fromEntries(vocab.map((w,i)=>[i,w]));
/**
 * Build a one-hot vector: all zeros except a 1 at `index`.
 * An out-of-range index yields an all-zero vector, matching the
 * comparison-based construction.
 * @param {number} index - position of the hot element
 * @param {number} vocabSize - length of the returned vector
 * @returns {number[]}
 */
function oneHot(index, vocabSize) {
  const vec = [];
  for (let pos = 0; pos < vocabSize; pos++) {
    vec.push(pos === index ? 1 : 0);
  }
  return vec;
}
// ------------------------
// Dataset (sequence length 2)
// ------------------------
// Training corpus: each sentence is a sequence of vocabulary tokens.
// All sentences are 3 words long, so with seqLength = 2 each one yields
// windows of two context words plus the following target word.
const sentences = [
["i","like","pizza"], ["i","like","coding"], ["i","love","music"],
["i","read","book"], ["i","watch","movie"], ["you","like","pizza"],
["you","love","coffee"], ["you","read","book"], ["you","play","game"],
["he","hate","coffee"], ["he","like","music"], ["he","play","game"],
["she","love","tea"], ["she","read","book"], ["she","watch","movie"],
["we","play","game"], ["we","read","book"], ["we","love","coffee"],
["they","eat","pizza"], ["they","play","game"], ["they","listen","music"],
["i","drink","coffee"], ["he","drink","tea"], ["she","drink","coffee"],
["we","drink","tea"], ["they","drink","coffee"], ["i","play","game"],
["you","watch","movie"], ["he","read","book"], ["she","listen","music"],
["he","reads","book"], ["you","read","book"], ["they","read","book"],
["they","watch","movie"], ["we","listen","music"], ["we","watch","movie"],
// NOTE(review): some pairs below repeat earlier sentences (e.g. ["i","read","book"]);
// duplicates simply weight those patterns more heavily in training.
["we","reads","book"],["we","drinks","coffee"],["we","love","you"],
["i","read","book"], ["i","love","you"]
];
// ------------------------
// Convert sentences into (input, target) training pairs.
// Input: seqLength consecutive words, one-hot encoded and concatenated.
// Target: the one-hot of the word that follows the window.
// The original loop also emitted a pair for the final window with the
// last input word as its own target; those degenerate pairs teach the
// model to repeat words, so windows with no successor are now skipped.
// ------------------------
const X = [], Y = [];
const seqLength = 2;
for(const s of sentences){
  for(let i = 0; i + seqLength < s.length; i++){
    X.push(s.slice(i, i + seqLength).flatMap(w => oneHot(word2idx[w], vocab.length)));
    Y.push(oneHot(word2idx[s[i + seqLength]], vocab.length));
  }
}
// ------------------------
// Model Definition
// ------------------------
// Two-layer MLP: concatenated one-hot context (vocab.length * seqLength
// inputs) -> 128-unit hidden layer with ReLU -> one logit per vocab word.
const model = new Sequential([
new Linear(vocab.length*seqLength, 128),
new ReLU(),
new Linear(128, vocab.length)
]);
// NOTE(review): presumably CrossEntropyLoss applies softmax internally,
// since inference softmaxes the raw logits manually — confirm in MainEngine.
const lossFn = new CrossEntropyLoss();
// Adam over all model parameters; second argument is the learning rate.
const optimizer = new Adam(model.parameters(), 0.01);
// ------------------------
// Training Loop
// ------------------------
// Full-batch training for 1600 epochs: each epoch runs the whole dataset
// through forward -> loss -> backward -> optimizer step.
// NOTE(review): there is no explicit gradient-zeroing call here —
// presumably backward() or step() resets gradients in this library; confirm.
for(let epoch=1; epoch<=1600; epoch++){
const pred = model.forward(X);
const loss = lossFn.forward(pred, Y);
const grad = lossFn.backward();
model.backward(grad);
optimizer.step();
// Progress log fires at epochs 500, 1000, and 1500.
if(epoch % 500 === 0) console.log(`Epoch ${epoch}, Loss: ${loss.toFixed(4)}`);
}
// ------------------------
// Beam Search Prediction
// ------------------------
/**
 * Beam search over the trained model.
 * Keeps the `beamWidth` highest-scoring partial sequences at each step and
 * extends each with the model's top next-word candidates. A sequence's
 * score is the product of the softmax probabilities along its path.
 * @param {string[]} inputWords - seed words; only the last seqLength are fed to the model
 * @param {number} [beamWidth=2] - beams kept per step
 * @param {number} [predLength=3] - number of words to generate
 * @returns {{sequence: string[], score: string}[]} ranked sequences, scores as fixed-3 strings
 */
function beamSearch(inputWords, beamWidth=2, predLength=3){
  let beams = [{words:[...inputWords], score:1}];
  for(let step=0; step<predLength; step++){
    const expanded = [];
    for(const beam of beams){
      // Encode the most recent seqLength words as the model input.
      const context = beam.words.slice(-seqLength);
      const features = context.map(w=>oneHot(word2idx[w],vocab.length)).flat();
      const logits = model.forward([features])[0];
      // Model emits raw logits; normalize to probabilities here.
      const probs = softmaxVector(logits);
      const ranked = probs
        .map((p,idx)=>[idx,p])
        .sort((a,b)=>b[1]-a[1])
        .slice(0,beamWidth);
      for(const [idx,p] of ranked){
        expanded.push({words:[...beam.words, idx2word[idx]], score:beam.score*p});
      }
    }
    // Prune: keep only the best beamWidth candidates across all beams.
    beams = expanded.sort((a,b)=>b.score-a.score).slice(0,beamWidth);
  }
  return beams.map(b=>({sequence:b.words, score:b.score.toFixed(3)}));
}
// ------------------------
// Test Predictions
// ------------------------
// Seed contexts for qualitative evaluation of the trained model.
const testInputs = [
["i","like"], ["you","love"], ["they","play"], ["he","hate"], ["she","reads"]
];
for(const seed of testInputs){
  // Beam width 2, generate the next 3 words.
  const candidates = beamSearch(seed, 2, 3);
  console.log(`Input: ${seed.join(" ")}`);
  for(const {sequence, score} of candidates){
    console.log(`  Sequence: ${sequence.join(" ")}, Score: ${score}`);
  }
}