UNPKG

tinygrad

Version:

A JavaScript/TypeScript autograd engine with operator overloading, inspired by micrograd

275 lines (272 loc) 7.91 kB
import { __export } from "./chunk-CTAAG5j7.js";

//#region src/tinygrad/engine.ts
var engine_exports = /* @__PURE__ */ __export({ Value: () => Value });

/**
 * Evaluates `lhs <op> rhs` the same way the operator-overloading transform does:
 * if `lhs` defines a method under `Symbol.for(op)`, that method is called with `rhs`;
 * otherwise the native JavaScript operator is applied.
 *
 * @param {"+" | "*" | "**" | "<" | ">"} op - operator symbol
 * @param {*} lhs - left operand (a `Value`'s `data`)
 * @param {*} rhs - right operand
 * @returns {*} result of the overloaded or native operation
 * @throws {Error} if `op` is not one of the supported symbols
 */
function applyBinaryOp(op, lhs, rhs) {
	"operator-overloading disabled";
	const sym = Symbol.for(op);
	if (lhs != null && lhs[sym] !== void 0) return lhs[sym](rhs);
	switch (op) {
		case "+": return lhs + rhs;
		case "*": return lhs * rhs;
		case "**": return lhs ** rhs;
		case "<": return lhs < rhs;
		case ">": return lhs > rhs;
		default: throw new Error(`unsupported operator: ${op}`);
	}
}

/**
 * Stores a single scalar value and its gradient for automatic differentiation
 */
var Value = class Value {
	/** Accumulated gradient of the final output w.r.t. this node; starts at 0. */
	grad;
	/** Closure that pushes `out.grad` into the operands' `grad`; a no-op for leaf nodes. */
	#backward;
	/** The operand nodes this value was computed from. */
	#prev;

	/**
	 * @param {*} data - scalar payload; presumably a number (arithmetic falls back to
	 *   native operators unless `data` defines `Symbol.for(op)` methods)
	 * @param {Iterable<Value>} [children=[]] - operands that produced this value
	 * @param {string} [_op=""] - label of the producing operation (e.g. "+", "*", "ReLU")
	 */
	constructor(data, children = [], _op = "") {
		this.data = data;
		this._op = _op;
		this.grad = 0;
		this.#backward = () => {};
		this.#prev = new Set(children);
	}

	/**
	 * @param {Value|*} other - a Value or a raw scalar (wrapped in a new Value)
	 * @returns {Value} new node holding `this.data + other.data`
	 */
	add(other) {
		const otherValue = other instanceof Value ? other : new Value(other);
		const out = new Value(applyBinaryOp("+", this.data, otherValue.data), [this, otherValue], "+");
		out.#backward = () => {
			// d(a + b)/da = d(a + b)/db = 1
			this.grad += out.grad;
			otherValue.grad += out.grad;
		};
		return out;
	}
	[Symbol.for("+")](other) {
		return this.add(other);
	}

	/**
	 * @param {Value|*} other - a Value or a raw scalar (wrapped in a new Value)
	 * @returns {Value} new node holding `this.data * other.data`
	 */
	mul(other) {
		const otherValue = other instanceof Value ? other : new Value(other);
		const out = new Value(applyBinaryOp("*", this.data, otherValue.data), [this, otherValue], "*");
		out.#backward = () => {
			// Product rule: each operand receives the other operand's value times the upstream gradient.
			this.grad += applyBinaryOp("*", otherValue.data, out.grad);
			otherValue.grad += applyBinaryOp("*", this.data, out.grad);
		};
		return out;
	}
	[Symbol.for("*")](other) {
		return this.mul(other);
	}

	/**
	 * @param {number} other - exponent; must be a plain number
	 * @returns {Value} new node holding `this.data ** other`
	 * @throws {Error} if `other` is not a number
	 */
	pow(other) {
		if (typeof other !== "number") throw new Error("only supporting number powers for now");
		const out = new Value(applyBinaryOp("**", this.data, other), [this], `**${other}`);
		out.#backward = () => {
			// Power rule: d(x^n)/dx = n * x^(n - 1)
			this.grad += other * this.data ** (other - 1) * out.grad;
		};
		return out;
	}
	[Symbol.for("**")](other) {
		return this.pow(other);
	}

	/** @returns {Value} `this * -1` */
	neg() {
		return this.mul(-1);
	}
	[Symbol.for("minus")]() {
		return this.neg();
	}

	/** @returns {Value} `this + (-other)` */
	sub(other) {
		const otherValue = other instanceof Value ? other : new Value(other);
		return this.add(otherValue.neg());
	}
	[Symbol.for("-")](other) {
		return this.sub(other);
	}

	/** @returns {Value} `this * other ** -1` */
	div(other) {
		const otherValue = other instanceof Value ? other : new Value(other);
		return this.mul(otherValue.pow(-1));
	}
	[Symbol.for("/")](other) {
		return this.div(other);
	}

	/** @returns {Value} new node holding `max(0, this.data)` (0 when `data < 0`) */
	relu() {
		const out = new Value(applyBinaryOp("<", this.data, 0) ? 0 : this.data, [this], "ReLU");
		out.#backward = () => {
			// Gradient passes through only where the output is positive.
			this.grad += (applyBinaryOp(">", out.data, 0) ? 1 : 0) * out.grad;
		};
		return out;
	}

	/**
	 * Back-propagates from this node: sets `this.grad` to 1, then runs every node's
	 * backward closure in reverse topological order. Gradients are accumulated with
	 * `+=`, so they must be reset (e.g. via `Module.zeroGrad`) between passes.
	 */
	backward() {
		// Iterative post-order DFS with an explicit stack. The previous recursive walk used
		// one call frame per node on the longest dependency chain and could overflow the JS
		// call stack on deep graphs. This visits children in the same (Set insertion) order,
		// so `topo` comes out identical.
		const topo = [];
		const visited = new Set([this]);
		const stack = [{ node: this, children: this.#prev.values() }];
		while (stack.length > 0) {
			const frame = stack[stack.length - 1];
			const next = frame.children.next();
			if (next.done) {
				stack.pop();
				topo.push(frame.node);
			} else if (!visited.has(next.value)) {
				visited.add(next.value);
				stack.push({ node: next.value, children: next.value.#prev.values() });
			}
		}
		this.grad = 1;
		for (let i = topo.length - 1; i >= 0; i--) topo[i].#backward();
	}

	toString() {
		return `Value(data=${this.data}, grad=${this.grad})`;
	}
};
//#endregion

//#region src/tinygrad/nn.ts
var nn_exports = /* @__PURE__ */ __export({
	Layer: () => Layer,
	MLP: () => MLP,
	Neuron: () => Neuron
});

/** Base class for components that own trainable `Value` parameters. */
var Module = class {
	/** Resets the gradient of every parameter returned by `parameters()` to 0. */
	zeroGrad() {
		for (const p of this.parameters()) p.grad = 0;
	}
	/** Default: no parameters. Subclasses override this. */
	parameters() {
		return [];
	}
};

/** A single neuron: weighted sum of inputs plus bias, optionally followed by ReLU. */
var Neuron = class extends Module {
	w;
	b;
	nonlin;
	/**
	 * @param {number} nin - number of inputs (one weight per input)
	 * @param {boolean} [nonlin=true] - apply ReLU to the output when true
	 */
	constructor(nin, nonlin = true) {
		super();
		// Weights are drawn uniformly from [-1, 1); the bias starts at 0.
		this.w = Array.from({ length: nin }, () => new Value(Math.random() * 2 - 1));
		this.b = new Value(0);
		this.nonlin = nonlin;
	}
	/**
	 * @param {Array<Value|number>} x - inputs; entries beyond `w.length` are ignored
	 * @returns {Value} `b + sum(w[i] * x[i])`, passed through ReLU when `nonlin`
	 * @throws {Error} if `x` has fewer entries than the neuron has weights
	 */
	call(x) {
		// A short input would otherwise wrap `undefined` in a Value and silently yield NaN.
		if (x.length < this.w.length) {
			throw new Error(`Neuron expected ${this.w.length} inputs, got ${x.length}`);
		}
		const act = this.w.reduce((sum, wi, i) => sum.add(wi.mul(x[i])), this.b);
		return this.nonlin ? act.relu() : act;
	}
	parameters() {
		return [...this.w, this.b];
	}
	toString() {
		return `${this.nonlin ? "ReLU" : "Linear"}Neuron(${this.w.length})`;
	}
};

/** A fully connected layer of `nout` neurons sharing the same input. */
var Layer = class extends Module {
	neurons;
	constructor(nin, nout, nonlin = true) {
		super();
		this.neurons = Array.from({ length: nout }, () => new Neuron(nin, nonlin));
	}
	/** @returns {Value|Value[]} a single Value when the layer has one neuron, else an array */
	call(x) {
		const out = this.neurons.map((n) => n.call(x));
		return out.length === 1 ? out[0] : out;
	}
	parameters() {
		return this.neurons.flatMap((n) => n.parameters());
	}
	toString() {
		return `Layer of [${this.neurons.map((n) => n.toString()).join(", ")}]`;
	}
};

/** Multi-layer perceptron: ReLU hidden layers followed by a linear output layer. */
var MLP = class extends Module {
	layers;
	/**
	 * @param {number} nin - input size
	 * @param {number[]} nouts - output size of each successive layer
	 */
	constructor(nin, nouts) {
		super();
		const sz = [nin, ...nouts];
		// Every layer except the last applies ReLU; the final layer is linear.
		this.layers = sz.slice(0, -1).map((_, i) => new Layer(sz[i], sz[i + 1], i !== nouts.length - 1));
	}
	call(x) {
		let out = x;
		// A single-neuron layer returns a bare Value; re-wrap it so the next layer gets an array.
		for (const layer of this.layers) out = layer.call(Array.isArray(out) ? out : [out]);
		return out;
	}
	parameters() {
		return this.layers.flatMap((layer) => layer.parameters());
	}
	toString() {
		return `MLP of [${this.layers.map((l) => l.toString()).join(", ")}]`;
	}
};
//#endregion

export { engine_exports as engine, nn_exports as nn };
//# sourceMappingURL=index.js.map