UNPKG

@astermind/astermind-premium

Version:

Astermind Premium - Premium ML Toolkit

210 lines 7.67 kB
// graph-kernel-elm.ts — Graph Kernel ELM // Graph kernels (Weisfeiler-Lehman, etc.) for graph structure encoding import { KernelELM } from '@astermind/astermind-elm'; import { requireLicense } from '../core/license.js'; /** * Graph Kernel ELM * Features: * - Graph kernels (Weisfeiler-Lehman, shortest-path, random-walk) * - Graph structure encoding * - Node classification/regression */ export class GraphKernelELM { constructor(options) { this.trained = false; requireLicense(); // Premium feature - requires valid license this.categories = options.categories; this.options = { categories: options.categories, kernelType: options.kernelType ?? 'weisfeiler-lehman', wlIterations: options.wlIterations ?? 3, kernel: options.kernel ?? 'rbf', gamma: options.gamma ?? 1.0, degree: options.degree ?? 2, coef0: options.coef0 ?? 0, activation: options.activation ?? 'relu', }; this.kelm = new KernelELM({ categories: this.options.categories, kernel: this.options.kernel, gamma: this.options.gamma, degree: this.options.degree, coef0: this.options.coef0, }); } /** * Train on graphs */ train(graphs, y) { // Prepare labels const labelIndices = y.map(label => typeof label === 'number' ? label : this.options.categories.indexOf(label)); // Compute graph kernel features const features = this._computeGraphKernelFeatures(graphs); // Train KELM this.kelm.setCategories?.(this.options.categories); this.kelm.trainFromData?.(features, labelIndices); this.trained = true; } /** * Compute graph kernel features */ _computeGraphKernelFeatures(graphs) { const features = []; for (const graph of graphs) { let graphFeatures; if (this.options.kernelType === 'weisfeiler-lehman') { graphFeatures = this._weisfeilerLehmanKernel(graph); } else if (this.options.kernelType === 'shortest-path') { graphFeatures = this._shortestPathKernel(graph); } else { graphFeatures = this._randomWalkKernel(graph); } features.push(graphFeatures); } return features; } /** * Weisfeiler-Lehman kernel */ _weisfeilerLehmanKernel(graph) { const features = []; const nodeLabels = new Map(); // Initialize labels with node features for (const node of graph.nodes) { const label = node.features.join(','); nodeLabels.set(node.id, label); } // WL iterations for (let iter = 0; iter < this.options.wlIterations; iter++) { const newLabels = new Map(); for (const node of graph.nodes) { // Get neighbor labels const neighbors = graph.edges .filter(e => e.source === node.id || e.target === node.id) .map(e => e.source === node.id ? e.target : e.source); const neighborLabels = neighbors .map(nid => nodeLabels.get(nid) || '') .sort() .join(','); // New label = current label + sorted neighbor labels const newLabel = `${nodeLabels.get(node.id)}|${neighborLabels}`; newLabels.set(node.id, newLabel); } // Count label frequencies const labelCounts = new Map(); for (const label of newLabels.values()) { labelCounts.set(label, (labelCounts.get(label) || 0) + 1); } // Add to features for (const [label, count] of labelCounts) { features.push(count); } nodeLabels.clear(); for (const [id, label] of newLabels) { nodeLabels.set(id, label); } } return features.length > 0 ? features : new Array(10).fill(0); } /** * Shortest-path kernel */ _shortestPathKernel(graph) { // Compute shortest paths between all pairs const distances = this._computeShortestPaths(graph); // Create histogram of distances const maxDist = Math.max(...distances.flat().filter(d => d < Infinity)); const bins = Math.min(10, maxDist + 1); const histogram = new Array(bins).fill(0); for (const row of distances) { for (const dist of row) { if (dist < Infinity) { const bin = Math.min(Math.floor(dist), bins - 1); histogram[bin]++; } } } return histogram; } /** * Random-walk kernel */ _randomWalkKernel(graph) { // Simplified random-walk kernel const features = []; // Node degree distribution const degrees = new Map(); for (const edge of graph.edges) { degrees.set(edge.source, (degrees.get(edge.source) || 0) + 1); degrees.set(edge.target, (degrees.get(edge.target) || 0) + 1); } const degreeHist = new Array(10).fill(0); for (const degree of degrees.values()) { const bin = Math.min(degree, 9); degreeHist[bin]++; } features.push(...degreeHist); // Graph statistics features.push(graph.nodes.length); features.push(graph.edges.length); features.push(graph.nodes.length > 0 ? graph.edges.length / graph.nodes.length : 0); return features; } /** * Compute shortest paths (Floyd-Warshall simplified) */ _computeShortestPaths(graph) { const n = graph.nodes.length; const dist = Array(n).fill(null).map(() => Array(n).fill(Infinity)); // Initialize for (let i = 0; i < n; i++) { dist[i][i] = 0; } // Add edges for (const edge of graph.edges) { const srcIdx = graph.nodes.findIndex(n => n.id === edge.source); const tgtIdx = graph.nodes.findIndex(n => n.id === edge.target); if (srcIdx >= 0 && tgtIdx >= 0) { dist[srcIdx][tgtIdx] = 1; dist[tgtIdx][srcIdx] = 1; } } // Floyd-Warshall for (let k = 0; k < n; k++) { for (let i = 0; i < n; i++) { for (let j = 0; j < n; j++) { if (dist[i][k] + dist[k][j] < dist[i][j]) { dist[i][j] = dist[i][k] + dist[k][j]; } } } } return dist; } /** * Predict on graphs */ predict(graphs, topK = 3) { if (!this.trained) { throw new Error('Model must be trained before prediction'); } const graphArray = Array.isArray(graphs) ? graphs : [graphs]; const features = this._computeGraphKernelFeatures(graphArray); const results = []; for (const feature of features) { const preds = this.kelm.predictFromVector?.([feature], topK) || []; for (const pred of preds.slice(0, topK)) { results.push({ label: pred.label || this.options.categories[pred.index || 0], prob: pred.prob || 0, }); } } return results; } } //# sourceMappingURL=graph-kernel-elm.js.map