@ruvector/attention-wasm

WebAssembly bindings for ruvector-attention - high-performance attention mechanisms

413 lines (341 loc) 8.74 kB
/**
 * TypeScript wrapper for ruvector-attention-wasm
 * Provides a clean, type-safe API for attention mechanisms
 */
import init, * as wasm from '../pkg/ruvector_attention_wasm';
import type {
  AttentionConfig,
  MultiHeadConfig,
  HyperbolicConfig,
  LinearAttentionConfig,
  FlashAttentionConfig,
  LocalGlobalConfig,
  MoEConfig,
  TrainingConfig,
  SchedulerConfig,
  ExpertStats,
  AttentionType,
} from './types';

export * from './types';

let initialized = false;

/**
 * Initialize the WASM module
 * Must be called before using any attention mechanisms
 */
export async function initialize(): Promise<void> {
  if (!initialized) {
    await init();
    initialized = true;
  }
}

/**
 * Get the version of the ruvector-attention-wasm package
 */
export function version(): string {
  return wasm.version();
}

/**
 * Get list of available attention mechanisms
 */
export function availableMechanisms(): AttentionType[] {
  return wasm.available_mechanisms() as AttentionType[];
}

/**
 * Multi-head attention mechanism
 */
export class MultiHeadAttention {
  private inner: wasm.WasmMultiHeadAttention;

  constructor(config: MultiHeadConfig) {
    this.inner = new wasm.WasmMultiHeadAttention(config.dim, config.numHeads);
  }

  /**
   * Compute multi-head attention
   */
  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  get numHeads(): number {
    return this.inner.num_heads;
  }

  get dim(): number {
    return this.inner.dim;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Hyperbolic attention mechanism
 */
export class HyperbolicAttention {
  private inner: wasm.WasmHyperbolicAttention;

  constructor(config: HyperbolicConfig) {
    this.inner = new wasm.WasmHyperbolicAttention(config.dim, config.curvature);
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  get curvature(): number {
    return this.inner.curvature;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Linear attention (Performer-style)
 */
export class LinearAttention {
  private inner: wasm.WasmLinearAttention;

  constructor(config: LinearAttentionConfig) {
    this.inner = new wasm.WasmLinearAttention(config.dim, config.numFeatures);
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Flash attention mechanism
 */
export class FlashAttention {
  private inner: wasm.WasmFlashAttention;

  constructor(config: FlashAttentionConfig) {
    this.inner = new wasm.WasmFlashAttention(config.dim, config.blockSize);
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Local-global attention mechanism
 */
export class LocalGlobalAttention {
  private inner: wasm.WasmLocalGlobalAttention;

  constructor(config: LocalGlobalConfig) {
    this.inner = new wasm.WasmLocalGlobalAttention(
      config.dim,
      config.localWindow,
      config.globalTokens
    );
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  free(): void {
    this.inner.free();
  }
}
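/*
 * Illustrative usage sketch (not part of the module): the attention classes
 * above all share the same lifecycle, shown here with MultiHeadAttention.
 * The dimensions and vector contents are made up for demonstration.
 *
 *   import { initialize, MultiHeadAttention } from '@ruvector/attention-wasm';
 *
 *   await initialize(); // required once, before constructing any mechanism
 *   const attn = new MultiHeadAttention({ dim: 4, numHeads: 2 });
 *   const query = new Float32Array([1, 0, 0, 0]);
 *   const keys = [new Float32Array([1, 0, 0, 0]), new Float32Array([0, 1, 0, 0])];
 *   const values = [new Float32Array([0, 0, 1, 0]), new Float32Array([0, 0, 0, 1])];
 *   const out = attn.compute(query, keys, values); // Float32Array of length dim
 *   attn.free(); // the classes wrap WASM memory, so free them explicitly
 */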
/**
 * Mixture of Experts attention
 */
export class MoEAttention {
  private inner: wasm.WasmMoEAttention;

  constructor(config: MoEConfig) {
    this.inner = new wasm.WasmMoEAttention(config.dim, config.numExperts, config.topK);
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  getExpertStats(): ExpertStats {
    return this.inner.expert_stats() as ExpertStats;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * InfoNCE contrastive loss
 */
export class InfoNCELoss {
  private inner: wasm.WasmInfoNCELoss;

  constructor(temperature: number = 0.07) {
    this.inner = new wasm.WasmInfoNCELoss(temperature);
  }

  compute(anchor: Float32Array, positive: Float32Array, negatives: Float32Array[]): number {
    return this.inner.compute(anchor, positive, negatives);
  }

  computeMultiPositive(
    anchor: Float32Array,
    positives: Float32Array[],
    negatives: Float32Array[]
  ): number {
    return this.inner.compute_multi_positive(anchor, positives, negatives);
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Adam optimizer
 */
export class Adam {
  private inner: wasm.WasmAdam;

  constructor(paramCount: number, config: TrainingConfig) {
    this.inner = new wasm.WasmAdam(
      paramCount,
      config.learningRate,
      config.beta1,
      config.beta2,
      config.epsilon
    );
  }

  step(params: Float32Array, gradients: Float32Array): void {
    this.inner.step(params, gradients);
  }

  reset(): void {
    this.inner.reset();
  }

  get learningRate(): number {
    return this.inner.learning_rate;
  }

  set learningRate(lr: number) {
    this.inner.learning_rate = lr;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * AdamW optimizer (Adam with decoupled weight decay)
 */
export class AdamW {
  private inner: wasm.WasmAdamW;

  constructor(paramCount: number, config: TrainingConfig) {
    if (!config.weightDecay) {
      throw new Error('AdamW requires weightDecay parameter');
    }
    this.inner = new wasm.WasmAdamW(
      paramCount,
      config.learningRate,
      config.weightDecay,
      config.beta1,
      config.beta2,
      config.epsilon
    );
  }

  step(params: Float32Array, gradients: Float32Array): void {
    this.inner.step(params, gradients);
  }

  reset(): void {
    this.inner.reset();
  }

  get learningRate(): number {
    return this.inner.learning_rate;
  }

  set learningRate(lr: number) {
    this.inner.learning_rate = lr;
  }

  get weightDecay(): number {
    return this.inner.weight_decay;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Learning rate scheduler with warmup and cosine decay
 */
export class LRScheduler {
  private inner: wasm.WasmLRScheduler;

  constructor(config: SchedulerConfig) {
    this.inner = new wasm.WasmLRScheduler(
      config.initialLR,
      config.warmupSteps,
      config.totalSteps
    );
  }

  getLR(): number {
    return this.inner.get_lr();
  }

  step(): void {
    this.inner.step();
  }

  reset(): void {
    this.inner.reset();
  }

  free(): void {
    this.inner.free();
  }
}
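/*
 * Illustrative training-loop sketch (not part of the module): wires
 * InfoNCELoss, Adam, and LRScheduler together. The parameter count,
 * hyperparameter values, and gradient computation are placeholders.
 *
 *   const params = new Float32Array(128);
 *   const grads = new Float32Array(128); // assumed to be filled each step
 *   const opt = new Adam(params.length, {
 *     learningRate: 1e-3, beta1: 0.9, beta2: 0.999, epsilon: 1e-8,
 *   });
 *   const sched = new LRScheduler({ initialLR: 1e-3, warmupSteps: 100, totalSteps: 1000 });
 *   for (let step = 0; step < 1000; step++) {
 *     // ...fill grads, e.g. from InfoNCELoss.compute(anchor, positive, negatives)...
 *     opt.learningRate = sched.getLR(); // follow warmup + cosine decay
 *     opt.step(params, grads);
 *     sched.step();
 *   }
 *   opt.free();
 *   sched.free();
 */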
/**
 * Utility functions
 */
export const utils = {
  /**
   * Compute cosine similarity between two vectors
   */
  cosineSimilarity(a: Float32Array, b: Float32Array): number {
    return wasm.cosine_similarity(a, b);
  },

  /**
   * Compute L2 norm of a vector
   */
  l2Norm(vec: Float32Array): number {
    return wasm.l2_norm(vec);
  },

  /**
   * Normalize a vector to unit length (in-place)
   */
  normalize(vec: Float32Array): void {
    wasm.normalize(vec);
  },

  /**
   * Apply softmax to a vector (in-place)
   */
  softmax(vec: Float32Array): void {
    wasm.softmax(vec);
  },

  /**
   * Compute attention weights from scores (in-place)
   */
  attentionWeights(scores: Float32Array, temperature?: number): void {
    wasm.attention_weights(scores, temperature);
  },

  /**
   * Batch normalize vectors
   */
  batchNormalize(vectors: Float32Array[], epsilon?: number): Float32Array {
    const result = wasm.batch_normalize(vectors, epsilon);
    return new Float32Array(result);
  },

  /**
   * Generate random orthogonal matrix
   */
  randomOrthogonalMatrix(dim: number): Float32Array {
    const result = wasm.random_orthogonal_matrix(dim);
    return new Float32Array(result);
  },

  /**
   * Compute pairwise distances between vectors
   */
  pairwiseDistances(vectors: Float32Array[]): Float32Array {
    const result = wasm.pairwise_distances(vectors);
    return new Float32Array(result);
  },
};

/**
 * Simple scaled dot-product attention (functional API)
 */
export function scaledDotAttention(
  query: Float32Array,
  keys: Float32Array[],
  values: Float32Array[],
  scale?: number
): Float32Array {
  const result = wasm.scaled_dot_attention(query, keys, values, scale);
  return new Float32Array(result);
}

// Re-export WASM module for advanced usage
export { wasm };
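/*
 * Illustrative functional-API sketch (not part of the module): the `utils`
 * helpers and `scaledDotAttention` need no class construction, only an
 * `initialize()` call. Vector contents here are arbitrary.
 *
 *   await initialize();
 *   const q = new Float32Array([0.3, 0.7]);
 *   utils.normalize(q); // unit length, in place
 *   const out = scaledDotAttention(
 *     q,
 *     [new Float32Array([1, 0]), new Float32Array([0, 1])],
 *     [new Float32Array([2, 0]), new Float32Array([0, 2])]
 *   ); // scale omitted; the WASM side presumably applies a default scaling
 */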