ruvector-attention-wasm
High-performance attention mechanisms for WebAssembly: Transformer, Hyperbolic, Flash, MoE, and Graph attention
TypeScript
/* tslint:disable */
/* eslint-disable */
/**
* Initialize the WASM module with panic hook
*/
export function init(): void;
/**
* Get the version of the ruvector-attention-wasm crate
*/
export function version(): string;
/**
* Get information about available attention mechanisms
*/
export function available_mechanisms(): any;
/**
 * Compute attention weights from scores (softmax with optional temperature; the void return implies `scores` is updated in place)
*/
export function attention_weights(scores: Float32Array, temperature?: number | null): void;
/**
* Compute cosine similarity between two vectors
*/
export function cosine_similarity(a: Float32Array, b: Float32Array): number;
/**
* Compute pairwise distances between vectors
*/
export function pairwise_distances(vectors: any): Float32Array;
/**
 * Generate a random orthogonal matrix (for initialization)
*/
export function random_orthogonal_matrix(dim: number): Float32Array;
/**
* Log an error to the browser console
*/
export function log_error(message: string): void;
/**
* Compute L2 norm of a vector
*/
export function l2_norm(vec: Float32Array): number;
/**
* Log a message to the browser console
*/
export function log(message: string): void;
/**
 * Normalize a vector to unit length in place
*/
export function normalize(vec: Float32Array): void;
/**
* Batch normalize vectors
*/
export function batch_normalize(vectors: any, epsilon?: number | null): Float32Array;
/**
 * Compute softmax of a vector in place
*/
export function softmax(vec: Float32Array): void;
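// Example: a minimal sketch of the vector utilities above. It assumes the
// module has already been initialized (see the default export at the end of
// this file) and that the void-returning functions mutate their argument in
// place, which their signatures suggest but the declarations do not state.
import { l2_norm, normalize, cosine_similarity, softmax } from "ruvector-attention-wasm";

const v = new Float32Array([3, 4]);
console.log(l2_norm(v));                                      // 5
normalize(v);                                                 // v is now [0.6, 0.8]
console.log(cosine_similarity(v, new Float32Array([1, 0]))); // 0.6
const scores = new Float32Array([1, 2, 3]);
softmax(scores);                                              // scores now sum to 1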
/**
* Compute scaled dot-product attention
*
* # Arguments
* * `query` - Query vector as Float32Array
* * `keys` - Array of key vectors
* * `values` - Array of value vectors
 * * `_scale` - Optional scaling factor (defaults to 1/sqrt(dim))
*/
export function scaled_dot_attention(query: Float32Array, keys: any, values: any, _scale?: number | null): Float32Array;
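// Example: a minimal sketch. `keys` and `values` are typed `any` above; the
// doc comments say "Array of key/value vectors", so arrays of Float32Array
// are assumed here. Omitting the last argument falls back to the documented
// default scale of 1/sqrt(dim).
import { scaled_dot_attention } from "ruvector-attention-wasm";

const query = new Float32Array([0.1, 0.2, 0.3, 0.4]);
const keys = [new Float32Array([1, 0, 0, 0]), new Float32Array([0, 1, 0, 0])];
const values = [new Float32Array([1, 2, 3, 4]), new Float32Array([5, 6, 7, 8])];
const context = scaled_dot_attention(query, keys, values); // Float32Array of length 4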
/**
* Adam optimizer
*/
export class WasmAdam {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new Adam optimizer
*
* # Arguments
* * `param_count` - Number of parameters
* * `learning_rate` - Learning rate
*/
constructor(param_count: number, learning_rate: number);
/**
* Perform optimization step
*
* # Arguments
* * `params` - Current parameter values (will be updated in-place)
* * `gradients` - Gradient values
*/
step(params: Float32Array, gradients: Float32Array): void;
/**
* Reset optimizer state
*/
reset(): void;
/**
* Get current learning rate
*/
learning_rate: number;
}
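// Example: a minimal sketch of one optimizer step. The gradient values are
// placeholders, not a real loss gradient; WasmAdamW and WasmSGD below follow
// the same step/reset pattern.
import { WasmAdam } from "ruvector-attention-wasm";

const params = new Float32Array([0.5, -0.3, 0.8]);
const grads = new Float32Array([0.1, -0.2, 0.05]);
const adam = new WasmAdam(params.length, 1e-3);
adam.step(params, grads); // params is updated in place, per the docs above
adam.free();              // release the WASM-side optimizer state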
/**
* AdamW optimizer (Adam with decoupled weight decay)
*/
export class WasmAdamW {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new AdamW optimizer
*
* # Arguments
* * `param_count` - Number of parameters
* * `learning_rate` - Learning rate
* * `weight_decay` - Weight decay coefficient
*/
constructor(param_count: number, learning_rate: number, weight_decay: number);
/**
* Perform optimization step with weight decay
*/
step(params: Float32Array, gradients: Float32Array): void;
/**
* Reset optimizer state
*/
reset(): void;
/**
* Get weight decay
*/
readonly weight_decay: number;
/**
* Get current learning rate
*/
learning_rate: number;
}
/**
* Flash attention mechanism
*/
export class WasmFlashAttention {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new flash attention instance
*
* # Arguments
* * `dim` - Embedding dimension
* * `block_size` - Block size for tiling
*/
constructor(dim: number, block_size: number);
/**
* Compute flash attention
*/
compute(query: Float32Array, keys: any, values: any): Float32Array;
}
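// Example: a minimal sketch. All attention classes in this module share the
// compute(query, keys, values) shape, so this pattern carries over to the
// hyperbolic, linear, local-global, MoE, and multi-head variants. Arrays of
// Float32Array for `keys`/`values` are an assumption (the declarations say `any`).
import { WasmFlashAttention } from "ruvector-attention-wasm";

const dim = 64;
const flash = new WasmFlashAttention(dim, 16); // block size 16 for tiling
const query = new Float32Array(dim).fill(0.1);
const keys = Array.from({ length: 8 }, () => new Float32Array(dim).fill(0.2));
const values = Array.from({ length: 8 }, () => new Float32Array(dim).fill(0.3));
const out = flash.compute(query, keys, values); // Float32Array of length dim
flash.free();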
/**
* Hyperbolic attention mechanism
*/
export class WasmHyperbolicAttention {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new hyperbolic attention instance
*
* # Arguments
* * `dim` - Embedding dimension
* * `curvature` - Hyperbolic curvature parameter
*/
constructor(dim: number, curvature: number);
/**
* Compute hyperbolic attention
*/
compute(query: Float32Array, keys: any, values: any): Float32Array;
/**
* Get the curvature
*/
readonly curvature: number;
}
/**
* InfoNCE contrastive loss for training
*/
export class WasmInfoNCELoss {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new InfoNCE loss instance
*
* # Arguments
* * `temperature` - Temperature parameter for softmax
*/
constructor(temperature: number);
/**
* Compute InfoNCE loss
*
* # Arguments
* * `anchor` - Anchor embedding
* * `positive` - Positive example embedding
* * `negatives` - Array of negative example embeddings
*/
compute(anchor: Float32Array, positive: Float32Array, negatives: any): number;
}
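// Example: a minimal sketch with made-up embeddings; in practice the anchor,
// positive, and negatives would come from your embedding model. 0.07 is a
// common temperature choice, not a library default.
import { WasmInfoNCELoss } from "ruvector-attention-wasm";

const loss = new WasmInfoNCELoss(0.07);
const anchor = new Float32Array([0.1, 0.9]);
const positive = new Float32Array([0.2, 0.8]);
const negatives = [new Float32Array([0.9, 0.1]), new Float32Array([0.7, 0.3])];
console.log(loss.compute(anchor, positive, negatives));
loss.free();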
/**
* Learning rate scheduler
*/
export class WasmLRScheduler {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new learning rate scheduler with warmup and cosine decay
*
* # Arguments
* * `initial_lr` - Initial learning rate
* * `warmup_steps` - Number of warmup steps
* * `total_steps` - Total training steps
*/
constructor(initial_lr: number, warmup_steps: number, total_steps: number);
/**
* Advance to next step
*/
step(): void;
/**
* Reset scheduler
*/
reset(): void;
/**
* Get learning rate for current step
*/
get_lr(): number;
}
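// Example: a minimal sketch pairing the scheduler with an optimizer. Feeding
// get_lr() back through the optimizer's learning_rate setter is an assumed
// usage pattern; the declarations do not wire the two together.
import { WasmLRScheduler, WasmSGD } from "ruvector-attention-wasm";

const sched = new WasmLRScheduler(1e-3, 100, 10_000); // warmup, then cosine decay
const sgd = new WasmSGD(1024, sched.get_lr(), 0.9);
for (let step = 0; step < 10_000; step++) {
  // ...compute gradients and call sgd.step(params, grads) here...
  sched.step();
  sgd.learning_rate = sched.get_lr();
}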
/**
* Linear attention (Performer-style)
*/
export class WasmLinearAttention {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new linear attention instance
*
* # Arguments
* * `dim` - Embedding dimension
* * `num_features` - Number of random features
*/
constructor(dim: number, num_features: number);
/**
* Compute linear attention
*/
compute(query: Float32Array, keys: any, values: any): Float32Array;
}
/**
* Local-global attention mechanism
*/
export class WasmLocalGlobalAttention {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new local-global attention instance
*
* # Arguments
* * `dim` - Embedding dimension
* * `local_window` - Size of local attention window
* * `global_tokens` - Number of global attention tokens
*/
constructor(dim: number, local_window: number, global_tokens: number);
/**
* Compute local-global attention
*/
compute(query: Float32Array, keys: any, values: any): Float32Array;
}
/**
* Mixture of Experts (MoE) attention
*/
export class WasmMoEAttention {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new MoE attention instance
*
* # Arguments
* * `dim` - Embedding dimension
* * `num_experts` - Number of expert attention mechanisms
* * `top_k` - Number of experts to use per query
*/
constructor(dim: number, num_experts: number, top_k: number);
/**
* Compute MoE attention
*/
compute(query: Float32Array, keys: any, values: any): Float32Array;
}
/**
* Multi-head attention mechanism
*/
export class WasmMultiHeadAttention {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new multi-head attention instance
*
* # Arguments
* * `dim` - Embedding dimension
* * `num_heads` - Number of attention heads
*/
constructor(dim: number, num_heads: number);
/**
* Compute multi-head attention
*/
compute(query: Float32Array, keys: any, values: any): Float32Array;
/**
* Get the dimension
*/
readonly dim: number;
/**
* Get the number of heads
*/
readonly num_heads: number;
}
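// Example: a minimal sketch of explicit resource management. Every class here
// implements [Symbol.dispose], so `using` (TypeScript 5.2+) calls free()
// automatically when the scope exits.
import { WasmMultiHeadAttention } from "ruvector-attention-wasm";

{
  using mha = new WasmMultiHeadAttention(64, 8);
  console.log(mha.dim, mha.num_heads); // 64 8
  // mha.compute(query, keys, values) works as in the flash attention example
} // mha is freed here via [Symbol.dispose]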
/**
* SGD optimizer with momentum
*/
export class WasmSGD {
free(): void;
[Symbol.dispose](): void;
/**
* Create a new SGD optimizer
*
* # Arguments
* * `param_count` - Number of parameters
* * `learning_rate` - Learning rate
* * `momentum` - Momentum coefficient (default: 0)
*/
constructor(param_count: number, learning_rate: number, momentum?: number | null);
/**
* Perform optimization step
*/
step(params: Float32Array, gradients: Float32Array): void;
/**
* Reset optimizer state
*/
reset(): void;
/**
* Get current learning rate
*/
learning_rate: number;
}
export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module;
export interface InitOutput {
readonly memory: WebAssembly.Memory;
readonly __wbg_wasmadam_free: (a: number, b: number) => void;
readonly __wbg_wasmadamw_free: (a: number, b: number) => void;
readonly __wbg_wasmflashattention_free: (a: number, b: number) => void;
readonly __wbg_wasmhyperbolicattention_free: (a: number, b: number) => void;
readonly __wbg_wasminfonceloss_free: (a: number, b: number) => void;
readonly __wbg_wasmlinearattention_free: (a: number, b: number) => void;
readonly __wbg_wasmmoeattention_free: (a: number, b: number) => void;
readonly __wbg_wasmmultiheadattention_free: (a: number, b: number) => void;
readonly __wbg_wasmsgd_free: (a: number, b: number) => void;
readonly attention_weights: (a: number, b: number, c: number, d: number) => void;
readonly available_mechanisms: () => number;
readonly batch_normalize: (a: number, b: number, c: number) => void;
readonly cosine_similarity: (a: number, b: number, c: number, d: number, e: number) => void;
readonly l2_norm: (a: number, b: number) => number;
readonly log: (a: number, b: number) => void;
readonly log_error: (a: number, b: number) => void;
readonly normalize: (a: number, b: number, c: number, d: number) => void;
readonly pairwise_distances: (a: number, b: number) => void;
readonly random_orthogonal_matrix: (a: number, b: number) => void;
readonly scaled_dot_attention: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly softmax: (a: number, b: number, c: number) => void;
readonly version: (a: number) => void;
readonly wasmadam_learning_rate: (a: number) => number;
readonly wasmadam_new: (a: number, b: number) => number;
readonly wasmadam_reset: (a: number) => void;
readonly wasmadam_set_learning_rate: (a: number, b: number) => void;
readonly wasmadam_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly wasmadamw_new: (a: number, b: number, c: number) => number;
readonly wasmadamw_reset: (a: number) => void;
readonly wasmadamw_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly wasmadamw_weight_decay: (a: number) => number;
readonly wasmflashattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly wasmflashattention_new: (a: number, b: number) => number;
readonly wasmhyperbolicattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly wasmhyperbolicattention_curvature: (a: number) => number;
readonly wasmhyperbolicattention_new: (a: number, b: number) => number;
readonly wasminfonceloss_compute: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => void;
readonly wasminfonceloss_new: (a: number) => number;
readonly wasmlinearattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly wasmlinearattention_new: (a: number, b: number) => number;
readonly wasmlocalglobalattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly wasmlocalglobalattention_new: (a: number, b: number, c: number) => number;
readonly wasmlrscheduler_get_lr: (a: number) => number;
readonly wasmlrscheduler_new: (a: number, b: number, c: number) => number;
readonly wasmlrscheduler_reset: (a: number) => void;
readonly wasmlrscheduler_step: (a: number) => void;
readonly wasmmoeattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly wasmmoeattention_new: (a: number, b: number, c: number) => number;
readonly wasmmultiheadattention_compute: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly wasmmultiheadattention_dim: (a: number) => number;
readonly wasmmultiheadattention_new: (a: number, b: number, c: number) => void;
readonly wasmmultiheadattention_num_heads: (a: number) => number;
readonly wasmsgd_learning_rate: (a: number) => number;
readonly wasmsgd_new: (a: number, b: number, c: number) => number;
readonly wasmsgd_reset: (a: number) => void;
readonly wasmsgd_set_learning_rate: (a: number, b: number) => void;
readonly wasmsgd_step: (a: number, b: number, c: number, d: number, e: number, f: number) => void;
readonly init: () => void;
readonly wasmadamw_set_learning_rate: (a: number, b: number) => void;
readonly wasmadamw_learning_rate: (a: number) => number;
readonly __wbg_wasmlocalglobalattention_free: (a: number, b: number) => void;
readonly __wbg_wasmlrscheduler_free: (a: number, b: number) => void;
readonly __wbindgen_export: (a: number, b: number) => number;
readonly __wbindgen_export2: (a: number, b: number, c: number, d: number) => number;
readonly __wbindgen_export3: (a: number) => void;
readonly __wbindgen_export4: (a: number, b: number, c: number) => void;
readonly __wbindgen_add_to_stack_pointer: (a: number) => number;
readonly __wbindgen_start: () => void;
}
export type SyncInitInput = BufferSource | WebAssembly.Module;
/**
* Instantiates the given `module`, which can either be bytes or
* a precompiled `WebAssembly.Module`.
*
* @param {{ module: SyncInitInput }} module - Passing `SyncInitInput` directly is deprecated.
*
* @returns {InitOutput}
*/
export function initSync(module: { module: SyncInitInput } | SyncInitInput): InitOutput;
/**
 * If `module_or_path` is a {RequestInfo} or {URL}, makes a request;
 * for everything else, calls `WebAssembly.instantiate` directly.
*
* @param {{ module_or_path: InitInput | Promise<InitInput> }} module_or_path - Passing `InitInput` directly is deprecated.
*
* @returns {Promise<InitOutput>}
*/
export default function __wbg_init(module_or_path?: { module_or_path: InitInput | Promise<InitInput> } | InitInput | Promise<InitInput>): Promise<InitOutput>;
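// Example: a minimal sketch of loading the module in a browser. Passing no
// argument lets the default export resolve the .wasm binary relative to the
// module, which is the usual wasm-pack behavior.
import __wbg_init, { init, version } from "ruvector-attention-wasm";

await __wbg_init();   // fetch and instantiate the .wasm binary
init();               // install the panic hook
console.log(version());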