@2bad/micrograd
Version:
[npm](https://www.npmjs.com/package/@2bad/micrograd) [License: MIT](https://opensource.org/license/MIT) • 1.84 kB
JavaScript
/**
 * Renders a computation graph of micrograd-style values as Mermaid flowchart
 * source (`graph LR`).
 *
 * Each value becomes a node tagged with the `valueNode` class that shows its
 * label, data and gradient. A value that has inputs also gets a separate
 * operation node (`<id>_op`) that points at it. Each of its inputs points at
 * that operation node.
 *
 * The traversal state below is public and is reset at the start of every
 * `generateMermaid` call, so one instance can be reused across calls.
 */
export class GraphVisualizer {
  /** @type {Map<object, string>} Value object -> Mermaid node id (`node<N>`), keyed by identity. */
  nodeIds = new Map();

  /** @type {Set<object>} Values whose node, operation node and input edges were already emitted. */
  visited = new Set();

  /** @type {string[]} Mermaid source lines accumulated during a trace. */
  mermaidCode = [];

  /**
   * Returns the stable Mermaid id for `value`.
   *
   * On first sight it allocates `node<N>`, where N is the number of ids
   * handed out so far.
   *
   * @param {object} value - A graph value (compared by object identity).
   * @returns {string} The node id.
   */
  getNodeId(value) {
    let id = this.nodeIds.get(value);
    if (id === undefined) {
      id = `node${this.nodeIds.size}`;
      this.nodeIds.set(value, id);
    }
    return id;
  }

  /**
   * Returns the id of the operation node that produces the value `valueId`.
   *
   * @param {string} valueId - An id returned by `getNodeId`.
   * @returns {string} `${valueId}_op`.
   */
  getOpNodeId(valueId) {
    return `${valueId}_op`;
  }

  /**
   * Depth-first walk from `value` that appends Mermaid lines to `mermaidCode`.
   *
   * Everything for a value is written on its first visit: its node, its
   * operation node (when `children` is non-empty) and the
   * `input --> operation` edges. Later visits (a value shared by several
   * parents) return immediately, so shared subexpressions do not produce
   * duplicate edges.
   *
   * Expects a Value-like object exposing:
   * - `label`;
   * - `data` and `grad`, formatted with `toFixed(4)`;
   * - `operation`;
   * - a `children` array;
   * - a `prev()` method returning an iterable of its inputs.
   *
   * NOTE(review): recursion depth equals graph depth, so very deep graphs
   * could exceed the call stack.
   *
   * @param {object} value - The value to trace.
   */
  trace(value) {
    if (this.visited.has(value)) {
      return;
    }
    this.visited.add(value);

    const valueId = this.getNodeId(value);
    const label = GraphVisualizer.#escapeLabel(`${value.label}`);
    this.mermaidCode.push(
      ` ${valueId}["${label}\ndata: ${value.data.toFixed(4)}\ngrad: ${value.grad.toFixed(4)}"]:::valueNode;`,
    );

    const hasOperation = value.children.length > 0;
    const opNodeId = this.getOpNodeId(valueId);
    if (hasOperation) {
      const op = GraphVisualizer.#escapeLabel(`${value.operation}`);
      // NOTE(review): presumably `*`/`+` are backslash-escaped so the renderer
      // does not treat them as markup. Only one leading backslash is added.
      const opText = /[*+]/.test(op) ? `\\${op}` : op;
      this.mermaidCode.push(` ${opNodeId}["${opText}"];`);
      this.mermaidCode.push(` ${opNodeId} --> ${valueId};`);
    }

    // Materialize once: the inputs are iterated twice (recursion, then edges).
    const inputs = Array.from(value.prev());
    for (const input of inputs) {
      this.trace(input);
    }
    if (hasOperation) {
      // Inputs were traced above, so each already has an id.
      for (const input of inputs) {
        this.mermaidCode.push(` ${this.getNodeId(input)} --> ${opNodeId};`);
      }
    }
  }

  /**
   * Builds the complete Mermaid flowchart for the graph reachable from `root`.
   *
   * @param {object} root - The output value to start tracing from.
   * @returns {string} Mermaid source, with lines joined by `\n`.
   */
  generateMermaid(root) {
    this.nodeIds.clear();
    this.visited.clear();
    this.mermaidCode = ['graph LR;'];
    this.trace(root);
    this.mermaidCode.push(' classDef valueNode rx,ry:10,10;');
    return this.mermaidCode.join('\n');
  }

  /**
   * Escapes text placed inside a quoted Mermaid label.
   *
   * A raw `"` would terminate the label early, so it is replaced with
   * Mermaid's `#quot;` entity code.
   *
   * @param {string} text - Raw label text.
   * @returns {string} Text safe to embed between `["` and `"]`.
   */
  static #escapeLabel(text) {
    return text.replaceAll('"', '#quot;');
  }
}