// catniff
// Version: (unspecified)
// Torch-like deep learning framework for JavaScript
// 1,242 lines • 86.1 kB
// JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Tensor = void 0;
const utils_1 = require("./utils");
// A strided tensor: flat 1D storage plus (shape, strides, offset) view
// metadata, with reverse-mode autograd bookkeeping (grad, gradFn, children).
class Tensor {
    value; // Flat 1D number array, or a plain number for scalar tensors
    shape; // Size of each dimension
    strides; // Storage step per dimension (in elements, not bytes)
    offset; // Start position of this view inside `value`
    numel; // Total number of elements (product of shape)
    grad; // Accumulated gradient Tensor (undefined until first addGrad)
    requiresGrad; // Whether this tensor participates in autograd
    gradFn; // Backward closure that propagates this.grad to children
    children; // Tensors this one was computed from (autograd DAG edges)
    device; // Device tag, defaults to "cpu"
    static training = false; // Global train/eval flag (not referenced in this chunk)
    static noGrad = false; // NOTE(review): presumably disables graph building globally — not referenced in this chunk
    static createGraph = false; // Keep operands attached during backward (higher-order grads)
    /**
     * Builds a tensor from a scalar or (nested) array.
     * @param value Scalar number or nD array; nested arrays are flattened to 1D storage.
     * @param options Optional overrides: shape, strides, offset, numel, device,
     *                and autograd fields (grad, requiresGrad, gradFn, children).
     */
    constructor(value, options = {}) {
        // Storage (always flat; shape/strides describe the view)
        this.value = Tensor.flattenValue(value);
        // Tensor metadata — derived from `value` unless explicitly provided
        this.shape = options.shape || Tensor.getShape(value);
        this.strides = options.strides || Tensor.getStrides(this.shape);
        this.offset = options.offset || 0;
        this.numel = options.numel || Tensor.shapeToSize(this.shape);
        this.device = options.device || "cpu";
        // Autograd data
        this.grad = options.grad;
        this.requiresGrad = options.requiresGrad ?? false;
        this.gradFn = options.gradFn || (() => { });
        this.children = options.children || [];
        // Move to device in-place (to_ is defined later in this class)
        this.to_(this.device);
    }
    // Utility to flatten an nD array to be 1D
    static flattenValue(tensor) {
        // Handle scalar tensors
        if (typeof tensor === "number")
            return tensor;
        // If value is already 1D, we just need to return the value ('s reference)
        if (typeof tensor[0] === "number")
            return tensor;
        // Or else recursively traverse through the nD array to flatten
        const result = [];
        function traverse(arr) {
            if (typeof arr === "number") {
                result.push(arr);
            }
            else if (Array.isArray(arr)) {
                arr.forEach(traverse);
            }
        }
        traverse(tensor);
        return result;
    }
    // Utility to get shape from tensor *value*
    static getShape(tensor) {
        const shape = [];
        let subA = tensor;
        while (Array.isArray(subA)) {
            shape.push(subA.length);
            subA = subA[0];
        }
        return shape;
    }
    // Utility to get strides from shape
    static getStrides(shape) {
        if (shape.length === 0)
            return [];
        const strides = new Array(shape.length);
        strides[strides.length - 1] = 1;
        for (let i = strides.length - 2; i >= 0; i--) {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        return strides;
    }
    // Left-pad shape and strides for two shape to be of same length
    static padShape(stridesA, stridesB, shapeA, shapeB) {
        const newStrideA = [...stridesA], newStrideB = [...stridesB];
        const newShapeA = [...shapeA], newShapeB = [...shapeB];
        while (newStrideA.length < newStrideB.length) {
            const newStride = newShapeA[0] * newStrideA[0];
            newStrideA.unshift(newStride);
            newShapeA.unshift(1);
        }
        while (newStrideA.length > newStrideB.length) {
            const newStride = newShapeB[0] * newStrideB[0];
            newStrideB.unshift(newStride);
            newShapeB.unshift(1);
        }
        return [newStrideA, newStrideB, newShapeA, newShapeB];
    }
    // Broadcast shapes
    static broadcastShapes(shapeA, shapeB) {
        const newShape = new Array(shapeA.length);
        for (let index = 0; index < shapeA.length; index++) {
            if (shapeA[index] === 1) {
                newShape[index] = shapeB[index];
            }
            else if (shapeB[index] === 1) {
                newShape[index] = shapeA[index];
            }
            else if (shapeA[index] === shapeB[index]) {
                newShape[index] = shapeA[index];
            }
            else {
                throw new Error(`Can not broadcast shapes: ${shapeA} and ${shapeB}`);
            }
        }
        return newShape;
    }
    // Utility to convert flat index to array of coordinates
    static indexToCoords(index, strides) {
        const coords = new Array(strides.length);
        let remaining = index;
        for (let dim = 0; dim < strides.length; dim++) {
            coords[dim] = Math.floor(remaining / strides[dim]);
            remaining %= strides[dim];
        }
        return coords;
    }
    // Utility to convert array of coordinates to *unbroadcasted* flat index 
    static coordsToUnbroadcastedIndex(coords, shape, strides) {
        let index = 0;
        for (let i = 0; i < coords.length; i++) {
            // Handle broadcasting
            const actualCoord = shape[i] === 1 ? 0 : coords[i];
            index += actualCoord * strides[i];
        }
        return index;
    }
    // Utility to convert array of coordinates to flat index 
    static coordsToIndex(coords, strides) {
        let index = 0;
        for (let i = 0; i < coords.length; i++) {
            index += coords[i] * strides[i];
        }
        return index;
    }
    // Utility to convert shape into 1D value array size
    static shapeToSize(shape) {
        let prod = 1;
        for (let i = 0; i < shape.length; i++) {
            prod *= shape[i];
        }
        return prod;
    }
    ;
    // Utility for binary (two operators involved) element-wise ops
    static elementWiseAB(tA, tB, op) {
        if (typeof tA.value === "number" && typeof tB.value === "number") {
            return new Tensor(op(tA.value, tB.value));
        }
        if (typeof tA.value === "number") {
            return Tensor.elementWiseSelf(tB, (a) => op(a, tA.value));
        }
        if (typeof tB.value === "number") {
            return Tensor.elementWiseSelf(tA, (a) => op(a, tB.value));
        }
        // Pad + broadcast shape
        const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
        const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape);
        // Get other output info
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize);
        for (let i = 0; i < outputSize; i++) {
            // Get coordinates from 1D index
            const coordsOutput = Tensor.indexToCoords(i, outputStrides);
            // Convert the coordinates to 1D index of flattened A with respect to A's shape
            const indexA = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides);
            // Convert the coordinates to 1D index of flattened B with respect to B's shape
            const indexB = Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
            // Calculate with op
            outputValue[i] = op(tA.value[indexA + tA.offset], tB.value[indexB + tB.offset]);
        }
        return new Tensor(outputValue, {
            shape: outputShape,
            strides: outputStrides,
            numel: outputSize
        });
    }
    // Utility for self-inflicting element-wise ops
    static elementWiseSelf(tA, op) {
        if (typeof tA.value === "number")
            return new Tensor(op(tA.value));
        const contiguous = tA.isContiguous();
        const outputShape = tA.shape;
        const outputStrides = contiguous ? tA.strides : Tensor.getStrides(outputShape);
        const outputSize = tA.numel;
        const outputValue = new Array(outputSize);
        if (contiguous) {
            for (let index = 0; index < outputSize; index++) {
                outputValue[index] = op(tA.value[index + tA.offset]);
            }
        }
        else {
            for (let index = 0; index < outputSize; index++) {
                const outputCoords = Tensor.indexToCoords(index, outputStrides);
                const originalIndex = tA.offset + Tensor.coordsToIndex(outputCoords, tA.strides);
                outputValue[index] = op(tA.value[originalIndex]);
            }
        }
        return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: tA.numel });
    }
    // Utility to do element-wise operation and build a dag node with another tensor
    elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
        other = this.handleOther(other);
        const out = Tensor.elementWiseAB(this, other, op);
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (other.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(other);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                const outGrad = out.grad;
                const selfWithGrad = Tensor.createGraph ? this : this.detach();
                const otherWithGrad = Tensor.createGraph ? other : other.detach();
                if (this.requiresGrad)
                    Tensor.addGrad(this, thisGrad(selfWithGrad, otherWithGrad, outGrad));
                if (other.requiresGrad)
                    Tensor.addGrad(other, otherGrad(selfWithGrad, otherWithGrad, outGrad));
            };
        }
        return out;
    }
    // Utility to do self-inflicting element-wise operation and build a dag node
    elementWiseSelfDAG(op, thisGrad = () => new Tensor(0)) {
        const out = Tensor.elementWiseSelf(this, op);
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                const outGrad = out.grad;
                const selfWithGrad = Tensor.createGraph ? this : this.detach();
                if (this.requiresGrad)
                    Tensor.addGrad(this, thisGrad(selfWithGrad, outGrad));
            };
        }
        return out;
    }
    // Utility to handle other tensor if an op needs a second operand
    handleOther(other) {
        if (other instanceof Tensor) {
            if (this.device !== other.device) {
                throw new Error("Can not operate on tensors that are not on the same device");
            }
            return other;
        }
        return new Tensor(other, { device: this.device });
    }
    // Utility to add to gradient of tensor
    static addGrad(tensor, accumGrad) {
        const axesToSqueeze = [];
        const axesToReduce = [];
        const shape = tensor.shape;
        const gradShape = accumGrad.shape;
        const paddedDims = gradShape.length - shape.length;
        for (let i = 0; i < paddedDims; i++) {
            axesToReduce.push(i);
            axesToSqueeze.push(i);
        }
        for (let i = 0; i < shape.length; i++) {
            if (shape[i] === 1 && gradShape[i + paddedDims] > 1) {
                axesToReduce.push(i + paddedDims);
            }
        }
        const reducedGrad = accumGrad.sum(axesToReduce, true);
        const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
        if (typeof tensor.grad === "undefined") {
            tensor.grad = squeezedGrad;
        }
        else {
            tensor.grad = tensor.grad.add(squeezedGrad);
        }
    }
    static normalizeDims(dims, numDims) {
        for (let index = 0; index < dims.length; index++) {
            // Handle negative indices
            if (dims[index] < 0) {
                dims[index] += numDims;
            }
            // If dimension out of bound, throw error
            if (dims[index] >= numDims || dims[index] < 0) {
                throw new Error("Dimensions do not exist");
            }
        }
        return dims;
    }
    // Contiguity-related ops
    isContiguous() {
        const expectedStrides = Tensor.getStrides(this.shape);
        for (let i = 0; i < this.strides.length; i++) {
            if (this.strides[i] !== expectedStrides[i]) {
                return false;
            }
        }
        return true;
    }
    contiguous() {
        // Check if scalar
        if (typeof this.value === "number")
            return this;
        // Check if already contiguous
        if (this.isContiguous())
            return this;
        const outputStrides = Tensor.getStrides(this.shape);
        const outputSize = this.numel;
        const outputValue = new Array(outputSize);
        for (let index = 0; index < outputSize; index++) {
            const outputCoords = Tensor.indexToCoords(index, outputStrides);
            const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
            outputValue[index] = this.value[this.offset + originalIndex];
        }
        const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides, numel: outputSize });
        // Gradient flow back to the original tensor
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad);
            };
        }
        return out;
    }
    view(newShape) {
        // Verify shape size
        const originalSize = this.numel;
        const outputSize = Tensor.shapeToSize(newShape);
        if (originalSize !== outputSize || typeof this.value === "number") {
            throw new Error("Can not create view: incompatible sizes");
        }
        // Verify compatibility (only contiguity for now)
        if (!this.isContiguous()) {
            throw new Error("Can not create view: incompatible metadata");
        }
        const outputStrides = Tensor.getStrides(newShape);
        const out = new Tensor(this.value, {
            shape: newShape,
            strides: outputStrides,
            offset: this.offset,
            numel: outputSize,
            device: this.device
        });
        // Gradient reshaped and flow back to the original tensor
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad.reshape(this.shape));
            };
        }
        return out;
    }
    reshape(newShape) {
        return this.contiguous().view(newShape);
    }
    flatten(startDim = 0, endDim = -1) {
        // Handle negative indices
        if (startDim < 0) {
            startDim += this.shape.length;
        }
        if (endDim < 0) {
            endDim += this.shape.length;
        }
        // If dimension out of bound, throw error
        if (startDim >= this.shape.length || endDim >= this.shape.length || startDim < 0 || endDim < 0) {
            throw new Error("Dimensions do not exist to flatten");
        }
        const newShape = [];
        let middleSize = 1;
        for (let index = 0; index < this.shape.length; index++) {
            // Keep dims before startDim
            if (index < startDim) {
                newShape.push(this.shape[index]);
            }
            // Multiply dims from startDim to endDim
            if (index >= startDim && index <= endDim) {
                middleSize *= this.shape[index];
            }
            // Push new flatten middle
            if (index === endDim) {
                newShape.push(middleSize);
            }
            // Keep dims after endDim
            if (index > endDim) {
                newShape.push(this.shape[index]);
            }
        }
        return this.reshape(newShape);
    }
    // Transpose
    transpose(dim1, dim2) {
        // Handle negative indices
        if (dim1 < 0) {
            dim1 += this.shape.length;
        }
        if (dim2 < 0) {
            dim2 += this.shape.length;
        }
        // If dimension out of bound, throw error
        if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) {
            throw new Error("Dimensions do not exist to transpose");
        }
        // If same dimension, return view
        if (dim1 === dim2)
            return this;
        // Create new shape and strides by swapping
        const newShape = [...this.shape];
        const newStrides = [...this.strides];
        [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
        [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
        // Create new tensor with same data but swapped shape/strides
        const out = new Tensor(this.value, {
            shape: newShape,
            strides: newStrides,
            offset: this.offset,
            numel: this.numel,
            device: this.device
        });
        out.requiresGrad = this.requiresGrad;
        // Handle gradient if needed
        if (this.requiresGrad) {
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
            };
        }
        return out;
    }
    // Per-instance aliases for transpose (class fields bound at construction)
    swapaxes = this.transpose;
    swapdims = this.transpose;
    // Transpose 2D
    t() {
        // Verify matrix shape
        if (this.shape.length !== 2) {
            throw new Error("Input is not a matrix");
        }
        return this.transpose(0, 1);
    }
    // Permute
    permute(dims) {
        dims = Tensor.normalizeDims(dims, this.shape.length);
        if (dims.length !== this.shape.length) {
            throw new Error("Permutation must specify all dimensions");
        }
        // Compute new shape and strides
        const newShape = new Array(dims.length);
        const newStrides = new Array(dims.length);
        for (let index = 0; index < dims.length; index++) {
            const dim = dims[index];
            newShape[index] = this.shape[dim];
            newStrides[index] = this.strides[dim];
        }
        const out = new Tensor(this.value, {
            shape: newShape,
            strides: newStrides,
            offset: this.offset,
            numel: this.numel,
            device: this.device
        });
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                // Compute inverse permutation
                const inverseAxes = new Array(dims.length);
                for (let i = 0; i < dims.length; i++) {
                    inverseAxes[dims[i]] = i;
                }
                // Permute gradient back to original order
                const permutedGrad = out.grad.permute(inverseAxes);
                Tensor.addGrad(this, permutedGrad);
            };
        }
        return out;
    }
    // Utility for indexing with array of indices
    indexWithArray(indices) {
        if (typeof this.value === "number")
            return this;
        indices = Tensor.normalizeDims(indices, this.shape[0]);
        // Init necessary stuff for indexing
        const reducedShape = this.shape.slice(1);
        const reducedStrides = this.strides.slice(1);
        const elementsPerIndex = Tensor.shapeToSize(reducedShape);
        // Init output data
        const outputShape = [indices.length, ...reducedShape];
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize);
        for (let i = 0; i < indices.length; i++) {
            const sourceRowIndex = indices[i];
            const targetStart = i * elementsPerIndex;
            for (let j = 0; j < elementsPerIndex; j++) {
                const fullCoords = Tensor.indexToCoords(j, reducedStrides);
                fullCoords.unshift(sourceRowIndex);
                const sourceIndex = Tensor.coordsToIndex(fullCoords, this.strides);
                outputValue[targetStart + j] = this.value[this.offset + sourceIndex];
            }
        }
        const out = new Tensor(outputValue, {
            shape: outputShape,
            numel: outputSize
        });
        // Handle gradient
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                const outGrad = out.grad;
                // Create zero gradient tensor with original shape
                const grad = Tensor.zerosLike(this);
                // Scatter gradients back to original positions
                for (let i = 0; i < indices.length; i++) {
                    const originalRowIndex = indices[i];
                    const sourceStart = i * elementsPerIndex;
                    for (let j = 0; j < elementsPerIndex; j++) {
                        const fullCoords = Tensor.indexToCoords(j, reducedStrides);
                        fullCoords.unshift(originalRowIndex);
                        const targetIndex = Tensor.coordsToIndex(fullCoords, this.strides);
                        grad.value[targetIndex] += outGrad.value[sourceStart + j];
                    }
                }
                Tensor.addGrad(this, grad);
            };
        }
        return out;
    }
    // Tensor indexing
    index(indices) {
        const tensorIndices = this.handleOther(indices).clone();
        if (typeof tensorIndices.value === "number") {
            return this.indexWithArray([tensorIndices.value]).squeeze(0);
        }
        else {
            const originalShape = tensorIndices.shape;
            const flatIndices = tensorIndices.value;
            const result = this.indexWithArray(flatIndices);
            // Reshape to preserve input shape
            const outputShape = [...originalShape, ...this.shape.slice(1)];
            return result.reshape(outputShape);
        }
    }
    // Tensor slicing
    // Tensor slicing
    // `ranges` holds one [start, end, step] triple per dimension (entries
    // and trailing dims may be omitted and default to the full range).
    // Returns a zero-copy strided view.
    slice(ranges) {
        // Handle scalars
        if (typeof this.value === "number")
            return this;
        const newShape = [];
        const newStrides = [];
        let newOffset = this.offset || 0;
        // Pad ranges to match tensor dimensions
        const paddedRanges = [...ranges];
        while (paddedRanges.length < this.shape.length) {
            paddedRanges.push([]);
        }
        for (let i = 0; i < this.shape.length; i++) {
            const range = paddedRanges[i] || [];
            const dimSize = this.shape[i];
            const stride = this.strides[i];
            // Default values
            // NOTE(review): start=0 / end=dimSize defaults only make sense
            // for positive steps; verify intended semantics for negative
            // steps (offset below always advances by `start * stride`).
            let start = range[0] ?? 0;
            let end = range[1] ?? dimSize;
            let step = range[2] ?? 1;
            // Handle negative indices
            if (start < 0)
                start += dimSize;
            if (end < 0)
                end += dimSize;
            // Clamp to valid range
            start = Math.max(0, Math.min(start, dimSize));
            end = Math.max(0, Math.min(end, dimSize));
            // Calculate new dimension size
            const newDimSize = step > 0
                ? Math.max(0, Math.ceil((end - start) / step))
                : Math.max(0, Math.ceil((start - end) / Math.abs(step)));
            newShape.push(newDimSize);
            // Stepping multiplies the stride; starting offset moves the base
            newStrides.push(stride * step);
            newOffset += start * stride;
        }
        const out = new Tensor(this.value, {
            shape: newShape,
            strides: newStrides,
            offset: newOffset,
            device: this.device
        });
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                // Create zero tensor of original shape
                const grad = Tensor.zerosLike(this);
                // Upstream grad
                const outGrad = out.grad;
                const totalElements = outGrad.numel;
                // Scatter each sliced-grad element back to its source slot
                for (let i = 0; i < totalElements; i++) {
                    // Convert flat index to coordinates in sliced tensor
                    const slicedCoords = Tensor.indexToCoords(i, outGrad.strides);
                    // Map back to original coordinates
                    const originalCoords = new Array(slicedCoords.length);
                    for (let dim = 0; dim < slicedCoords.length; dim++) {
                        const coord = slicedCoords[dim];
                        const range = paddedRanges[dim] || [];
                        const start = range[0] ?? 0;
                        const step = range[2] ?? 1;
                        // Re-apply the same start normalization used above
                        const normalizedStart = start < 0 ? start + this.shape[dim] : start;
                        originalCoords[dim] = normalizedStart + coord * step;
                    }
                    // Get flat indices with offsets
                    const srcIndex = Tensor.coordsToIndex(slicedCoords, outGrad.strides) + outGrad.offset;
                    const targetIndex = Tensor.coordsToIndex(originalCoords, grad.strides) + grad.offset;
                    // Accumulate gradient
                    grad.value[targetIndex] += outGrad.value[srcIndex];
                }
                Tensor.addGrad(this, grad);
            };
        }
        return out;
    }
    // Tensor chunk
    chunk(chunks, dim = 0) {
        // Handle negative indices
        if (dim < 0) {
            dim += this.shape.length;
        }
        // If dimension out of bound, throw error
        if (dim >= this.shape.length || dim < 0) {
            throw new Error("Dimension do not exist to chunk");
        }
        const sliceOpt = new Array(this.shape.length);
        for (let index = 0; index < sliceOpt.length; index++) {
            sliceOpt[index] = [];
        }
        const dimSize = this.shape[dim];
        const chunkDimSize = Math.ceil(dimSize / chunks);
        const results = [];
        for (let index = 0; index < dimSize; index += chunkDimSize) {
            sliceOpt[dim] = [index, Math.min(index + chunkDimSize, dimSize)];
            results.push(this.slice(sliceOpt));
        }
        return results;
    }
    // Tensor expansion
    expand(newShape) {
        // Handle scalars
        let self = this;
        if (typeof this.value === "number") {
            self = self.unsqueeze(0);
        }
        // Pad shapes to same length
        const ndim = Math.max(self.shape.length, newShape.length);
        const oldShape = [...Array(ndim - self.shape.length).fill(1), ...self.shape];
        const oldStrides = [...Array(ndim - self.strides.length).fill(0), ...self.strides];
        const targetShape = [...Array(ndim - newShape.length).fill(1), ...newShape];
        const newStrides = new Array(ndim);
        for (let i = 0; i < ndim; i++) {
            if (oldShape[i] === targetShape[i]) {
                newStrides[i] = oldStrides[i];
            }
            else if (oldShape[i] === 1) {
                newStrides[i] = 0;
            }
            else {
                throw new Error(`Cannot expand dimension of size ${oldShape[i]} to ${targetShape[i]}`);
            }
        }
        const out = new Tensor(self.value, {
            shape: targetShape,
            strides: newStrides,
            offset: self.offset,
            numel: Tensor.shapeToSize(targetShape),
            device: self.device
        });
        if (self.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(self);
            out.gradFn = () => {
                Tensor.addGrad(self, out.grad);
            };
        }
        return out;
    }
    // Tensor concatentation
    // Tensor concatentation
    // Joins `other` onto this tensor along `dim`; all other dims must match.
    // Produces a fresh contiguous tensor; gradients are routed back to each
    // operand by slicing the upstream grad along `dim`.
    cat(other, dim = 0) {
        other = this.handleOther(other);
        // Handle scalars
        if (typeof this.value === "number" || typeof other.value === "number") {
            throw new Error("Can not concatenate scalars");
        }
        // Handle negative indices
        if (dim < 0) {
            dim += this.shape.length;
        }
        // If dimension out of bound, throw error
        if (dim >= this.shape.length || dim < 0) {
            throw new Error("Dimension does not exist to concatenate");
        }
        // If shape does not match, throw error
        if (this.shape.length !== other.shape.length) {
            throw new Error("Shape does not match to concatenate");
        }
        // Output shape: sum along `dim`, identical sizes everywhere else
        const outputShape = new Array(this.shape.length);
        for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
            if (currentDim === dim) {
                outputShape[currentDim] = this.shape[currentDim] + other.shape[currentDim];
            }
            else if (this.shape[currentDim] !== other.shape[currentDim]) {
                throw new Error("Shape does not match to concatenate");
            }
            else {
                outputShape[currentDim] = this.shape[currentDim];
            }
        }
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputStrides = Tensor.getStrides(outputShape);
        const outputValue = new Array(outputSize);
        for (let outIndex = 0; outIndex < outputSize; outIndex++) {
            const coords = Tensor.indexToCoords(outIndex, outputStrides);
            // Check which tensor this output position comes from
            if (coords[dim] < this.shape[dim]) {
                // Comes from this tensor
                const srcIndex = Tensor.coordsToIndex(coords, this.strides);
                outputValue[outIndex] = this.value[srcIndex + this.offset];
            }
            else {
                // Comes from other tensor - adjust coordinate in concat dimension
                const otherCoords = [...coords];
                otherCoords[dim] -= this.shape[dim];
                const srcIndex = Tensor.coordsToIndex(otherCoords, other.strides);
                outputValue[outIndex] = other.value[srcIndex + other.offset];
            }
        }
        const out = new Tensor(outputValue, {
            shape: outputShape,
            strides: outputStrides,
            numel: outputSize
        });
        // Register autograd parents (either operand may require grad)
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
        }
        if (other.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(other);
        }
        if (out.requiresGrad) {
            out.gradFn = () => {
                const outGrad = out.grad;
                // Each operand receives its own span of the upstream grad
                // along `dim` (full range on every other dim)
                const thisRanges = new Array(this.shape.length);
                const otherRanges = new Array(other.shape.length);
                for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
                    if (currentDim === dim) {
                        thisRanges[currentDim] = [0, this.shape[currentDim], 1];
                        otherRanges[currentDim] = [this.shape[currentDim], outputShape[currentDim], 1];
                    }
                    else {
                        thisRanges[currentDim] = [];
                        otherRanges[currentDim] = [];
                    }
                }
                // NOTE(review): grads are added unconditionally here, unlike
                // other ops which re-check requiresGrad — confirm intended.
                Tensor.addGrad(this, outGrad.slice(thisRanges));
                Tensor.addGrad(other, outGrad.slice(otherRanges));
            };
        }
        return out;
    }
    // Tensor squeeze
    squeeze(dims) {
        if (typeof this.value === "number")
            return this;
        if (typeof dims === "number") {
            dims = [dims];
        }
        if (typeof dims === "undefined") {
            const shape = this.shape;
            dims = [];
            for (let index = 0; index < shape.length; index++) {
                if (shape[index] === 1) {
                    dims.push(index);
                }
            }
        }
        dims = Tensor.normalizeDims(dims, this.shape.length);
        // Remove size-1 dims only
        const outShape = [], outStrides = [];
        for (let index = 0; index < this.shape.length; index++) {
            const dim = this.shape[index];
            const stride = this.strides[index];
            if (dims.includes(index)) {
                if (dim !== 1)
                    throw new Error(`Can not squeeze dim with size ${dim}`);
            }
            else {
                outShape.push(dim);
                outStrides.push(stride);
            }
        }
        const outValue = outShape.length === 0 ? this.value[this.offset] : this.value;
        const out = new Tensor(outValue, {
            shape: outShape,
            strides: outStrides,
            offset: this.offset,
            device: this.device
        });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                let restoredGrad = out.grad;
                for (let i = dims.length - 1; i >= 0; i--) {
                    restoredGrad = restoredGrad.unsqueeze(dims[i]);
                }
                Tensor.addGrad(this, restoredGrad);
            };
        }
        return out;
    }
    // Tensor unsqueeze - adds dimension of size 1 at specified position
    unsqueeze(dim) {
        // Handle negative indices
        if (dim < 0) {
            dim += this.shape.length;
        }
        let thisValue = this.value;
        if (typeof thisValue === "number") {
            thisValue = [thisValue];
        }
        // Insert size-1 dimension at specified position
        const newShape = [...this.shape];
        newShape.splice(dim, 0, 1);
        // New stride
        const newStrides = [...this.strides];
        let newDimStride;
        if (dim >= this.shape.length) {
            // Inserting at the back: use 1
            newDimStride = 1;
        }
        else {
            // Inserting before dim: use current stride * current shape
            newDimStride = this.strides[dim] * this.shape[dim];
        }
        newStrides.splice(dim, 0, newDimStride);
        const out = new Tensor(thisValue, {
            shape: newShape,
            strides: newStrides,
            offset: this.offset,
            device: this.device
        });
        // Set up gradient if needed
        if (this.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(this);
            out.gradFn = () => {
                Tensor.addGrad(this, out.grad.squeeze(dim));
            };
        }
        return out;
    }
    // Generic reduction operation handler
    static reduce(tensor, dims, keepDims, config) {
        if (typeof tensor.value === "number")
            return tensor;
        if (typeof dims === "undefined") {
            dims = Array.from({ length: tensor.shape.length }, (_, index) => index);
        }
        if (Array.isArray(dims)) {
            dims = Tensor.normalizeDims(dims, tensor.shape.length);
            const sortedDims = dims.sort((a, b) => b - a);
            let reducedThis = tensor;
            for (let i = 0; i < sortedDims.length; i++) {
                reducedThis = Tensor.reduce(reducedThis, sortedDims[i], true, config);
            }
            return keepDims ? reducedThis : reducedThis.squeeze(dims);
        }
        const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
        const outputStrides = Tensor.getStrides(outputShape);
        const outputSize = Tensor.shapeToSize(outputShape);
        const outputValue = new Array(outputSize).fill(config.identity);
        const outputCounters = config.needsCounters ? new Array(outputSize).fill(0) : [];
        const originalSize = tensor.numel;
        const originalValue = tensor.value;
        const linearStrides = Tensor.getStrides(tensor.shape);
        // Forward pass
        for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
            // Convert linear index to coordinates using contiguous strides
            const coords = Tensor.indexToCoords(flatIndex, linearStrides);
            // Convert coordinates to actual strided index
            const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
            // Convert coords to reduced index
            coords[dims] = 0;
            const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
            // Apply op
            outputValue[outFlatIndex] = config.operation(outputValue[outFlatIndex], originalValue[realFlatIndex]);
            // Count el if needed
            if (config.needsCounters) {
                outputCounters[outFlatIndex]++;
            }
        }
        // Post-process if needed (e.g., divide by count for mean)
        if (config.postProcess) {
            config.postProcess({ values: outputValue, counters: outputCounters });
        }
        const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, numel: outputSize });
        // Gradient setup
        if (tensor.requiresGrad) {
            out.requiresGrad = true;
            out.children.push(tensor);
            out.gradFn = () => {
                let shareCounts = [];
                if (config.needsShareCounts) {
                    shareCounts = new Array(outputSize).fill(0);
                    for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
                        // Convert linear index to coordinates using contiguous strides
                        const coords = Tensor.indexToCoords(flatIndex, linearStrides);
                        // Convert coordinates to actual strided index
                        const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
                        // Convert coords to reduced index
                        coords[dims] = 0;
                        const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
                        // We collect how many elements share the same max value first
                        shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
                    }
                }
                const gradValue = new Array(originalSize);
                for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
                    // Convert linear index to coordinates using contiguous strides
                    const coords = Tensor.indexToCoords(flatIndex, linearStrides);
                    // Convert coordinates to actual strided index
                    const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
                    // Convert coords to reduced index
                    coords[dims] = 0;
                    const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
                    gradValue[flatIndex] = config.gradientFn({
                        outputValue,
                        originalValue: tensor.value,
                        counters: outputCounters,
                        shareCounts,
                        realIndex: realFlatIndex,
                        outIndex: outFlatIndex
                    });
                }
                const localGrad = new Tensor(gradValue, { shape: tensor.shape, numel: tensor.numel });
                Tensor.addGrad(tensor, out.grad.mul(localGrad));
            };
        }
        return keepDims ? out : out.squeeze(dims);
    }
    // Simplified reduction operations
    sum(dims, keepDims = false) {
        return Tensor.reduce(this, dims, keepDims, {
            identity: 0,
            operation: (a, b) => a + b,
            gradientFn: ({}) => 1
        });
    }
    // Tensor product reduction
    prod(dims, keepDims = false) {
        return Tensor.reduce(this, dims, keepDims, {
            identity: 1,
            operation: (a, b) => a * b,
            // Local grad of a product w.r.t. one factor is the product of the
            // remaining factors, computed as total / factor.
            // NOTE(review): dividing by a zero element yields Infinity/NaN
            // gradients — confirm whether zeros need a dedicated path
            // (PyTorch recomputes the product excluding x_i).
            gradientFn: ({ outputValue, originalValue, realIndex, outIndex }) => outputValue[outIndex] / originalValue[realIndex]
        });
    }
    mean(dims, keepDims = false) {
        return Tensor.reduce(this, dims, keepDims, {
            identity: 0,
            operation: (a, b) => a + b,
            needsCounters: true,
            postProcess: ({ values, counters }) => {
                for (let i = 0; i < values.length; i++) {
                    values[i] /= counters[i];
                }
            },
            gradientFn: ({ counters, outIndex }) => 1 / counters[outIndex]
        });
    }
    // Tensor max reduction
    max(dims, keepDims = false) {
        return Tensor.reduce(this, dims, keepDims, {
            identity: -Infinity,
            operation: (a, b) => Math.max(a, b),
            // Count how many elements tie for the max in each slot
            needsShareCounts: true,
            // Subgradient: split 1 evenly across all elements equal to the
            // max, 0 elsewhere (ties share the gradient, no single winner)
            gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0
        });
    }
    // Tensor min reduction
    min(dims, keepDims = false) {
        return Tensor.reduce(this, dims, keepDims, {
            identity: Infinity,
            operation: (a, b) => Math.min(a, b),
            // Count how many elements tie for the min in each slot
            needsShareCounts: true,
            // Subgradient: split 1 evenly across all elements equal to the
            // min, 0 elsewhere (mirrors max above)
            gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0
        });
    }
    // Tensor all condition reduction
    all(dims, keepDims = false) {
        return this.min(dims, keepDims).ne(0);
    }
    // Tensor any condition reduction
    any(dims, keepDims = false) {
        return this.max(dims, keepDims).ne(0);
    }
    // Tensor variance reduction
    var(dims, keepDims = false) {
        const meanXSquared = this.square().mean(dims, keepDims);
        const meanXSquaredExpanded = this.mean(dims, keepDims).square();
        return meanXSquared.sub(meanXSquaredExpanded);
    }
    // Tensor standard deviation reduction
    std(dims, keepDims = false) {
        return this.var(dims, keepDims).sqrt();
    }
    // Tensor softmax
    softmax(dim = -1) {
        if (typeof this.value === "number")
            return this;
        // Handle negative indexing
        if (dim < 0) {
            dim += this.shape.length;
        }
        // If dimension out of bound, throw error
        if (dim >= this.shape.length || dim < 0) {
            throw new Error("Dimension do not exist to apply softmax");
        }
        const maxVals = this.max(dim, true);
        const shifted = this.sub(maxVals);
        const expVals = shifted.exp();
        const sumExp = expVals.sum(dim, true);
        return expVals.div(sumExp);
    }
    // Tensor softmin
    softmin(dim = -1) {
        if (typeof this.value === "number")
            return this;
        // Handle negative indexing
        if (dim < 0) {
            dim += this.shape.length;
        }
        // If dimension out of bound, throw error
        if (dim >= this.shape.length || dim < 0) {
            throw new Error("Dimension do not exist to apply softmin");
        }
        const maxVals = this.max(dim, true);
        const shifted = maxVals.sub(this);
        const expVals = shifted.exp();
        const sumExp = expVals.sum(dim, true);
        return expVals.div(sumExp);
    }
    // Tensor element-wise addition
    add(other) {
        return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
    }
    // Tensor element-wise subtraction
    sub(other) {
        return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad.neg());
    }
    subtract = this.sub;
    // Tensor element-wise multiplication
    mul(other) {
        return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => outGrad.mul(other), (self, other, outGrad) => outGrad.mul(self));
    }
    multiply = this.mul;
    // Tensor element-wise power
    // d(a^b)/da = b * a^(b-1); d(a^b)/db = a^b * ln(a)
    // NOTE(review): the exponent-side gradient uses log(self), which is NaN
    // for non-positive bases — standard convention, but confirm callers only
    // differentiate w.r.t. the exponent when the base is positive.
    pow(other) {
        return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => outGrad.mul(other.mul(self.pow(other.sub(1)))), (self, other, outGrad) => outGrad.mul(self.pow(other).mul(self.log())));
    }
    // Tensor element-wise division
    div(other) {
        return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => outGrad.div(other), (self, other, outGrad) => outGrad.mul(self.neg().div(other.square())));
    }
    divide = this.div;
    // Tensor element-wise modulo
    // Uses JS `%`, i.e. truncated remainder taking the sign of the dividend.
    // NOTE(review): torch.remainder is floored (sign of the divisor) — this
    // matches torch.fmod instead; confirm which is intended.
    // No gradient functions are supplied, so this op is non-differentiable.
    remainder(other) {
        return this.elementWiseABDAG(other, (a, b) => a % b);
    }
    // Tensor element-wise greater or equal comparison
    // All comparison ops below return 0/1 masks and supply no gradient
    // functions, so they are non-differentiable (autograd stops here).
    ge(other) {
        return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0);
    }
    greaterEqual = this.ge;
    // Tensor element-wise less or equal comparison
    le(other) {
        return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0);
    }
    lessEqual = this.le;
    // Tensor element-wise greater-than comparison
    gt(other) {
        return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0);
    }
    greater = this.gt;
    // Tensor element-wise less-than comparison
    lt(other) {
        return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0);
    }
    less = this.lt;
    // Tensor element-wise equality comparison (exact float equality)
    eq(other) {
        return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0);
    }
    equal = this.eq;
    // Tensor element-wise not equality comparison
    ne(other) {
        return this.elementWiseABDAG(other, (a, b) => a !== b ? 1 : 0);
    }
    notEqual = this.ne;
    // Tensor element-wise logical and
    logicalAnd(other) {
        return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0);
    }
    // Tensor element-wise logical or
    logicalOr(other) {
        return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0);
    }
    // Tensor element-wise logical xor
    logicalXor(other) {
        return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0);
    }
    // Tensor element-wise logical not
    logicalNot() {
        return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1);
    }
    // Tensor element-wise bitwise and
    // Bitwise ops below rely on JS operators, which coerce operands to 32-bit
    // signed integers; no gradient functions are supplied (non-differentiable).
    bitwiseAnd(other) {
        return this.elementWiseABDAG(other, (a, b) => a & b);
    }
    // Tensor element-wise bitwise or
    bitwiseOr(other) {
        return this.elementWiseABDAG(other, (a, b) => a | b);
    }
    // Tensor element-wise bitwise xor
    bitwiseXor(other) {
        return this.elementWiseABDAG(other, (a, b) => a ^ b);
    }
    // Tensor element-wise bitwise not
    bitwiseNot() {
        return this.elementWiseSelfDAG((a) => ~a);
    }
    // Tensor element-wise left shift
    bitwiseLeftShift(other) {
        return this.elementWiseABDAG(other, (a, b) => a << b);
    }
    // Tensor element-wise right shift (sign-propagating >>)
    bitwiseRightShift(other) {
        return this.elementWiseABDAG(other, (a, b) => a >> b);
    }
    // Tensor element-wise negation
    neg() {
        return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => outGrad.mul(-1));
    }
    negative = this.neg;
    // Tensor element-wise reciprocal
    reciprocal() {
        return this.elementWiseSelfDAG((a) => 1 / a, (self, outGrad) => outGrad.mul(self.pow(-2).neg()));
    }
    // Tensor element-wise square
    square() {
        return this.elementWiseSelfDAG((a) => a * a, (self, outGrad) => outGrad.mul(self.mul(2)));
    }
    // Tensor element-wise absolute
    abs() {
        return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => outGrad.mul(self.sign()));
    }
    absolute = this.abs;
    // Tensor element-wise sign function
    sign() {
        return this.elementWiseSelfDAG((a) => Math.sign(a));
    }
    // Tensor element-wise sin
    sin() {
        return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => outGrad.mul(self.cos()));
    }
    // Tensor element-wise cos
    cos() {
        return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => outGrad.mul(self.sin().neg()));
    }
    // Tensor element-wise tan
    tan() {
        return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => outGrad.mul(self.tan().square().add(1)));
    }
    // Tensor element-wise asin
    asin() {
        return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()));
    }
    arcsin = this.asin;
    // Tensor element-wise acos
    acos() {
        return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()).neg());
    }
    arccos = this.acos;
    // Tensor element-wise atan
    atan() {
        return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => outGrad.div(self.square().add(1)));
    }
    arctan = this.atan;
    // Tensor element-wise atan2
    atan2(other) {
        return this.elementWiseABDAG(other, (a, b) => Math.atan2(a, b), (self, other, outGrad) => outGrad.mul(other.div(self.square().add(other.square()))), (self, other, outGrad) => outGrad.mul(self.neg().div(self.square().add(other.square()))));
    }
    arctan2 = this.atan2;
    // Tensor element-wise sinh
    sinh() {
        return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => outGrad.mul(self.cosh()));
    }
    // Tensor element-wise cosh
    cosh() {
        return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => outGrad.mul(self.sinh()));
    }
    // Tensor element-wise asinh
    asinh() {
        return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => outGrad.div(self.square().add(1).sqrt()));
    }
    arcsinh = this.asinh;
    // Tensor element-wise acosh
    acosh() {
        return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
    }
    arccosh = this.acosh;
    // Tensor element-wise atanh
    atanh() {
        return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => outGrad.div(self.square().neg().add(1)));
    }
    arctanh = this.atanh;
    // Tensor element-wise degree to radian
    deg2rad() {
        return this.elementWiseSelfDAG((a) => a * (Math.PI / 180), (self, outGrad) => outGrad.mul(Math.PI / 180));
    }
    // Tensor element-wise radian to degree
    rad2deg() {
        return this.elementWiseSelfDAG((a) => a / (Math.PI / 180), (self, outGrad) => outGrad.div(Math.PI / 180));
    }
    // Tensor element-wise square root
    sqrt() {
        return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => outGrad.div(self.sqrt().mul(2)));
    }
    // Tensor element-wise reciprocal of square root
    rsqrt() {
        return this.elementWiseSelfDAG((a) => 1 / Math.sqrt(a), (self, outGrad) => outGrad