@2bad/micrograd
Version:
[npm](https://www.npmjs.com/package/@2bad/micrograd) [License: MIT](https://opensource.org/license/MIT) • Bundle size: 2.67 kB
JavaScript
/* eslint-disable @typescript-eslint/no-non-null-assertion */ import { Value } from "./value.js";
/**
 * A single neuron that computes `tanh(w · x + b)` over its inputs.
 */
export class Neuron {
  /** @type {Value[]} One trainable weight per input. */
  weights;
  /** @type {Value} Trainable bias term. */
  bias;

  /**
   * @param {number} inputs - Number of inputs the neuron accepts (one weight each).
   */
  constructor(inputs) {
    // Weights and bias are initialised uniformly at random in [-1, 1).
    this.weights = Array.from({ length: inputs }, () => new Value(Math.random() * 2 - 1));
    this.bias = new Value(Math.random() * 2 - 1);
  }

  /**
   * Runs the neuron on one input vector.
   *
   * @param {Array} inputs - One input per weight, combined with each weight via `Value#mul`.
   * @returns {Value} `tanh(sum_i w_i * x_i + b)`.
   * @throws {RangeError} If `inputs.length` differs from the number of weights; otherwise extra
   *   inputs would be silently ignored and missing ones would feed `undefined` into the graph.
   */
  forward(inputs) {
    if (inputs.length !== this.weights.length) {
      throw new RangeError(
        `Neuron.forward: expected ${this.weights.length} inputs, got ${inputs.length}`,
      );
    }
    // w · x + b, accumulated onto the bias so the bias is part of the computation graph.
    const activation = this.weights.reduce((sum, w, i) => sum.add(w.mul(inputs[i])), this.bias);
    return activation.tanh();
  }

  /**
   * @returns {Value[]} All trainable parameters: the weights in order, then the bias.
   */
  parameters() {
    return [...this.weights, this.bias];
  }
}
/**
 * A fully connected layer: a list of independent neurons that all receive the same inputs.
 */
export class Layer {
  /** @type {Neuron[]} The layer's neurons; their count is the layer's output width. */
  neurons;

  /**
   * @param {number} inputs - Fan-in passed to every neuron's constructor.
   * @param {number} outputs - How many neurons to create.
   */
  constructor(inputs, outputs) {
    const makeNeuron = () => new Neuron(inputs);
    this.neurons = Array.from({ length: outputs }, makeNeuron);
  }

  /**
   * Feeds the same inputs to every neuron.
   *
   * @param {Array} inputs - Passed unchanged to each neuron's `forward`.
   * @returns {Array} One output per neuron, in neuron order.
   */
  forward(inputs) {
    const results = [];
    for (const neuron of this.neurons) {
      results.push(neuron.forward(inputs));
    }
    return results;
  }

  /**
   * @returns {Array} Every neuron's parameters concatenated into one flat list, in neuron order.
   */
  parameters() {
    return this.neurons.reduce((collected, neuron) => collected.concat(neuron.parameters()), []);
  }
}
/**
 * Multi-layer perceptron: a stack of fully connected layers where each layer feeds the next.
 */
// biome-ignore lint/style/useNamingConvention: MLP is a conventional acronym
export class MLP {
  /** @type {Layer[]} Layers in evaluation order. */
  layers;

  /**
   * @param {number} inputs - Number of input features.
   * @param {number[]} outputs - Width of each layer in order; the last entry is the number of
   *   network outputs.
   */
  constructor(inputs, outputs) {
    const sizes = [inputs, ...outputs];
    // Layer i maps sizes[i] inputs to sizes[i + 1] outputs.
    this.layers = sizes.slice(1).map((size, i) => new Layer(sizes[i], size));
  }

  /**
   * Runs the inputs through every layer in turn.
   *
   * @param {Array} inputs - Values for the first layer.
   * @returns {Array} The last layer's outputs (or `inputs` itself if there are no layers).
   */
  forward(inputs) {
    return this.layers.reduce((prev, layer) => layer.forward(prev), inputs);
  }

  /**
   * @returns {Array} All trainable parameters of all layers, flattened in layer order.
   */
  parameters() {
    return this.layers.flatMap((layer) => layer.parameters());
  }

  /**
   * Fits the network with full-batch gradient descent on a summed squared-error loss.
   *
   * @param {number[][]} xs - Training rows; each number is wrapped in a `Value` before the
   *   forward pass.
   * @param {Array<number | number[]>} ys - One target per row of `xs`. A number is compared with
   *   the first network output. An array is compared element-wise with the leading outputs.
   * @param {number} [learningRate=0.1] - Step size of each parameter update.
   * @param {number} [epochs=100] - Number of passes over the whole dataset.
   * @param {number} [logEvery=10] - Log the loss every `logEvery` epochs; 0 disables logging.
   * @returns {number | undefined} Summed loss of the final epoch (measured before that epoch's
   *   update), or `undefined` when `epochs` is 0.
   * @throws {RangeError} If `xs` and `ys` differ in length, or a target array has more entries
   *   than the network has outputs.
   */
  train(xs, ys, learningRate = 0.1, epochs = 100, logEvery = 10) {
    // A length mismatch would otherwise create `new Value(undefined)` targets whose NaN
    // gradients silently corrupt every parameter.
    if (xs.length !== ys.length) {
      throw new RangeError(
        `MLP.train: xs and ys must have the same length (got ${xs.length} and ${ys.length})`,
      );
    }
    // The parameter list does not change during training, so collect it once.
    const params = this.parameters();
    let lastLoss;
    for (let epoch = 0; epoch < epochs; epoch++) {
      let totalLoss = new Value(0);
      for (let i = 0; i < xs.length; i++) {
        const inputs = xs[i].map((x) => new Value(x));
        const preds = this.forward(inputs);
        // Scalar targets keep the original behaviour: only the first output is trained.
        const targets = Array.isArray(ys[i]) ? ys[i] : [ys[i]];
        if (targets.length > preds.length) {
          throw new RangeError(
            `MLP.train: sample ${i} has ${targets.length} targets but the network has ${preds.length} outputs`,
          );
        }
        for (let k = 0; k < targets.length; k++) {
          const loss = preds[k].sub(new Value(targets[k])).pow(2);
          totalLoss = totalLoss.add(loss);
        }
      }
      totalLoss.resetGrad();
      // Backward pass
      totalLoss.backward();
      // Gradient-descent step; grads are then cleared for the next epoch.
      for (const p of params) {
        p.data -= learningRate * p.grad;
        p.grad = 0;
      }
      lastLoss = totalLoss.data;
      if (logEvery > 0 && epoch % logEvery === 0) {
        console.log(`Epoch ${epoch}, Loss: ${totalLoss.data}`);
      }
    }
    return lastLoss;
  }
}