UNPKG

catniff

Version:

Torch-like deep learning framework for Javascript

471 lines (470 loc) 21.2 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.nn = exports.MultiheadAttention = exports.Embedding = exports.RMSNorm = exports.LayerNorm = exports.BatchNorm = exports.LSTMCell = exports.GRUCell = exports.RNNCell = exports.Linear = void 0; exports.scaledDotProductAttention = scaledDotProductAttention; const core_1 = require("./core"); function linearTransform(input, weight, bias) { let output = input.matmul(weight.t()); if (bias) { output = output.add(bias); } return output; } class Linear { weight; bias; constructor(inFeatures, outFeatures, bias = true, device, dtype) { const bound = 1 / Math.sqrt(inFeatures); this.weight = core_1.Tensor.uniform([outFeatures, inFeatures], -bound, bound, { requiresGrad: true, device, dtype }); if (bias) { this.bias = core_1.Tensor.uniform([outFeatures], -bound, bound, { requiresGrad: true, device, dtype }); } } forward(input) { input = this.weight.handleOther(input); return linearTransform(input, this.weight, this.bias); } } exports.Linear = Linear; function rnnTransform(input, hidden, inputWeight, hiddenWeight, inputBias, hiddenBias) { let output = input.matmul(inputWeight.t()).add(hidden.matmul(hiddenWeight.t())); if (inputBias) { output = output.add(inputBias); } if (hiddenBias) { output = output.add(hiddenBias); } return output; } class RNNCell { weightIH; weightHH; biasIH; biasHH; constructor(inputSize, hiddenSize, bias = true, device, dtype) { const bound = 1 / Math.sqrt(hiddenSize); this.weightIH = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightHH = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); if (bias) { this.biasIH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasHH = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); } } forward(input, hidden) { input = this.weightIH.handleOther(input); hidden = this.weightHH.handleOther(hidden); return rnnTransform(input, hidden, this.weightIH, this.weightHH, this.biasIH, this.biasHH).tanh(); } } exports.RNNCell = RNNCell; class GRUCell { weightIR; weightIZ; weightIN; weightHR; weightHZ; weightHN; biasIR; biasIZ; biasIN; biasHR; biasHZ; biasHN; constructor(inputSize, hiddenSize, bias = true, device, dtype) { const bound = 1 / Math.sqrt(hiddenSize); this.weightIR = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightIZ = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightIN = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightHR = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightHZ = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightHN = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); if (bias) { this.biasIR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasIZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasIN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasHR = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasHZ = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasHN = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); } } forward(input, hidden) { input = this.weightIN.handleOther(input); hidden = this.weightHN.handleOther(hidden); const r = rnnTransform(input, hidden, this.weightIR, this.weightHR, this.biasIR, this.biasHR).sigmoid(); const z = rnnTransform(input, hidden, this.weightIZ, this.weightHZ, this.biasIZ, this.biasHZ).sigmoid(); const n = linearTransform(input, this.weightIN, this.biasIN).add(r.mul(linearTransform(hidden, this.weightHN, this.biasHN))).tanh(); return (z.neg().add(1).mul(n).add(z.mul(hidden))); } } exports.GRUCell = GRUCell; class LSTMCell { weightII; weightIF; weightIG; weightIO; weightHI; weightHF; weightHG; weightHO; biasII; biasIF; biasIG; biasIO; biasHI; biasHF; biasHG; biasHO; constructor(inputSize, hiddenSize, bias = true, device, dtype) { const bound = 1 / Math.sqrt(hiddenSize); this.weightII = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightIF = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightIG = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightIO = core_1.Tensor.uniform([hiddenSize, inputSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightHI = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightHF = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightHG = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.weightHO = core_1.Tensor.uniform([hiddenSize, hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); if (bias) { this.biasII = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasIF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasIG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasIO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasHI = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasHF = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasHG = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); this.biasHO = core_1.Tensor.uniform([hiddenSize], -bound, bound, { requiresGrad: true, device, dtype }); } } forward(input, hidden, cell) { input = this.weightII.handleOther(input); hidden = this.weightHI.handleOther(hidden); cell = this.weightHI.handleOther(cell); const i = rnnTransform(input, hidden, this.weightII, this.weightHI, this.biasII, this.biasHI).sigmoid(); const f = rnnTransform(input, hidden, this.weightIF, this.weightHF, this.biasIF, this.biasHF).sigmoid(); const g = rnnTransform(input, hidden, this.weightIG, this.weightHG, this.biasIG, this.biasHG).tanh(); const o = rnnTransform(input, hidden, this.weightIO, this.weightHO, this.biasIO, this.biasHO).sigmoid(); const c = f.mul(cell).add(i.mul(g)); const h = o.mul(c.tanh()); return [h, c]; } } exports.LSTMCell = LSTMCell; class BatchNorm { weight; bias; runningMean; runningVar; eps; momentum; numFeatures; affine; trackRunningStats; numBatchesTracked; constructor(numFeatures, eps = 1e-5, momentum = 0.1, affine = true, trackRunningStats = true, device, dtype) { this.numFeatures = numFeatures; this.eps = eps; this.momentum = momentum; this.affine = affine; this.trackRunningStats = trackRunningStats; this.numBatchesTracked = 0; if (this.affine) { this.weight = core_1.Tensor.ones([numFeatures], { requiresGrad: true, device, dtype }); this.bias = core_1.Tensor.zeros([numFeatures], { requiresGrad: true, device, dtype }); } if (this.trackRunningStats) { this.runningMean = core_1.Tensor.zeros([numFeatures], { requiresGrad: false, device, dtype }); this.runningVar = core_1.Tensor.ones([numFeatures], { requiresGrad: false, device, dtype }); } } forward(input) { // Input shape: (N, C, ...) where C = numFeatures // Normalize over batch dimension and spatial dimensions (if any) if (input.shape.length < 2) { throw new Error("Input must have at least 2 dimensions (batch, features)"); } if (input.shape[1] !== this.numFeatures) { throw new Error(`Expected ${this.numFeatures} features, got ${input.shape[1]}`); } let mean; let variance; if (core_1.Tensor.training || !this.trackRunningStats) { // Training or trackRunningStats disabled - calculate mean and variance from scratch // Calculate mean and variance over batch and spatial dimensions // Keep only the channel dimension const dims = [0, ...Array.from({ length: input.shape.length - 2 }, (_, i) => i + 2)]; mean = input.mean(dims, true); variance = input.sub(mean).pow(2).mean(dims, true); // Update running statistics if enabled and in training mode if (this.trackRunningStats && core_1.Tensor.training) { const exponentialAverageFactor = this.momentum; this.runningMean = this.runningMean .mul(1 - exponentialAverageFactor) .add(mean.squeeze().mul(exponentialAverageFactor)); // Use unbiased variance for running estimate const n = input.shape.reduce((acc, val, idx) => idx === 1 ? acc : acc * val, 1); const unbiasingFactor = n / (n - 1); this.runningVar = this.runningVar .mul(1 - exponentialAverageFactor) .add(variance.squeeze().mul(exponentialAverageFactor * unbiasingFactor)); this.numBatchesTracked++; } } else { // Inference with trackRunningStats enabled - use running statistics mean = this.runningMean.reshape([1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)]); variance = this.runningVar.reshape([1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)]); } // Normalize let normalized = input.sub(mean).div(variance.add(this.eps).sqrt()); // Apply affine transformation if (this.affine) { const weightReshaped = this.weight.reshape([1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)]); const biasReshaped = this.bias.reshape([1, this.numFeatures, ...Array(input.shape.length - 2).fill(1)]); normalized = normalized.mul(weightReshaped).add(biasReshaped); } return normalized; } } exports.BatchNorm = BatchNorm; class LayerNorm { weight; bias; eps; normalizedShape; constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, bias = true, device, dtype) { this.eps = eps; this.normalizedShape = Array.isArray(normalizedShape) ? normalizedShape : [normalizedShape]; if (this.normalizedShape.length === 0) { throw new Error("Normalized shape cannot be empty"); } if (elementwiseAffine) { this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device, dtype }); if (bias) { this.bias = core_1.Tensor.zeros(this.normalizedShape, { requiresGrad: true, device, dtype }); } } } forward(input) { // Normalize over the specified dimensions const normalizedDims = this.normalizedShape.length; const startDim = input.shape.length - normalizedDims; if (startDim < 0) { throw new Error("Input does not have enough dims to normalize"); } const dims = []; for (let i = 0; i < normalizedDims; i++) { if (input.shape[startDim + i] !== this.normalizedShape[i]) { throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`); } dims.push(startDim + i); } const mean = input.mean(dims, true); const variance = input.sub(mean).pow(2).mean(dims, true); let normalized = input.sub(mean).div(variance.add(this.eps).sqrt()); if (this.weight) { normalized = normalized.mul(this.weight); } if (this.bias) { normalized = normalized.add(this.bias); } return normalized; } } exports.LayerNorm = LayerNorm; class RMSNorm { weight; eps; normalizedShape; constructor(normalizedShape, eps = 1e-5, elementwiseAffine = true, device, dtype) { this.eps = eps; this.normalizedShape = Array.isArray(normalizedShape) ? normalizedShape : [normalizedShape]; if (this.normalizedShape.length === 0) { throw new Error("Normalized shape cannot be empty"); } if (elementwiseAffine) { this.weight = core_1.Tensor.ones(this.normalizedShape, { requiresGrad: true, device, dtype }); } } forward(input) { // Normalize over the specified dimensions const normalizedDims = this.normalizedShape.length; const startDim = input.shape.length - normalizedDims; if (startDim < 0) { throw new Error("Input does not have enough dims to normalize"); } const dims = []; for (let i = 0; i < normalizedDims; i++) { if (input.shape[startDim + i] !== this.normalizedShape[i]) { throw new Error(`Shape mismatch at dim ${startDim + i}: expected ${this.normalizedShape[i]}, got ${input.shape[startDim + i]}`); } dims.push(startDim + i); } let rms = input.square().mean(dims, true).add(this.eps).sqrt(); let normalized = input.div(rms); if (this.weight) { normalized = normalized.mul(this.weight); } return normalized; } } exports.RMSNorm = RMSNorm; class Embedding { weight; constructor(numEmbeddings, embeddingDim, device, dtype) { this.weight = core_1.Tensor.randn([numEmbeddings, embeddingDim], { requiresGrad: true, device, dtype }); } forward(input) { return this.weight.index(input); } } exports.Embedding = Embedding; function scaledDotProductAttention(query, key, value, attnMask, dropout = 0, isCausal = false, scale) { const targetLen = query.shape[query.shape.length - 2]; const sourceLen = key.shape[key.shape.length - 2]; const dimSize = query.shape[query.shape.length - 1]; // Attention scores let scores = query.matmul(key.transpose(-2, -1)).div(scale ?? Math.sqrt(dimSize)); // Set attention mask to causal mask if specified if (isCausal) { attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: query.device }).triu(1); } // Apply attention mask if specified if (attnMask) { scores = scores.maskedFill(attnMask, -Infinity); } // Calculate attention weights let attnWeights = scores.softmax().dropout(dropout); // Apply attention to values return attnWeights.matmul(value); } class MultiheadAttention { qProjection; kProjection; vProjection; oProjection; embedDim; numHeads; headDim; dropout; constructor(embedDim, numHeads, dropout = 0, bias = true, device, dtype) { this.qProjection = new Linear(embedDim, embedDim, bias, device, dtype); this.kProjection = new Linear(embedDim, embedDim, bias, device, dtype); this.vProjection = new Linear(embedDim, embedDim, bias, device, dtype); this.oProjection = new Linear(embedDim, embedDim, bias, device, dtype); this.embedDim = embedDim; this.numHeads = numHeads; this.headDim = Math.floor(embedDim / numHeads); this.dropout = dropout; } forward(query, key, value, needWeights = true, attnMask, averageAttnWeights = true, isCausal = false) { // Batch-first const [batchSize, targetLen, embedDim] = query.shape; const sourceLen = key.shape[1]; let Q = this.qProjection.forward(query); // (batchSize, targetLen, embedDim) let K = this.kProjection.forward(key); // (batchSize, sourceLen, embedDim) let V = this.vProjection.forward(value); // (batchSize, sourceLen, embedDim) // (batchSize, numHeads, targetLen/sourceLen, headDim) Q = Q.reshape([batchSize, targetLen, this.numHeads, this.headDim]).transpose(1, 2); K = K.reshape([batchSize, sourceLen, this.numHeads, this.headDim]).transpose(1, 2); V = V.reshape([batchSize, sourceLen, this.numHeads, this.headDim]).transpose(1, 2); // Attention scores let scores = Q.matmul(K.transpose(-2, -1)).div(Math.sqrt(this.headDim)); // Set attention mask to causal mask if specified if (isCausal) { attnMask = core_1.Tensor.ones([targetLen, sourceLen], { device: this.qProjection.weight.device }).triu(1); } // Apply attention mask if specified if (attnMask) { scores = scores.maskedFill(attnMask, -Infinity); } // Calculate attention weights let attnWeights = scores.softmax().dropout(this.dropout); // Apply attention to values let attnOutput = attnWeights.matmul(V); // (batchSize, numHeads, targetLen, headDim) // (batchSize, targetLen, embedDim) attnOutput = attnOutput.transpose(1, 2).reshape([batchSize, targetLen, embedDim]); // Output const output = this.oProjection.forward(attnOutput); // Average weights if needed if (averageAttnWeights) { attnWeights = attnWeights.mean(1); } return [output, needWeights ? attnWeights : undefined]; } } exports.MultiheadAttention = MultiheadAttention; const state = { getParameters(model, visited = new WeakSet()) { if (visited.has(model)) return []; visited.add(model); const parameters = []; for (const key in model) { if (!model.hasOwnProperty(key)) continue; const value = model[key]; if (value instanceof core_1.Tensor) { parameters.push(value); } else if (typeof value === "object" && value !== null) { parameters.push(...state.getParameters(value, visited)); } } return parameters; }, moveParameters(model, device) { const params = state.getParameters(model); for (const param of params) { param.to_(device); } }, getStateDict(model, prefix = "", visited = new WeakSet()) { if (visited.has(model)) return {}; visited.add(model); const stateDict = {}; for (const key in model) { if (!model.hasOwnProperty(key)) continue; const value = model[key]; const fullKey = prefix ? `${prefix}.${key}` : key; if (value instanceof core_1.Tensor) { stateDict[fullKey] = value.val(); } else if (typeof value === "object" && value !== null) { Object.assign(stateDict, state.getStateDict(value, fullKey, visited)); } } return stateDict; }, loadStateDict(model, stateDict, prefix = "", visited = new WeakSet()) { if (visited.has(model)) return; visited.add(model); for (const key in model) { if (!model.hasOwnProperty(key)) continue; const value = model[key]; const fullKey = prefix ? `${prefix}.${key}` : key; if (value instanceof core_1.Tensor && stateDict[fullKey]) { value.replace(new core_1.Tensor(stateDict[fullKey], { device: value.device })); } else if (typeof value === "object" && value !== null) { state.loadStateDict(value, stateDict, fullKey, visited); } } } }; exports.nn = { Linear, RNNCell, GRUCell, LSTMCell, BatchNorm, LayerNorm, RMSNorm, Embedding, scaledDotProductAttention, MultiheadAttention, state };