catniff

A small Torch-like deep learning framework for JavaScript

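The bundled CommonJS build of the package's `Tensor` class follows. A minimal usage sketch (illustrative only; it sticks to the constructor and the element-wise/reduction ops visible in this file, and assumes the package's main entry exposes `Tensor` as `exports.Tensor` suggests):

    const { Tensor } = require("catniff");

    const a = new Tensor([[1, 2], [3, 4]], { requiresGrad: true });
    const b = new Tensor([10, 20]);   // broadcast across the rows of `a`
    const c = a.mul(b).sum();         // scalar output
    // c.children and c.gradFn now describe the autograd DAG for the backward pass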
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.Tensor = void 0; const utils_1 = require("./utils"); class Tensor { value; shape; strides; grad; requiresGrad; gradFn; children; device; static training = false; constructor(value, options = {}) { this.value = Tensor.flatten(value); this.shape = options.shape || Tensor.getShape(value); this.strides = options.strides || Tensor.getStrides(this.shape); this.grad = options.grad; this.requiresGrad = options.requiresGrad ?? false; this.gradFn = options.gradFn || (() => { }); this.children = options.children || []; this.device = options.device || "cpu"; // Move tensor to device if (this.device !== "cpu") { const backend = Tensor.backends.get(this.device); if (backend && backend.transfer) { backend.transfer(this); } } } // Utility to flatten an nD array to be 1D static flatten(tensor) { // Handle scalar tensors if (typeof tensor === "number") return tensor; // If value is already 1D, we just need to return the value ('s reference) if (typeof tensor[0] === "number") return tensor; // Or else recursively traverse through the nD array to flatten const result = []; function traverse(arr) { if (typeof arr === "number") { result.push(arr); } else if (Array.isArray(arr)) { arr.forEach(traverse); } } traverse(tensor); return result; } // Utility to get shape from tensor *value* static getShape(tensor) { const shape = []; let subA = tensor; while (Array.isArray(subA)) { shape.push(subA.length); subA = subA[0]; } return shape; } // Utility to get strides from shape static getStrides(shape) { if (shape.length === 0) return []; const strides = new Array(shape.length); strides[strides.length - 1] = 1; for (let i = strides.length - 2; i >= 0; i--) { strides[i] = strides[i + 1] * shape[i + 1]; } return strides; } // Left-pad shape and strides for two shape to be of same length static padShape(stridesA, stridesB, shapeA, shapeB) { const newStrideA = [...stridesA], newStrideB = [...stridesB]; const newShapeA = [...shapeA], newShapeB = [...shapeB]; while (newStrideA.length < newStrideB.length) { const newStride = newShapeA[0] * newStrideA[0]; newStrideA.unshift(newStride); newShapeA.unshift(1); } while (newStrideA.length > newStrideB.length) { const newStride = newShapeB[0] * newStrideB[0]; newStrideB.unshift(newStride); newShapeB.unshift(1); } return [newStrideA, newStrideB, newShapeA, newShapeB]; } // Broadcast shapes static broadcastShapes(shapeA, shapeB) { const newShape = new Array(shapeA.length); for (let index = 0; index < shapeA.length; index++) { if (shapeA[index] === 1) { newShape[index] = shapeB[index]; } else if (shapeB[index] === 1) { newShape[index] = shapeA[index]; } else if (shapeA[index] === shapeB[index]) { newShape[index] = shapeA[index]; } else { throw new Error(`Cannot broadcast shapes: ${shapeA} and ${shapeB}`); } } return newShape; } // Utility to convert flat index to array of coordinates static indexToCoords(index, strides) { const coords = new Array(strides.length); let remaining = index; for (let dim = 0; dim < strides.length; dim++) { coords[dim] = Math.floor(remaining / strides[dim]); remaining %= strides[dim]; } return coords; } // Utility to convert array of coordinates to *unbroadcasted* flat index static coordsToUnbroadcastedIndex(coords, shape, strides) { let index = 0; for (let i = 0; i < coords.length; i++) { // Handle broadcasting const actualCoord = shape[i] === 1 ? 
    // Utility to convert an array of coordinates into a flat index
    static coordsToIndex(coords, strides) {
        let index = 0;
        for (let i = 0; i < coords.length; i++) {
            index += coords[i] * strides[i];
        }
        return index;
    }

    // Utility to convert a shape into the size of its 1D value array
    static shapeToSize(shape) {
        let prod = 1;
        for (let i = 0; i < shape.length; i++) {
            prod *= shape[i];
        }
        return prod;
    }

    // Utility for binary (two-operand) element-wise ops
    static elementWiseAB(tA, tB, op) {
        if (typeof tA.value === "number" && typeof tB.value === "number") {
            return new Tensor(op(tA.value, tB.value));
        }
        if (typeof tA.value === "number") {
            // Keep the operand order: tA is the left-hand side of op
            return Tensor.elementWiseSelf(tB, (b) => op(tA.value, b));
        }
        if (typeof tB.value === "number") {
            return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
        }
        // Pad + broadcast shapes
        const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
        const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape);
        // Get other output info
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize);
        for (let i = 0; i < outputSize; i++) {
            // Get coordinates from the flat output index
            const coordsOutput = Tensor.indexToCoords(i, outputStrides);
            // Convert the coordinates into a flat index into A, with respect to A's shape
            const indexA = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides);
            // Convert the coordinates into a flat index into B, with respect to B's shape
            const indexB = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
            // Apply op
            outputValue[i] = op(tA.value[indexA], tB.value[indexB]);
        }
        return new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
    }

    // Utility for unary (self) element-wise ops
    static elementWiseSelf(tA, op) {
        if (typeof tA.value === "number") return new Tensor(op(tA.value));
        const newValue = new Array(tA.value.length);
        for (let index = 0; index < tA.value.length; index++) {
            newValue[index] = op(tA.value[index]);
        }
        return new Tensor(newValue, { shape: tA.shape, strides: tA.strides });
    }

    // Utility to run a binary element-wise op and build a DAG node with another tensor
    elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
        other = Tensor.forceTensor(other);
        const out = Tensor.elementWiseAB(this, other, op);
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (other.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(other);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                // Detach so no graph is built for the gradients themselves
                const outGrad = out.grad;
                const selfNoGrad = this.detach();
                const otherNoGrad = other.detach();
                if (this.requiresGrad) Tensor.addGrad(this, thisGrad(selfNoGrad, otherNoGrad, outGrad));
                if (other.requiresGrad) Tensor.addGrad(other, otherGrad(selfNoGrad, otherNoGrad, outGrad));
            };
        }
        return out;
    }

    // Utility to run a unary element-wise op and build a DAG node
    elementWiseSelfDAG(op, thisGrad = () => new Tensor(0)) {
        const out = Tensor.elementWiseSelf(this, op);
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                // Detach so no graph is built for the gradients themselves
                const outGrad = out.grad;
                const selfNoGrad = this.detach();
                if (this.requiresGrad) Tensor.addGrad(this, thisGrad(selfNoGrad, outGrad));
            };
        }
        return out;
    }
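    // Illustrative trace of elementWiseAB: adding A of shape [2, 3] (values
    // [1..6]) and B of shape [3] (values [10, 20, 30]) first pads B to shape
    // [1, 3], broadcasts to the output shape [2, 3], then computes e.g.
    // output[4] = A[4] + B[1] = 5 + 20 = 25. The *DAG variants only attach
    // children and a gradFn when an input has requiresGrad set.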
    // Utility to force an input value to be a tensor
    static forceTensor(value) {
        if (value instanceof Tensor) return value;
        return new Tensor(value);
    }

    // Utility to accumulate into a tensor's gradient, reducing broadcasted axes
    static addGrad(tensor, accumGrad) {
        const axesToSqueeze = [];
        const axesToReduce = [];
        const shape = tensor.shape;
        const gradShape = accumGrad.shape;
        const paddedDims = gradShape.length - shape.length;
        for (let i = 0; i < paddedDims; i++) {
            axesToReduce.push(i);
            axesToSqueeze.push(i);
        }
        for (let i = 0; i < shape.length; i++) {
            if (shape[i] === 1 && gradShape[i + paddedDims] > 1) {
                axesToReduce.push(i + paddedDims);
            }
        }
        const reducedGrad = accumGrad.sum(axesToReduce, true);
        const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
        if (typeof tensor.grad === "undefined") {
            tensor.grad = squeezedGrad;
        } else {
            tensor.grad = tensor.grad.add(squeezedGrad);
        }
    }

    // Contiguity-related ops
    isContiguous() {
        const expectedStrides = Tensor.getStrides(this.shape);
        if (expectedStrides.length !== this.strides.length) {
            return false;
        }
        for (let i = 0; i < this.strides.length; i++) {
            if (this.strides[i] !== expectedStrides[i]) {
                return false;
            }
        }
        return true;
    }

    contiguous() {
        // Scalars are trivially contiguous
        if (typeof this.value === "number") return this;
        // Check if already contiguous
        if (this.isContiguous()) return this;
        const outputStrides = Tensor.getStrides(this.shape);
        const outputSize = Tensor.shapeToSize(this.shape);
        const outputValue = new Array(outputSize);
        for (let index = 0; index < outputSize; index++) {
            const outputCoords = Tensor.indexToCoords(index, outputStrides);
            const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
            outputValue[index] = this.value[originalIndex];
        }
        const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides });
        // The gradient flows back to the original tensor unchanged
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad);
            };
        }
        return out;
    }

    reshape(newShape) {
        // Verify that the sizes match
        const originalSize = Tensor.shapeToSize(this.shape);
        const outputSize = Tensor.shapeToSize(newShape);
        if (originalSize !== outputSize) {
            throw new Error("Cannot reshape: incompatible sizes");
        }
        const outputStrides = Tensor.getStrides(newShape);
        const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides });
        // The gradient is reshaped back and flows to the original tensor
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad.reshape(this.shape));
            };
        }
        return out;
    }

    // Tensor squeeze - removes size-1 dimensions
    squeeze(dims) {
        if (typeof this.value === "number") return this;
        if (typeof dims === "number") {
            dims = [dims];
        }
        if (typeof dims === "undefined") {
            const shape = this.shape;
            dims = [];
            for (let index = 0; index < shape.length; index++) {
                if (shape[index] === 1) {
                    dims.push(index);
                }
            }
        }
        // Remove size-1 dims only
        const outShape = [], outStrides = [];
        for (let index = 0; index < this.shape.length; index++) {
            const dim = this.shape[index];
            const stride = this.strides[index];
            if (dims.includes(index)) {
                if (dim !== 1) throw new Error(`Cannot squeeze dim with size ${dim}`);
            } else {
                outShape.push(dim);
                outStrides.push(stride);
            }
        }
        const outValue = outShape.length === 0 ? this.value[0] : this.value;
        const out = new Tensor(outValue, { shape: outShape, strides: outStrides, device: this.device });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                let restoredGrad = out.grad;
                for (let i = dims.length - 1; i >= 0; i--) {
                    restoredGrad = restoredGrad.unsqueeze(dims[i]);
                }
                Tensor.addGrad(this, restoredGrad);
            };
        }
        return out;
    }
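    // Illustrative note on addGrad: if a tensor of shape [3] received a
    // gradient of shape [2, 3] (because it was broadcast in the forward pass),
    // the gradient is summed over the padded axis (sum([0], true) -> shape
    // [1, 3]) and squeezed back to shape [3] before being accumulated.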
    // Tensor unsqueeze - adds a dimension of size 1 at the specified position
    unsqueeze(dim) {
        let thisValue = this.value;
        if (typeof thisValue === "number") {
            thisValue = [thisValue];
        }
        // Insert a size-1 dimension at the specified position
        const newShape = [...this.shape];
        newShape.splice(dim, 0, 1);
        // New stride
        const newStrides = [...this.strides];
        let newDimStride;
        if (dim >= this.shape.length) {
            // Inserting at the back: use 1
            newDimStride = 1;
        } else {
            // Inserting before dim: use current stride * current shape
            newDimStride = this.strides[dim] * this.shape[dim];
        }
        newStrides.splice(dim, 0, newDimStride);
        const out = new Tensor(thisValue, { shape: newShape, strides: newStrides, device: this.device });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad.squeeze(dim));
            };
        }
        return out;
    }

    // Tensor sum reduction
    sum(dims, keepDims = false) {
        if (typeof this.value === "number") return this;
        if (typeof dims === "undefined") {
            dims = Array.from({ length: this.shape.length }, (_, index) => index);
        }
        if (Array.isArray(dims)) {
            // Sort a copy in descending order so the caller's array is not mutated
            const sortedDims = [...dims].sort((a, b) => b - a);
            let reducedThis = this;
            for (let i = 0; i < sortedDims.length; i++) {
                reducedThis = reducedThis.sum(sortedDims[i], true);
            }
            return keepDims ? reducedThis : reducedThis.squeeze(dims);
        }
        // Reduced dims now have size 1
        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize).fill(0);
        const originalSize = Tensor.shapeToSize(this.shape);
        // Gradient data
        let gradShape, gradStrides, gradValue = [];
        // Allocate gradient data only when needed
        if (this.requiresGrad) {
            gradShape = this.shape;
            gradStrides = this.strides;
            gradValue = new Array(originalSize).fill(0);
        }
        // Calculate the new value after summing
        for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
            // Force 0 on reduced axes to collapse into size-1 dims
            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
            // Convert output coordinates to a flat index
            const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
            // Add into the sum
            outputValue[outFlatIndex] += this.value[realFlatIndex];
            // Mark for gradient if needed
            if (this.requiresGrad) {
                gradValue[realFlatIndex] = 1;
            }
        }
        const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
                Tensor.addGrad(this, out.grad.mul(localGrad));
            };
        }
        return keepDims ? out : out.squeeze(dims);
    }
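    // Example (illustrative): new Tensor([[1, 2], [3, 4]]).sum(1) gives [3, 7],
    // and sum(1, true) keeps the reduced axis as [[3], [7]]. The local gradient
    // of a sum is all ones, so the upstream gradient is simply broadcast back.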
    // Tensor product reduction
    prod(dims, keepDims = false) {
        if (typeof this.value === "number") return this;
        if (typeof dims === "undefined") {
            dims = Array.from({ length: this.shape.length }, (_, index) => index);
        }
        if (Array.isArray(dims)) {
            // Sort a copy in descending order so the caller's array is not mutated
            const sortedDims = [...dims].sort((a, b) => b - a);
            let reducedThis = this;
            for (let i = 0; i < sortedDims.length; i++) {
                reducedThis = reducedThis.prod(sortedDims[i], true);
            }
            return keepDims ? reducedThis : reducedThis.squeeze(dims);
        }
        // Reduced dims now have size 1
        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize).fill(1);
        const originalSize = Tensor.shapeToSize(this.shape);
        // Calculate the new value after multiplying
        for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
            // Force 0 on reduced axes to collapse into size-1 dims
            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
            // Convert output coordinates to a flat index
            const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
            // Multiply into the product
            outputValue[outFlatIndex] *= this.value[realFlatIndex];
        }
        const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
                for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                    const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                    const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                    const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                    // The grad is the product of the other elements on the axis,
                    // i.e. the full product divided by the current element
                    gradValue[realFlatIndex] = outputValue[outFlatIndex] / this.value[realFlatIndex];
                }
                const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
                Tensor.addGrad(this, out.grad.mul(localGrad));
            };
        }
        return keepDims ? out : out.squeeze(dims);
    }

    // Tensor mean reduction
    mean(dims, keepDims = false) {
        if (typeof this.value === "number") return this;
        if (typeof dims === "undefined") {
            dims = Array.from({ length: this.shape.length }, (_, index) => index);
        }
        if (Array.isArray(dims)) {
            // Sort a copy in descending order so the caller's array is not mutated
            const sortedDims = [...dims].sort((a, b) => b - a);
            let reducedThis = this;
            for (let i = 0; i < sortedDims.length; i++) {
                reducedThis = reducedThis.mean(sortedDims[i], true);
            }
            return keepDims ? reducedThis : reducedThis.squeeze(dims);
        }
        // Reduced dims now have size 1
        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize).fill(0);
        const outputFeeders = new Array(outputSize).fill(0);
        const originalSize = Tensor.shapeToSize(this.shape);
        // Calculate sums and how many elements contribute to each position
        for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
            // Force 0 on reduced axes to collapse into size-1 dims
            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
            // Convert output coordinates to a flat index
            const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
            // Accumulate the sum and count its contributors
            outputValue[outFlatIndex] += this.value[realFlatIndex];
            outputFeeders[outFlatIndex]++;
        }
        // Calculate the mean by dividing each sum by its number of contributors
        for (let index = 0; index < outputSize; index++) {
            outputValue[index] /= outputFeeders[index];
        }
        const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
                // Each element's grad is 1 divided by the number of contributors to its position
                for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                    const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                    const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                    const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                    // Mean = 1/n * (el1 + el2 + ... + eln), so grad = 1/n
                    gradValue[realFlatIndex] = 1 / outputFeeders[outFlatIndex];
                }
                const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
                Tensor.addGrad(this, out.grad.mul(localGrad));
            };
        }
        return keepDims ? out : out.squeeze(dims);
    }
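    // Example (illustrative): prod of [2, 3, 4] is 24, and d(prod)/dx_i is the
    // product of the other elements, computed above as 24 / x_i (e.g. 8 for
    // x_i = 3; the quotient form is undefined when an element is 0). For mean,
    // each contributing element receives gradient 1/n instead.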
    // Tensor maximum reduction
    max(dims, keepDims = false) {
        if (typeof this.value === "number") return this;
        if (typeof dims === "undefined") {
            dims = Array.from({ length: this.shape.length }, (_, index) => index);
        }
        if (Array.isArray(dims)) {
            // Sort a copy in descending order so the caller's array is not mutated
            const sortedDims = [...dims].sort((a, b) => b - a);
            let reducedThis = this;
            for (let i = 0; i < sortedDims.length; i++) {
                reducedThis = reducedThis.max(sortedDims[i], true);
            }
            return keepDims ? reducedThis : reducedThis.squeeze(dims);
        }
        // Reduced dims now have size 1
        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize).fill(-Infinity);
        const originalSize = Tensor.shapeToSize(this.shape);
        // Calculate the maximum values over the axis
        for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
            // Force 0 on reduced axes to collapse into size-1 dims
            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
            // Convert output coordinates to a flat index
            const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
            // Keep the running max
            if (this.value[realFlatIndex] > outputValue[outFlatIndex]) {
                outputValue[outFlatIndex] = this.value[realFlatIndex];
            }
        }
        const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
                const shareCounts = new Array(outputSize).fill(0);
                const originalValue = this.value;
                for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                    const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                    const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                    const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                    // First, count how many elements share the same max value
                    shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
                }
                for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                    const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                    const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                    const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                    // Then share the grad between the elements holding the max value
                    gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
                }
                const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
                Tensor.addGrad(this, out.grad.mul(localGrad));
            };
        }
        return keepDims ? out : out.squeeze(dims);
    }
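    // Example (illustrative): max of [1, 3, 3] is 3; the two tied elements
    // split the gradient evenly via shareCounts, giving [0, 0.5, 0.5].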
    // Tensor minimum reduction
    min(dims, keepDims = false) {
        if (typeof this.value === "number") return this;
        if (typeof dims === "undefined") {
            dims = Array.from({ length: this.shape.length }, (_, index) => index);
        }
        if (Array.isArray(dims)) {
            // Sort a copy in descending order so the caller's array is not mutated
            const sortedDims = [...dims].sort((a, b) => b - a);
            let reducedThis = this;
            for (let i = 0; i < sortedDims.length; i++) {
                reducedThis = reducedThis.min(sortedDims[i], true);
            }
            return keepDims ? reducedThis : reducedThis.squeeze(dims);
        }
        // Reduced dims now have size 1
        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize).fill(Infinity);
        const originalSize = Tensor.shapeToSize(this.shape);
        // Calculate the minimum values over the axis
        for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
            const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
            // Force 0 on reduced axes to collapse into size-1 dims
            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
            // Convert output coordinates to a flat index
            const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
            // Keep the running min
            if (this.value[realFlatIndex] < outputValue[outFlatIndex]) {
                outputValue[outFlatIndex] = this.value[realFlatIndex];
            }
        }
        const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
                const shareCounts = new Array(outputSize).fill(0);
                const originalValue = this.value;
                for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                    const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                    const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                    const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                    // First, count how many elements share the same min value
                    shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
                }
                for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                    const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                    const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                    const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                    // Then share the grad between the elements holding the min value
                    gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
                }
                const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
                Tensor.addGrad(this, out.grad.mul(localGrad));
            };
        }
        return keepDims ? out : out.squeeze(dims);
    }
    // Tensor variance reduction: Var(X) = E[X²] - (E[X])²
    var(dims, keepDims = false) {
        const meanOfSquare = this.square().mean(dims, keepDims);
        const squareOfMean = this.mean(dims, keepDims).square();
        return meanOfSquare.sub(squareOfMean);
    }

    // Tensor standard deviation reduction
    std(dims, keepDims = false) {
        return this.var(dims, keepDims).sqrt();
    }

    // Tensor softmax
    softmax(dims) {
        if (typeof this.value === "number") return this;
        if (typeof dims === "undefined") {
            dims = Array.from({ length: this.shape.length }, (_, index) => index);
        }
        if (Array.isArray(dims)) {
            // Sort a copy in descending order so the caller's array is not mutated
            const sortedDims = [...dims].sort((a, b) => b - a);
            let reducedThis = this;
            for (let i = 0; i < sortedDims.length; i++) {
                reducedThis = reducedThis.softmax(sortedDims[i]);
            }
            return reducedThis;
        }
        // Reduced dims now have size 1
        const expSumShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
        const expSumStrides = Tensor.getStrides(expSumShape);
        const expSumSize = Tensor.shapeToSize(expSumShape);
        const expSumValue = new Array(expSumSize).fill(0);
        const outputShape = this.shape;
        const outputStrides = this.strides;
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize);
        // Calculate the sums of e^xi over the axis
        for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
            const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
            // Force 0 on reduced axes to collapse into size-1 dims
            const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
            // Convert exp-sum coordinates to a flat index
            const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
            // Add e^x to the sum cache
            expSumValue[expSumFlatIndex] += Math.exp(this.value[realFlatIndex]);
        }
        // Calculate e^xi / sum over the axis
        for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
            const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
            const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
            const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
            outputValue[realFlatIndex] = Math.exp(this.value[realFlatIndex]) / expSumValue[expSumFlatIndex];
        }
        const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                const upstreamGrad = out.grad;
                const softmaxOutput = out.detach();
                // Compute the element-wise product: ∂L/∂σᵢ × σᵢ
                const gradTimesOutput = upstreamGrad.mul(softmaxOutput);
                // Sum over the softmax dimension: Σᵢ(∂L/∂σᵢ × σᵢ); keepDims=true for broadcasting
                const sumGradOutput = gradTimesOutput.sum(dims, true);
                // Apply the softmax gradient formula:
                // ∂L/∂zⱼ = (∂L/∂σⱼ × σⱼ) - (σⱼ × Σᵢ(∂L/∂σᵢ × σᵢ))
                const term1 = upstreamGrad.mul(softmaxOutput);
                const term2 = softmaxOutput.mul(sumGradOutput);
                const localGrad = term1.sub(term2);
                Tensor.addGrad(this, localGrad);
            };
        }
        return out;
    }
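    // Example (illustrative): softmax([0, ln 2]) = [1/3, 2/3]. The backward
    // pass above implements the softmax Jacobian ∂σⱼ/∂zₖ = σⱼ(δⱼₖ - σₖ),
    // folded into grad_z = σ ⊙ (g - Σᵢ gᵢσᵢ).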
    // Tensor element-wise addition
    add(other) { return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad); }
    // Tensor element-wise subtraction
    sub(other) { return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad.neg()); }
    subtract = this.sub;
    // Tensor element-wise multiplication
    mul(other) { return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => outGrad.mul(other), (self, other, outGrad) => outGrad.mul(self)); }
    multiply = this.mul;
    // Tensor element-wise power
    pow(other) { return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => outGrad.mul(other.mul(self.pow(other.sub(1)))), (self, other, outGrad) => outGrad.mul(self.pow(other).mul(self.log()))); }
    // Tensor element-wise division
    div(other) { return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => outGrad.div(other), (self, other, outGrad) => outGrad.mul(self.neg().div(other.square()))); }
    divide = this.div;
    // Tensor element-wise modulo
    remainder(other) { return this.elementWiseABDAG(other, (a, b) => a % b); }
    // Tensor element-wise greater-or-equal comparison
    ge(other) { return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0); }
    greaterEqual = this.ge;
    // Tensor element-wise less-or-equal comparison
    le(other) { return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0); }
    lessEqual = this.le;
    // Tensor element-wise greater-than comparison
    gt(other) { return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0); }
    greater = this.gt;
    // Tensor element-wise less-than comparison
    lt(other) { return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0); }
    less = this.lt;
    // Tensor element-wise equality comparison
    eq(other) { return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0); }
    equal = this.eq;
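    // Example (illustrative) of the pow gradients above, for a = 2, b = 3:
    // out = 8, d(out)/da = b·a^(b-1) = 12, and d(out)/db = a^b·ln(a) = 8·ln 2.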
    // Tensor element-wise not-equal comparison
    ne(other) { return this.elementWiseABDAG(other, (a, b) => a !== b ? 1 : 0); }
    notEqual = this.ne;
    // Tensor element-wise logical and
    logicalAnd(other) { return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0); }
    // Tensor element-wise logical or
    logicalOr(other) { return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0); }
    // Tensor element-wise logical xor
    logicalXor(other) { return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0); }
    // Tensor element-wise logical not
    logicalNot() { return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1); }
    // Tensor element-wise bitwise and
    bitwiseAnd(other) { return this.elementWiseABDAG(other, (a, b) => a & b); }
    // Tensor element-wise bitwise or
    bitwiseOr(other) { return this.elementWiseABDAG(other, (a, b) => a | b); }
    // Tensor element-wise bitwise xor
    bitwiseXor(other) { return this.elementWiseABDAG(other, (a, b) => a ^ b); }
    // Tensor element-wise bitwise not
    bitwiseNot() { return this.elementWiseSelfDAG((a) => ~a); }
    // Tensor element-wise left shift
    bitwiseLeftShift(other) { return this.elementWiseABDAG(other, (a, b) => a << b); }
    // Tensor element-wise right shift
    bitwiseRightShift(other) { return this.elementWiseABDAG(other, (a, b) => a >> b); }
    // Tensor element-wise negation
    neg() { return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => outGrad.mul(-1)); }
    negative = this.neg;
    // Tensor element-wise reciprocal
    reciprocal() { return this.elementWiseSelfDAG((a) => 1 / a, (self, outGrad) => outGrad.mul(self.pow(-2).neg())); }
    // Tensor element-wise square
    square() { return this.elementWiseSelfDAG((a) => a * a, (self, outGrad) => outGrad.mul(self.mul(2))); }
    // Tensor element-wise absolute value
    abs() { return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => outGrad.mul(self.sign())); }
    absolute = this.abs;
    // Tensor element-wise sign function
    sign() { return this.elementWiseSelfDAG((a) => Math.sign(a)); }
    // Tensor element-wise sin
    sin() { return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => outGrad.mul(self.cos())); }
    // Tensor element-wise cos
    cos() { return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => outGrad.mul(self.sin().neg())); }
    // Tensor element-wise tan
    tan() { return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => outGrad.mul(self.tan().square().add(1))); }
    // Tensor element-wise asin
    asin() { return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt())); }
    arcsin = this.asin;
    // Tensor element-wise acos
    acos() { return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()).neg()); }
    arccos = this.acos;
    // Tensor element-wise atan
    atan() { return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => outGrad.div(self.square().add(1))); }
    arctan = this.atan;
    // Tensor element-wise atan2
    atan2(other) { return this.elementWiseABDAG(other, (a, b) => Math.atan2(a, b), (self, other, outGrad) => outGrad.mul(other.div(self.square().add(other.square()))), (self, other, outGrad) => outGrad.mul(self.neg().div(self.square().add(other.square())))); }
    arctan2 = this.atan2;
    // Tensor element-wise sinh
    sinh() { return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => outGrad.mul(self.cosh())); }
    // Tensor element-wise cosh
    cosh() { return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => outGrad.mul(self.sinh())); }
    // Tensor element-wise asinh
    asinh() { return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => outGrad.div(self.square().add(1).sqrt())); }
    arcsinh = this.asinh;
    // Tensor element-wise acosh
    acosh() { return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt()))); }
    arccosh = this.acosh;
    // Tensor element-wise atanh
    atanh() { return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => outGrad.div(self.square().neg().add(1))); }
    arctanh = this.atanh;
    // Tensor element-wise degrees to radians
    deg2rad() { return this.elementWiseSelfDAG((a) => a * (Math.PI / 180), (self, outGrad) => outGrad.mul(Math.PI / 180)); }
    // Tensor element-wise radians to degrees
    rad2deg() { return this.elementWiseSelfDAG((a) => a / (Math.PI / 180), (self, outGrad) => outGrad.div(Math.PI / 180)); }
    // Tensor element-wise square root
    sqrt() { return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => outGrad.div(self.sqrt().mul(2))); }
    // Tensor element-wise reciprocal of square root
    rsqrt() { return this.elementWiseSelfDAG((a) => 1 / Math.sqrt(a), (self, outGrad) => outGrad.mul(self.pow(-1.5).mul(-0.5))); }
    // Tensor element-wise e^x
    exp() { return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => outGrad.mul(self.exp())); }
    // Tensor element-wise 2^x
    exp2() { return this.elementWiseSelfDAG((a) => 2 ** a, (self, outGrad) => outGrad.mul(self.exp2().mul(Math.log(2)))); }
    // Tensor element-wise e^x - 1
    expm1() { return this.elementWiseSelfDAG((a) => Math.expm1(a), (self, outGrad) => outGrad.mul(self.exp())); }
    // Tensor element-wise natural log
    log() { return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => outGrad.div(self)); }
    // Tensor element-wise log2
    log2() { return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => outGrad.div(self.mul(Math.log(2)))); }
    // Tensor element-wise log10
    log10() { return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => outGrad.div(self.mul(Math.log(10)))); }
    // Tensor element-wise log(1+x)
    log1p() { return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => outGrad.div(self.add(1))); }
    // Tensor element-wise relu
    relu() { return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => outGrad.mul(self.gt(0))); }
    // Tensor element-wise sigmoid
    sigmoid() {
        return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
            const sig = self.sigmoid();
            return outGrad.mul(sig).mul(sig.neg().add(1));
        });
    }
    // Tensor element-wise tanh
    tanh() { return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => outGrad.mul(self.tanh().square().neg().add(1))); }
    // Tensor element-wise softplus
    softplus() { return this.elementWiseSelfDAG((a) => Math.log1p(Math.exp(a)), (self, outGrad) => outGrad.mul(self.sigmoid())); }
    // Tensor element-wise softsign
    softsign() { return this.elementWiseSelfDAG((a) => a / (1 + Math.abs(a)), (self, outGrad) => outGrad.div(self.abs().add(1).square())); }
    // Tensor element-wise silu (swish)
    silu() {
        return this.elementWiseSelfDAG((a) => a / (1 + Math.exp(-a)), (self, outGrad) => {
            const sig = self.sigmoid();
            return outGrad.mul(sig.add(self.mul(sig).mul(sig.neg().add(1))));
        });
    }
    // Tensor element-wise mish
    mish() {
        return this.elementWiseSelfDAG((a) => a * Math.tanh(Math.log1p(Math.exp(a))), (self, outGrad) => {
            const tanhSoftPlus = self.exp().add(1).log().tanh();
            // tanh(softplus(x)) + x * (1 - tanh²(softplus(x))) * sigmoid(x)
            const derivative = tanhSoftPlus.add(self.mul(tanhSoftPlus.square().neg().add(1)).mul(self.sigmoid()));
            return outGrad.mul(derivative);
        });
    }
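    // Note (illustrative): the silu backward above uses the identity
    // d/dx [x·σ(x)] = σ(x) + x·σ(x)(1 - σ(x)), reusing the sigmoid of x.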
    // Tensor element-wise gelu
    gelu(approximate = "none") {
        if (approximate === "none") {
            return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + (0, utils_1.erf)(a / Math.sqrt(2))), (self, outGrad) => {
                const sqrt2 = Math.sqrt(2);
                const xOverSqrt2 = self.div(sqrt2);
                const erfVal = xOverSqrt2.erf();
                // Gaussian pdf: φ(x) = e^(-x²/2) / √(2π)
                const phi = xOverSqrt2.square().neg().exp().div(Math.sqrt(2 * Math.PI));
                // GELU'(x) = Φ(x) + x·φ(x)
                const derivative = erfVal.add(1).mul(0.5).add(self.mul(phi));
                return outGrad.mul(derivative);
            });
        } else if (approximate === "tanh") {
            return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (a + 0.044715 * a * a * a))), (self, outGrad) => {
                const sqrt2OverPi = Math.sqrt(2 / Math.PI);
                const c = 0.044715;
                const tanhArg = self.add(self.pow(3).mul(c)).mul(sqrt2OverPi);
                const tanhVal = tanhArg.tanh();
                const sechSquared = tanhVal.square().neg().add(1);
                const term1 = tanhVal.add(1).mul(0.5);
                const term2 = self.mul(sechSquared).mul(sqrt2OverPi).mul(self.square().mul(c * 3).add(1)).mul(0.5);
                const derivative = term1.add(term2);
                return outGrad.mul(derivative);
            });
        }
        throw new Error("Specified approximation does not exist");
    }
    // Tensor element-wise maximum
    maximum(other) { return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5)))); }
    // Tensor element-wise minimum
    minimum(other) { return this.elementWiseABDAG(other, (a, b) => Math.min(a, b), (self, other, outGrad) => outGrad.mul(self.lt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.lt(self).add(other.eq(self).mul(0.5)))); }
    // Tensor element-wise round
    round() { return this.elementWiseSelfDAG((a) => Math.round(a)); }
    // Tensor element-wise floor
    floor() { return this.elementWiseSelfDAG((a) => Math.floor(a)); }
    // Tensor element-wise ceil
    ceil() { return this.elementWiseSelfDAG((a) => Math.ceil(a)); }
    // Tensor element-wise truncation
    trunc() { return this.elementWiseSelfDAG((a) => Math.trunc(a)); }
    fix = this.trunc;
    // Tensor element-wise fractional part
    frac() { return this.elementWiseSelfDAG((a) => a - Math.floor(a)); }
    // Tensor element-wise clip/clamp
    clip(min, max) { return this.elementWiseSelfDAG((a) => Math.max(min, Math.min(max, a)), (self, outGrad) => outGrad.mul(self.ge(min).mul(self.le(max)))); }
    clamp = this.clip;
    // Tensor element-wise error function
    erf() { return this.elementWiseSelfDAG((a) => (0, utils_1.erf)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(2 / Math.sqrt(Math.PI)))); }
    // Tensor element-wise complementary error function
    // (the source excerpt cuts off mid-expression here; the gradient below is an
    // assumed completion from the identity erfc'(x) = -erf'(x) = -(2/√π)·e^(-x²))
    erfc() { return this.elementWiseSelfDAG((a) => (0, utils_1.erfc)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(-2 / Math.sqrt(Math.PI)))); }