@astermind/astermind-premium
Version:
Astermind Premium - Premium ML Toolkit
201 lines • 7.87 kB
JavaScript
// graph-elm.ts — Graph ELM for graph-structured data
// Graph neural network + ELM for node/edge classification
import { ELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';
/**
* Graph ELM for graph-structured data
* Features:
* - Node feature learning
* - Graph structure encoding
* - Edge-aware classification
* - Graph convolution operations
*/
export class GraphELM {
  /**
   * Graph ELM for graph-structured data.
   *
   * Each graph is turned into a fixed feature vector by a few rounds of a simple
   * message-passing "convolution" (pool neighbor features, add to self features,
   * apply an activation). All node vectors are then pooled into one graph-level
   * vector, which is classified by an ELM.
   *
   * Expected graph shape, as read by this class:
   *   { nodes: Array<{ id, features: number[] }>, edges: Array<{ source, target }> }
   * Edges are treated as undirected. Node and edge ids are compared by their string
   * form, so `1` and `'1'` refer to the same node.
   *
   * @param {object} options
   * @param {string[]} options.categories - Output class labels.
   * @param {number} [options.hiddenUnits=256] - Hidden units of the underlying ELM.
   * @param {'mean'|'sum'|'max'} [options.aggregationType='mean'] - Pooling used both
   *   for neighbor aggregation and for the final node-to-graph readout.
   * @param {number} [options.numLayers=2] - Number of convolution rounds.
   * @param {string} [options.activation='relu'] - 'relu' | 'tanh' | 'sigmoid' are
   *   applied after each round; any other value leaves the features unchanged here.
   *   The value is also forwarded to the ELM.
   * @param {number} [options.maxLen=100] - Forwarded to the ELM.
   * @param {boolean} [options.useTokenizer=true] - Forwarded to the ELM.
   * @throws {Error} If `aggregationType` is not 'mean', 'sum' or 'max'.
   */
  constructor(options) {
    this.trained = false;
    // Read by predict() for interpretability. Nothing in this class writes to it,
    // so predict() currently falls back to the raw first-node features.
    this.nodeFeatureMap = new Map();
    requireLicense(); // Premium feature - requires valid license
    this.categories = options.categories;
    this.options = {
      categories: options.categories,
      hiddenUnits: options.hiddenUnits ?? 256,
      aggregationType: options.aggregationType ?? 'mean',
      numLayers: options.numLayers ?? 2,
      activation: options.activation ?? 'relu',
      maxLen: options.maxLen ?? 100,
      useTokenizer: options.useTokenizer ?? true,
    };
    // Fail fast: an unknown pooling type would otherwise yield meaningless features.
    this._checkAggregationType(this.options.aggregationType);
    this.elm = new ELM({
      useTokenizer: this.options.useTokenizer ? true : undefined,
      hiddenUnits: this.options.hiddenUnits,
      categories: this.options.categories,
      maxLen: this.options.maxLen,
      activation: this.options.activation,
    });
  }

  /**
   * Train on a batch of graphs, one label per graph.
   *
   * Graph vectors and label indices are handed to the ELM's `trainFromData`, if the
   * ELM instance exposes it. NOTE(review): that method is assumed to accept
   * (featureVectors, labelIndices); confirm against @astermind/astermind-elm.
   *
   * @param {Array<object>} graphs - Graphs to train on.
   * @param {Array<string|number>} y - One label per graph: a category name, or an
   *   integer index into `options.categories`.
   * @throws {Error} If `graphs` and `y` differ in length, or a label is not a known
   *   category / valid index.
   */
  train(graphs, y) {
    if (graphs.length !== y.length) {
      throw new Error(
        `GraphELM.train: expected ${graphs.length} labels, got ${y.length} (graphs and y must have the same length)`,
      );
    }
    const labelIndices = y.map((label) => this._toLabelIndex(label));
    const graphFeatures = graphs.map((graph) => this._extractGraphFeatures(graph));
    this.elm.setCategories?.(this.options.categories);
    this.elm.trainFromData?.(graphFeatures, labelIndices);
    this.trained = true;
  }

  /**
   * Resolve a label (category name or index) to a validated category index.
   * @throws {Error} For names not in `categories`, or indices that are out of range
   *   or not integers.
   */
  _toLabelIndex(label) {
    const { categories } = this.options;
    const index = typeof label === 'number' ? label : categories.indexOf(label);
    if (!Number.isInteger(index) || index < 0 || index >= categories.length) {
      throw new Error(
        `GraphELM.train: Unknown label ${JSON.stringify(label)} (expected a category name or an index in [0, ${categories.length}))`,
      );
    }
    return index;
  }

  /** Throw if `type` is not a supported pooling type. */
  _checkAggregationType(type) {
    if (type !== 'mean' && type !== 'sum' && type !== 'max') {
      throw new Error(
        `GraphELM: unknown aggregationType ${JSON.stringify(type)} (expected 'mean', 'sum' or 'max')`,
      );
    }
  }

  /**
   * Build an undirected adjacency list keyed by the string form of each node id.
   * Keys and neighbor entries are both stringified, so lookups work regardless of
   * whether edges use numeric or string ids.
   */
  _buildAdjacency(edges) {
    const adjacency = new Map();
    const link = (from, to) => {
      const key = String(from);
      let neighbors = adjacency.get(key);
      if (!neighbors) {
        neighbors = [];
        adjacency.set(key, neighbors);
      }
      neighbors.push(String(to));
    };
    for (const { source, target } of edges) {
      link(source, target);
      link(target, source);
    }
    return adjacency;
  }

  /**
   * Turn a graph into a single feature vector.
   *
   * Runs `numLayers` synchronous convolution rounds: every node's new vector is
   * computed from the previous round's vectors. The node vectors are then pooled.
   * Returns `[]` for a graph with no nodes.
   */
  _extractGraphFeatures(graph) {
    const nodes = graph.nodes ?? [];
    const adjacency = this._buildAdjacency(graph.edges ?? []);
    // O(1) neighbor lookup. On duplicate ids the first node wins, as Array#find would.
    const nodeByKey = new Map();
    for (const node of nodes) {
      const key = String(node.id);
      if (!nodeByKey.has(key)) {
        nodeByKey.set(key, node);
      }
    }
    // Copy the input features so the caller's arrays are never mutated.
    let nodeFeatures = new Map();
    for (const node of nodes) {
      nodeFeatures.set(node.id, [...(node.features ?? [])]);
    }
    for (let layer = 0; layer < this.options.numLayers; layer++) {
      const next = new Map();
      for (const node of nodes) {
        const neighborFeatures = [];
        for (const neighborKey of adjacency.get(String(node.id)) ?? []) {
          const neighbor = nodeByKey.get(neighborKey);
          const features = neighbor ? nodeFeatures.get(neighbor.id) : undefined;
          // Edges may point at ids that are not in `nodes`; skip those.
          if (features) {
            neighborFeatures.push(features);
          }
        }
        const aggregated = this._aggregateNeighbors(neighborFeatures);
        const selfFeatures = nodeFeatures.get(node.id) ?? [];
        next.set(node.id, this._combineFeatures(selfFeatures, aggregated));
      }
      nodeFeatures = next;
    }
    return this._aggregateNodes(Array.from(nodeFeatures.values()));
  }

  /**
   * Pool vectors element-wise with `options.aggregationType`.
   * Shorter vectors are zero-padded to the longest one. 'max' starts from -Infinity,
   * so all-negative inputs pool correctly. Returns `[]` for no input vectors.
   */
  _pool(vectors) {
    if (vectors.length === 0) {
      return [];
    }
    const type = this.options.aggregationType;
    this._checkAggregationType(type);
    const dim = vectors.reduce((acc, v) => Math.max(acc, v.length), 0);
    const isMax = type === 'max';
    const pooled = new Array(dim).fill(isMax ? -Infinity : 0);
    for (const v of vectors) {
      for (let i = 0; i < dim; i++) {
        const x = v[i] ?? 0;
        pooled[i] = isMax ? Math.max(pooled[i], x) : pooled[i] + x;
      }
    }
    return type === 'mean' ? pooled.map((s) => s / vectors.length) : pooled;
  }

  /** Pool a node's neighbor vectors; `[]` when the node has no neighbors. */
  _aggregateNeighbors(neighborFeatures) {
    return this._pool(neighborFeatures);
  }

  /**
   * Add self and aggregated-neighbor vectors element-wise, zero-padding the shorter
   * one, then apply `options.activation`. Unknown activations return the sum as-is.
   */
  _combineFeatures(self, neighbors) {
    const dim = Math.max(self.length, neighbors.length);
    const combined = Array.from({ length: dim }, (_, i) => (self[i] ?? 0) + (neighbors[i] ?? 0));
    switch (this.options.activation) {
      case 'relu':
        return combined.map((x) => Math.max(0, x));
      case 'tanh':
        return combined.map((x) => Math.tanh(x));
      case 'sigmoid':
        return combined.map((x) => 1 / (1 + Math.exp(-x)));
      default:
        return combined;
    }
  }

  /** Pool all node vectors into one graph-level vector; `[]` for no nodes. */
  _aggregateNodes(nodeFeatures) {
    return this._pool(nodeFeatures);
  }

  /**
   * Predict labels for one graph or an array of graphs.
   *
   * Results for all input graphs are concatenated into one flat array, with up to
   * `topK` entries per graph, in input order.
   *
   * @param {object|Array<object>} graph - A graph or an array of graphs.
   * @param {number} [topK=3] - Maximum predictions kept per graph.
   * @returns {Array<{label: string, prob: number, nodeFeatures: (number[]|undefined)}>}
   *   `nodeFeatures` holds the first node's entry from `nodeFeatureMap`, falling back
   *   to that node's raw input features.
   * @throws {Error} If called before `train`.
   */
  predict(graph, topK = 3) {
    if (!this.trained) {
      throw new Error('Model must be trained before prediction');
    }
    const graphs = Array.isArray(graph) ? graph : [graph];
    const results = [];
    for (const g of graphs) {
      const graphFeatures = this._extractGraphFeatures(g);
      const preds = this.elm.predictFromVector?.([graphFeatures], topK) ?? [];
      const firstNode = g.nodes?.[0];
      const firstNodeFeatures = firstNode
        ? (this.nodeFeatureMap.get(firstNode.id) ?? firstNode.features)
        : undefined;
      for (const pred of preds.slice(0, topK)) {
        results.push({
          // Fall back to the category at `index` when the ELM returns no label.
          label: pred.label ?? this.options.categories[pred.index ?? 0],
          prob: pred.prob ?? 0,
          nodeFeatures: firstNodeFeatures,
        });
      }
    }
    return results;
  }
}
//# sourceMappingURL=graph-elm.js.map