UNPKG

catniff

Version:

Torch-like deep learning framework for JavaScript

1,248 lines 105 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.Tensor = void 0; const dtype_1 = require("./dtype"); const utils_1 = require("./utils"); class Tensor { value; shape; strides; offset; numel; grad; requiresGrad; gradFn; children; device; dtype; static training = false; static noGrad = false; static createGraph = false; constructor(value, options = {}) { // Memory buffer this.dtype = options.dtype || "float32"; const flatValue = Tensor.flattenValue(value); const TypedArrayConstructor = dtype_1.TypedArray[this.dtype]; this.value = flatValue instanceof TypedArrayConstructor ? flatValue : TypedArrayConstructor.from(flatValue); // Tensor metadata this.shape = options.shape || Tensor.getShape(value); this.strides = options.strides || Tensor.getStrides(this.shape); this.offset = options.offset || 0; this.numel = options.numel || Tensor.shapeToSize(this.shape); this.device = options.device || "cpu"; // Autograd data this.grad = options.grad; this.requiresGrad = options.requiresGrad ?? 
false; this.gradFn = options.gradFn || (() => { }); this.children = options.children || []; // Move to device in-place this.to_(this.device); } // Utility to flatten an nD array to be 1D static flattenValue(tensorValue) { // Handle scalar tensors if (typeof tensorValue === "number") return [tensorValue]; // If value is already 1D, we just need to return the value ('s reference) if (typeof tensorValue[0] === "number") return tensorValue; // Or else recursively traverse through the nD array to flatten const result = []; function traverse(arr) { if (typeof arr === "number") { result.push(arr); // Assume if we can index a value, it is an ArrayLike } else if (typeof arr[0] !== "undefined") { for (let index = 0; index < arr.length; index++) { traverse(arr[index]); } } } traverse(tensorValue); return result; } // Utility to get shape from tensor *value* static getShape(tensorValue) { const shape = []; let subA = tensorValue; while (typeof subA !== "number") { shape.push(subA.length); subA = subA[0]; } return shape; } // Utility to get strides from shape static getStrides(shape) { if (shape.length === 0) return []; const strides = new Array(shape.length); strides[strides.length - 1] = 1; for (let i = strides.length - 2; i >= 0; i--) { strides[i] = strides[i + 1] * shape[i + 1]; } return strides; } // Left-pad shape and strides for two shape to be of same length static padShape(stridesA, stridesB, shapeA, shapeB) { const newStrideA = [...stridesA], newStrideB = [...stridesB]; const newShapeA = [...shapeA], newShapeB = [...shapeB]; while (newStrideA.length < newStrideB.length) { const newStride = newShapeA[0] * newStrideA[0]; newStrideA.unshift(newStride); newShapeA.unshift(1); } while (newStrideA.length > newStrideB.length) { const newStride = newShapeB[0] * newStrideB[0]; newStrideB.unshift(newStride); newShapeB.unshift(1); } return [newStrideA, newStrideB, newShapeA, newShapeB]; } // Broadcast shapes static broadcastShapes(shapeA, shapeB) { const newShape = new 
Array(shapeA.length); for (let index = 0; index < shapeA.length; index++) { if (shapeA[index] === 1) { newShape[index] = shapeB[index]; } else if (shapeB[index] === 1) { newShape[index] = shapeA[index]; } else if (shapeA[index] === shapeB[index]) { newShape[index] = shapeA[index]; } else { throw new Error(`Can not broadcast shapes: ${shapeA} and ${shapeB}`); } } return newShape; } // Utility to convert flat index to array of coordinates static indexToCoords(index, strides) { const coords = new Array(strides.length); let remaining = index; for (let dim = 0; dim < strides.length; dim++) { coords[dim] = Math.floor(remaining / strides[dim]); remaining %= strides[dim]; } return coords; } // Utility to convert array of coordinates to *unbroadcasted* flat index static coordsToUnbroadcastedIndex(coords, shape, strides) { let index = 0; for (let i = 0; i < coords.length; i++) { // Handle broadcasting const actualCoord = shape[i] === 1 ? 0 : coords[i]; index += actualCoord * strides[i]; } return index; } // Utility to convert array of coordinates to flat index static coordsToIndex(coords, strides) { let index = 0; for (let i = 0; i < coords.length; i++) { index += coords[i] * strides[i]; } return index; } // Utility to convert shape into 1D value array size static shapeToSize(shape) { let prod = 1; for (let i = 0; i < shape.length; i++) { prod *= shape[i]; } return prod; } // Utility to get best possible result type if type conflicts happen: static getResultDtype(type1, type2) { if (type1 === type2) return type1; const type1Ranking = dtype_1.dtypeHiearchy[type1]; const type2Ranking = dtype_1.dtypeHiearchy[type2]; if (type1Ranking > type2Ranking) { return type1; } return type2; } // Utility to handle other tensor if an op needs a second operand handleOther(other) { if (other instanceof Tensor) { if (this.device !== other.device) { throw new Error("Can not operate on tensors that are not on the same device"); } return other; } return new Tensor(other, { offset: 0, device: 
this.device, dtype: this.dtype }); } // Utility for binary (two operators involved) element-wise ops static elementWiseAB(tA, tB, op) { const outputDtype = Tensor.getResultDtype(tA.dtype, tB.dtype); // Both are scalars if (tA.shape.length === 0 && tB.shape.length === 0) { return new Tensor(op(tA.value[0], tB.value[0]), { shape: [], strides: [], offset: 0, numel: 1, device: tA.device, dtype: outputDtype }); } // First tensor is scalar if (tA.shape.length === 0) { return Tensor.elementWiseSelf(tB.cast(outputDtype), (a) => op(a, tA.value[0])); } // Second tensor is scalar if (tB.shape.length === 0) { return Tensor.elementWiseSelf(tA.cast(outputDtype), (a) => op(a, tB.value[0])); } // Pad + broadcast shape const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape); const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape); // Get other output info const outputStrides = Tensor.getStrides(outputShape); const outputSize = Tensor.shapeToSize(outputShape); const outputValue = new dtype_1.TypedArray[outputDtype](outputSize); // Check fast path conditions of two tensors const aFastPath = tA.isContiguous() && tA.numel === outputSize; const bFastPath = tB.isContiguous() && tB.numel === outputSize; for (let i = 0; i < outputSize; i++) { // Get coordinates from 1D index const coordsOutput = aFastPath && bFastPath ? [] : Tensor.indexToCoords(i, outputStrides); // Convert the coordinates to 1D index of flattened A with respect to A's shape const indexA = aFastPath ? i : Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides); // Convert the coordinates to 1D index of flattened B with respect to B's shape const indexB = bFastPath ? 
i : Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides); // Calculate with op outputValue[i] = op(tA.value[indexA + tA.offset], tB.value[indexB + tB.offset]); } return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, offset: 0, numel: outputSize, device: tA.device, dtype: outputDtype }); } // Utility for self-inflicting element-wise ops static elementWiseSelf(tA, op) { // Handle scalar case if (tA.shape.length === 0) return new Tensor(op(tA.value[0]), { shape: [], strides: [], offset: 0, numel: 1, device: tA.device, dtype: tA.dtype }); const contiguous = tA.isContiguous(); const outputShape = tA.shape; const outputStrides = contiguous ? tA.strides : Tensor.getStrides(outputShape); const outputSize = tA.numel; const outputValue = new dtype_1.TypedArray[tA.dtype](outputSize); if (contiguous) { for (let index = 0; index < outputSize; index++) { outputValue[index] = op(tA.value[index + tA.offset]); } } else { for (let index = 0; index < outputSize; index++) { const outputCoords = Tensor.indexToCoords(index, outputStrides); const originalIndex = tA.offset + Tensor.coordsToIndex(outputCoords, tA.strides); outputValue[index] = op(tA.value[originalIndex]); } } return new Tensor(outputValue, { shape: outputShape, strides: outputStrides, offset: 0, numel: tA.numel, device: tA.device, dtype: tA.dtype }); } // Utility to do element-wise operation and build a dag node with another tensor elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) { other = this.handleOther(other); const out = Tensor.elementWiseAB(this, other, op); if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); } if (other.requiresGrad) { out.requiresGrad = true; out.children.push(other); } if (out.requiresGrad) { out.gradFn = () => { const outGrad = out.grad; const selfWithGrad = Tensor.createGraph ? this : this.detach(); const otherWithGrad = Tensor.createGraph ? 
other : other.detach(); if (this.requiresGrad) Tensor.addGrad(this, thisGrad(selfWithGrad, otherWithGrad, outGrad)); if (other.requiresGrad) Tensor.addGrad(other, otherGrad(selfWithGrad, otherWithGrad, outGrad)); }; } return out; } // Utility to do self-inflicting element-wise operation and build a dag node elementWiseSelfDAG(op, thisGrad = () => new Tensor(0)) { const out = Tensor.elementWiseSelf(this, op); if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); } if (out.requiresGrad) { out.gradFn = () => { const outGrad = out.grad; const selfWithGrad = Tensor.createGraph ? this : this.detach(); if (this.requiresGrad) Tensor.addGrad(this, thisGrad(selfWithGrad, outGrad)); }; } return out; } // Utility to add to gradient of tensor static addGrad(tensor, accumGrad) { const axesToSqueeze = []; const axesToReduce = []; const shape = tensor.shape; const gradShape = accumGrad.shape; const paddedDims = gradShape.length - shape.length; for (let i = 0; i < paddedDims; i++) { axesToReduce.push(i); axesToSqueeze.push(i); } for (let i = 0; i < shape.length; i++) { if (shape[i] === 1 && gradShape[i + paddedDims] > 1) { axesToReduce.push(i + paddedDims); } } const reducedGrad = accumGrad.sum(axesToReduce, true); const squeezedGrad = reducedGrad.squeeze(axesToSqueeze); if (typeof tensor.grad === "undefined") { tensor.grad = squeezedGrad; } else { tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype)); } } static normalizeDims(dims, numDims) { for (let index = 0; index < dims.length; index++) { // Handle negative indices if (dims[index] < 0) { dims[index] += numDims; } // If dimension out of bound, throw error if (dims[index] >= numDims || dims[index] < 0) { throw new Error("Dimensions do not exist"); } } return dims; } // Contiguity-related ops isContiguous() { const expectedStrides = Tensor.getStrides(this.shape); for (let i = 0; i < this.strides.length; i++) { if (this.strides[i] !== expectedStrides[i]) { return false; } } return true; } 
contiguous() { // Check if scalar if (this.shape.length === 0) return this; // Check if already contiguous if (this.isContiguous()) return this; const outputStrides = Tensor.getStrides(this.shape); const outputSize = this.numel; const outputValue = new dtype_1.TypedArray[this.dtype](outputSize); for (let index = 0; index < outputSize; index++) { const outputCoords = Tensor.indexToCoords(index, outputStrides); const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides); outputValue[index] = this.value[this.offset + originalIndex]; } const out = new Tensor(outputValue, { shape: this.shape, strides: outputStrides, offset: 0, numel: outputSize, device: this.device, dtype: this.dtype }); // Gradient flow back to the original tensor if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); out.gradFn = () => { Tensor.addGrad(this, out.grad); }; } return out; } view(newShape) { // Verify shape size const originalSize = this.numel; const outputSize = Tensor.shapeToSize(newShape); if (originalSize !== outputSize) { throw new Error("Can not create view: incompatible sizes"); } // Verify compatibility (only contiguity for now) if (!this.isContiguous()) { throw new Error("Can not create view: incompatible metadata"); } const outputStrides = Tensor.getStrides(newShape); const out = new Tensor(this.value, { shape: newShape, strides: outputStrides, offset: this.offset, numel: outputSize, device: this.device, dtype: this.dtype }); // Gradient reshaped and flow back to the original tensor if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); out.gradFn = () => { Tensor.addGrad(this, out.grad.reshape(this.shape)); }; } return out; } reshape(newShape) { return this.contiguous().view(newShape); } flatten(startDim = 0, endDim = -1) { // Handle negative indices if (startDim < 0) { startDim += this.shape.length; } if (endDim < 0) { endDim += this.shape.length; } // If dimension out of bound, throw error if (startDim >= this.shape.length || 
endDim >= this.shape.length || startDim < 0 || endDim < 0) { throw new Error("Dimensions do not exist to flatten"); } const newShape = []; let middleSize = 1; for (let index = 0; index < this.shape.length; index++) { // Keep dims before startDim if (index < startDim) { newShape.push(this.shape[index]); } // Multiply dims from startDim to endDim if (index >= startDim && index <= endDim) { middleSize *= this.shape[index]; } // Push new flatten middle if (index === endDim) { newShape.push(middleSize); } // Keep dims after endDim if (index > endDim) { newShape.push(this.shape[index]); } } return this.reshape(newShape); } // Transpose transpose(dim1, dim2) { // Handle negative indices if (dim1 < 0) { dim1 += this.shape.length; } if (dim2 < 0) { dim2 += this.shape.length; } // If dimension out of bound, throw error if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) { throw new Error("Dimensions do not exist to transpose"); } // If same dimension, return view if (dim1 === dim2) return this; // Create new shape and strides by swapping const newShape = [...this.shape]; const newStrides = [...this.strides]; [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]]; [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]]; // Create new tensor with same data but swapped shape/strides const out = new Tensor(this.value, { shape: newShape, strides: newStrides, offset: this.offset, numel: this.numel, device: this.device, dtype: this.dtype }); out.requiresGrad = this.requiresGrad; // Handle gradient if needed if (this.requiresGrad) { out.children.push(this); out.gradFn = () => { Tensor.addGrad(this, out.grad.transpose(dim1, dim2)); }; } return out; } swapaxes = this.transpose; swapdims = this.transpose; // Transpose 2D t() { // Verify matrix shape if (this.shape.length !== 2) { throw new Error("Input is not a matrix"); } return this.transpose(0, 1); } // Permute permute(dims) { dims = Tensor.normalizeDims(dims, 
this.shape.length); if (dims.length !== this.shape.length) { throw new Error("Permutation must specify all dimensions"); } // Compute new shape and strides const newShape = new Array(dims.length); const newStrides = new Array(dims.length); for (let index = 0; index < dims.length; index++) { const dim = dims[index]; newShape[index] = this.shape[dim]; newStrides[index] = this.strides[dim]; } const out = new Tensor(this.value, { shape: newShape, strides: newStrides, offset: this.offset, numel: this.numel, device: this.device, dtype: this.dtype }); if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); out.gradFn = () => { // Compute inverse permutation const inverseAxes = new Array(dims.length); for (let i = 0; i < dims.length; i++) { inverseAxes[dims[i]] = i; } // Permute gradient back to original order const permutedGrad = out.grad.permute(inverseAxes); Tensor.addGrad(this, permutedGrad); }; } return out; } // Utility for indexing with array of indices indexWithArray(indices) { if (this.shape.length === 0) return this; indices = Tensor.normalizeDims(indices, this.shape[0]); // Init necessary stuff for indexing const reducedShape = this.shape.slice(1); const reducedStrides = this.strides.slice(1); const elementsPerIndex = Tensor.shapeToSize(reducedShape); // Init output data const outputShape = [indices.length, ...reducedShape]; const outputSize = Tensor.shapeToSize(outputShape); const outputValue = new dtype_1.TypedArray[this.dtype](outputSize); for (let i = 0; i < indices.length; i++) { const sourceRowIndex = indices[i]; const targetStart = i * elementsPerIndex; for (let j = 0; j < elementsPerIndex; j++) { const fullCoords = Tensor.indexToCoords(j, reducedStrides); fullCoords.unshift(sourceRowIndex); const sourceIndex = Tensor.coordsToIndex(fullCoords, this.strides); outputValue[targetStart + j] = this.value[this.offset + sourceIndex]; } } const out = new Tensor(outputValue, { shape: outputShape, offset: 0, numel: outputSize, device: 
this.device, dtype: this.dtype }); // Handle gradient if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); out.gradFn = () => { const outGrad = out.grad; // Create zero gradient tensor with original shape const grad = Tensor.zerosLike(this); // Scatter gradients back to original positions for (let i = 0; i < indices.length; i++) { const originalRowIndex = indices[i]; const sourceStart = i * elementsPerIndex; for (let j = 0; j < elementsPerIndex; j++) { const fullCoords = Tensor.indexToCoords(j, reducedStrides); fullCoords.unshift(originalRowIndex); const targetIndex = Tensor.coordsToIndex(fullCoords, this.strides); grad.value[targetIndex] += outGrad.value[sourceStart + j]; } } Tensor.addGrad(this, grad); }; } return out; } // Tensor indexing index(indices) { const tensorIndices = this.handleOther(indices).clone(); if (tensorIndices.shape.length === 0) { return this.indexWithArray([tensorIndices.value[0]]).squeeze(0); } else { const originalShape = tensorIndices.shape; const flatIndices = tensorIndices.value; const result = this.indexWithArray(Array.from(flatIndices)); // Reshape to preserve input shape const outputShape = [...originalShape, ...this.shape.slice(1)]; return result.reshape(outputShape); } } // Tensor slicing slice(ranges) { // Handle scalars if (this.shape.length === 0) return this; const newShape = []; const newStrides = []; let newOffset = this.offset || 0; // Pad ranges to match tensor dimensions const paddedRanges = [...ranges]; while (paddedRanges.length < this.shape.length) { paddedRanges.push([]); } for (let i = 0; i < this.shape.length; i++) { const range = paddedRanges[i] || []; const dimSize = this.shape[i]; const stride = this.strides[i]; // Default values let start = range[0] ?? 0; let end = range[1] ?? dimSize; let step = range[2] ?? 
1; // Handle negative indices if (start < 0) start += dimSize; if (end < 0) end += dimSize; // Clamp to valid range start = Math.max(0, Math.min(start, dimSize)); end = Math.max(0, Math.min(end, dimSize)); // Calculate new dimension size const newDimSize = step > 0 ? Math.max(0, Math.ceil((end - start) / step)) : Math.max(0, Math.ceil((start - end) / Math.abs(step))); newShape.push(newDimSize); newStrides.push(stride * step); newOffset += start * stride; } const out = new Tensor(this.value, { shape: newShape, strides: newStrides, offset: newOffset, device: this.device, dtype: this.dtype }); if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); out.gradFn = () => { // Create zero tensor of original shape const grad = Tensor.zerosLike(this); // Upstream grad const outGrad = out.grad; const totalElements = outGrad.numel; for (let i = 0; i < totalElements; i++) { // Convert flat index to coordinates in sliced tensor const slicedCoords = Tensor.indexToCoords(i, outGrad.strides); // Map back to original coordinates const originalCoords = new Array(slicedCoords.length); for (let dim = 0; dim < slicedCoords.length; dim++) { const coord = slicedCoords[dim]; const range = paddedRanges[dim] || []; const start = range[0] ?? 0; const step = range[2] ?? 1; const normalizedStart = start < 0 ? 
start + this.shape[dim] : start; originalCoords[dim] = normalizedStart + coord * step; } // Get flat indices with offsets const srcIndex = Tensor.coordsToIndex(slicedCoords, outGrad.strides) + outGrad.offset; const targetIndex = Tensor.coordsToIndex(originalCoords, grad.strides) + grad.offset; // Accumulate gradient grad.value[targetIndex] += outGrad.value[srcIndex]; } Tensor.addGrad(this, grad); }; } return out; } // Tensor chunk chunk(chunks, dim = 0) { // Handle negative indices if (dim < 0) { dim += this.shape.length; } // If dimension out of bound, throw error if (dim >= this.shape.length || dim < 0) { throw new Error("Dimension do not exist to chunk"); } const sliceOpt = new Array(this.shape.length); for (let index = 0; index < sliceOpt.length; index++) { sliceOpt[index] = []; } const dimSize = this.shape[dim]; const chunkDimSize = Math.ceil(dimSize / chunks); const results = []; for (let index = 0; index < dimSize; index += chunkDimSize) { sliceOpt[dim] = [index, Math.min(index + chunkDimSize, dimSize)]; results.push(this.slice(sliceOpt)); } return results; } // Tensor expansion expand(newShape) { // Handle scalars let self = this; if (this.shape.length === 0) { self = self.unsqueeze(0); } // Pad shapes to same length const ndim = Math.max(self.shape.length, newShape.length); const oldShape = [...Array(ndim - self.shape.length).fill(1), ...self.shape]; const oldStrides = [...Array(ndim - self.strides.length).fill(0), ...self.strides]; const targetShape = [...Array(ndim - newShape.length).fill(1), ...newShape]; const newStrides = new Array(ndim); for (let i = 0; i < ndim; i++) { if (oldShape[i] === targetShape[i]) { newStrides[i] = oldStrides[i]; } else if (oldShape[i] === 1) { newStrides[i] = 0; } else { throw new Error(`Cannot expand dimension of size ${oldShape[i]} to ${targetShape[i]}`); } } const out = new Tensor(self.value, { shape: targetShape, strides: newStrides, offset: self.offset, device: self.device, dtype: self.dtype }); if (self.requiresGrad) { 
out.requiresGrad = true; out.children.push(self); out.gradFn = () => { Tensor.addGrad(self, out.grad); }; } return out; } // Tensor concatentation cat(other, dim = 0) { other = this.handleOther(other); // Handle scalars if (this.shape.length === 0 || other.shape.length === 0) { throw new Error("Can not concatenate scalars"); } // Handle negative indices if (dim < 0) { dim += this.shape.length; } // If dimension out of bound, throw error if (dim >= this.shape.length || dim < 0) { throw new Error("Dimension does not exist to concatenate"); } // If shape does not match, throw error if (this.shape.length !== other.shape.length) { throw new Error("Shape does not match to concatenate"); } const outputShape = new Array(this.shape.length); for (let currentDim = 0; currentDim < this.shape.length; currentDim++) { if (currentDim === dim) { outputShape[currentDim] = this.shape[currentDim] + other.shape[currentDim]; } else if (this.shape[currentDim] !== other.shape[currentDim]) { throw new Error("Shape does not match to concatenate"); } else { outputShape[currentDim] = this.shape[currentDim]; } } const outputSize = Tensor.shapeToSize(outputShape); const outputStrides = Tensor.getStrides(outputShape); const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype); const outputValue = new dtype_1.TypedArray[outputDtype](outputSize); for (let outIndex = 0; outIndex < outputSize; outIndex++) { const coords = Tensor.indexToCoords(outIndex, outputStrides); // Check which tensor this output position comes from if (coords[dim] < this.shape[dim]) { // Comes from this tensor const srcIndex = Tensor.coordsToIndex(coords, this.strides); outputValue[outIndex] = this.value[srcIndex + this.offset]; } else { // Comes from other tensor - adjust coordinate in concat dimension const otherCoords = [...coords]; otherCoords[dim] -= this.shape[dim]; const srcIndex = Tensor.coordsToIndex(otherCoords, other.strides); outputValue[outIndex] = other.value[srcIndex + other.offset]; } } const out = new 
Tensor(outputValue, { shape: outputShape, strides: outputStrides, offset: 0, numel: outputSize, device: this.device, dtype: this.dtype }); if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); } if (other.requiresGrad) { out.requiresGrad = true; out.children.push(other); } if (out.requiresGrad) { out.gradFn = () => { const outGrad = out.grad; const thisRanges = new Array(this.shape.length); const otherRanges = new Array(other.shape.length); for (let currentDim = 0; currentDim < this.shape.length; currentDim++) { if (currentDim === dim) { thisRanges[currentDim] = [0, this.shape[currentDim], 1]; otherRanges[currentDim] = [this.shape[currentDim], outputShape[currentDim], 1]; } else { thisRanges[currentDim] = []; otherRanges[currentDim] = []; } } Tensor.addGrad(this, outGrad.slice(thisRanges)); Tensor.addGrad(other, outGrad.slice(otherRanges)); }; } return out; } // Tensor squeeze squeeze(dims) { if (this.shape.length === 0) return this; if (typeof dims === "number") { dims = [dims]; } if (typeof dims === "undefined") { const shape = this.shape; dims = []; for (let index = 0; index < shape.length; index++) { if (shape[index] === 1) { dims.push(index); } } } dims = Tensor.normalizeDims(dims, this.shape.length); // Remove size-1 dims only const outShape = [], outStrides = []; for (let index = 0; index < this.shape.length; index++) { const dim = this.shape[index]; const stride = this.strides[index]; if (dims.includes(index)) { if (dim !== 1) throw new Error(`Can not squeeze dim with size ${dim}`); } else { outShape.push(dim); outStrides.push(stride); } } const outValue = outShape.length === 0 ? 
this.value[this.offset] : this.value; const out = new Tensor(outValue, { shape: outShape, strides: outStrides, offset: this.offset, numel: this.numel, device: this.device, dtype: this.dtype }); // Set up gradient if needed if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); out.gradFn = () => { let restoredGrad = out.grad; for (let i = dims.length - 1; i >= 0; i--) { restoredGrad = restoredGrad.unsqueeze(dims[i]); } Tensor.addGrad(this, restoredGrad); }; } return out; } // Tensor unsqueeze - adds dimension of size 1 at specified position unsqueeze(dim) { // Handle negative indices if (dim < 0) { dim += this.shape.length; } let thisValue = this.value; // Insert size-1 dimension at specified position const newShape = [...this.shape]; newShape.splice(dim, 0, 1); // New stride const newStrides = [...this.strides]; let newDimStride; if (dim >= this.shape.length) { // Inserting at the back: use 1 newDimStride = 1; } else { // Inserting before dim: use current stride * current shape newDimStride = this.strides[dim] * this.shape[dim]; } newStrides.splice(dim, 0, newDimStride); const out = new Tensor(thisValue, { shape: newShape, strides: newStrides, offset: this.offset, numel: this.numel, device: this.device, dtype: this.dtype }); // Set up gradient if needed if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); out.gradFn = () => { Tensor.addGrad(this, out.grad.squeeze(dim)); }; } return out; } // Tensor sort sort(dim = -1, descending = false) { if (dim < 0) { dim += this.shape.length; } // If dimension out of bound, throw error if (dim >= this.shape.length || dim < 0) { throw new Error("Dimension do not exist to sort"); } // Copy if not contiguous const outputSize = this.numel; const outputShape = this.shape; const outputValue = new dtype_1.TypedArray[this.dtype](outputSize); const outputStrides = Tensor.getStrides(outputShape); if (this.isContiguous()) { // Fast path: direct copy 
outputValue.set(this.value.subarray(this.offset, this.offset + outputSize)); } else { // Slow path: coordinate conversion for (let flatIndex = 0; flatIndex < outputSize; flatIndex++) { const coords = Tensor.indexToCoords(flatIndex, outputStrides); const originalIndex = Tensor.coordsToIndex(coords, this.strides); outputValue[flatIndex] = this.value[originalIndex + this.offset]; } } // Calculate dimensions for gather-scatter const dimSize = outputShape[dim]; const outerSize = outputShape.slice(0, dim).reduce((a, b) => a * b, 1); const innerSize = outputShape.slice(dim + 1).reduce((a, b) => a * b, 1); // Store permutation indices for gradient const permutation = new Array(outputSize); // Sort each group independently for (let outer = 0; outer < outerSize; outer++) { for (let inner = 0; inner < innerSize; inner++) { const group = []; for (let i = 0; i < dimSize; i++) { const flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner; group.push({ value: outputValue[flatIdx], dimIdx: i }); } // Sort this group by value group.sort((a, b) => descending ? 
b.value - a.value : a.value - b.value); // Scatter: write back sorted values and record permutation for (let i = 0; i < dimSize; i++) { const flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner; outputValue[flatIdx] = group[i].value; // Record where this element came from (for gradient) const originalFlatIdx = outer * (dimSize * innerSize) + group[i].dimIdx * innerSize + inner; permutation[flatIdx] = originalFlatIdx; } } } const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, offset: 0, numel: outputSize, device: this.device, dtype: this.dtype }); // Gradient setup if (this.requiresGrad) { out.requiresGrad = true; out.children.push(this); out.gradFn = () => { const outGrad = out.grad; // Scatter output gradients back to original positions const inputGradValue = new dtype_1.TypedArray[this.dtype](outputSize); for (let sortedIdx = 0; sortedIdx < outputSize; sortedIdx++) { const originalIdx = permutation[sortedIdx]; inputGradValue[originalIdx] = outGrad.value[sortedIdx]; } const inputGrad = new Tensor(inputGradValue, { shape: outputShape, strides: outputStrides, offset: 0, numel: outputSize, device: this.device, dtype: this.dtype }); Tensor.addGrad(this, inputGrad); }; } return out; } // Top-k sampling topk(k, dim = -1, largest = true) { if (dim < 0) { dim += this.shape.length; } // If dimension out of bound, throw error if (dim >= this.shape.length || dim < 0) { throw new Error("Dimension do not exist to get topk"); } const dimRanges = new Array(this.shape.length); for (let index = 0; index < dimRanges.length; index++) { if (index === dim) { dimRanges[index] = [0, k]; } else { dimRanges[index] = []; } } return this.sort(dim, largest).slice(dimRanges); } // Generic reduction operation handler static reduce(tensor, dims, keepDims, config) { if (tensor.shape.length === 0) return tensor; if (typeof dims === "undefined") { dims = new Array(tensor.shape.length); for (let index = 0; index < dims.length; index++) { dims[index] = index; 
} } if (Array.isArray(dims)) { dims = Tensor.normalizeDims(dims, tensor.shape.length); const sortedDims = dims.sort((a, b) => b - a); let reducedThis = tensor; for (let i = 0; i < sortedDims.length; i++) { reducedThis = Tensor.reduce(reducedThis, sortedDims[i], true, config); } return keepDims ? reducedThis : reducedThis.squeeze(dims); } const dimSize = tensor.shape[dims]; const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim); const outputStrides = Tensor.getStrides(outputShape); const outputSize = tensor.numel / dimSize; const outputValue = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity); const originalSize = tensor.numel; const originalValue = tensor.value; const linearStrides = Tensor.getStrides(tensor.shape); // Forward pass for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) { // Convert linear index to coordinates using contiguous strides const coords = Tensor.indexToCoords(flatIndex, linearStrides); // Convert coordinates to actual strided index const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset; // Convert coords to reduced index coords[dims] = 0; const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides); // Apply op outputValue[outFlatIndex] = config.operation(outputValue[outFlatIndex], originalValue[realFlatIndex]); } // Post-process if needed (e.g., divide by count for mean) if (config.postProcess) { config.postProcess({ values: outputValue, dimSize }); } const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides, offset: 0, numel: outputSize, device: tensor.device, dtype: tensor.dtype }); // Gradient setup if (tensor.requiresGrad) { out.requiresGrad = true; out.children.push(tensor); out.gradFn = () => { let shareCounts = new dtype_1.TypedArray[tensor.dtype](); if (config.needsShareCounts) { shareCounts = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0); for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) { // Convert linear 
index to coordinates using contiguous strides const coords = Tensor.indexToCoords(flatIndex, linearStrides); // Convert coordinates to actual strided index const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset; // Convert coords to reduced index coords[dims] = 0; const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides); // We collect how many elements share the same max value first shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0; } } const gradValue = new dtype_1.TypedArray[tensor.dtype](originalSize); for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) { // Convert linear index to coordinates using contiguous strides const coords = Tensor.indexToCoords(flatIndex, linearStrides); // Convert coordinates to actual strided index const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset; // Convert coords to reduced index coords[dims] = 0; const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides); gradValue[flatIndex] = config.gradientFn({ outputValue, originalValue: tensor.value, dimSize, shareCounts, realIndex: realFlatIndex, outIndex: outFlatIndex }); } const localGrad = new Tensor(gradValue, { shape: tensor.shape, offset: 0, numel: tensor.numel, device: tensor.device, dtype: tensor.dtype }); Tensor.addGrad(tensor, out.grad.mul(localGrad)); }; } return keepDims ? 
out : out.squeeze(dims); } // Simplified reduction operations sum(dims, keepDims = false) { return Tensor.reduce(this, dims, keepDims, { identity: 0, operation: (a, b) => a + b, gradientFn: ({}) => 1 }); } prod(dims, keepDims = false) { return Tensor.reduce(this, dims, keepDims, { identity: 1, operation: (a, b) => a * b, gradientFn: ({ outputValue, originalValue, realIndex, outIndex }) => outputValue[outIndex] / originalValue[realIndex] }); } mean(dims, keepDims = false) { return Tensor.reduce(this, dims, keepDims, { identity: 0, operation: (a, b) => a + b, postProcess: ({ values, dimSize }) => { for (let i = 0; i < values.length; i++) { values[i] /= dimSize; } }, gradientFn: ({ dimSize }) => 1 / dimSize }); } max(dims, keepDims = false) { return Tensor.reduce(this, dims, keepDims, { identity: -Infinity, operation: (a, b) => Math.max(a, b), needsShareCounts: true, gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0 }); } min(dims, keepDims = false) { return Tensor.reduce(this, dims, keepDims, { identity: Infinity, operation: (a, b) => Math.min(a, b), needsShareCounts: true, gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 
1 / shareCounts[outIndex] : 0 }); } // Tensor all condition reduction all(dims, keepDims = false) { return this.min(dims, keepDims).ne(0); } // Tensor any condition reduction any(dims, keepDims = false) { return this.max(dims, keepDims).ne(0); } // Tensor variance reduction var(dims, keepDims = false) { const meanXSquared = this.square().mean(dims, keepDims); const meanXSquaredExpanded = this.mean(dims, keepDims).square(); return meanXSquared.sub(meanXSquaredExpanded); } // Tensor standard deviation reduction std(dims, keepDims = false) { return this.var(dims, keepDims).sqrt(); } // Tensor softmax softmax(dim = -1) { if (this.shape.length === 0) return this; // Handle negative indexing if (dim < 0) { dim += this.shape.length; } // If dimension out of bound, throw error if (dim >= this.shape.length || dim < 0) { throw new Error("Dimension do not exist to apply softmax"); } const maxVals = this.max(dim, true); const shifted = this.sub(maxVals); const expVals = shifted.exp(); const sumExp = expVals.sum(dim, true); return expVals.div(sumExp); } // Tensor softmin softmin(dim = -1) { if (this.shape.length === 0) return this; // Handle negative indexing if (dim < 0) { dim += this.shape.length; } // If dimension out of bound, throw error if (dim >= this.shape.length || dim < 0) { throw new Error("Dimension do not exist to apply softmin"); } const maxVals = this.max(dim, true); const shifted = maxVals.sub(this); const expVals = shifted.exp(); const sumExp = expVals.sum(dim, true); return expVals.div(sumExp); } // Tensor element-wise addition add(other) { return this.