claude-flow
Version:
Ruflo - Enterprise AI agent orchestration for Claude Code. Deploy 60+ specialized agents in coordinated swarms with self-learning, fault-tolerant consensus, vector memory, and MCP integration
1,647 lines (1,423 loc) • 89.9 kB
text/typescript
/**
* RuVector PostgreSQL Bridge - Graph Neural Network (GNN) Layers Module
*
* Comprehensive GNN support for RuVector PostgreSQL vector database integration.
* Implements GCN, GAT, GraphSAGE, GIN, MPNN, EdgeConv, and more.
*
* @module @claude-flow/plugins/integrations/ruvector/gnn
* @version 1.0.0
*/
import type {
GNNLayerType,
GNNLayer,
GraphData,
GNNOutput,
GNNAggregation,
GNNStats,
ActivationFunction,
} from './types.js';
// ============================================================================
// Constants and Configuration
// ============================================================================
/**
* Default configuration values for GNN layers.
*/
export const GNN_DEFAULTS = {
dropout: 0.0,
addSelfLoops: true,
normalize: true,
useBias: true,
activation: 'relu' as ActivationFunction,
aggregation: 'mean' as GNNAggregation,
numHeads: 1,
negativeSlope: 0.2, // For LeakyReLU in GAT
eps: 0.0, // For GIN
sampleSize: 10, // For GraphSAGE
k: 20, // For EdgeConv k-NN
} as const;
/**
* SQL function mapping for GNN operations.
*/
export const GNN_SQL_FUNCTIONS = {
gcn: 'ruvector.gcn_layer',
gat: 'ruvector.gat_layer',
gat_v2: 'ruvector.gat_v2_layer',
sage: 'ruvector.sage_layer',
gin: 'ruvector.gin_layer',
mpnn: 'ruvector.mpnn_layer',
edge_conv: 'ruvector.edge_conv_layer',
point_conv: 'ruvector.point_conv_layer',
transformer: 'ruvector.graph_transformer_layer',
pna: 'ruvector.pna_layer',
film: 'ruvector.film_layer',
rgcn: 'ruvector.rgcn_layer',
hgt: 'ruvector.hgt_layer',
han: 'ruvector.han_layer',
metapath: 'ruvector.metapath_layer',
} as const;
// ============================================================================
// Core Interfaces
// ============================================================================
/**
* Node identifier type.
*/
export type NodeId = string | number;
/**
* Node features representation.
*/
export interface NodeFeatures {
/** Node IDs */
readonly ids: NodeId[];
/** Feature vectors [num_nodes, feature_dim] */
readonly features: number[][];
/** Optional node types for heterogeneous graphs */
readonly types?: string[];
/** Optional node labels */
readonly labels?: number[];
}
/**
* Edge features representation.
*/
export interface EdgeFeatures {
/** Source node IDs */
readonly sources: NodeId[];
/** Target node IDs */
readonly targets: NodeId[];
/** Edge feature vectors [num_edges, edge_dim] (optional) */
readonly features?: number[][];
/** Edge weights (optional) */
readonly weights?: number[];
/** Edge types for heterogeneous graphs (optional) */
readonly types?: string[];
}
/**
* Message representation for message passing.
*/
export interface Message {
/** Source node ID */
readonly source: NodeId;
/** Target node ID */
readonly target: NodeId;
/** Message vector */
readonly vector: number[];
/** Edge features (if applicable) */
readonly edgeFeatures?: number[];
/** Message weight */
readonly weight?: number;
}
/**
* Aggregation method type with extended options.
*/
export type AggregationMethod =
| GNNAggregation
| 'concat'
| 'weighted_mean'
| 'multi_head';
/**
* Path representation for graph traversal.
*/
export interface Path {
/** Ordered list of node IDs */
readonly nodes: NodeId[];
/** Total path weight/distance */
readonly weight: number;
/** Edge types along the path (for heterogeneous graphs) */
readonly edgeTypes?: string[];
}
/**
* Community detection result.
*/
export interface Community {
/** Community identifier */
readonly id: number;
/** Member node IDs */
readonly members: NodeId[];
/** Community centroid (average features) */
readonly centroid?: number[];
/** Modularity score */
readonly modularity?: number;
/** Internal edge density */
readonly density?: number;
}
/**
* PageRank computation options.
*/
export interface PageRankOptions {
/** Damping factor (default: 0.85) */
readonly damping?: number;
/** Maximum iterations (default: 100) */
readonly maxIterations?: number;
/** Convergence tolerance (default: 1e-6) */
readonly tolerance?: number;
/** Personalization vector (teleport probabilities) */
readonly personalization?: Map<NodeId, number>;
/** Whether to use weighted edges */
readonly weighted?: boolean;
}
/**
* Community detection options.
*/
export interface CommunityOptions {
/** Detection algorithm */
readonly algorithm: 'louvain' | 'label_propagation' | 'girvan_newman' | 'spectral';
/** Resolution parameter (for Louvain) */
readonly resolution?: number;
/** Maximum iterations */
readonly maxIterations?: number;
/** Minimum community size */
readonly minSize?: number;
/** Random seed for reproducibility */
readonly seed?: number;
}
/**
* GNN layer configuration with validation.
*/
export interface GNNLayerConfig extends GNNLayer {
/** Layer name/identifier */
readonly name?: string;
/** Whether to cache intermediate results */
readonly cache?: boolean;
/** Quantization bits for memory efficiency */
readonly quantizeBits?: 8 | 16 | 32;
}
/**
* Factory function type for creating GNN layers.
*/
export type GNNLayerFactory = (config: GNNLayerConfig) => IGNNLayer;
/**
* Interface for GNN layer implementations.
*/
export interface IGNNLayer {
/** Layer type */
readonly type: GNNLayerType;
/** Layer configuration */
readonly config: GNNLayerConfig;
/**
* Forward pass through the GNN layer.
* @param graph - Input graph data
* @returns Promise resolving to GNN output
*/
forward(graph: GraphData): Promise<GNNOutput>;
/**
* Message passing step.
* @param nodes - Node features
* @param edges - Edge features
* @returns Promise resolving to updated node features
*/
messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures>;
/**
* Aggregate messages using the specified method.
* @param messages - Array of messages to aggregate
* @param method - Aggregation method
* @returns Promise resolving to aggregated vector
*/
aggregate(messages: Message[], method: AggregationMethod): Promise<number[]>;
/**
* Reset layer state (if stateful).
*/
reset(): void;
/**
* Generate SQL for this layer.
* @param tableName - Target table name
* @param options - SQL generation options
* @returns SQL string
*/
toSQL(tableName: string, options?: SQLGenerationOptions): string;
}
/**
* SQL generation options.
*/
export interface SQLGenerationOptions {
/** Schema name */
readonly schema?: string;
/** Node features column */
readonly nodeColumn?: string;
/** Edge table name */
readonly edgeTable?: string;
/** Whether to use prepared statements */
readonly prepared?: boolean;
/** Parameter prefix for prepared statements */
readonly paramPrefix?: string;
}
// ============================================================================
// GNN Layer Registry
// ============================================================================
/**
* Registry for managing GNN layer types and factories.
*
* @example
* ```typescript
* const registry = new GNNLayerRegistry();
* registry.registerLayer('custom_gnn', CustomGNNFactory);
* const layer = registry.createLayer('gcn', { inputDim: 64, outputDim: 32 });
* ```
*/
export class GNNLayerRegistry {
private readonly factories: Map<GNNLayerType | string, GNNLayerFactory> = new Map();
private readonly defaultConfigs: Map<GNNLayerType | string, Partial<GNNLayerConfig>> = new Map();
constructor() {
// Register built-in layer factories
this.registerBuiltinLayers();
}
/**
* Register a GNN layer factory.
* @param type - Layer type identifier
* @param factory - Factory function
* @param defaultConfig - Optional default configuration
*/
registerLayer(
type: GNNLayerType | string,
factory: GNNLayerFactory,
defaultConfig?: Partial<GNNLayerConfig>
): void {
this.factories.set(type, factory);
if (defaultConfig) {
this.defaultConfigs.set(type, defaultConfig);
}
}
/**
* Unregister a GNN layer factory.
* @param type - Layer type to remove
* @returns Whether the layer was removed
*/
unregisterLayer(type: GNNLayerType | string): boolean {
this.defaultConfigs.delete(type);
return this.factories.delete(type);
}
/**
* Create a GNN layer instance.
* @param type - Layer type
* @param config - Layer configuration
* @returns IGNNLayer instance
* @throws Error if layer type is not registered
*/
createLayer(type: GNNLayerType, config: Partial<GNNLayerConfig>): IGNNLayer {
const factory = this.factories.get(type);
if (!factory) {
throw new Error(`Unknown GNN layer type: ${type}. Available types: ${this.getLayerTypes().join(', ')}`);
}
const defaultConfig = this.defaultConfigs.get(type) ?? {};
const fullConfig: GNNLayerConfig = {
type,
inputDim: config.inputDim ?? 64,
outputDim: config.outputDim ?? 64,
dropout: config.dropout ?? defaultConfig.dropout ?? GNN_DEFAULTS.dropout,
aggregation: config.aggregation ?? defaultConfig.aggregation ?? GNN_DEFAULTS.aggregation,
addSelfLoops: config.addSelfLoops ?? defaultConfig.addSelfLoops ?? GNN_DEFAULTS.addSelfLoops,
normalize: config.normalize ?? defaultConfig.normalize ?? GNN_DEFAULTS.normalize,
useBias: config.useBias ?? defaultConfig.useBias ?? GNN_DEFAULTS.useBias,
activation: config.activation ?? defaultConfig.activation ?? GNN_DEFAULTS.activation,
...config,
};
return factory(fullConfig);
}
/**
* Check if a layer type is registered.
* @param type - Layer type to check
* @returns Whether the layer is registered
*/
hasLayer(type: GNNLayerType | string): boolean {
return this.factories.has(type);
}
/**
* Get all registered layer types.
* @returns Array of layer type identifiers
*/
getLayerTypes(): string[] {
return Array.from(this.factories.keys());
}
/**
* Get default configuration for a layer type.
* @param type - Layer type
* @returns Default configuration or undefined
*/
getDefaultConfig(type: GNNLayerType | string): Partial<GNNLayerConfig> | undefined {
return this.defaultConfigs.get(type);
}
/**
* Register all built-in GNN layer factories.
*/
private registerBuiltinLayers(): void {
// GCN - Graph Convolutional Network
this.registerLayer('gcn', (config) => new GCNLayer(config), {
normalize: true,
addSelfLoops: true,
});
// GAT - Graph Attention Network
this.registerLayer('gat', (config) => new GATLayer(config), {
numHeads: 8,
params: { negativeSlope: 0.2, concat: true },
});
// GAT v2 - Improved Graph Attention
this.registerLayer('gat_v2', (config) => new GATv2Layer(config), {
numHeads: 8,
params: { negativeSlope: 0.2, concat: true },
});
// GraphSAGE - Sampling and Aggregation
this.registerLayer('sage', (config) => new GraphSAGELayer(config), {
aggregation: 'mean',
params: { sampleSize: 10, samplingStrategy: 'uniform' },
});
// GIN - Graph Isomorphism Network
this.registerLayer('gin', (config) => new GINLayer(config), {
params: { eps: 0, trainEps: false },
});
// MPNN - Message Passing Neural Network
this.registerLayer('mpnn', (config) => new MPNNLayer(config), {
aggregation: 'sum',
});
// EdgeConv - Dynamic Edge Convolution
this.registerLayer('edge_conv', (config) => new EdgeConvLayer(config), {
params: { k: 20, dynamic: true },
});
// Point Convolution
this.registerLayer('point_conv', (config) => new PointConvLayer(config), {
params: { k: 16 },
});
// Graph Transformer
this.registerLayer('transformer', (config) => new GraphTransformerLayer(config), {
numHeads: 8,
params: { numLayers: 1 },
});
// PNA - Principal Neighbourhood Aggregation
this.registerLayer('pna', (config) => new PNALayer(config), {
params: {
aggregators: ['mean', 'sum', 'max', 'min'],
scalers: ['identity', 'amplification', 'attenuation'],
},
});
// FiLM - Feature-wise Linear Modulation
this.registerLayer('film', (config) => new FiLMLayer(config), {});
// RGCN - Relational Graph Convolutional Network
this.registerLayer('rgcn', (config) => new RGCNLayer(config), {
params: { numRelations: 1 },
});
// HGT - Heterogeneous Graph Transformer
this.registerLayer('hgt', (config) => new HGTLayer(config), {
numHeads: 8,
});
// HAN - Heterogeneous Attention Network
this.registerLayer('han', (config) => new HANLayer(config), {
numHeads: 8,
});
// MetaPath aggregation
this.registerLayer('metapath', (config) => new MetaPathLayer(config), {
params: { metapaths: [] },
});
}
}
// ============================================================================
// Base GNN Layer Implementation
// ============================================================================
/**
* Abstract base class for GNN layer implementations.
*/
export abstract class BaseGNNLayer implements IGNNLayer {
readonly type: GNNLayerType;
readonly config: GNNLayerConfig;
constructor(config: GNNLayerConfig) {
this.type = config.type;
this.config = config;
this.validateConfig();
}
/**
* Validate layer configuration.
* @throws Error if configuration is invalid
*/
protected validateConfig(): void {
if (this.config.inputDim <= 0) {
throw new Error(`Invalid inputDim: ${this.config.inputDim}. Must be positive.`);
}
if (this.config.outputDim <= 0) {
throw new Error(`Invalid outputDim: ${this.config.outputDim}. Must be positive.`);
}
if (this.config.dropout !== undefined && (this.config.dropout < 0 || this.config.dropout > 1)) {
throw new Error(`Invalid dropout: ${this.config.dropout}. Must be between 0 and 1.`);
}
if (this.config.numHeads !== undefined && this.config.numHeads <= 0) {
throw new Error(`Invalid numHeads: ${this.config.numHeads}. Must be positive.`);
}
}
abstract forward(graph: GraphData): Promise<GNNOutput>;
abstract messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures>;
/**
* Aggregate messages using the specified method.
*/
async aggregate(messages: Message[], method: AggregationMethod): Promise<number[]> {
if (messages.length === 0) {
return new Array(this.config.outputDim).fill(0);
}
const vectors = messages.map((m) => m.vector);
const weights = messages.map((m) => m.weight ?? 1);
switch (method) {
case 'sum':
return this.aggregateSum(vectors);
case 'mean':
return this.aggregateMean(vectors);
case 'max':
return this.aggregateMax(vectors);
case 'min':
return this.aggregateMin(vectors);
case 'attention':
return this.aggregateAttention(vectors, weights);
case 'weighted_mean':
return this.aggregateWeightedMean(vectors, weights);
case 'softmax':
return this.aggregateSoftmax(vectors);
case 'power_mean':
return this.aggregatePowerMean(vectors, 2);
case 'std':
return this.aggregateStd(vectors);
case 'var':
return this.aggregateVar(vectors);
case 'concat':
return this.aggregateConcat(vectors);
case 'lstm':
return this.aggregateLSTM(vectors);
case 'multi_head':
return this.aggregateMultiHead(vectors);
default:
return this.aggregateMean(vectors);
}
}
/**
* Reset layer state.
*/
reset(): void {
// Override in stateful layers
}
/**
* Generate SQL for this layer.
*/
toSQL(tableName: string, options: SQLGenerationOptions = {}): string {
const schema = options.schema ?? 'public';
const nodeColumn = options.nodeColumn ?? 'embedding';
const edgeTable = options.edgeTable ?? `${tableName}_edges`;
const sqlFunction = GNN_SQL_FUNCTIONS[this.type] ?? 'ruvector.gnn_layer';
const configJson = JSON.stringify({
type: this.type,
input_dim: this.config.inputDim,
output_dim: this.config.outputDim,
num_heads: this.config.numHeads,
dropout: this.config.dropout,
aggregation: this.config.aggregation,
add_self_loops: this.config.addSelfLoops,
normalize: this.config.normalize,
use_bias: this.config.useBias,
activation: this.config.activation,
params: this.config.params,
});
if (options.prepared) {
const prefix = options.paramPrefix ?? '$';
return `
SELECT ${sqlFunction}(
(SELECT array_agg(${nodeColumn}) FROM "${schema}"."${tableName}"),
(SELECT array_agg(ARRAY[source_id, target_id]) FROM "${schema}"."${edgeTable}"),
${prefix}1::jsonb
);`.trim();
}
return `
SELECT ${sqlFunction}(
(SELECT array_agg(${nodeColumn}) FROM "${schema}"."${tableName}"),
(SELECT array_agg(ARRAY[source_id, target_id]) FROM "${schema}"."${edgeTable}"),
'${configJson}'::jsonb
);`.trim();
}
// Aggregation implementations
protected aggregateSum(vectors: number[][]): number[] {
const dim = vectors[0]?.length ?? 0;
const result = new Array(dim).fill(0);
for (const vec of vectors) {
for (let i = 0; i < dim; i++) {
result[i] += vec[i] ?? 0;
}
}
return result;
}
protected aggregateMean(vectors: number[][]): number[] {
const sum = this.aggregateSum(vectors);
return sum.map((v) => v / vectors.length);
}
protected aggregateMax(vectors: number[][]): number[] {
const dim = vectors[0]?.length ?? 0;
const result = new Array(dim).fill(-Infinity);
for (const vec of vectors) {
for (let i = 0; i < dim; i++) {
result[i] = Math.max(result[i], vec[i] ?? -Infinity);
}
}
return result;
}
protected aggregateMin(vectors: number[][]): number[] {
const dim = vectors[0]?.length ?? 0;
const result = new Array(dim).fill(Infinity);
for (const vec of vectors) {
for (let i = 0; i < dim; i++) {
result[i] = Math.min(result[i], vec[i] ?? Infinity);
}
}
return result;
}
protected aggregateWeightedMean(vectors: number[][], weights: number[]): number[] {
const dim = vectors[0]?.length ?? 0;
const result = new Array(dim).fill(0);
let totalWeight = 0;
for (let j = 0; j < vectors.length; j++) {
const w = weights[j] ?? 1;
totalWeight += w;
for (let i = 0; i < dim; i++) {
result[i] += (vectors[j]?.[i] ?? 0) * w;
}
}
return result.map((v) => (totalWeight > 0 ? v / totalWeight : 0));
}
protected aggregateAttention(vectors: number[][], weights: number[]): number[] {
// Softmax over weights then weighted mean
const maxWeight = Math.max(...weights);
const expWeights = weights.map((w) => Math.exp(w - maxWeight));
const sumExp = expWeights.reduce((a, b) => a + b, 0);
const attentionWeights = expWeights.map((w) => w / sumExp);
return this.aggregateWeightedMean(vectors, attentionWeights);
}
protected aggregateSoftmax(vectors: number[][]): number[] {
// Softmax aggregation across vectors
const dim = vectors[0]?.length ?? 0;
const result = new Array(dim).fill(0);
for (let i = 0; i < dim; i++) {
const values = vectors.map((v) => v[i] ?? 0);
const maxVal = Math.max(...values);
const expValues = values.map((v) => Math.exp(v - maxVal));
const sumExp = expValues.reduce((a, b) => a + b, 0);
result[i] = expValues.reduce((sum, exp, j) => sum + (exp / sumExp) * values[j], 0);
}
return result;
}
protected aggregatePowerMean(vectors: number[][], p: number): number[] {
const dim = vectors[0]?.length ?? 0;
const result = new Array(dim).fill(0);
for (let i = 0; i < dim; i++) {
let sum = 0;
for (const vec of vectors) {
sum += Math.pow(Math.abs(vec[i] ?? 0), p);
}
result[i] = Math.pow(sum / vectors.length, 1 / p);
}
return result;
}
protected aggregateStd(vectors: number[][]): number[] {
const mean = this.aggregateMean(vectors);
const dim = mean.length;
const variance = new Array(dim).fill(0);
for (const vec of vectors) {
for (let i = 0; i < dim; i++) {
variance[i] += Math.pow((vec[i] ?? 0) - mean[i], 2);
}
}
return variance.map((v) => Math.sqrt(v / vectors.length));
}
protected aggregateVar(vectors: number[][]): number[] {
const mean = this.aggregateMean(vectors);
const dim = mean.length;
const variance = new Array(dim).fill(0);
for (const vec of vectors) {
for (let i = 0; i < dim; i++) {
variance[i] += Math.pow((vec[i] ?? 0) - mean[i], 2);
}
}
return variance.map((v) => v / vectors.length);
}
protected aggregateConcat(vectors: number[][]): number[] {
return vectors.flat();
}
protected aggregateLSTM(vectors: number[][]): number[] {
// Simplified LSTM-style aggregation (sequential processing)
let hidden = new Array(this.config.outputDim).fill(0);
for (const vec of vectors) {
hidden = this.lstmCell(vec, hidden);
}
return hidden;
}
protected aggregateMultiHead(vectors: number[][]): number[] {
// Split into heads, aggregate each, then combine
const numHeads = this.config.numHeads ?? 1;
const headDim = Math.floor((vectors[0]?.length ?? 0) / numHeads);
const results: number[][] = [];
for (let h = 0; h < numHeads; h++) {
const headVectors = vectors.map((v) =>
v.slice(h * headDim, (h + 1) * headDim)
);
results.push(this.aggregateMean(headVectors));
}
return results.flat();
}
private lstmCell(input: number[], hidden: number[]): number[] {
// Simplified LSTM update (no learned parameters)
const dim = hidden.length;
const inputDim = input.length;
const result = new Array(dim).fill(0);
for (let i = 0; i < dim; i++) {
const inputVal = input[i % inputDim] ?? 0;
const hiddenVal = hidden[i] ?? 0;
// Simple gated update
const gate = 1 / (1 + Math.exp(-(inputVal + hiddenVal)));
result[i] = gate * inputVal + (1 - gate) * hiddenVal;
}
return result;
}
/**
* Apply activation function.
*/
protected applyActivation(x: number): number {
switch (this.config.activation) {
case 'relu':
return Math.max(0, x);
case 'gelu':
return 0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * Math.pow(x, 3))));
case 'silu':
case 'swish':
return x / (1 + Math.exp(-x));
case 'leaky_relu':
return x >= 0 ? x : 0.01 * x;
case 'elu':
return x >= 0 ? x : Math.exp(x) - 1;
case 'selu':
const alpha = 1.6732632423543772;
const scale = 1.0507009873554805;
return scale * (x >= 0 ? x : alpha * (Math.exp(x) - 1));
case 'tanh':
return Math.tanh(x);
case 'sigmoid':
return 1 / (1 + Math.exp(-x));
case 'softmax':
case 'none':
default:
return x;
}
}
/**
* Apply dropout (during training).
*/
protected applyDropout(vector: number[], training: boolean = false): number[] {
if (!training || !this.config.dropout || this.config.dropout === 0) {
return vector;
}
const scale = 1 / (1 - this.config.dropout);
return vector.map((v) => (Math.random() > this.config.dropout! ? v * scale : 0));
}
/**
* Normalize vector (L2 normalization).
*/
protected normalizeVector(vector: number[]): number[] {
const norm = Math.sqrt(vector.reduce((sum, v) => sum + v * v, 0));
return norm > 0 ? vector.map((v) => v / norm) : vector;
}
/**
* Create statistics for GNN computation.
*/
protected createStats(
startTime: number,
numNodes: number,
numEdges: number,
numIterations: number = 1
): GNNStats {
return {
forwardTimeMs: Date.now() - startTime,
numNodes,
numEdges,
memoryBytes: numNodes * this.config.outputDim * 4 + numEdges * 8,
numIterations,
};
}
}
// ============================================================================
// GCN Layer Implementation
// ============================================================================
/**
* Graph Convolutional Network (GCN) layer.
*
* Implements spectral graph convolution with first-order approximation.
* Reference: Kipf & Welling, "Semi-Supervised Classification with Graph Convolutional Networks" (2017)
*/
export class GCNLayer extends BaseGNNLayer {
async forward(graph: GraphData): Promise<GNNOutput> {
const startTime = Date.now();
const { nodeFeatures, edgeIndex, edgeWeights } = graph;
const numNodes = nodeFeatures.length;
const numEdges = edgeIndex[0].length;
// Build adjacency with self-loops
const adj = this.buildAdjacency(numNodes, edgeIndex, edgeWeights);
// Normalize adjacency (D^-0.5 * A * D^-0.5)
const normAdj = this.config.normalize ? this.symmetricNormalize(adj, numNodes) : adj;
// Message passing: H' = sigma(A_norm * H * W)
const outputFeatures = this.convolve(nodeFeatures, normAdj);
return {
nodeEmbeddings: outputFeatures,
graphEmbedding: this.poolGraph(outputFeatures),
stats: this.createStats(startTime, numNodes, numEdges),
};
}
async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
const numNodes = nodes.ids.length;
const edgeIndex: [number[], number[]] = [
edges.sources.map((s) => nodes.ids.indexOf(s)),
edges.targets.map((t) => nodes.ids.indexOf(t)),
];
const adj = this.buildAdjacency(numNodes, edgeIndex, edges.weights);
const normAdj = this.config.normalize ? this.symmetricNormalize(adj, numNodes) : adj;
const outputFeatures = this.convolve(nodes.features, normAdj);
return {
ids: nodes.ids,
features: outputFeatures,
types: nodes.types,
labels: nodes.labels,
};
}
private buildAdjacency(
numNodes: number,
edgeIndex: [number[], number[]],
weights?: number[]
): Map<number, Map<number, number>> {
const adj = new Map<number, Map<number, number>>();
// Initialize with self-loops if configured
for (let i = 0; i < numNodes; i++) {
adj.set(i, new Map());
if (this.config.addSelfLoops) {
adj.get(i)!.set(i, 1);
}
}
// Add edges
const [sources, targets] = edgeIndex;
for (let i = 0; i < sources.length; i++) {
const src = sources[i];
const tgt = targets[i];
const weight = weights?.[i] ?? 1;
if (src >= 0 && src < numNodes && tgt >= 0 && tgt < numNodes) {
adj.get(src)!.set(tgt, weight);
// Undirected: add reverse edge
adj.get(tgt)!.set(src, weight);
}
}
return adj;
}
private symmetricNormalize(
adj: Map<number, Map<number, number>>,
numNodes: number
): Map<number, Map<number, number>> {
// Compute degree
const degree = new Array(numNodes).fill(0);
for (let i = 0; i < numNodes; i++) {
for (const weight of adj.get(i)!.values()) {
degree[i] += weight;
}
}
// D^-0.5 * A * D^-0.5
const normAdj = new Map<number, Map<number, number>>();
for (let i = 0; i < numNodes; i++) {
normAdj.set(i, new Map());
for (const [j, weight] of adj.get(i)!.entries()) {
const normWeight = weight / Math.sqrt(degree[i] * degree[j] + 1e-10);
normAdj.get(i)!.set(j, normWeight);
}
}
return normAdj;
}
private convolve(features: number[][], adj: Map<number, Map<number, number>>): number[][] {
const numNodes = features.length;
const inputDim = this.config.inputDim;
const outputDim = this.config.outputDim;
const output: number[][] = [];
for (let i = 0; i < numNodes; i++) {
const aggregated = new Array(inputDim).fill(0);
// Aggregate neighbor features
for (const [j, weight] of adj.get(i)!.entries()) {
const neighborFeatures = features[j] ?? new Array(inputDim).fill(0);
for (let k = 0; k < inputDim; k++) {
aggregated[k] += weight * (neighborFeatures[k] ?? 0);
}
}
// Project to output dimension
const projected = this.projectFeatures(aggregated, inputDim, outputDim);
// Apply activation
const activated = projected.map((x) => this.applyActivation(x));
// Apply dropout
output.push(this.applyDropout(activated));
}
return output;
}
private projectFeatures(input: number[], inputDim: number, outputDim: number): number[] {
// Simple linear projection (in practice, this would use learned weights)
const output = new Array(outputDim).fill(0);
for (let i = 0; i < outputDim; i++) {
for (let j = 0; j < inputDim; j++) {
// Use a deterministic pseudo-weight based on position
const weight = Math.sin((i * inputDim + j) * 0.1) * 0.5;
output[i] += input[j] * weight;
}
if (this.config.useBias) {
output[i] += 0.01; // Small bias term
}
}
return output;
}
private poolGraph(features: number[][]): number[] {
if (features.length === 0) return [];
return this.aggregateMean(features);
}
}
// ============================================================================
// GAT Layer Implementation
// ============================================================================
/**
* Graph Attention Network (GAT) layer.
*
* Implements attention-based message passing.
* Reference: Veličković et al., "Graph Attention Networks" (2018)
*/
export class GATLayer extends BaseGNNLayer {
async forward(graph: GraphData): Promise<GNNOutput> {
const startTime = Date.now();
const { nodeFeatures, edgeIndex } = graph;
const numNodes = nodeFeatures.length;
const numEdges = edgeIndex[0].length;
const numHeads = this.config.numHeads ?? 1;
const negativeSlope = this.config.params?.negativeSlope ?? 0.2;
// Compute attention for each head
const headOutputs: number[][][] = [];
for (let h = 0; h < numHeads; h++) {
const headDim = Math.floor(this.config.outputDim / numHeads);
const headFeatures: number[][] = [];
for (let i = 0; i < numNodes; i++) {
const neighbors = this.getNeighbors(i, edgeIndex, numNodes);
const messages: { feature: number[]; attention: number }[] = [];
// Compute attention for each neighbor
for (const j of neighbors) {
const attention = this.computeAttention(
nodeFeatures[i],
nodeFeatures[j],
h,
negativeSlope
);
messages.push({
feature: this.projectHead(nodeFeatures[j], h, headDim),
attention,
});
}
// Softmax attention weights
const attentionSum = messages.reduce(
(sum, m) => sum + Math.exp(m.attention),
0
);
const normalizedMessages = messages.map((m) => ({
feature: m.feature,
weight: Math.exp(m.attention) / (attentionSum + 1e-10),
}));
// Aggregate with attention weights
const aggregated = new Array(headDim).fill(0);
for (const m of normalizedMessages) {
for (let k = 0; k < headDim; k++) {
aggregated[k] += m.weight * (m.feature[k] ?? 0);
}
}
headFeatures.push(aggregated);
}
headOutputs.push(headFeatures);
}
// Combine heads (concat or average)
const concat = this.config.params?.concat ?? true;
const outputFeatures = this.combineHeads(headOutputs, concat);
// Apply activation and dropout
const finalFeatures = outputFeatures.map((f) =>
this.applyDropout(f.map((x) => this.applyActivation(x)))
);
return {
nodeEmbeddings: finalFeatures,
graphEmbedding: this.aggregateMean(finalFeatures),
attentionWeights: this.extractAttentionWeights(headOutputs),
stats: this.createStats(startTime, numNodes, numEdges),
};
}
async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
const graph: GraphData = {
nodeFeatures: nodes.features,
edgeIndex: [
edges.sources.map((s) => nodes.ids.indexOf(s)),
edges.targets.map((t) => nodes.ids.indexOf(t)),
],
};
const output = await this.forward(graph);
return {
ids: nodes.ids,
features: output.nodeEmbeddings,
types: nodes.types,
labels: nodes.labels,
};
}
private getNeighbors(
nodeIdx: number,
edgeIndex: [number[], number[]],
numNodes: number
): number[] {
const neighbors = new Set<number>();
// Add self-loop
if (this.config.addSelfLoops) {
neighbors.add(nodeIdx);
}
// Find neighbors from edges
const [sources, targets] = edgeIndex;
for (let i = 0; i < sources.length; i++) {
if (sources[i] === nodeIdx && targets[i] < numNodes) {
neighbors.add(targets[i]);
}
if (targets[i] === nodeIdx && sources[i] < numNodes) {
neighbors.add(sources[i]);
}
}
return Array.from(neighbors);
}
protected computeAttention(
nodeI: number[],
nodeJ: number[],
head: number,
negativeSlope: number
): number {
// Compute attention score using concatenation of features
let score = 0;
const dim = nodeI.length;
for (let k = 0; k < dim; k++) {
// Simple attention mechanism (in practice, uses learned attention weights)
const combined = (nodeI[k] ?? 0) + (nodeJ[k] ?? 0);
score += combined * Math.sin((head * dim + k) * 0.1);
}
// LeakyReLU
return score >= 0 ? score : negativeSlope * score;
}
private projectHead(features: number[], head: number, headDim: number): number[] {
const output = new Array(headDim).fill(0);
const inputDim = features.length;
for (let i = 0; i < headDim; i++) {
for (let j = 0; j < inputDim; j++) {
const weight = Math.cos((head * headDim * inputDim + i * inputDim + j) * 0.05);
output[i] += (features[j] ?? 0) * weight;
}
}
return output;
}
private combineHeads(heads: number[][][], concat: boolean): number[][] {
const numNodes = heads[0]?.length ?? 0;
const result: number[][] = [];
for (let i = 0; i < numNodes; i++) {
if (concat) {
// Concatenate all head outputs
result.push(heads.flatMap((h) => h[i] ?? []));
} else {
// Average head outputs
const headDim = heads[0]?.[0]?.length ?? 0;
const averaged = new Array(headDim).fill(0);
for (const head of heads) {
for (let j = 0; j < headDim; j++) {
averaged[j] += (head[i]?.[j] ?? 0) / heads.length;
}
}
result.push(averaged);
}
}
return result;
}
private extractAttentionWeights(heads: number[][][]): number[][] {
// Return simplified attention representation
return heads.map((h) => h.map((node) => node.reduce((a, b) => a + b, 0) / node.length));
}
}
// ============================================================================
// GAT v2 Layer Implementation
// ============================================================================
/**
* Graph Attention Network v2 layer.
*
* Improved attention mechanism with dynamic attention.
* Reference: Brody et al., "How Attentive are Graph Attention Networks?" (2022)
*/
export class GATv2Layer extends GATLayer {
protected override computeAttention(
nodeI: number[],
nodeJ: number[],
head: number,
negativeSlope: number
): number {
// GAT v2: Apply attention AFTER concatenation and transformation
const dim = nodeI.length;
const combined = new Array(dim).fill(0);
// First, transform and combine
for (let k = 0; k < dim; k++) {
combined[k] = (nodeI[k] ?? 0) + (nodeJ[k] ?? 0);
}
// Apply LeakyReLU
for (let k = 0; k < dim; k++) {
combined[k] = combined[k] >= 0 ? combined[k] : negativeSlope * combined[k];
}
// Then compute attention
let score = 0;
for (let k = 0; k < dim; k++) {
score += combined[k] * Math.sin((head * dim + k) * 0.1);
}
return score;
}
}
// ============================================================================
// GraphSAGE Layer Implementation
// ============================================================================
/**
* GraphSAGE (Sample and Aggregate) layer.
*
* Implements inductive representation learning with neighbor sampling.
* Reference: Hamilton et al., "Inductive Representation Learning on Large Graphs" (2017)
*/
export class GraphSAGELayer extends BaseGNNLayer {
async forward(graph: GraphData): Promise<GNNOutput> {
const startTime = Date.now();
const { nodeFeatures, edgeIndex } = graph;
const numNodes = nodeFeatures.length;
const numEdges = edgeIndex[0].length;
const sampleSize = this.config.params?.sampleSize ?? 10;
const outputFeatures: number[][] = [];
for (let i = 0; i < numNodes; i++) {
// Sample neighbors
const allNeighbors = this.getNeighbors(i, edgeIndex, numNodes);
const sampledNeighbors = this.sampleNeighbors(allNeighbors, sampleSize);
// Aggregate neighbor features
const neighborFeatures = sampledNeighbors.map((j) => nodeFeatures[j] ?? []);
const aggregated = await this.aggregate(
neighborFeatures.map((f) => ({ source: i, target: i, vector: f })),
this.config.aggregation ?? 'mean'
);
// Concatenate with self features and project
const selfFeatures = nodeFeatures[i] ?? [];
const combined = [...selfFeatures, ...aggregated];
const projected = this.projectFeatures(combined, combined.length, this.config.outputDim);
// Normalize, activate, and apply dropout
const normalized = this.config.normalize ? this.normalizeVector(projected) : projected;
const activated = normalized.map((x) => this.applyActivation(x));
outputFeatures.push(this.applyDropout(activated));
}
return {
nodeEmbeddings: outputFeatures,
graphEmbedding: this.aggregateMean(outputFeatures),
stats: this.createStats(startTime, numNodes, numEdges),
};
}
async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
const graph: GraphData = {
nodeFeatures: nodes.features,
edgeIndex: [
edges.sources.map((s) => nodes.ids.indexOf(s)),
edges.targets.map((t) => nodes.ids.indexOf(t)),
],
};
const output = await this.forward(graph);
return {
ids: nodes.ids,
features: output.nodeEmbeddings,
types: nodes.types,
labels: nodes.labels,
};
}
private getNeighbors(
nodeIdx: number,
edgeIndex: [number[], number[]],
numNodes: number
): number[] {
const neighbors = new Set<number>();
const [sources, targets] = edgeIndex;
for (let i = 0; i < sources.length; i++) {
if (sources[i] === nodeIdx && targets[i] < numNodes) {
neighbors.add(targets[i]);
}
if (targets[i] === nodeIdx && sources[i] < numNodes) {
neighbors.add(sources[i]);
}
}
return Array.from(neighbors);
}
private sampleNeighbors(neighbors: number[], k: number): number[] {
if (neighbors.length <= k) {
return neighbors;
}
// Random sampling
const sampled: number[] = [];
const available = [...neighbors];
for (let i = 0; i < k && available.length > 0; i++) {
const idx = Math.floor(Math.random() * available.length);
sampled.push(available[idx]);
available.splice(idx, 1);
}
return sampled;
}
private projectFeatures(input: number[], inputDim: number, outputDim: number): number[] {
const output = new Array(outputDim).fill(0);
for (let i = 0; i < outputDim; i++) {
for (let j = 0; j < inputDim; j++) {
const weight = Math.sin((i * inputDim + j) * 0.1) * Math.sqrt(2 / (inputDim + outputDim));
output[i] += (input[j] ?? 0) * weight;
}
if (this.config.useBias) {
output[i] += 0.01;
}
}
return output;
}
}
// ============================================================================
// GIN Layer Implementation
// ============================================================================
/**
* Graph Isomorphism Network (GIN) layer.
*
* Maximally powerful GNN for graph classification.
* Reference: Xu et al., "How Powerful are Graph Neural Networks?" (2019)
*/
export class GINLayer extends BaseGNNLayer {
async forward(graph: GraphData): Promise<GNNOutput> {
const startTime = Date.now();
const { nodeFeatures, edgeIndex } = graph;
const numNodes = nodeFeatures.length;
const numEdges = edgeIndex[0].length;
const eps = this.config.params?.eps ?? 0;
const outputFeatures: number[][] = [];
for (let i = 0; i < numNodes; i++) {
const neighbors = this.getNeighbors(i, edgeIndex, numNodes);
// Sum neighbor features
const neighborSum = new Array(this.config.inputDim).fill(0);
for (const j of neighbors) {
const neighborFeatures = nodeFeatures[j] ?? [];
for (let k = 0; k < this.config.inputDim; k++) {
neighborSum[k] += neighborFeatures[k] ?? 0;
}
}
// GIN update: h_v = MLP((1 + eps) * h_v + sum(h_u))
const selfFeatures = nodeFeatures[i] ?? [];
const combined = new Array(this.config.inputDim).fill(0);
for (let k = 0; k < this.config.inputDim; k++) {
combined[k] = (1 + eps) * (selfFeatures[k] ?? 0) + neighborSum[k];
}
// MLP (2-layer)
const hidden = this.mlpLayer1(combined);
const output = this.mlpLayer2(hidden);
outputFeatures.push(this.applyDropout(output));
}
return {
nodeEmbeddings: outputFeatures,
graphEmbedding: this.aggregateSum(outputFeatures), // Sum pooling for graph classification
stats: this.createStats(startTime, numNodes, numEdges),
};
}
async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
const graph: GraphData = {
nodeFeatures: nodes.features,
edgeIndex: [
edges.sources.map((s) => nodes.ids.indexOf(s)),
edges.targets.map((t) => nodes.ids.indexOf(t)),
],
};
const output = await this.forward(graph);
return {
ids: nodes.ids,
features: output.nodeEmbeddings,
types: nodes.types,
labels: nodes.labels,
};
}
private getNeighbors(
nodeIdx: number,
edgeIndex: [number[], number[]],
numNodes: number
): number[] {
const neighbors = new Set<number>();
const [sources, targets] = edgeIndex;
for (let i = 0; i < sources.length; i++) {
if (sources[i] === nodeIdx && targets[i] < numNodes) {
neighbors.add(targets[i]);
}
if (targets[i] === nodeIdx && sources[i] < numNodes) {
neighbors.add(sources[i]);
}
}
return Array.from(neighbors);
}
private mlpLayer1(input: number[]): number[] {
const hiddenDim = this.config.hiddenDim ?? this.config.inputDim;
const output = new Array(hiddenDim).fill(0);
for (let i = 0; i < hiddenDim; i++) {
for (let j = 0; j < input.length; j++) {
const weight = Math.sin((i * input.length + j) * 0.1) * 0.5;
output[i] += (input[j] ?? 0) * weight;
}
output[i] = this.applyActivation(output[i]);
}
return output;
}
private mlpLayer2(input: number[]): number[] {
const output = new Array(this.config.outputDim).fill(0);
for (let i = 0; i < this.config.outputDim; i++) {
for (let j = 0; j < input.length; j++) {
const weight = Math.cos((i * input.length + j) * 0.1) * 0.5;
output[i] += (input[j] ?? 0) * weight;
}
}
return output;
}
}
// ============================================================================
// MPNN Layer Implementation
// ============================================================================
/**
* Message Passing Neural Network (MPNN) layer.
*
* General framework for GNN with customizable message and update functions.
* Reference: Gilmer et al., "Neural Message Passing for Quantum Chemistry" (2017)
*/
export class MPNNLayer extends BaseGNNLayer {
async forward(graph: GraphData): Promise<GNNOutput> {
const startTime = Date.now();
const { nodeFeatures, edgeIndex, edgeFeatures } = graph;
const numNodes = nodeFeatures.length;
const numEdges = edgeIndex[0].length;
let currentFeatures = nodeFeatures.map((f) => [...f]);
// Multiple rounds of message passing
const numIterations = this.config.params?.numLayers ?? 1;
for (let t = 0; t < numIterations; t++) {
const newFeatures: number[][] = [];
for (let i = 0; i < numNodes; i++) {
// Collect messages from neighbors
const messages: Message[] = [];
const [sources, targets] = edgeIndex;
for (let e = 0; e < sources.length; e++) {
if (targets[e] === i) {
const j = sources[e];
const edgeFeat = edgeFeatures?.[e];
const message = this.messageFunction(
currentFeatures[j] ?? [],
currentFeatures[i] ?? [],
edgeFeat
);
messages.push({
source: j,
target: i,
vector: message,
edgeFeatures: edgeFeat,
});
}
if (sources[e] === i) {
const j = targets[e];
const edgeFeat = edgeFeatures?.[e];
const message = this.messageFunction(
currentFeatures[j] ?? [],
currentFeatures[i] ?? [],
edgeFeat
);
messages.push({
source: j,
target: i,
vector: message,
edgeFeatures: edgeFeat,
});
}
}
// Aggregate messages
const aggregated = await this.aggregate(messages, this.config.aggregation ?? 'sum');
// Update node features
const updated = this.updateFunction(currentFeatures[i] ?? [], aggregated);
newFeatures.push(this.applyDropout(updated));
}
currentFeatures = newFeatures;
}
return {
nodeEmbeddings: currentFeatures,
graphEmbedding: this.aggregateMean(currentFeatures),
stats: this.createStats(startTime, numNodes, numEdges, numIterations),
};
}
async messagePass(nodes: NodeFeatures, edges: EdgeFeatures): Promise<NodeFeatures> {
const graph: GraphData = {
nodeFeatures: nodes.features,
edgeIndex: [
edges.sources.map((s) => nodes.ids.indexOf(s)),
edges.targets.map((t) => nodes.ids.indexOf(t)),
],
edgeFeatures: edges.features,
};
const output = await this.forward(graph);
return {
ids: nodes.ids,
features: output.nodeEmbeddings,
types: nodes.types,
labels: nodes.labels,
};
}
private messageFunction(
sourceFeatures: number[],
targetFeatures: number[],
edgeFeatures?: number[]
): number[] {
const dim = this.config.inputDim;
const message = new Array(dim).fill(0);
for (let i = 0; i < dim; i++) {
message[i] = (sourceFeatures[i] ?? 0) * 0.5 + (targetFeatures[i] ?? 0) * 0.3;
if (edgeFeatures && edgeFeatures[i] !== undefined) {
message[i] += edgeFeatures[i] * 0.2;
}
}
return message;
}
private updateFunction(nodeFeatures: number[], aggregated: number[]): number[] {
const output = new Array(this.config.outputDim).fill(0);
// GRU-like update
for (let i = 0; i < this.config.outputDim; i++) {
const nodeVal = nodeFeatures[i % nodeFeatures.length] ?? 0;
const aggVal = aggregated[i % aggregated.length] ?? 0;
const gate = 1 / (1 + Math.exp(-(nodeVal + aggVal)));
output[i] = this.applyActivation(gate * aggVal + (1 - gate) * nodeVal);
}
return output;
}
}
// ============================================================================
// EdgeConv Layer Implementation
// ============================================================================
/**
* EdgeConv layer for dynamic graph convolution.
*
* Uses k-NN graph construction and edge features.
* Reference: Wang et al., "Dynamic Graph CNN for Learning on Point Clouds" (2019)
*/
export class EdgeConvLayer extends BaseGNNLayer {
async forward(graph: GraphData): Promise<GNNOutput> {
const startTime = Date.now();
const { nodeFeatures } = graph;
const numNodes = nodeFeatures.length;
const k = this.config.params?.k ?? 20;
const dynamic = this.config.params?.dynamic ?? true;
// Build k-NN graph
const knnGraph = dynamic
? this.buildKNNGraph(nodeFeatures, k)
: graph.edgeIndex;
const outputFeatures: number[][] = [];
for (let i = 0; i < numNodes; i++) {
const neighbors = this.getKNNNeighbors(i, knnGraph);
const selfFeatures = nodeFeatures[i] ?? [];
// Edge features: (x_j - x_i) || x_i
const edgeFeatures: number[][] = [];
for (const j of neighbors) {
const neighborFeatures = nodeFeatures[j] ?? [];
const diff = selfFeatures.map((v, idx) => (neighborFeatures[idx] ?? 0) - v);
edgeFeatures.push([...diff, ...selfFeatures]);
}
// Max pooling over edge features
const pooled = this.maxPoolEdges(edgeFeatures);
// MLP on pooled features
const output = this.edgeMLP(pooled);
outputFeatures.push(this.applyDropout(output));
}
return {
nodeEmbeddings: outputFeatures,
graphEmbedding: this.aggregateMean(outputFeatures),
stats: this.createStats(startTime, numNodes, numN