UNPKG

ruvector-attention-wasm

Version:

High-performance attention mechanisms for WebAssembly - Transformer, Hyperbolic, Flash, MoE, and Graph attention

1,497 lines (1,393 loc) 50.6 kB
let wasm; let cachedUint8ArrayMemory0 = null; function getUint8ArrayMemory0() { if (cachedUint8ArrayMemory0 === null || cachedUint8ArrayMemory0.byteLength === 0) { cachedUint8ArrayMemory0 = new Uint8Array(wasm.memory.buffer); } return cachedUint8ArrayMemory0; } let cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }); cachedTextDecoder.decode(); const MAX_SAFARI_DECODE_BYTES = 2146435072; let numBytesDecoded = 0; function decodeText(ptr, len) { numBytesDecoded += len; if (numBytesDecoded >= MAX_SAFARI_DECODE_BYTES) { cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }); cachedTextDecoder.decode(); numBytesDecoded = len; } return cachedTextDecoder.decode(getUint8ArrayMemory0().subarray(ptr, ptr + len)); } function getStringFromWasm0(ptr, len) { ptr = ptr >>> 0; return decodeText(ptr, len); } let heap = new Array(128).fill(undefined); heap.push(undefined, null, true, false); let heap_next = heap.length; function addHeapObject(obj) { if (heap_next === heap.length) heap.push(heap.length + 1); const idx = heap_next; heap_next = heap[idx]; heap[idx] = obj; return idx; } function getObject(idx) { return heap[idx]; } let WASM_VECTOR_LEN = 0; const cachedTextEncoder = new TextEncoder(); if (!('encodeInto' in cachedTextEncoder)) { cachedTextEncoder.encodeInto = function (arg, view) { const buf = cachedTextEncoder.encode(arg); view.set(buf); return { read: arg.length, written: buf.length }; } } function passStringToWasm0(arg, malloc, realloc) { if (realloc === undefined) { const buf = cachedTextEncoder.encode(arg); const ptr = malloc(buf.length, 1) >>> 0; getUint8ArrayMemory0().subarray(ptr, ptr + buf.length).set(buf); WASM_VECTOR_LEN = buf.length; return ptr; } let len = arg.length; let ptr = malloc(len, 1) >>> 0; const mem = getUint8ArrayMemory0(); let offset = 0; for (; offset < len; offset++) { const code = arg.charCodeAt(offset); if (code > 0x7F) break; mem[ptr + offset] = code; } if (offset !== len) { if (offset !== 0) { arg = arg.slice(offset); } ptr = realloc(ptr, len, len = offset + arg.length * 3, 1) >>> 0; const view = getUint8ArrayMemory0().subarray(ptr + offset, ptr + len); const ret = cachedTextEncoder.encodeInto(arg, view); offset += ret.written; ptr = realloc(ptr, len, offset, 1) >>> 0; } WASM_VECTOR_LEN = offset; return ptr; } let cachedDataViewMemory0 = null; function getDataViewMemory0() { if (cachedDataViewMemory0 === null || cachedDataViewMemory0.buffer.detached === true || (cachedDataViewMemory0.buffer.detached === undefined && cachedDataViewMemory0.buffer !== wasm.memory.buffer)) { cachedDataViewMemory0 = new DataView(wasm.memory.buffer); } return cachedDataViewMemory0; } function isLikeNone(x) { return x === undefined || x === null; } function getArrayU8FromWasm0(ptr, len) { ptr = ptr >>> 0; return getUint8ArrayMemory0().subarray(ptr / 1, ptr / 1 + len); } function debugString(val) { // primitive types const type = typeof val; if (type == 'number' || type == 'boolean' || val == null) { return `${val}`; } if (type == 'string') { return `"${val}"`; } if (type == 'symbol') { const description = val.description; if (description == null) { return 'Symbol'; } else { return `Symbol(${description})`; } } if (type == 'function') { const name = val.name; if (typeof name == 'string' && name.length > 0) { return `Function(${name})`; } else { return 'Function'; } } // objects if (Array.isArray(val)) { const length = val.length; let debug = '['; if (length > 0) { debug += debugString(val[0]); } for(let i = 1; i < length; i++) { debug += ', ' + debugString(val[i]); } debug += ']'; return debug; } // Test for built-in const builtInMatches = /\[object ([^\]]+)\]/.exec(toString.call(val)); let className; if (builtInMatches && builtInMatches.length > 1) { className = builtInMatches[1]; } else { // Failed to match the standard '[object ClassName]' return toString.call(val); } if (className == 'Object') { // we're a user defined class or Object // JSON.stringify avoids problems with cycles, and is generally much // easier than looping through ownProperties of `val`. try { return 'Object(' + JSON.stringify(val) + ')'; } catch (_) { return 'Object'; } } // errors if (val instanceof Error) { return `${val.name}: ${val.message}\n${val.stack}`; } // TODO we could test for more things here, like `Set`s and `Map`s. return className; } function handleError(f, args) { try { return f.apply(this, args); } catch (e) { wasm.__wbindgen_export3(addHeapObject(e)); } } function dropObject(idx) { if (idx < 132) return; heap[idx] = heap_next; heap_next = idx; } function takeObject(idx) { const ret = getObject(idx); dropObject(idx); return ret; } /** * Initialize the WASM module with panic hook */ export function init() { wasm.init(); } /** * Get the version of the ruvector-attention-wasm crate * @returns {string} */ export function version() { let deferred1_0; let deferred1_1; try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); wasm.version(retptr); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); deferred1_0 = r0; deferred1_1 = r1; return getStringFromWasm0(r0, r1); } finally { wasm.__wbindgen_add_to_stack_pointer(16); wasm.__wbindgen_export4(deferred1_0, deferred1_1, 1); } } /** * Get information about available attention mechanisms * @returns {any} */ export function available_mechanisms() { const ret = wasm.available_mechanisms(); return takeObject(ret); } let cachedFloat32ArrayMemory0 = null; function getFloat32ArrayMemory0() { if (cachedFloat32ArrayMemory0 === null || cachedFloat32ArrayMemory0.byteLength === 0) { cachedFloat32ArrayMemory0 = new Float32Array(wasm.memory.buffer); } return cachedFloat32ArrayMemory0; } function passArrayF32ToWasm0(arg, malloc) { const ptr = malloc(arg.length * 4, 4) >>> 0; getFloat32ArrayMemory0().set(arg, ptr / 4); WASM_VECTOR_LEN = arg.length; return ptr; } /** * Compute attention weights from scores * @param {Float32Array} scores * @param {number | null} [temperature] */ export function attention_weights(scores, temperature) { var ptr0 = passArrayF32ToWasm0(scores, wasm.__wbindgen_export); var len0 = WASM_VECTOR_LEN; wasm.attention_weights(ptr0, len0, addHeapObject(scores), isLikeNone(temperature) ? 0x100000001 : Math.fround(temperature)); } /** * Compute cosine similarity between two vectors * @param {Float32Array} a * @param {Float32Array} b * @returns {number} */ export function cosine_similarity(a, b) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(a, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; const ptr1 = passArrayF32ToWasm0(b, wasm.__wbindgen_export); const len1 = WASM_VECTOR_LEN; wasm.cosine_similarity(retptr, ptr0, len0, ptr1, len1); var r0 = getDataViewMemory0().getFloat32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); if (r2) { throw takeObject(r1); } return r0; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } function getArrayF32FromWasm0(ptr, len) { ptr = ptr >>> 0; return getFloat32ArrayMemory0().subarray(ptr / 4, ptr / 4 + len); } /** * Compute pairwise distances between vectors * @param {any} vectors * @returns {Float32Array} */ export function pairwise_distances(vectors) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); wasm.pairwise_distances(retptr, addHeapObject(vectors)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v1 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v1; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } /** * Generate random orthogonal matrix (for initialization) * @param {number} dim * @returns {Float32Array} */ export function random_orthogonal_matrix(dim) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); wasm.random_orthogonal_matrix(retptr, dim); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var v1 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v1; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } /** * Log an error to the browser console * @param {string} message */ export function log_error(message) { const ptr0 = passStringToWasm0(message, wasm.__wbindgen_export, wasm.__wbindgen_export2); const len0 = WASM_VECTOR_LEN; wasm.log_error(ptr0, len0); } /** * Compute L2 norm of a vector * @param {Float32Array} vec * @returns {number} */ export function l2_norm(vec) { const ptr0 = passArrayF32ToWasm0(vec, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; const ret = wasm.l2_norm(ptr0, len0); return ret; } /** * Log a message to the browser console * @param {string} message */ export function log(message) { const ptr0 = passStringToWasm0(message, wasm.__wbindgen_export, wasm.__wbindgen_export2); const len0 = WASM_VECTOR_LEN; wasm.log(ptr0, len0); } /** * Normalize a vector to unit length * @param {Float32Array} vec */ export function normalize(vec) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); var ptr0 = passArrayF32ToWasm0(vec, wasm.__wbindgen_export); var len0 = WASM_VECTOR_LEN; wasm.normalize(retptr, ptr0, len0, addHeapObject(vec)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); if (r1) { throw takeObject(r0); } } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } /** * Batch normalize vectors * @param {any} vectors * @param {number | null} [epsilon] * @returns {Float32Array} */ export function batch_normalize(vectors, epsilon) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); wasm.batch_normalize(retptr, addHeapObject(vectors), isLikeNone(epsilon) ? 0x100000001 : Math.fround(epsilon)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v1 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v1; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } /** * Compute softmax of a vector * @param {Float32Array} vec */ export function softmax(vec) { var ptr0 = passArrayF32ToWasm0(vec, wasm.__wbindgen_export); var len0 = WASM_VECTOR_LEN; wasm.softmax(ptr0, len0, addHeapObject(vec)); } /** * Compute scaled dot-product attention * * # Arguments * * `query` - Query vector as Float32Array * * `keys` - Array of key vectors * * `values` - Array of value vectors * * `scale` - Optional scaling factor (defaults to 1/sqrt(dim)) * @param {Float32Array} query * @param {any} keys * @param {any} values * @param {number | null} [_scale] * @returns {Float32Array} */ export function scaled_dot_attention(query, keys, values, _scale) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(query, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; wasm.scaled_dot_attention(retptr, ptr0, len0, addHeapObject(keys), addHeapObject(values), isLikeNone(_scale) ? 0x100000001 : Math.fround(_scale)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v2 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v2; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } const WasmAdamFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmadam_free(ptr >>> 0, 1)); /** * Adam optimizer */ export class WasmAdam { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmAdamFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmadam_free(ptr, 0); } /** * Get current learning rate * @returns {number} */ get learning_rate() { const ret = wasm.wasmadam_learning_rate(this.__wbg_ptr); return ret; } /** * Set learning rate * @param {number} lr */ set learning_rate(lr) { wasm.wasmadam_set_learning_rate(this.__wbg_ptr, lr); } /** * Create a new Adam optimizer * * # Arguments * * `param_count` - Number of parameters * * `learning_rate` - Learning rate * @param {number} param_count * @param {number} learning_rate */ constructor(param_count, learning_rate) { const ret = wasm.wasmadam_new(param_count, learning_rate); this.__wbg_ptr = ret >>> 0; WasmAdamFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Perform optimization step * * # Arguments * * `params` - Current parameter values (will be updated in-place) * * `gradients` - Gradient values * @param {Float32Array} params * @param {Float32Array} gradients */ step(params, gradients) { var ptr0 = passArrayF32ToWasm0(params, wasm.__wbindgen_export); var len0 = WASM_VECTOR_LEN; const ptr1 = passArrayF32ToWasm0(gradients, wasm.__wbindgen_export); const len1 = WASM_VECTOR_LEN; wasm.wasmadam_step(this.__wbg_ptr, ptr0, len0, addHeapObject(params), ptr1, len1); } /** * Reset optimizer state */ reset() { wasm.wasmadam_reset(this.__wbg_ptr); } } if (Symbol.dispose) WasmAdam.prototype[Symbol.dispose] = WasmAdam.prototype.free; const WasmAdamWFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmadamw_free(ptr >>> 0, 1)); /** * AdamW optimizer (Adam with decoupled weight decay) */ export class WasmAdamW { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmAdamWFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmadamw_free(ptr, 0); } /** * Get weight decay * @returns {number} */ get weight_decay() { const ret = wasm.wasmadamw_weight_decay(this.__wbg_ptr); return ret; } /** * Get current learning rate * @returns {number} */ get learning_rate() { const ret = wasm.wasmadam_learning_rate(this.__wbg_ptr); return ret; } /** * Set learning rate * @param {number} lr */ set learning_rate(lr) { wasm.wasmadam_set_learning_rate(this.__wbg_ptr, lr); } /** * Create a new AdamW optimizer * * # Arguments * * `param_count` - Number of parameters * * `learning_rate` - Learning rate * * `weight_decay` - Weight decay coefficient * @param {number} param_count * @param {number} learning_rate * @param {number} weight_decay */ constructor(param_count, learning_rate, weight_decay) { const ret = wasm.wasmadamw_new(param_count, learning_rate, weight_decay); this.__wbg_ptr = ret >>> 0; WasmAdamWFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Perform optimization step with weight decay * @param {Float32Array} params * @param {Float32Array} gradients */ step(params, gradients) { var ptr0 = passArrayF32ToWasm0(params, wasm.__wbindgen_export); var len0 = WASM_VECTOR_LEN; const ptr1 = passArrayF32ToWasm0(gradients, wasm.__wbindgen_export); const len1 = WASM_VECTOR_LEN; wasm.wasmadamw_step(this.__wbg_ptr, ptr0, len0, addHeapObject(params), ptr1, len1); } /** * Reset optimizer state */ reset() { wasm.wasmadamw_reset(this.__wbg_ptr); } } if (Symbol.dispose) WasmAdamW.prototype[Symbol.dispose] = WasmAdamW.prototype.free; const WasmFlashAttentionFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmflashattention_free(ptr >>> 0, 1)); /** * Flash attention mechanism */ export class WasmFlashAttention { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmFlashAttentionFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmflashattention_free(ptr, 0); } /** * Create a new flash attention instance * * # Arguments * * `dim` - Embedding dimension * * `block_size` - Block size for tiling * @param {number} dim * @param {number} block_size */ constructor(dim, block_size) { const ret = wasm.wasmflashattention_new(dim, block_size); this.__wbg_ptr = ret >>> 0; WasmFlashAttentionFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Compute flash attention * @param {Float32Array} query * @param {any} keys * @param {any} values * @returns {Float32Array} */ compute(query, keys, values) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(query, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; wasm.wasmflashattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v2 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v2; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } } if (Symbol.dispose) WasmFlashAttention.prototype[Symbol.dispose] = WasmFlashAttention.prototype.free; const WasmHyperbolicAttentionFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmhyperbolicattention_free(ptr >>> 0, 1)); /** * Hyperbolic attention mechanism */ export class WasmHyperbolicAttention { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmHyperbolicAttentionFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmhyperbolicattention_free(ptr, 0); } /** * Create a new hyperbolic attention instance * * # Arguments * * `dim` - Embedding dimension * * `curvature` - Hyperbolic curvature parameter * @param {number} dim * @param {number} curvature */ constructor(dim, curvature) { const ret = wasm.wasmhyperbolicattention_new(dim, curvature); this.__wbg_ptr = ret >>> 0; WasmHyperbolicAttentionFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Compute hyperbolic attention * @param {Float32Array} query * @param {any} keys * @param {any} values * @returns {Float32Array} */ compute(query, keys, values) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(query, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; wasm.wasmhyperbolicattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v2 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v2; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } /** * Get the curvature * @returns {number} */ get curvature() { const ret = wasm.wasmhyperbolicattention_curvature(this.__wbg_ptr); return ret; } } if (Symbol.dispose) WasmHyperbolicAttention.prototype[Symbol.dispose] = WasmHyperbolicAttention.prototype.free; const WasmInfoNCELossFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasminfonceloss_free(ptr >>> 0, 1)); /** * InfoNCE contrastive loss for training */ export class WasmInfoNCELoss { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmInfoNCELossFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasminfonceloss_free(ptr, 0); } /** * Create a new InfoNCE loss instance * * # Arguments * * `temperature` - Temperature parameter for softmax * @param {number} temperature */ constructor(temperature) { const ret = wasm.wasminfonceloss_new(temperature); this.__wbg_ptr = ret >>> 0; WasmInfoNCELossFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Compute InfoNCE loss * * # Arguments * * `anchor` - Anchor embedding * * `positive` - Positive example embedding * * `negatives` - Array of negative example embeddings * @param {Float32Array} anchor * @param {Float32Array} positive * @param {any} negatives * @returns {number} */ compute(anchor, positive, negatives) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(anchor, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; const ptr1 = passArrayF32ToWasm0(positive, wasm.__wbindgen_export); const len1 = WASM_VECTOR_LEN; wasm.wasminfonceloss_compute(retptr, this.__wbg_ptr, ptr0, len0, ptr1, len1, addHeapObject(negatives)); var r0 = getDataViewMemory0().getFloat32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); if (r2) { throw takeObject(r1); } return r0; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } } if (Symbol.dispose) WasmInfoNCELoss.prototype[Symbol.dispose] = WasmInfoNCELoss.prototype.free; const WasmLRSchedulerFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmlrscheduler_free(ptr >>> 0, 1)); /** * Learning rate scheduler */ export class WasmLRScheduler { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmLRSchedulerFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmlrscheduler_free(ptr, 0); } /** * Create a new learning rate scheduler with warmup and cosine decay * * # Arguments * * `initial_lr` - Initial learning rate * * `warmup_steps` - Number of warmup steps * * `total_steps` - Total training steps * @param {number} initial_lr * @param {number} warmup_steps * @param {number} total_steps */ constructor(initial_lr, warmup_steps, total_steps) { const ret = wasm.wasmlrscheduler_new(initial_lr, warmup_steps, total_steps); this.__wbg_ptr = ret >>> 0; WasmLRSchedulerFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Advance to next step */ step() { wasm.wasmlrscheduler_step(this.__wbg_ptr); } /** * Reset scheduler */ reset() { wasm.wasmlrscheduler_reset(this.__wbg_ptr); } /** * Get learning rate for current step * @returns {number} */ get_lr() { const ret = wasm.wasmlrscheduler_get_lr(this.__wbg_ptr); return ret; } } if (Symbol.dispose) WasmLRScheduler.prototype[Symbol.dispose] = WasmLRScheduler.prototype.free; const WasmLinearAttentionFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmlinearattention_free(ptr >>> 0, 1)); /** * Linear attention (Performer-style) */ export class WasmLinearAttention { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmLinearAttentionFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmlinearattention_free(ptr, 0); } /** * Create a new linear attention instance * * # Arguments * * `dim` - Embedding dimension * * `num_features` - Number of random features * @param {number} dim * @param {number} num_features */ constructor(dim, num_features) { const ret = wasm.wasmlinearattention_new(dim, num_features); this.__wbg_ptr = ret >>> 0; WasmLinearAttentionFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Compute linear attention * @param {Float32Array} query * @param {any} keys * @param {any} values * @returns {Float32Array} */ compute(query, keys, values) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(query, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; wasm.wasmlinearattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v2 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v2; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } } if (Symbol.dispose) WasmLinearAttention.prototype[Symbol.dispose] = WasmLinearAttention.prototype.free; const WasmLocalGlobalAttentionFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmlocalglobalattention_free(ptr >>> 0, 1)); /** * Local-global attention mechanism */ export class WasmLocalGlobalAttention { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmLocalGlobalAttentionFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmlocalglobalattention_free(ptr, 0); } /** * Create a new local-global attention instance * * # Arguments * * `dim` - Embedding dimension * * `local_window` - Size of local attention window * * `global_tokens` - Number of global attention tokens * @param {number} dim * @param {number} local_window * @param {number} global_tokens */ constructor(dim, local_window, global_tokens) { const ret = wasm.wasmlocalglobalattention_new(dim, local_window, global_tokens); this.__wbg_ptr = ret >>> 0; WasmLocalGlobalAttentionFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Compute local-global attention * @param {Float32Array} query * @param {any} keys * @param {any} values * @returns {Float32Array} */ compute(query, keys, values) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(query, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; wasm.wasmlocalglobalattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v2 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v2; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } } if (Symbol.dispose) WasmLocalGlobalAttention.prototype[Symbol.dispose] = WasmLocalGlobalAttention.prototype.free; const WasmMoEAttentionFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmmoeattention_free(ptr >>> 0, 1)); /** * Mixture of Experts (MoE) attention */ export class WasmMoEAttention { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmMoEAttentionFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmmoeattention_free(ptr, 0); } /** * Create a new MoE attention instance * * # Arguments * * `dim` - Embedding dimension * * `num_experts` - Number of expert attention mechanisms * * `top_k` - Number of experts to use per query * @param {number} dim * @param {number} num_experts * @param {number} top_k */ constructor(dim, num_experts, top_k) { const ret = wasm.wasmmoeattention_new(dim, num_experts, top_k); this.__wbg_ptr = ret >>> 0; WasmMoEAttentionFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Compute MoE attention * @param {Float32Array} query * @param {any} keys * @param {any} values * @returns {Float32Array} */ compute(query, keys, values) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(query, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; wasm.wasmmoeattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v2 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v2; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } } if (Symbol.dispose) WasmMoEAttention.prototype[Symbol.dispose] = WasmMoEAttention.prototype.free; const WasmMultiHeadAttentionFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmmultiheadattention_free(ptr >>> 0, 1)); /** * Multi-head attention mechanism */ export class WasmMultiHeadAttention { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmMultiHeadAttentionFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmmultiheadattention_free(ptr, 0); } /** * Get the dimension * @returns {number} */ get dim() { const ret = wasm.wasmmultiheadattention_dim(this.__wbg_ptr); return ret >>> 0; } /** * Create a new multi-head attention instance * * # Arguments * * `dim` - Embedding dimension * * `num_heads` - Number of attention heads * @param {number} dim * @param {number} num_heads */ constructor(dim, num_heads) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); wasm.wasmmultiheadattention_new(retptr, dim, num_heads); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); if (r2) { throw takeObject(r1); } this.__wbg_ptr = r0 >>> 0; WasmMultiHeadAttentionFinalization.register(this, this.__wbg_ptr, this); return this; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } /** * Compute multi-head attention * @param {Float32Array} query * @param {any} keys * @param {any} values * @returns {Float32Array} */ compute(query, keys, values) { try { const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); const ptr0 = passArrayF32ToWasm0(query, wasm.__wbindgen_export); const len0 = WASM_VECTOR_LEN; wasm.wasmmultiheadattention_compute(retptr, this.__wbg_ptr, ptr0, len0, addHeapObject(keys), addHeapObject(values)); var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); var r3 = getDataViewMemory0().getInt32(retptr + 4 * 3, true); if (r3) { throw takeObject(r2); } var v2 = getArrayF32FromWasm0(r0, r1).slice(); wasm.__wbindgen_export4(r0, r1 * 4, 4); return v2; } finally { wasm.__wbindgen_add_to_stack_pointer(16); } } /** * Get the number of heads * @returns {number} */ get num_heads() { const ret = wasm.wasmmultiheadattention_num_heads(this.__wbg_ptr); return ret >>> 0; } } if (Symbol.dispose) WasmMultiHeadAttention.prototype[Symbol.dispose] = WasmMultiHeadAttention.prototype.free; const WasmSGDFinalization = (typeof FinalizationRegistry === 'undefined') ? { register: () => {}, unregister: () => {} } : new FinalizationRegistry(ptr => wasm.__wbg_wasmsgd_free(ptr >>> 0, 1)); /** * SGD optimizer with momentum */ export class WasmSGD { __destroy_into_raw() { const ptr = this.__wbg_ptr; this.__wbg_ptr = 0; WasmSGDFinalization.unregister(this); return ptr; } free() { const ptr = this.__destroy_into_raw(); wasm.__wbg_wasmsgd_free(ptr, 0); } /** * Get current learning rate * @returns {number} */ get learning_rate() { const ret = wasm.wasmsgd_learning_rate(this.__wbg_ptr); return ret; } /** * Set learning rate * @param {number} lr */ set learning_rate(lr) { wasm.wasmsgd_set_learning_rate(this.__wbg_ptr, lr); } /** * Create a new SGD optimizer * * # Arguments * * `param_count` - Number of parameters * * `learning_rate` - Learning rate * * `momentum` - Momentum coefficient (default: 0) * @param {number} param_count * @param {number} learning_rate * @param {number | null} [momentum] */ constructor(param_count, learning_rate, momentum) { const ret = wasm.wasmsgd_new(param_count, learning_rate, isLikeNone(momentum) ? 0x100000001 : Math.fround(momentum)); this.__wbg_ptr = ret >>> 0; WasmSGDFinalization.register(this, this.__wbg_ptr, this); return this; } /** * Perform optimization step * @param {Float32Array} params * @param {Float32Array} gradients */ step(params, gradients) { var ptr0 = passArrayF32ToWasm0(params, wasm.__wbindgen_export); var len0 = WASM_VECTOR_LEN; const ptr1 = passArrayF32ToWasm0(gradients, wasm.__wbindgen_export); const len1 = WASM_VECTOR_LEN; wasm.wasmsgd_step(this.__wbg_ptr, ptr0, len0, addHeapObject(params), ptr1, len1); } /** * Reset optimizer state */ reset() { wasm.wasmsgd_reset(this.__wbg_ptr); } } if (Symbol.dispose) WasmSGD.prototype[Symbol.dispose] = WasmSGD.prototype.free; const EXPECTED_RESPONSE_TYPES = new Set(['basic', 'cors', 'default']); async function __wbg_load(module, imports) { if (typeof Response === 'function' && module instanceof Response) { if (typeof WebAssembly.instantiateStreaming === 'function') { try { return await WebAssembly.instantiateStreaming(module, imports); } catch (e) { const validResponse = module.ok && EXPECTED_RESPONSE_TYPES.has(module.type); if (validResponse && module.headers.get('Content-Type') !== 'application/wasm') { console.warn("`WebAssembly.instantiateStreaming` failed because your server does not serve Wasm with `application/wasm` MIME type. Falling back to `WebAssembly.instantiate` which is slower. Original error:\n", e); } else { throw e; } } } const bytes = await module.arrayBuffer(); return await WebAssembly.instantiate(bytes, imports); } else { const instance = await WebAssembly.instantiate(module, imports); if (instance instanceof WebAssembly.Instance) { return { instance, module }; } else { return instance; } } } function __wbg_get_imports() { const imports = {}; imports.wbg = {}; imports.wbg.__wbg_Error_e83987f665cf5504 = function(arg0, arg1) { const ret = Error(getStringFromWasm0(arg0, arg1)); return addHeapObject(ret); }; imports.wbg.__wbg_String_8f0eb39a4a4c2f66 = function(arg0, arg1) { const ret = String(getObject(arg1)); const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_export, wasm.__wbindgen_export2); const len1 = WASM_VECTOR_LEN; getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); }; imports.wbg.__wbg___wbindgen_boolean_get_6d5a1ee65bab5f68 = function(arg0) { const v = getObject(arg0); const ret = typeof(v) === 'boolean' ? v : undefined; return isLikeNone(ret) ? 0xFFFFFF : ret ? 1 : 0; }; imports.wbg.__wbg___wbindgen_copy_to_typed_array_33fbd71146904370 = function(arg0, arg1, arg2) { new Uint8Array(getObject(arg2).buffer, getObject(arg2).byteOffset, getObject(arg2).byteLength).set(getArrayU8FromWasm0(arg0, arg1)); }; imports.wbg.__wbg___wbindgen_debug_string_df47ffb5e35e6763 = function(arg0, arg1) { const ret = debugString(getObject(arg1)); const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_export, wasm.__wbindgen_export2); const len1 = WASM_VECTOR_LEN; getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); }; imports.wbg.__wbg___wbindgen_is_function_ee8a6c5833c90377 = function(arg0) { const ret = typeof(getObject(arg0)) === 'function'; return ret; }; imports.wbg.__wbg___wbindgen_is_object_c818261d21f283a4 = function(arg0) { const val = getObject(arg0); const ret = typeof(val) === 'object' && val !== null; return ret; }; imports.wbg.__wbg___wbindgen_jsval_loose_eq_b664b38a2f582147 = function(arg0, arg1) { const ret = getObject(arg0) == getObject(arg1); return ret; }; imports.wbg.__wbg___wbindgen_number_get_a20bf9b85341449d = function(arg0, arg1) { const obj = getObject(arg1); const ret = typeof(obj) === 'number' ? obj : undefined; getDataViewMemory0().setFloat64(arg0 + 8 * 1, isLikeNone(ret) ? 0 : ret, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true); }; imports.wbg.__wbg___wbindgen_string_get_e4f06c90489ad01b = function(arg0, arg1) { const obj = getObject(arg1); const ret = typeof(obj) === 'string' ? obj : undefined; var ptr1 = isLikeNone(ret) ? 0 : passStringToWasm0(ret, wasm.__wbindgen_export, wasm.__wbindgen_export2); var len1 = WASM_VECTOR_LEN; getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); }; imports.wbg.__wbg___wbindgen_throw_b855445ff6a94295 = function(arg0, arg1) { throw new Error(getStringFromWasm0(arg0, arg1)); }; imports.wbg.__wbg_call_e762c39fa8ea36bf = function() { return handleError(function (arg0, arg1) { const ret = getObject(arg0).call(getObject(arg1)); return addHeapObject(ret); }, arguments) }; imports.wbg.__wbg_done_2042aa2670fb1db1 = function(arg0) { const ret = getObject(arg0).done; return ret; }; imports.wbg.__wbg_error_7534b8e9a36f1ab4 = function(arg0, arg1) { let deferred0_0; let deferred0_1; try { deferred0_0 = arg0; deferred0_1 = arg1; console.error(getStringFromWasm0(arg0, arg1)); } finally { wasm.__wbindgen_export4(deferred0_0, deferred0_1, 1); } }; imports.wbg.__wbg_error_a7f8fbb0523dae15 = function(arg0) { console.error(getObject(arg0)); }; imports.wbg.__wbg_get_7bed016f185add81 = function(arg0, arg1) { const ret = getObject(arg0)[arg1 >>> 0]; return addHeapObject(ret); }; imports.wbg.__wbg_get_efcb449f58ec27c2 = function() { return handleError(function (arg0, arg1) { const ret = Reflect.get(getObject(arg0), getObject(arg1)); return addHeapObject(ret); }, arguments) }; imports.wbg.__wbg_instanceof_ArrayBuffer_70beb1189ca63b38 = function(arg0) { let result; try { result = getObject(arg0) instanceof ArrayBuffer; } catch (_) { result = false; } const ret = result; return ret; }; imports.wbg.__wbg_instanceof_Uint8Array_20c8e73002f7af98 = function(arg0) { let result; try { result = getObject(arg0) instanceof Uint8Array; } catch (_) { result = false; } const ret = result; return ret; }; imports.wbg.__wbg_isArray_96e0af9891d0945d = function(arg0) { const ret = Array.isArray(getObject(arg0)); return ret; }; imports.wbg.__wbg_iterator_e5822695327a3c39 = function() { const ret = Symbol.iterator; return addHeapObject(ret); }; imports.wbg.__wbg_length_69bca3cb64fc8748 = function(arg0) { const ret = getObject(arg0).length; return ret; }; imports.wbg.__wbg_length_cdd215e10d9dd507 = function(arg0) { const ret = getObject(arg0).length; return ret; }; imports.wbg.__wbg_log_8cec76766b8c0e33 = function(arg0) { console.log(getObject(arg0)); }; imports.wbg.__wbg_new_5a79be3ab53b8aa5 = function(arg0) { const ret = new Uint8Array(getObject(arg0)); return addHeapObject(ret); }; imports.wbg.__wbg_new_8a6f238a6ece86ea = function() { const ret = new Error(); return addHeapObject(ret); }; imports.wbg.__wbg_new_e17d9f43105b08be = function() { const ret = new Array(); return addHeapObject(ret); }; imports.wbg.__wbg_next_020810e0ae8ebcb0 = function() { return handleError(function (arg0) { const ret = getObject(arg0).next(); return addHeapObject(ret); }, arguments) }; imports.wbg.__wbg_next_2c826fe5dfec6b6a = function(arg0) { const ret = getObject(arg0).next; return addHeapObject(ret); }; imports.wbg.__wbg_prototypesetcall_2a6620b6922694b2 = function(arg0, arg1, arg2) { Uint8Array.prototype.set.call(getArrayU8FromWasm0(arg0, arg1), getObject(arg2)); }; imports.wbg.__wbg_random_babe96ffc73e60a2 = function() { const ret = Math.random(); return ret; }; imports.wbg.__wbg_set_c213c871859d6500 = function(arg0, arg1, arg2) { getObject(arg0)[arg1 >>> 0] = takeObject(arg2); }; imports.wbg.__wbg_stack_0ed75d68575b0f3c = function(arg0, arg1) { const ret = getObject(arg1).stack; const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_export, wasm.__wbindgen_export2); const len1 = WASM_VECTOR_LEN; getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); }; imports.wbg.__wbg_value_692627309814bb8c = function(arg0) { const ret = getObject(arg0).value; return addHeapObject(ret); }; imports.wbg.__wbindgen_cast_2241b6af4c4b2941 = function(arg0, arg1) { // Cast intrinsic for `Ref(String) -> Externref`. const ret = getStringFromWasm0(arg0, arg1); return addHeapObject(ret); }; imports.wbg.__wbindgen_object_drop_ref = function(arg0) { takeObject(arg0); }; return imports; } function __wbg_finalize_init(instance, module) { wasm = instance.exports; __wbg_init.__wbindgen_wasm_module = module; cachedDataViewMemory0 = null; cachedFloat32ArrayMemory0 = null; cachedUint8ArrayMemory0 = null; wasm.__wbindgen_start(); return wasm; } function initSync(module) { if (wasm !== undefined) return wasm; if (typeof module !== 'undefined') { if (Object.getPrototypeOf(module) === Object.prototype) { ({module} = module) } else { console.warn('using deprecated parameters for `initSync()`; pass a single object instead') } } const imports = __wbg_get_imports(); if (!(module instanceof WebAssembly.Module)) { module = new WebAssembly.Module(module); } const instance = new WebAssembly.Instance(module, imports); return __wbg_finalize_init(instance, module); } async function __wbg_init(module_or_path) { if (wasm !== undefined) return wasm; if (typeof module_or_path !== 'undefined') { if (Object.getPrototypeOf(module_or_path) === Object.prototype) { ({module_or_path} = module_or_path) } else { console.warn('using deprecated parameters for the initialization function; pass a single