UNPKG

@astermind/astermind-premium

Version:

Astermind Premium - Premium ML Toolkit

201 lines 7.87 kB
// graph-elm.ts — Graph ELM for graph-structured data
// Graph neural network + ELM for node/edge classification
import { ELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';
/**
 * Graph ELM for graph-structured data.
 *
 * Each graph is reduced to one fixed-length vector: node features are updated
 * for `numLayers` rounds by aggregating neighbour features (combined with the
 * node's own features and passed through an activation), and the resulting
 * node vectors are then pooled into a graph-level vector. That vector is what
 * the underlying ELM classifier is trained on and predicts from.
 *
 * Features:
 * - Node feature learning
 * - Graph structure encoding
 * - Edge-aware classification
 * - Graph convolution operations
 */
export class GraphELM {
    /**
     * @param {object} options
     * @param {string[]} options.categories - Class labels; string labels passed to
     *   `train` are mapped to their index in this array.
     * @param {number} [options.hiddenUnits=256] - Hidden units of the underlying ELM.
     * @param {string} [options.aggregationType='mean'] - 'mean', 'sum' or 'max'.
     *   Used both for neighbour aggregation and for graph-level pooling.
     * @param {number} [options.numLayers=2] - Number of neighbour-aggregation rounds.
     * @param {string} [options.activation='relu'] - 'relu', 'tanh' or 'sigmoid' is
     *   applied after each round; any other value leaves the values unchanged.
     * @param {number} [options.maxLen=100] - Forwarded to the ELM.
     * @param {boolean} [options.useTokenizer=true] - Forwarded to the ELM.
     * NOTE(review): requireLicense() presumably throws without a valid license —
     * its behaviour is defined in ../core/license.js, not here.
     */
    constructor(options) {
        this.trained = false;
        // Never written by this class; predict() reads it and falls back to the
        // raw node features when it has no entry.
        this.nodeFeatureMap = new Map();
        requireLicense(); // Premium feature - requires valid license
        this.categories = options.categories;
        this.options = {
            categories: options.categories,
            hiddenUnits: options.hiddenUnits ?? 256,
            aggregationType: options.aggregationType ?? 'mean',
            numLayers: options.numLayers ?? 2,
            activation: options.activation ?? 'relu',
            maxLen: options.maxLen ?? 100,
            useTokenizer: options.useTokenizer ?? true,
        };
        this.elm = new ELM({
            useTokenizer: this.options.useTokenizer ? true : undefined,
            hiddenUnits: this.options.hiddenUnits,
            categories: this.options.categories,
            maxLen: this.options.maxLen,
            activation: this.options.activation,
        });
    }
    /**
     * Train on graph data.
     * @param graphs Array of graphs ({ nodes: [{ id, features }], edges: [{ source, target }] })
     * @param y One label per graph: a category index (number) or a category name.
     * @throws {Error} if `graphs` and `y` differ in length, or a string label is
     *   not one of `options.categories`.
     */
    train(graphs, y) {
        if (graphs.length !== y.length) {
            throw new Error(`GraphELM.train: got ${graphs.length} graphs but ${y.length} labels`);
        }
        const labelIndices = y.map((label) => this._toLabelIndex(label));
        const graphFeatures = graphs.map((graph) => this._extractGraphFeatures(graph));
        // Optional calls: skipped if the ELM instance does not expose these methods.
        this.elm.setCategories?.(this.options.categories);
        this.elm.trainFromData?.(graphFeatures, labelIndices);
        this.trained = true;
    }
    /**
     * Map a label to its category index. Numbers are passed through unchanged.
     * @throws {Error} for a string label that is not a known category (this
     *   previously became -1 silently).
     */
    _toLabelIndex(label) {
        if (typeof label === 'number') {
            return label;
        }
        const index = this.options.categories.indexOf(label);
        if (index === -1) {
            throw new Error(`GraphELM.train: unknown label "${label}"; expected one of: ${this.options.categories.join(', ')}`);
        }
        return index;
    }
    /**
     * Extract a graph-level feature vector from the graph structure.
     * Edges are treated as undirected. Node and edge ids are compared by their
     * String() form, so numeric and string spellings of the same id match.
     * Edges that point at ids absent from `graph.nodes` are ignored.
     */
    _extractGraphFeatures(graph) {
        // String(id) -> the node's own id; the first node wins on collisions.
        const idByKey = new Map();
        for (const node of graph.nodes) {
            const key = String(node.id);
            if (!idByKey.has(key)) {
                idByKey.set(key, node.id);
            }
        }
        // Undirected adjacency list, keyed and valued by String(id).
        const adjacency = new Map();
        const link = (from, to) => {
            const key = String(from);
            if (!adjacency.has(key)) {
                adjacency.set(key, []);
            }
            adjacency.get(key).push(String(to));
        };
        for (const edge of graph.edges ?? []) {
            link(edge.source, edge.target);
            link(edge.target, edge.source);
        }
        // Initialize with (copies of) the raw node features.
        const nodeFeatures = new Map();
        for (const node of graph.nodes) {
            nodeFeatures.set(node.id, [...node.features]);
        }
        // Graph convolution layers. Each round reads only the previous round's
        // features; updates are applied after every node has been computed.
        for (let layer = 0; layer < this.options.numLayers; layer++) {
            const newFeatures = new Map();
            for (const node of graph.nodes) {
                const neighborFeatures = [];
                for (const neighborKey of adjacency.get(String(node.id)) ?? []) {
                    if (idByKey.has(neighborKey)) {
                        neighborFeatures.push(nodeFeatures.get(idByKey.get(neighborKey)));
                    }
                }
                const aggregated = this._aggregateNeighbors(neighborFeatures);
                const selfFeatures = nodeFeatures.get(node.id) ?? [];
                newFeatures.set(node.id, this._combineFeatures(selfFeatures, aggregated));
            }
            for (const [id, features] of newFeatures) {
                nodeFeatures.set(id, features);
            }
        }
        // Pool all node vectors into one graph-level vector.
        return this._aggregateNodes(Array.from(nodeFeatures.values()));
    }
    /**
     * Element-wise reduction of equal-length vectors using `aggregationType`.
     * The output length is taken from the first vector; an empty input gives [].
     * 'max' starts from -Infinity so all-negative inputs are reduced correctly
     * (a 0 seed would clamp them to 0).
     * NOTE(review): an unrecognised aggregationType yields a zero vector, as before.
     */
    _aggregate(vectors) {
        if (vectors.length === 0) {
            return [];
        }
        const type = this.options.aggregationType;
        const dim = vectors[0].length;
        const result = new Array(dim).fill(type === 'max' ? -Infinity : 0);
        if (type !== 'mean' && type !== 'sum' && type !== 'max') {
            return result;
        }
        for (const vector of vectors) {
            for (let i = 0; i < dim; i++) {
                if (type === 'mean') {
                    result[i] += vector[i] / vectors.length;
                }
                else if (type === 'sum') {
                    result[i] += vector[i];
                }
                else {
                    result[i] = Math.max(result[i], vector[i]);
                }
            }
        }
        return result;
    }
    /**
     * Aggregate neighbor features.
     */
    _aggregateNeighbors(neighborFeatures) {
        return this._aggregate(neighborFeatures);
    }
    /**
     * Combine self and neighbor features: element-wise sum (the shorter vector
     * is zero-padded), followed by the configured activation.
     */
    _combineFeatures(self, neighbors) {
        const dim = Math.max(self.length, neighbors.length);
        const combined = new Array(dim).fill(0);
        for (let i = 0; i < dim; i++) {
            const selfVal = i < self.length ? self[i] : 0;
            const neighborVal = i < neighbors.length ? neighbors[i] : 0;
            combined[i] = selfVal + neighborVal; // Simple addition
        }
        // Apply activation
        if (this.options.activation === 'relu') {
            return combined.map((x) => Math.max(0, x));
        }
        else if (this.options.activation === 'tanh') {
            return combined.map((x) => Math.tanh(x));
        }
        else if (this.options.activation === 'sigmoid') {
            return combined.map((x) => 1 / (1 + Math.exp(-x)));
        }
        return combined;
    }
    /**
     * Aggregate all node features to graph level.
     */
    _aggregateNodes(nodeFeatures) {
        return this._aggregate(nodeFeatures);
    }
    /**
     * Predict on one graph or an array of graphs.
     * Results for all graphs are concatenated into one flat array, with up to
     * `topK` entries per graph.
     * @param graph A graph or an array of graphs.
     * @param topK Maximum predictions kept per graph.
     * @returns Array of { label, prob, nodeFeatures }, where nodeFeatures belongs
     *   to the graph's first node (undefined for an empty graph).
     * @throws {Error} if called before train().
     */
    predict(graph, topK = 3) {
        if (!this.trained) {
            throw new Error('Model must be trained before prediction');
        }
        const graphs = Array.isArray(graph) ? graph : [graph];
        const results = [];
        for (const g of graphs) {
            const graphFeatures = this._extractGraphFeatures(g);
            const preds = this.elm.predictFromVector?.([graphFeatures], topK) ?? [];
            // Store node features for first node (for interpretability)
            const firstNode = g.nodes.length > 0 ? g.nodes[0] : undefined;
            const firstNodeFeatures = firstNode
                ? this.nodeFeatureMap.get(firstNode.id) || firstNode.features
                : undefined;
            for (const pred of preds.slice(0, topK)) {
                results.push({
                    label: pred.label ?? this.options.categories[pred.index ?? 0],
                    prob: pred.prob ?? 0,
                    nodeFeatures: firstNodeFeatures,
                });
            }
        }
        return results;
    }
}
//# sourceMappingURL=graph-elm.js.map