UNPKG

@2bad/micrograd

Version:

[![NPM version](https://img.shields.io/npm/v/@2bad/micrograd)](https://www.npmjs.com/package/@2bad/micrograd) [![License](https://img.shields.io/npm/l/@2bad/micrograd)](https://opensource.org/license/MIT) [![GitHub Build Status](https://img.shields.io/git

57 lines (56 loc) 1.84 kB
/**
 * Escapes double quotes so arbitrary text can be placed inside a Mermaid
 * `["..."]` node label without terminating the label early. Mermaid accepts
 * the `#quot;` entity code for a literal double quote.
 *
 * @param {unknown} text - Text to embed; coerced with `String()`.
 * @returns {string} Text safe to place between the label's quotes.
 */
function escapeMermaidText(text) {
  return String(text).replaceAll('"', '#quot;');
}

/**
 * Renders a computation graph as Mermaid flowchart source (`graph LR`).
 *
 * Every reachable value becomes a node showing its label, `data` and `grad`
 * (formatted with 4 decimals) and styled with the `valueNode` class. A value
 * that has children additionally gets an operation node (`<id>_op`) labelled
 * with `value.operation`. Each child points into that operation node, which
 * in turn points at the value. Each node and each edge is emitted once, even
 * when a value is shared by several parents.
 *
 * Shape of a value, inferred from usage here (TODO confirm against the Value
 * class): `label`, numeric `data` and `grad`, `operation`, an array-like
 * `children`, and a `prev()` method returning an iterable of child values.
 */
export class GraphVisualizer {
  /** @type {Map<object, string>} value -> Mermaid node id */
  nodeIds;
  /** @type {Set<object>} values already emitted by `trace` */
  visited;
  /** @type {string[]} Mermaid source lines accumulated so far */
  mermaidCode;

  constructor() {
    this.nodeIds = new Map();
    this.visited = new Set();
    this.mermaidCode = [];
  }

  /**
   * Returns the Mermaid node id for `value`, allocating `node<N>` (N = number
   * of ids allocated so far) the first time the value is seen.
   *
   * @param {object} value - Graph value.
   * @returns {string} Stable node id for this visualizer run.
   */
  getNodeId(value) {
    let id = this.nodeIds.get(value);
    if (id === undefined) {
      id = `node${this.nodeIds.size}`;
      this.nodeIds.set(value, id);
    }
    return id;
  }

  /**
   * @param {string} valueId - Node id of a value.
   * @returns {string} Id of the operation node feeding that value.
   */
  getOpNodeId(valueId) {
    return `${valueId}_op`;
  }

  /**
   * Depth-first walk from `value`, appending node and edge lines to
   * `mermaidCode`. A value's incoming edges are written after all its
   * children have been traced.
   *
   * @param {object} value - Graph value to emit.
   */
  trace(value) {
    // Emit each value (and its `child --> op` edges) exactly once. A value
    // shared by several parents is reached repeatedly, and re-emitting its
    // edges would draw duplicate arrows.
    if (this.visited.has(value)) {
      return;
    }
    this.visited.add(value);

    const valueId = this.getNodeId(value);
    const opNodeId = this.getOpNodeId(valueId);
    const hasChildren = value.children.length > 0;

    const label = escapeMermaidText(value.label);
    this.mermaidCode.push(` ${valueId}["${label} data: ${value.data.toFixed(4)} grad: ${value.grad.toFixed(4)}"]:::valueNode;`);

    if (hasChildren) {
      const opLabel = value.operation;
      // `*` and `+` are backslash-prefixed (behaviour kept as-is; presumably
      // so Mermaid does not interpret them as markup).
      const opText = /[*+]/.test(opLabel) ? `\\${opLabel}` : opLabel;
      this.mermaidCode.push(` ${opNodeId}["${escapeMermaidText(opText)}"];`);
      this.mermaidCode.push(` ${opNodeId} --> ${valueId};`);
    }

    for (const child of value.prev()) {
      this.trace(child);
    }

    if (hasChildren) {
      for (const child of value.prev()) {
        this.mermaidCode.push(` ${this.getNodeId(child)} --> ${opNodeId};`);
      }
    }
  }

  /**
   * Builds the full Mermaid flowchart for the graph rooted at `root`.
   * Resets all state first, so one instance can be reused.
   *
   * @param {object} root - Output value of the graph.
   * @returns {string} Mermaid source, lines joined with `\n`.
   */
  generateMermaid(root) {
    this.nodeIds.clear();
    this.visited.clear();
    this.mermaidCode = ['graph LR;'];
    this.trace(root);
    this.mermaidCode.push(' classDef valueNode rx,ry:10,10;');
    return this.mermaidCode.join('\n');
  }
}