tinygrad
Version:
A JavaScript/TypeScript autograd engine with operator overloading, inspired by micrograd
297 lines (293 loc) • 8.27 kB
JavaScript
//#region rolldown:runtime
var __defProp = Object.defineProperty;
/**
* Builds a namespace object with one enumerable, getter-only property per key
* of `all`; the value stored under each key in `all` is used as that getter.
* @param {Object<string, Function>} all - Map from export name to getter.
* @returns {Object} The namespace object.
*/
var __export = (all) => {
	const namespace = {};
	for (const key in all) {
		const getter = all[key];
		__defProp(namespace, key, { enumerable: true, get: getter });
	}
	return namespace;
};
//#endregion
//#region src/tinygrad/engine.ts
/** Namespace object for the autograd engine; `Value` is resolved lazily through a getter. */
var engine_exports = /* @__PURE__ */ __export({
	Value() {
		return Value;
	}
});
/**
* Stores a single scalar value and its gradient for reverse-mode automatic
* differentiation (micrograd-style).
*
* Each arithmetic method returns a new `Value` that records its operands
* (`#prev`), a label for the producing operation (`_op`) and a closure
* (`#backward`) that adds the local derivative times `out.grad` into each
* operand's `grad`. Calling `backward()` on a result populates `grad` for every
* node reachable from it.
*
* `data` is treated as a plain JavaScript number and combined with the native
* arithmetic operators. The `[Symbol.for(op)]` methods let code that dispatches
* `a <op> b` as `a[Symbol.for(op)](b)` route to `add`, `mul`, `sub`, etc.
*/
var Value = class Value {
	/** Accumulated d(output)/d(this). `backward()` adds into it, so reset between passes. */
	grad;
	/** Propagates this node's `grad` into its operands; a no-op for leaf values. */
	#backward;
	/** Distinct operand nodes this value was computed from. */
	#prev;
	/**
	* @param {number} data - Scalar held by this node.
	* @param {Iterable<Value>} [children=[]] - Operands; stored in a Set, so duplicates collapse.
	* @param {string} [_op=""] - Label of the producing operation (informational only).
	*/
	constructor(data, children = [], _op = "") {
		this.data = data;
		this._op = _op;
		this.grad = 0;
		this.#backward = () => {};
		this.#prev = new Set(children);
	}
	/**
	* @param {Value|number} other - Raw numbers are wrapped in a leaf `Value`.
	* @returns {Value} `this + other`.
	*/
	add(other) {
		const otherValue = other instanceof Value ? other : new Value(other);
		const out = new Value(this.data + otherValue.data, [this, otherValue], "+");
		out.#backward = () => {
			// Addition passes the upstream gradient through unchanged.
			this.grad += out.grad;
			otherValue.grad += out.grad;
		};
		return out;
	}
	[Symbol.for("+")](other) {
		return this.add(other);
	}
	/**
	* @param {Value|number} other - Raw numbers are wrapped in a leaf `Value`.
	* @returns {Value} `this * other`.
	*/
	mul(other) {
		const otherValue = other instanceof Value ? other : new Value(other);
		const out = new Value(this.data * otherValue.data, [this, otherValue], "*");
		out.#backward = () => {
			// d(a*b)/da = b and d(a*b)/db = a.
			this.grad += otherValue.data * out.grad;
			otherValue.grad += this.data * out.grad;
		};
		return out;
	}
	[Symbol.for("*")](other) {
		return this.mul(other);
	}
	/**
	* @param {number} other - Constant exponent.
	* @returns {Value} `this ** other`.
	* @throws {Error} If `other` is not a number.
	*/
	pow(other) {
		if (typeof other !== "number") throw new Error("only supporting number powers for now");
		const out = new Value(this.data ** other, [this], `**${other}`);
		out.#backward = () => {
			// Power rule: d(x^n)/dx = n * x^(n-1).
			this.grad += other * this.data ** (other - 1) * out.grad;
		};
		return out;
	}
	[Symbol.for("**")](other) {
		return this.pow(other);
	}
	/** @returns {Value} `-this`, expressed as multiplication by -1. */
	neg() {
		return this.mul(-1);
	}
	[Symbol.for("minus")]() {
		return this.neg();
	}
	/** @returns {Value} `this - other`, expressed as `this + (-other)`. */
	sub(other) {
		const otherValue = other instanceof Value ? other : new Value(other);
		return this.add(otherValue.neg());
	}
	[Symbol.for("-")](other) {
		return this.sub(other);
	}
	/** @returns {Value} `this / other`, expressed as `this * other ** -1`. */
	div(other) {
		const otherValue = other instanceof Value ? other : new Value(other);
		return this.mul(otherValue.pow(-1));
	}
	[Symbol.for("/")](other) {
		return this.div(other);
	}
	/** @returns {Value} `max(0, this)`; the gradient is passed only where the output is positive. */
	relu() {
		const out = new Value(this.data < 0 ? 0 : this.data, [this], "ReLU");
		out.#backward = () => {
			this.grad += (out.data > 0 ? 1 : 0) * out.grad;
		};
		return out;
	}
	/**
	* Back-propagates from this node: sets `this.grad = 1` and runs every
	* reachable node's `#backward` in reverse topological order. Gradients
	* accumulate (`+=`), so callers reset them between passes.
	*/
	backward() {
		// Iterative DFS producing the same post-order as a recursive
		// "visit children, then push self" traversal, without growing the
		// call stack. This keeps very deep graphs (long sums) from overflowing.
		const topo = [];
		const visited = new Set([this]);
		const stack = [[this, this.#prev.values()]];
		while (stack.length > 0) {
			const [node, children] = stack[stack.length - 1];
			const next = children.next();
			if (next.done) {
				stack.pop();
				topo.push(node);
			} else if (!visited.has(next.value)) {
				visited.add(next.value);
				stack.push([next.value, next.value.#prev.values()]);
			}
		}
		this.grad = 1;
		for (let i = topo.length - 1; i >= 0; i--) topo[i].#backward();
	}
	toString() {
		return `Value(data=${this.data}, grad=${this.grad})`;
	}
};
//#endregion
//#region src/tinygrad/nn.ts
/** Namespace object for the neural-network building blocks, exposed through lazy getters. */
var nn_exports = /* @__PURE__ */ __export({
	Layer() {
		return Layer;
	},
	MLP() {
		return MLP;
	},
	Neuron() {
		return Neuron;
	}
});
/**
* Base class for network components built from `Value` parameters.
* Subclasses override `parameters()`; the base implementation has none.
*/
var Module = class {
	/**
	* Sets `grad` to 0 on every parameter. `Value.backward()` accumulates into
	* `grad`, so this is needed between training steps.
	*/
	zeroGrad() {
		for (const p of this.parameters()) p.grad = 0;
	}
	/**
	* @returns {Array} Trainable parameters; a bare Module has none. Without this
	*   default, `zeroGrad()` on a Module lacking an override threw a TypeError.
	*/
	parameters() {
		return [];
	}
};
/**
* A single neuron computing `act = b + w_0*x_0 + w_1*x_1 + ...`,
* optionally passed through ReLU.
*/
var Neuron = class extends Module {
	w;
	b;
	nonlin;
	/**
	* @param {number} nin - Number of inputs; one weight per input, drawn uniformly from [-1, 1).
	* @param {boolean} [nonlin=true] - Apply ReLU to the activation when true.
	*/
	constructor(nin, nonlin = true) {
		super();
		const randomWeight = () => new Value(Math.random() * 2 - 1);
		this.w = Array.from({ length: nin }, randomWeight);
		this.b = new Value(0);
		this.nonlin = nonlin;
	}
	/**
	* @param {Array<Value|number>} x - Inputs, indexed in step with `this.w`.
	* @returns {Value} The (optionally rectified) activation.
	*/
	call(x) {
		let act = this.b;
		this.w.forEach((weight, idx) => {
			act = act.add(weight.mul(x[idx]));
		});
		if (this.nonlin) {
			return act.relu();
		}
		return act;
	}
	/** @returns {Value[]} The weights followed by the bias. */
	parameters() {
		return this.w.concat(this.b);
	}
	toString() {
		const kind = this.nonlin ? "ReLU" : "Linear";
		return `${kind}Neuron(${this.w.length})`;
	}
};
/**
* A fully connected layer: `nout` independent neurons that all receive the same input.
*/
var Layer = class extends Module {
	neurons;
	/**
	* @param {number} nin - Inputs per neuron.
	* @param {number} nout - Number of neurons (outputs).
	* @param {boolean} [nonlin=true] - Whether each neuron applies ReLU.
	*/
	constructor(nin, nout, nonlin = true) {
		super();
		const makeNeuron = () => new Neuron(nin, nonlin);
		this.neurons = Array.from({ length: nout }, makeNeuron);
	}
	/**
	* @param {Array<Value|number>} x - Input vector passed unchanged to every neuron.
	* @returns {Value|Value[]} The lone output for a one-neuron layer, otherwise one output per neuron.
	*/
	call(x) {
		const outputs = this.neurons.map((neuron) => neuron.call(x));
		if (outputs.length === 1) {
			return outputs[0];
		}
		return outputs;
	}
	/** @returns {Value[]} Every neuron's parameters, concatenated in neuron order. */
	parameters() {
		const params = [];
		for (const neuron of this.neurons) {
			for (const p of neuron.parameters()) params.push(p);
		}
		return params;
	}
	toString() {
		const inner = this.neurons.map((neuron) => neuron.toString()).join(", ");
		return `Layer of [${inner}]`;
	}
};
/**
* Multi-layer perceptron: a chain of fully connected `Layer`s in which every
* layer except the last applies ReLU and the last one is linear.
*/
var MLP = class extends Module {
	layers;
	/**
	* @param {number} nin - Size of the input vector.
	* @param {number[]} nouts - Output size of each successive layer.
	*/
	constructor(nin, nouts) {
		super();
		const sizes = [nin, ...nouts];
		const lastLayer = nouts.length - 1;
		const built = [];
		for (let i = 0; i < sizes.length - 1; i++) {
			built.push(new Layer(sizes[i], sizes[i + 1], i !== lastLayer));
		}
		this.layers = built;
	}
	/**
	* Feeds `x` through each layer in turn. A single-Value layer output is
	* wrapped back into a one-element array before the next layer consumes it.
	* @returns {Value|Value[]} Output of the final layer.
	*/
	call(x) {
		return this.layers.reduce(
			(signal, layer) => layer.call(Array.isArray(signal) ? signal : [signal]),
			x
		);
	}
	/** @returns {Value[]} Every layer's parameters, concatenated in layer order. */
	parameters() {
		const params = [];
		for (const layer of this.layers) {
			for (const p of layer.parameters()) params.push(p);
		}
		return params;
	}
	toString() {
		const inner = this.layers.map((layer) => layer.toString()).join(", ");
		return `MLP of [${inner}]`;
	}
};
//#endregion
// Expose the two namespaces as enumerable, getter-only CommonJS exports.
Object.defineProperties(exports, {
	engine: {
		enumerable: true,
		get() {
			return engine_exports;
		}
	},
	nn: {
		enumerable: true,
		get() {
			return nn_exports;
		}
	}
});
//# sourceMappingURL=index.cjs.map