catniff
A small Torch-like deep learning framework for JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Tensor = void 0;
const utils_1 = require("./utils");
class Tensor {
value;
shape;
strides;
grad;
requiresGrad;
gradFn;
children;
device;
static training = false;
constructor(value, options = {}) {
this.value = Tensor.flatten(value);
this.shape = options.shape || Tensor.getShape(value);
this.strides = options.strides || Tensor.getStrides(this.shape);
this.grad = options.grad;
this.requiresGrad = options.requiresGrad ?? false;
this.gradFn = options.gradFn || (() => { });
this.children = options.children || [];
this.device = options.device || "cpu";
// Move tensor to device
if (this.device !== "cpu") {
const backend = Tensor.backends.get(this.device);
if (backend && backend.transfer) {
backend.transfer(this);
}
}
}
// Utility to flatten an nD array into a 1D array
static flatten(tensor) {
// Handle scalar tensors
if (typeof tensor === "number")
return tensor;
// If the value is already 1D, return the existing array (by reference)
if (typeof tensor[0] === "number")
return tensor;
// Otherwise, recursively traverse the nD array and flatten it
const result = [];
function traverse(arr) {
if (typeof arr === "number") {
result.push(arr);
}
else if (Array.isArray(arr)) {
arr.forEach(traverse);
}
}
traverse(tensor);
return result;
}
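// Example (illustrative): nested arrays flatten depth-first; scalars pass through.
//   Tensor.flatten([[1, 2], [3, 4]]) -> [1, 2, 3, 4]
//   Tensor.flatten(5)                -> 5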
// Utility to get shape from tensor *value*
static getShape(tensor) {
const shape = [];
let subA = tensor;
while (Array.isArray(subA)) {
shape.push(subA.length);
subA = subA[0];
}
return shape;
}
// Utility to get strides from shape
static getStrides(shape) {
if (shape.length === 0)
return [];
const strides = new Array(shape.length);
strides[strides.length - 1] = 1;
for (let i = strides.length - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
}
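// Example (illustrative): row-major strides for shape [2, 3, 4].
//   Tensor.getStrides([2, 3, 4]) -> [12, 4, 1]
//   (one step along dim 0 skips 3 * 4 = 12 flat elements)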
// Left-pad shapes and strides so both tensors have the same rank
static padShape(stridesA, stridesB, shapeA, shapeB) {
const newStrideA = [...stridesA], newStrideB = [...stridesB];
const newShapeA = [...shapeA], newShapeB = [...shapeB];
while (newStrideA.length < newStrideB.length) {
const newStride = newShapeA[0] * newStrideA[0];
newStrideA.unshift(newStride);
newShapeA.unshift(1);
}
while (newStrideA.length > newStrideB.length) {
const newStride = newShapeB[0] * newStrideB[0];
newStrideB.unshift(newStride);
newShapeB.unshift(1);
}
return [newStrideA, newStrideB, newShapeA, newShapeB];
}
// Broadcast shapes
static broadcastShapes(shapeA, shapeB) {
const newShape = new Array(shapeA.length);
for (let index = 0; index < shapeA.length; index++) {
if (shapeA[index] === 1) {
newShape[index] = shapeB[index];
}
else if (shapeB[index] === 1) {
newShape[index] = shapeA[index];
}
else if (shapeA[index] === shapeB[index]) {
newShape[index] = shapeA[index];
}
else {
throw new Error(`Cannot broadcast shapes: ${shapeA} and ${shapeB}`);
}
}
return newShape;
}
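// Example (illustrative): after padShape equalizes ranks, size-1 dims stretch to match.
//   Tensor.broadcastShapes([1, 3], [2, 1]) -> [2, 3]
//   Tensor.broadcastShapes([2, 3], [2, 2]) -> throws (3 vs 2, neither is 1)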
// Utility to convert flat index to array of coordinates
static indexToCoords(index, strides) {
const coords = new Array(strides.length);
let remaining = index;
for (let dim = 0; dim < strides.length; dim++) {
coords[dim] = Math.floor(remaining / strides[dim]);
remaining %= strides[dim];
}
return coords;
}
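// Example (illustrative): with strides [4, 1] (shape [3, 4]), flat index 5 decodes as
//   dim 0: floor(5 / 4) = 1, remainder 1; dim 1: floor(1 / 1) = 1
//   Tensor.indexToCoords(5, [4, 1]) -> [1, 1]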
// Utility to convert array of coordinates to *unbroadcasted* flat index
static coordsToUnbroadcastedIndex(coords, shape, strides) {
let index = 0;
for (let i = 0; i < coords.length; i++) {
// Handle broadcasting
const actualCoord = shape[i] === 1 ? 0 : coords[i];
index += actualCoord * strides[i];
}
return index;
}
// Utility to convert array of coordinates to flat index
static coordsToIndex(coords, strides) {
let index = 0;
for (let i = 0; i < coords.length; i++) {
index += coords[i] * strides[i];
}
return index;
}
// Utility to compute the number of elements (flat size) for a shape
static shapeToSize(shape) {
let prod = 1;
for (let i = 0; i < shape.length; i++) {
prod *= shape[i];
}
return prod;
}
// Utility for binary (two-operand) element-wise ops
static elementWiseAB(tA, tB, op) {
if (typeof tA.value === "number" && typeof tB.value === "number") {
return new Tensor(op(tA.value, tB.value));
}
if (typeof tA.value === "number") {
// Preserve operand order: tA is the scalar and must stay on the left
return Tensor.elementWiseSelf(tB, (b) => op(tA.value, b));
}
if (typeof tB.value === "number") {
return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
}
// Pad + broadcast shape
const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape);
// Get other output info
const outputStrides = Tensor.getStrides(outputShape);
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new Array(outputSize);
for (let i = 0; i < outputSize; i++) {
// Get coordinates from 1D index
const coordsOutput = Tensor.indexToCoords(i, outputStrides);
// Convert the coordinates to 1D index of flattened A with respect to A's shape
const indexA = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides);
// Convert the coordinates to 1D index of flattened B with respect to B's shape
const indexB = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
// Calculate with op
outputValue[i] = op(tA.value[indexA], tB.value[indexB]);
}
return new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides
});
}
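// Example (sketch): broadcasting a [2, 3] tensor against a [3] tensor.
//   const a = new Tensor([[1, 2, 3], [4, 5, 6]]);
//   const b = new Tensor([10, 20, 30]);
//   const c = Tensor.elementWiseAB(a, b, (x, y) => x + y);
//   c.shape -> [2, 3]; c.value -> [11, 22, 33, 14, 25, 36]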
// Utility for unary (single-operand) element-wise ops
static elementWiseSelf(tA, op) {
if (typeof tA.value === "number")
return new Tensor(op(tA.value));
const newValue = new Array(tA.value.length);
for (let index = 0; index < tA.value.length; index++) {
newValue[index] = op(tA.value[index]);
}
return new Tensor(newValue, { shape: tA.shape, strides: tA.strides });
}
// Utility to apply a binary element-wise op and record a DAG node for autograd
elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
other = Tensor.forceTensor(other);
const out = Tensor.elementWiseAB(this, other, op);
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
}
if (other.requiresGrad) {
out.requiresGrad = true;
out.children.push(other);
}
if (out.requiresGrad) {
out.gradFn = () => {
// Detach inputs so the gradient computation itself is not tracked
const outGrad = out.grad;
const selfNoGrad = this.detach();
const otherNoGrad = other.detach();
if (this.requiresGrad)
Tensor.addGrad(this, thisGrad(selfNoGrad, otherNoGrad, outGrad));
if (other.requiresGrad)
Tensor.addGrad(other, otherGrad(selfNoGrad, otherNoGrad, outGrad));
};
}
return out;
}
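// Example (sketch): requiresGrad propagates to the output and records children.
//   const a = new Tensor([1, 2], { requiresGrad: true });
//   const b = new Tensor([3, 4]);
//   const c = a.mul(b); // c.requiresGrad === true, c.children -> [a]
//   (a backward pass would later invoke c.gradFn() to accumulate a.grad)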
// Utility to apply a unary element-wise op and record a DAG node for autograd
elementWiseSelfDAG(op, thisGrad = () => new Tensor(0)) {
const out = Tensor.elementWiseSelf(this, op);
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
}
if (out.requiresGrad) {
out.gradFn = () => {
// Detach inputs so the gradient computation itself is not tracked
const outGrad = out.grad;
const selfNoGrad = this.detach();
if (this.requiresGrad)
Tensor.addGrad(this, thisGrad(selfNoGrad, outGrad));
};
}
return out;
}
// Utility to force an input value to be a tensor
static forceTensor(value) {
if (value instanceof Tensor)
return value;
return new Tensor(value);
}
// Utility to accumulate a gradient into a tensor, reducing over broadcast axes first
static addGrad(tensor, accumGrad) {
const axesToSqueeze = [];
const axesToReduce = [];
const shape = tensor.shape;
const gradShape = accumGrad.shape;
const paddedDims = gradShape.length - shape.length;
for (let i = 0; i < paddedDims; i++) {
axesToReduce.push(i);
axesToSqueeze.push(i);
}
for (let i = 0; i < shape.length; i++) {
if (shape[i] === 1 && gradShape[i + paddedDims] > 1) {
axesToReduce.push(i + paddedDims);
}
}
const reducedGrad = accumGrad.sum(axesToReduce, true);
const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
if (typeof tensor.grad === "undefined") {
tensor.grad = squeezedGrad;
}
else {
tensor.grad = tensor.grad.add(squeezedGrad);
}
}
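// Example (illustrative): gradients of broadcast inputs are reduced back to the
// input's shape. If tensor.shape is [3] and accumGrad.shape is [2, 3], the padded
// leading axis is summed away, so tensor.grad receives a [3]-shaped row sum.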
// Contiguity-related ops
isContiguous() {
const expectedStrides = Tensor.getStrides(this.shape);
if (expectedStrides.length !== this.strides.length) {
return false;
}
for (let i = 0; i < this.strides.length; i++) {
if (this.strides[i] !== expectedStrides[i]) {
return false;
}
}
return true;
}
contiguous() {
// Check if scalar
if (typeof this.value === "number")
return this;
// Check if already contiguous
if (this.isContiguous())
return this;
const outputStrides = Tensor.getStrides(this.shape);
const outputSize = Tensor.shapeToSize(this.shape);
const outputValue = new Array(outputSize);
for (let index = 0; index < outputSize; index++) {
const outputCoords = Tensor.indexToCoords(index, outputStrides);
const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
outputValue[index] = this.value[originalIndex];
}
const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides });
// Gradient flows back to the original tensor
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
Tensor.addGrad(this, out.grad);
};
}
return out;
}
reshape(newShape) {
// Verify shape size
const originalSize = Tensor.shapeToSize(this.shape);
const outputSize = Tensor.shapeToSize(newShape);
if (originalSize !== outputSize) {
throw new Error("Cannot reshape: incompatible sizes");
}
const outputStrides = Tensor.getStrides(newShape);
const out = new Tensor(this.contiguous().value, { shape: newShape, strides: outputStrides });
// Gradient is reshaped and flows back to the original tensor
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
Tensor.addGrad(this, out.grad.reshape(this.shape));
};
}
return out;
}
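// Example (sketch): reshape keeps the flat data; only shape/strides change.
//   const t = new Tensor([[1, 2, 3], [4, 5, 6]]); // shape [2, 3]
//   const r = t.reshape([3, 2]);                  // shape [3, 2], strides [2, 1]
//   r.value -> [1, 2, 3, 4, 5, 6], read as [[1, 2], [3, 4], [5, 6]]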
// Tensor squeeze
squeeze(dims) {
if (typeof this.value === "number")
return this;
if (typeof dims === "number") {
dims = [dims];
}
if (typeof dims === "undefined") {
const shape = this.shape;
dims = [];
for (let index = 0; index < shape.length; index++) {
if (shape[index] === 1) {
dims.push(index);
}
}
}
// Remove size-1 dims only
const outShape = [], outStrides = [];
for (let index = 0; index < this.shape.length; index++) {
const dim = this.shape[index];
const stride = this.strides[index];
if (dims.includes(index)) {
if (dim !== 1)
throw new Error(`Cannot squeeze dim with size ${dim}`);
}
else {
outShape.push(dim);
outStrides.push(stride);
}
}
const outValue = outShape.length === 0 ? this.value[0] : this.value;
const out = new Tensor(outValue, {
shape: outShape,
strides: outStrides,
device: this.device
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
let restoredGrad = out.grad;
for (let i = dims.length - 1; i >= 0; i--) {
restoredGrad = restoredGrad.unsqueeze(dims[i]);
}
Tensor.addGrad(this, restoredGrad);
};
}
return out;
}
// Tensor unsqueeze - adds dimension of size 1 at specified position
unsqueeze(dim) {
let thisValue = this.value;
if (typeof thisValue === "number") {
thisValue = [thisValue];
}
// Insert size-1 dimension at specified position
const newShape = [...this.shape];
newShape.splice(dim, 0, 1);
// New stride
const newStrides = [...this.strides];
let newDimStride;
if (dim >= this.shape.length) {
// Inserting at the back: use 1
newDimStride = 1;
}
else {
// Inserting before dim: use current stride * current shape
newDimStride = this.strides[dim] * this.shape[dim];
}
newStrides.splice(dim, 0, newDimStride);
const out = new Tensor(thisValue, { shape: newShape, strides: newStrides, device: this.device });
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
Tensor.addGrad(this, out.grad.squeeze(dim));
};
}
return out;
}
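// Example (sketch): unsqueeze is a zero-copy view built from an inserted stride.
//   const t = new Tensor([1, 2, 3]); // shape [3], strides [1]
//   const u = t.unsqueeze(0);        // shape [1, 3], strides [3, 1]
//   u.value is the same array reference as t.value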
// Tensor sum reduction
sum(dims, keepDims = false) {
if (typeof this.value === "number")
return this;
if (typeof dims === "undefined") {
dims = Array.from({ length: this.shape.length }, (_, index) => index);
}
if (Array.isArray(dims)) {
// Sort a copy in descending order (avoid mutating the caller's array)
const sortedDims = [...dims].sort((a, b) => b - a);
let reducedThis = this;
for (let i = 0; i < sortedDims.length; i++) {
reducedThis = reducedThis.sum(sortedDims[i], true);
}
return keepDims ? reducedThis : reducedThis.squeeze(dims);
}
// Dims that are reduced now have size-1
const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
const outputStrides = Tensor.getStrides(outputShape);
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new Array(outputSize).fill(0);
const originalSize = Tensor.shapeToSize(this.shape);
// Gradient data
let gradShape, gradStrides, gradValue = [];
// Allocate gradient data only when needed
if (this.requiresGrad) {
gradShape = this.shape;
gradStrides = this.strides;
gradValue = new Array(originalSize).fill(0);
}
// Calculate new value after sum
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Add into sum
outputValue[outFlatIndex] += this.value[realFlatIndex];
// Mark for gradient if needed
if (this.requiresGrad) {
gradValue[realFlatIndex] = 1;
}
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
Tensor.addGrad(this, out.grad.mul(localGrad));
};
}
return keepDims ? out : out.squeeze(dims);
}
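// Example (sketch): reducing axes of a [2, 3] tensor.
//   const t = new Tensor([[1, 2, 3], [4, 5, 6]]);
//   t.sum(0).value       -> [5, 7, 9] (shape [3])
//   t.sum(1, true).value -> [6, 15]   (shape [2, 1], keepDims)
//   t.sum().value        -> 21        (all axes, squeezed to a scalar)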
// Tensor product reduction
prod(dims, keepDims = false) {
if (typeof this.value === "number")
return this;
if (typeof dims === "undefined") {
dims = Array.from({ length: this.shape.length }, (_, index) => index);
}
if (Array.isArray(dims)) {
// Sort a copy in descending order (avoid mutating the caller's array)
const sortedDims = [...dims].sort((a, b) => b - a);
let reducedThis = this;
for (let i = 0; i < sortedDims.length; i++) {
reducedThis = reducedThis.prod(sortedDims[i], true);
}
return keepDims ? reducedThis : reducedThis.squeeze(dims);
}
// Dims that are reduced now have size-1
const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
const outputStrides = Tensor.getStrides(outputShape);
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new Array(outputSize).fill(1);
const originalSize = Tensor.shapeToSize(this.shape);
// Calculate new value after multiplying
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Multiply into product
outputValue[outFlatIndex] *= this.value[realFlatIndex];
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Grad w.r.t. each element is the product of the other elements on the same axis, i.e. the full product divided by the current value (NaN/Infinity when that value is 0)
gradValue[realFlatIndex] = outputValue[outFlatIndex] / this.value[realFlatIndex];
}
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
Tensor.addGrad(this, out.grad.mul(localGrad));
};
}
return keepDims ? out : out.squeeze(dims);
}
// Tensor mean reduction
mean(dims, keepDims = false) {
if (typeof this.value === "number")
return this;
if (typeof dims === "undefined") {
dims = Array.from({ length: this.shape.length }, (_, index) => index);
}
if (Array.isArray(dims)) {
// Sort a copy in descending order (avoid mutating the caller's array)
const sortedDims = [...dims].sort((a, b) => b - a);
let reducedThis = this;
for (let i = 0; i < sortedDims.length; i++) {
reducedThis = reducedThis.mean(sortedDims[i], true);
}
return keepDims ? reducedThis : reducedThis.squeeze(dims);
}
// Dims that are reduced now have size-1
const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
const outputStrides = Tensor.getStrides(outputShape);
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new Array(outputSize).fill(0);
const outputFeeders = new Array(outputSize).fill(0);
const originalSize = Tensor.shapeToSize(this.shape);
// Calculate sums and how many elements contribute to specific positions
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Calculate sum and contributors to the sum
outputValue[outFlatIndex] += this.value[realFlatIndex];
outputFeeders[outFlatIndex]++;
}
// Calculate mean by dividing sum by the number of contributors to the position
for (let index = 0; index < outputSize; index++) {
outputValue[index] /= outputFeeders[index];
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
// Calculate grad by assigning 1 divided by the number of contributors to the position
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
gradValue[realFlatIndex] = 1 / outputFeeders[outFlatIndex];
}
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
Tensor.addGrad(this, out.grad.mul(localGrad));
};
}
return keepDims ? out : out.squeeze(dims);
}
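// Example (sketch): mean divides each position's sum by its contributor count.
//   const t = new Tensor([[1, 2], [3, 4]]);
//   t.mean(0).value -> [2, 3]
//   t.mean().value  -> 2.5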
// Tensor maximum reduction
max(dims, keepDims = false) {
if (typeof this.value === "number")
return this;
if (typeof dims === "undefined") {
dims = Array.from({ length: this.shape.length }, (_, index) => index);
}
if (Array.isArray(dims)) {
// Sort a copy in descending order (avoid mutating the caller's array)
const sortedDims = [...dims].sort((a, b) => b - a);
let reducedThis = this;
for (let i = 0; i < sortedDims.length; i++) {
reducedThis = reducedThis.max(sortedDims[i], true);
}
return keepDims ? reducedThis : reducedThis.squeeze(dims);
}
// Dims that are reduced now have size-1
const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
const outputStrides = Tensor.getStrides(outputShape);
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new Array(outputSize).fill(-Infinity);
const originalSize = Tensor.shapeToSize(this.shape);
// Calculate maximum values of axes
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Get max over time
if (this.value[realFlatIndex] > outputValue[outFlatIndex]) {
outputValue[outFlatIndex] = this.value[realFlatIndex];
}
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
const shareCounts = new Array(outputSize).fill(0);
const originalValue = this.value;
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// We collect how many elements share the same max value first
shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
}
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Here we share the grad between the elements that share the same max value
gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
}
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
Tensor.addGrad(this, out.grad.mul(localGrad));
};
}
return keepDims ? out : out.squeeze(dims);
}
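// Example (illustrative): ties split the gradient evenly among the maxima.
//   For a row [3, 5, 5] reduced along its axis, the max is 5 with shareCounts = 2,
//   so the local grad for that row is [0, 0.5, 0.5].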
// Tensor minimum reduction
min(dims, keepDims = false) {
if (typeof this.value === "number")
return this;
if (typeof dims === "undefined") {
dims = Array.from({ length: this.shape.length }, (_, index) => index);
}
if (Array.isArray(dims)) {
// Sort a copy in descending order (avoid mutating the caller's array)
const sortedDims = [...dims].sort((a, b) => b - a);
let reducedThis = this;
for (let i = 0; i < sortedDims.length; i++) {
reducedThis = reducedThis.min(sortedDims[i], true);
}
return keepDims ? reducedThis : reducedThis.squeeze(dims);
}
// Dims that are reduced now have size-1
const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
const outputStrides = Tensor.getStrides(outputShape);
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new Array(outputSize).fill(Infinity);
const originalSize = Tensor.shapeToSize(this.shape);
// Calculate minimum values of axes
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Get min over time
if (this.value[realFlatIndex] < outputValue[outFlatIndex]) {
outputValue[outFlatIndex] = this.value[realFlatIndex];
}
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
const shareCounts = new Array(outputSize).fill(0);
const originalValue = this.value;
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// We collect how many elements share the same min value first
shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
}
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
// Force 0 on reduced axes to collapse into size-1 dims
const outCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert output coordinates to flat index
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
// Here we share the grad between the elements that share the same min value
gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
}
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
Tensor.addGrad(this, out.grad.mul(localGrad));
};
}
return keepDims ? out : out.squeeze(dims);
}
// Tensor variance reduction
var(dims, keepDims = false) {
// Var(x) = E[x^2] - (E[x])^2 (population/biased variance)
const meanOfSquares = this.square().mean(dims, keepDims);
const squaredMean = this.mean(dims, keepDims).square();
return meanOfSquares.sub(squaredMean);
}
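// Example (illustrative): the identity above gives the population (biased)
// variance, dividing by n rather than n - 1.
//   new Tensor([1, 2, 3, 4]).var().value -> 7.5 - 2.5^2 = 1.25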
// Tensor standard deviation reduction
std(dims, keepDims = false) {
return this.var(dims, keepDims).sqrt();
}
// Tensor softmax
softmax(dims) {
if (typeof this.value === "number")
return this;
if (typeof dims === "undefined") {
dims = Array.from({ length: this.shape.length }, (_, index) => index);
}
if (Array.isArray(dims)) {
// Sort a copy in descending order (avoid mutating the caller's array)
const sortedDims = [...dims].sort((a, b) => b - a);
let reducedThis = this;
for (let i = 0; i < sortedDims.length; i++) {
reducedThis = reducedThis.softmax(sortedDims[i]);
}
return reducedThis;
}
// Dims that are reduced now have size-1
const expSumShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
const expSumStrides = Tensor.getStrides(expSumShape);
const expSumSize = Tensor.shapeToSize(expSumShape);
const expSumValue = new Array(expSumSize).fill(0);
const outputShape = this.shape;
const outputStrides = this.strides;
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new Array(outputSize);
// Calculate sums of e^xi over axes
for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
// Force 0 on reduced axes to collapse into size-1 dims
const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert exp sum coordinates to flat index
const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
// Add e^x to the sum cache
expSumValue[expSumFlatIndex] += Math.exp(this.value[realFlatIndex]);
}
// Calculate e^xi / sum over axes
for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
// Force 0 on reduced axes to collapse into size-1 dims
const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
// Convert exp sum coordinates to flat index
const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
// Calculate e^xi / sum
outputValue[realFlatIndex] = Math.exp(this.value[realFlatIndex]) / expSumValue[expSumFlatIndex];
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
const upstreamGrad = out.grad;
const softmaxOutput = out.detach();
// Compute element-wise product: ∂L/∂σᵢ × σᵢ
const gradTimesOutput = upstreamGrad.mul(softmaxOutput);
// Sum over the softmax dimension: Σᵢ(∂L/∂σᵢ × σᵢ); keepDims=true for broadcasting
const sumGradOutput = gradTimesOutput.sum(dims, true);
// Apply the softmax gradient formula:
// ∂L/∂zⱼ = (∂L/∂σⱼ × σⱼ) - (σⱼ × Σᵢ(∂L/∂σᵢ × σᵢ))
const localGrad = gradTimesOutput.sub(softmaxOutput.mul(sumGradOutput));
Tensor.addGrad(this, localGrad);
};
}
return out;
}
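// Example (sketch): softmax normalizes along the given axis.
//   const t = new Tensor([[1, 2], [3, 4]]);
//   t.softmax(1).value -> [~0.269, ~0.731, ~0.269, ~0.731] (each row sums to 1)
//   (note: exponentials are not max-shifted, so very large inputs can overflow)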
// Tensor element-wise addition
add(other) {
return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
}
// Tensor element-wise subtraction
sub(other) {
return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad.neg());
}
subtract = this.sub;
// Tensor element-wise multiplication
mul(other) {
return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => outGrad.mul(other), (self, other, outGrad) => outGrad.mul(self));
}
multiply = this.mul;
// Tensor element-wise power
pow(other) {
return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => outGrad.mul(other.mul(self.pow(other.sub(1)))), (self, other, outGrad) => outGrad.mul(self.pow(other).mul(self.log())));
}
// Tensor element-wise division
div(other) {
return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => outGrad.div(other), (self, other, outGrad) => outGrad.mul(self.neg().div(other.square())));
}
divide = this.div;
// Tensor element-wise modulo
remainder(other) {
return this.elementWiseABDAG(other, (a, b) => a % b);
}
// Tensor element-wise greater or equal comparison
ge(other) {
return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0);
}
greaterEqual = this.ge;
// Tensor element-wise less or equal comparison
le(other) {
return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0);
}
lessEqual = this.le;
// Tensor element-wise greater-than comparison
gt(other) {
return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0);
}
greater = this.gt;
// Tensor element-wise less-than comparison
lt(other) {
return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0);
}
less = this.lt;
// Tensor element-wise equality comparison
eq(other) {
return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0);
}
equal = this.eq;
// Tensor element-wise inequality comparison
ne(other) {
return this.elementWiseABDAG(other, (a, b) => a !== b ? 1 : 0);
}
notEqual = this.ne;
// Tensor element-wise logical and
logicalAnd(other) {
return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0);
}
// Tensor element-wise logical or
logicalOr(other) {
return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0);
}
// Tensor element-wise logical xor
logicalXor(other) {
return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0);
}
// Tensor element-wise logical not
logicalNot() {
return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1);
}
// Tensor element-wise bitwise and
bitwiseAnd(other) {
return this.elementWiseABDAG(other, (a, b) => a & b);
}
// Tensor element-wise bitwise or
bitwiseOr(other) {
return this.elementWiseABDAG(other, (a, b) => a | b);
}
// Tensor element-wise bitwise xor
bitwiseXor(other) {
return this.elementWiseABDAG(other, (a, b) => a ^ b);
}
// Tensor element-wise bitwise not
bitwiseNot() {
return this.elementWiseSelfDAG((a) => ~a);
}
// Tensor element-wise left shift
bitwiseLeftShift(other) {
return this.elementWiseABDAG(other, (a, b) => a << b);
}
// Tensor element-wise right shift
bitwiseRightShift(other) {
return this.elementWiseABDAG(other, (a, b) => a >> b);
}
// Tensor element-wise negation
neg() {
return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => outGrad.mul(-1));
}
negative = this.neg;
// Tensor element-wise reciprocal
reciprocal() {
return this.elementWiseSelfDAG((a) => 1 / a, (self, outGrad) => outGrad.mul(self.pow(-2).neg()));
}
// Tensor element-wise square
square() {
return this.elementWiseSelfDAG((a) => a * a, (self, outGrad) => outGrad.mul(self.mul(2)));
}
// Tensor element-wise absolute
abs() {
return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => outGrad.mul(self.sign()));
}
absolute = this.abs;
// Tensor element-wise sign function
sign() {
return this.elementWiseSelfDAG((a) => Math.sign(a));
}
// Tensor element-wise sin
sin() {
return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => outGrad.mul(self.cos()));
}
// Tensor element-wise cos
cos() {
return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => outGrad.mul(self.sin().neg()));
}
// Tensor element-wise tan
tan() {
return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => outGrad.mul(self.tan().square().add(1)));
}
// Tensor element-wise asin
asin() {
return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()));
}
arcsin = this.asin;
// Tensor element-wise acos
acos() {
return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()).neg());
}
arccos = this.acos;
// Tensor element-wise atan
atan() {
return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => outGrad.div(self.square().add(1)));
}
arctan = this.atan;
// Tensor element-wise atan2
atan2(other) {
return this.elementWiseABDAG(other, (a, b) => Math.atan2(a, b), (self, other, outGrad) => outGrad.mul(other.div(self.square().add(other.square()))), (self, other, outGrad) => outGrad.mul(self.neg().div(self.square().add(other.square()))));
}
arctan2 = this.atan2;
// Tensor element-wise sinh
sinh() {
return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => outGrad.mul(self.cosh()));
}
// Tensor element-wise cosh
cosh() {
return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => outGrad.mul(self.sinh()));
}
// Tensor element-wise asinh
asinh() {
return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => outGrad.div(self.square().add(1).sqrt()));
}
arcsinh = this.asinh;
// Tensor element-wise acosh
acosh() {
return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
}
arccosh = this.acosh;
// Tensor element-wise atanh
atanh() {
return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => outGrad.div(self.square().neg().add(1)));
}
arctanh = this.atanh;
// Tensor element-wise degree to radian
deg2rad() {
return this.elementWiseSelfDAG((a) => a * (Math.PI / 180), (self, outGrad) => outGrad.mul(Math.PI / 180));
}
// Tensor element-wise radian to degree
rad2deg() {
return this.elementWiseSelfDAG((a) => a / (Math.PI / 180), (self, outGrad) => outGrad.div(Math.PI / 180));
}
// Tensor element-wise square root
sqrt() {
return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => outGrad.div(self.sqrt().mul(2)));
}
// Tensor element-wise reciprocal of square root
rsqrt() {
return this.elementWiseSelfDAG((a) => 1 / Math.sqrt(a), (self, outGrad) => outGrad.mul(self.pow(-1.5).mul(-0.5)));
}
// Tensor element-wise e^x
exp() {
return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => outGrad.mul(self.exp()));
}
// Tensor element-wise 2^x
exp2() {
return this.elementWiseSelfDAG((a) => 2 ** a, (self, outGrad) => outGrad.mul(self.exp2().mul(Math.log(2))));
}
// Tensor element-wise e^x - 1
expm1() {
return this.elementWiseSelfDAG((a) => Math.expm1(a), (self, outGrad) => outGrad.mul(self.exp()));
}
// Tensor element-wise natural log
log() {
return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => outGrad.div(self));
}
// Tensor element-wise log2
log2() {
return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => outGrad.div(self.mul(Math.log(2))));
}
// Tensor element-wise log10
log10() {
return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => outGrad.div(self.mul(Math.log(10))));
}
// Tensor element-wise log(1+x)
log1p() {
return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => outGrad.div(self.add(1)));
}
// Tensor element-wise relu
relu() {
return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => outGrad.mul(self.gt(0)));
}
// Tensor element-wise sigmoid
sigmoid() {
return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
const sig = self.sigmoid();
return outGrad.mul(sig).mul(sig.neg().add(1));
});
}
// Tensor element-wise tanh
tanh() {
return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => outGrad.mul(self.tanh().square().neg().add(1)));
}
// Tensor element-wise softplus
softplus() {
return this.elementWiseSelfDAG((a) => Math.log1p(Math.exp(a)), (self, outGrad) => outGrad.mul(self.sigmoid()));
}
// Tensor element-wise softsign
softsign() {
return this.elementWiseSelfDAG((a) => a / (1 + Math.abs(a)), (self, outGrad) => outGrad.div(self.abs().add(1).square()));
}
// Tensor element-wise silu (swish)
silu() {
return this.elementWiseSelfDAG((a) => a / (1 + Math.exp(-a)), (self, outGrad) => {
const sig = self.sigmoid();
return outGrad.mul(sig.add(self.mul(sig).mul(sig.neg().add(1))));
});
}
// Tensor element-wise mish
mish() {
return this.elementWiseSelfDAG((a) => a * Math.tanh(Math.log1p(Math.exp(a))), (self, outGrad) => {
const tanhSoftPlus = self.exp().add(1).log().tanh();
// tanh(softplus(x)) + x * (1 - tanh²(softplus(x))) * sigmoid(x)
const derivative = tanhSoftPlus.add(self.mul(tanhSoftPlus.square().neg().add(1)).mul(self.sigmoid()));
return outGrad.mul(derivative);
});
}
// Tensor element-wise gelu
gelu(approximate = "none") {
if (approximate === "none") {
return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + (0, utils_1.erf)(a / Math.sqrt(2))), (self, outGrad) => {
const sqrt2 = Math.sqrt(2);
const xOverSqrt2 = self.div(sqrt2);
const erfVal = xOverSqrt2.erf();
// Standard normal PDF: phi(x) = exp(-x^2 / 2) / sqrt(2 * PI)
const phi = xOverSqrt2.square().neg().exp().div(Math.sqrt(2 * Math.PI));
const derivative = erfVal.add(1).mul(0.5).add(self.mul(phi));
return outGrad.mul(derivative);
});
}
else if (approximate === "tanh") {
return this.elementWiseSelfDAG((a) => 0.5 * a * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (a + 0.044715 * a * a * a))), (self, outGrad) => {
const sqrt2OverPi = Math.sqrt(2 / Math.PI);
const c = 0.044715;
const tanhArg = self.add(self.pow(3).mul(c)).mul(sqrt2OverPi);
const tanhVal = tanhArg.tanh();
const sechSquared = tanhVal.square().neg().add(1);
const term1 = tanhVal.add(1).mul(0.5);
const term2 = self.mul(sechSquared).mul(sqrt2OverPi).mul(self.square().mul(c * 3).add(1)).mul(0.5);
const derivative = term1.add(term2);
return outGrad.mul(derivative);
});
}
throw new Error("Specified approximation does not exist");
}
// Tensor element-wise maximum
maximum(other) {
return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5))));
}
// Tensor element-wise minimum
minimum(other) {
return this.elementWiseABDAG(other, (a, b) => Math.min(a, b), (self, other, outGrad) => outGrad.mul(self.lt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.lt(self).add(other.eq(self).mul(0.5))));
}
// Tensor element-wise round
round() {
return this.elementWiseSelfDAG((a) => Math.round(a));
}
// Tensor element-wise floor
floor() {
return this.elementWiseSelfDAG((a) => Math.floor(a));
}
// Tensor element-wise ceil
ceil() {
return this.elementWiseSelfDAG((a) => Math.ceil(a));
}
// Tensor element-wise truncation
trunc() {
return this.elementWiseSelfDAG((a) => Math.trunc(a));
}
fix = this.trunc;
// Tensor element-wise fractional part
frac() {
return this.elementWiseSelfDAG((a) => a - Math.floor(a));
}
// Tensor element-wise clip and clamp
clip(min, max) {
return this.elementWiseSelfDAG((a) => Math.max(min, Math.min(max, a)), (self, outGrad) => outGrad.mul(self.ge(min).mul(self.le(max))));
}
clamp = this.clip;
// Tensor element-wise error function
erf() {
return this.elementWiseSelfDAG((a) => (0, utils_1.erf)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(2 / Math.sqrt(Math.PI))));
}
// Tensor element-wise complementary error function
erfc() {
return this.elementWiseSelfDAG((a) => (0, utils_1.erfc)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(-2 / Math.sqrt(Math.PI))));
}