catniff
Torch-like deep learning framework for JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Tensor = void 0;
const dtype_1 = require("./dtype");
const utils_1 = require("./utils");
class Tensor {
value;
shape;
strides;
offset;
numel;
grad;
requiresGrad;
gradFn;
children;
device;
dtype;
static training = false;
static noGrad = false;
static createGraph = false;
constructor(value, options = {}) {
// Memory buffer
this.dtype = options.dtype || "float32";
const flatValue = Tensor.flattenValue(value);
const TypedArrayConstructor = dtype_1.TypedArray[this.dtype];
this.value = flatValue instanceof TypedArrayConstructor ? flatValue : TypedArrayConstructor.from(flatValue);
// Tensor metadata
this.shape = options.shape || Tensor.getShape(value);
this.strides = options.strides || Tensor.getStrides(this.shape);
this.offset = options.offset || 0;
this.numel = options.numel || Tensor.shapeToSize(this.shape);
this.device = options.device || "cpu";
// Autograd data
this.grad = options.grad;
this.requiresGrad = options.requiresGrad ?? false;
this.gradFn = options.gradFn || (() => { });
this.children = options.children || [];
// Move to device in-place
this.to_(this.device);
}
// Utility to flatten an nD array into a 1D one
static flattenValue(tensorValue) {
// Handle scalar tensors
if (typeof tensorValue === "number")
return [tensorValue];
// If the value is already 1D, return it directly (by reference)
if (typeof tensorValue[0] === "number")
return tensorValue;
// Otherwise recursively traverse the nD array to flatten it
const result = [];
function traverse(arr) {
if (typeof arr === "number") {
result.push(arr);
}
// Assume that anything with a defined first element is an ArrayLike
else if (typeof arr[0] !== "undefined") {
for (let index = 0; index < arr.length; index++) {
traverse(arr[index]);
}
}
}
traverse(tensorValue);
return result;
}
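// Usage sketch (hypothetical values): flattenValue walks nested arrays
// depth-first, so:
//   Tensor.flattenValue([[1, 2], [3, 4]]) // => [1, 2, 3, 4]
//   Tensor.flattenValue(5)                // => [5]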
// Utility to get shape from tensor *value*
static getShape(tensorValue) {
const shape = [];
let subA = tensorValue;
while (typeof subA !== "number") {
shape.push(subA.length);
subA = subA[0];
}
return shape;
}
// Utility to get strides from shape
static getStrides(shape) {
if (shape.length === 0)
return [];
const strides = new Array(shape.length);
strides[strides.length - 1] = 1;
for (let i = strides.length - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
}
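// Example: strides are row-major, so stride[i] is the product of all
// dimensions after i:
//   Tensor.getStrides([2, 3, 4]) // => [12, 4, 1]
//   Tensor.getStrides([])        // => [] (scalar)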
// Left-pad shapes and strides so two shapes have the same length
static padShape(stridesA, stridesB, shapeA, shapeB) {
const newStrideA = [...stridesA], newStrideB = [...stridesB];
const newShapeA = [...shapeA], newShapeB = [...shapeB];
while (newStrideA.length < newStrideB.length) {
const newStride = newShapeA[0] * newStrideA[0];
newStrideA.unshift(newStride);
newShapeA.unshift(1);
}
while (newStrideA.length > newStrideB.length) {
const newStride = newShapeB[0] * newStrideB[0];
newStrideB.unshift(newStride);
newShapeB.unshift(1);
}
return [newStrideA, newStrideB, newShapeA, newShapeB];
}
// Broadcast shapes
static broadcastShapes(shapeA, shapeB) {
const newShape = new Array(shapeA.length);
for (let index = 0; index < shapeA.length; index++) {
if (shapeA[index] === 1) {
newShape[index] = shapeB[index];
}
else if (shapeB[index] === 1) {
newShape[index] = shapeA[index];
}
else if (shapeA[index] === shapeB[index]) {
newShape[index] = shapeA[index];
}
else {
throw new Error(`Can not broadcast shapes: ${shapeA} and ${shapeB}`);
}
}
return newShape;
}
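// Example (shapes already left-padded to equal rank by padShape): size-1 dims
// stretch to match the other shape, anything else must agree:
//   Tensor.broadcastShapes([1, 3], [2, 1]) // => [2, 3]
//   Tensor.broadcastShapes([2, 3], [4, 3]) // throws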
// Utility to convert flat index to array of coordinates
static indexToCoords(index, strides) {
const coords = new Array(strides.length);
let remaining = index;
for (let dim = 0; dim < strides.length; dim++) {
coords[dim] = Math.floor(remaining / strides[dim]);
remaining %= strides[dim];
}
return coords;
}
// Utility to convert array of coordinates to *unbroadcasted* flat index
static coordsToUnbroadcastedIndex(coords, shape, strides) {
let index = 0;
for (let i = 0; i < coords.length; i++) {
// Handle broadcasting
const actualCoord = shape[i] === 1 ? 0 : coords[i];
index += actualCoord * strides[i];
}
return index;
}
// Utility to convert array of coordinates to flat index
static coordsToIndex(coords, strides) {
let index = 0;
for (let i = 0; i < coords.length; i++) {
index += coords[i] * strides[i];
}
return index;
}
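// Example: with contiguous strides [12, 4, 1] (shape [2, 3, 4]), flat index 17
// decomposes as 17 = 1*12 + 1*4 + 1*1, so the two helpers are inverses:
//   Tensor.indexToCoords(17, [12, 4, 1])        // => [1, 1, 1]
//   Tensor.coordsToIndex([1, 1, 1], [12, 4, 1]) // => 17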
// Utility to convert shape into 1D value array size
static shapeToSize(shape) {
let prod = 1;
for (let i = 0; i < shape.length; i++) {
prod *= shape[i];
}
return prod;
}
// Utility to pick the resulting dtype when two dtypes conflict
static getResultDtype(type1, type2) {
if (type1 === type2)
return type1;
const type1Ranking = dtype_1.dtypeHiearchy[type1];
const type2Ranking = dtype_1.dtypeHiearchy[type2];
if (type1Ranking > type2Ranking) {
return type1;
}
return type2;
}
// Utility to handle other tensor if an op needs a second operand
handleOther(other) {
if (other instanceof Tensor) {
if (this.device !== other.device) {
throw new Error("Can not operate on tensors that are not on the same device");
}
return other;
}
return new Tensor(other, {
offset: 0,
device: this.device,
dtype: this.dtype
});
}
// Utility for binary (two-operand) element-wise ops
static elementWiseAB(tA, tB, op) {
const outputDtype = Tensor.getResultDtype(tA.dtype, tB.dtype);
// Both are scalars
if (tA.shape.length === 0 && tB.shape.length === 0) {
return new Tensor(op(tA.value[0], tB.value[0]), {
shape: [],
strides: [],
offset: 0,
numel: 1,
device: tA.device,
dtype: outputDtype
});
}
// First tensor is scalar
if (tA.shape.length === 0) {
return Tensor.elementWiseSelf(tB.cast(outputDtype), (a) => op(a, tA.value[0]));
}
// Second tensor is scalar
if (tB.shape.length === 0) {
return Tensor.elementWiseSelf(tA.cast(outputDtype), (a) => op(a, tB.value[0]));
}
// Pad + broadcast shape
const [paddedAStrides, paddedBStrides, paddedAShape, paddedBShape] = Tensor.padShape(tA.strides, tB.strides, tA.shape, tB.shape);
const outputShape = Tensor.broadcastShapes(paddedAShape, paddedBShape);
// Get other output info
const outputStrides = Tensor.getStrides(outputShape);
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
// Check fast path conditions of two tensors
const aFastPath = tA.isContiguous() && tA.numel === outputSize;
const bFastPath = tB.isContiguous() && tB.numel === outputSize;
for (let i = 0; i < outputSize; i++) {
// Get coordinates from 1D index
const coordsOutput = aFastPath && bFastPath ? [] : Tensor.indexToCoords(i, outputStrides);
// Convert the coordinates to 1D index of flattened A with respect to A's shape
const indexA = aFastPath ? i : Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedAShape, paddedAStrides);
// Convert the coordinates to 1D index of flattened B with respect to B's shape
const indexB = bFastPath ? i : Tensor.coordsToUnbroadcastedIndex(coordsOutput, paddedBShape, paddedBStrides);
// Calculate with op
outputValue[i] = op(tA.value[indexA + tA.offset], tB.value[indexB + tB.offset]);
}
return new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides,
offset: 0,
numel: outputSize,
device: tA.device,
dtype: outputDtype
});
}
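// Usage sketch (hypothetical values): a [2, 1] tensor and a [3] tensor
// broadcast to [2, 3] before the op is applied element-wise:
//   const a = new Tensor([[1], [2]]);
//   const b = new Tensor([10, 20, 30]);
//   Tensor.elementWiseAB(a, b, (x, y) => x + y);
//   // => [[11, 21, 31], [12, 22, 32]]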
// Utility for unary (single-operand) element-wise ops
static elementWiseSelf(tA, op) {
// Handle scalar case
if (tA.shape.length === 0)
return new Tensor(op(tA.value[0]), {
shape: [],
strides: [],
offset: 0,
numel: 1,
device: tA.device,
dtype: tA.dtype
});
const contiguous = tA.isContiguous();
const outputShape = tA.shape;
const outputStrides = contiguous ? tA.strides : Tensor.getStrides(outputShape);
const outputSize = tA.numel;
const outputValue = new dtype_1.TypedArray[tA.dtype](outputSize);
if (contiguous) {
for (let index = 0; index < outputSize; index++) {
outputValue[index] = op(tA.value[index + tA.offset]);
}
}
else {
for (let index = 0; index < outputSize; index++) {
const outputCoords = Tensor.indexToCoords(index, outputStrides);
const originalIndex = tA.offset + Tensor.coordsToIndex(outputCoords, tA.strides);
outputValue[index] = op(tA.value[originalIndex]);
}
}
return new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides,
offset: 0,
numel: tA.numel,
device: tA.device,
dtype: tA.dtype
});
}
// Utility to do an element-wise operation with another tensor and build a DAG node
elementWiseABDAG(other, op, thisGrad = () => new Tensor(0), otherGrad = () => new Tensor(0)) {
other = this.handleOther(other);
const out = Tensor.elementWiseAB(this, other, op);
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
}
if (other.requiresGrad) {
out.requiresGrad = true;
out.children.push(other);
}
if (out.requiresGrad) {
out.gradFn = () => {
const outGrad = out.grad;
const selfWithGrad = Tensor.createGraph ? this : this.detach();
const otherWithGrad = Tensor.createGraph ? other : other.detach();
if (this.requiresGrad)
Tensor.addGrad(this, thisGrad(selfWithGrad, otherWithGrad, outGrad));
if (other.requiresGrad)
Tensor.addGrad(other, otherGrad(selfWithGrad, otherWithGrad, outGrad));
};
}
return out;
}
// Utility to apply a unary element-wise operation and build a DAG node
elementWiseSelfDAG(op, thisGrad = () => new Tensor(0)) {
const out = Tensor.elementWiseSelf(this, op);
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
}
if (out.requiresGrad) {
out.gradFn = () => {
const outGrad = out.grad;
const selfWithGrad = Tensor.createGraph ? this : this.detach();
if (this.requiresGrad)
Tensor.addGrad(this, thisGrad(selfWithGrad, outGrad));
};
}
return out;
}
// Utility to accumulate gradient into a tensor, reducing broadcasted dims first
static addGrad(tensor, accumGrad) {
const axesToSqueeze = [];
const axesToReduce = [];
const shape = tensor.shape;
const gradShape = accumGrad.shape;
const paddedDims = gradShape.length - shape.length;
for (let i = 0; i < paddedDims; i++) {
axesToReduce.push(i);
axesToSqueeze.push(i);
}
for (let i = 0; i < shape.length; i++) {
if (shape[i] === 1 && gradShape[i + paddedDims] > 1) {
axesToReduce.push(i + paddedDims);
}
}
const reducedGrad = accumGrad.sum(axesToReduce, true);
const squeezedGrad = reducedGrad.squeeze(axesToSqueeze);
if (typeof tensor.grad === "undefined") {
tensor.grad = squeezedGrad;
}
else {
tensor.grad = tensor.grad.add(squeezedGrad.cast(tensor.dtype));
}
}
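// Sketch of the broadcast reduction above: if a [1, 3] tensor was broadcast to
// [2, 3] in the forward pass, its upstream gradient arrives as [2, 3] and is
// summed over the broadcast axis back to [1, 3]; left-padded axes are summed
// and then squeezed away so the gradient shape matches the tensor shape.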
static normalizeDims(dims, numDims) {
for (let index = 0; index < dims.length; index++) {
// Handle negative indices
if (dims[index] < 0) {
dims[index] += numDims;
}
// If dimension out of bound, throw error
if (dims[index] >= numDims || dims[index] < 0) {
throw new Error("Dimensions do not exist");
}
}
return dims;
}
// Contiguity-related ops
isContiguous() {
const expectedStrides = Tensor.getStrides(this.shape);
for (let i = 0; i < this.strides.length; i++) {
if (this.strides[i] !== expectedStrides[i]) {
return false;
}
}
return true;
}
contiguous() {
// Check if scalar
if (this.shape.length === 0)
return this;
// Check if already contiguous
if (this.isContiguous())
return this;
const outputStrides = Tensor.getStrides(this.shape);
const outputSize = this.numel;
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
for (let index = 0; index < outputSize; index++) {
const outputCoords = Tensor.indexToCoords(index, outputStrides);
const originalIndex = Tensor.coordsToIndex(outputCoords, this.strides);
outputValue[index] = this.value[this.offset + originalIndex];
}
const out = new Tensor(outputValue, {
shape: this.shape,
strides: outputStrides,
offset: 0,
numel: outputSize,
device: this.device,
dtype: this.dtype
});
// Gradient flows back to the original tensor
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
Tensor.addGrad(this, out.grad);
};
}
return out;
}
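// Usage sketch: transpose returns a non-contiguous view; contiguous() copies
// it into a fresh row-major buffer:
//   const t = new Tensor([[1, 2], [3, 4]]).transpose(0, 1);
//   t.isContiguous()              // => false
//   t.contiguous().isContiguous() // => true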
view(newShape) {
// Verify shape size
const originalSize = this.numel;
const outputSize = Tensor.shapeToSize(newShape);
if (originalSize !== outputSize) {
throw new Error("Can not create view: incompatible sizes");
}
// Verify compatibility (only contiguity for now)
if (!this.isContiguous()) {
throw new Error("Can not create view: incompatible metadata");
}
const outputStrides = Tensor.getStrides(newShape);
const out = new Tensor(this.value, {
shape: newShape,
strides: outputStrides,
offset: this.offset,
numel: outputSize,
device: this.device,
dtype: this.dtype
});
// Gradient is reshaped and flows back to the original tensor
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
Tensor.addGrad(this, out.grad.reshape(this.shape));
};
}
return out;
}
reshape(newShape) {
return this.contiguous().view(newShape);
}
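// Note: view() shares the buffer and requires contiguity, while reshape()
// falls back to a copy via contiguous() when needed:
//   new Tensor([1, 2, 3, 4, 5, 6]).reshape([2, 3]) // shape [2, 3]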
flatten(startDim = 0, endDim = -1) {
// Handle negative indices
if (startDim < 0) {
startDim += this.shape.length;
}
if (endDim < 0) {
endDim += this.shape.length;
}
// If dimension out of bound, throw error
if (startDim >= this.shape.length || endDim >= this.shape.length || startDim < 0 || endDim < 0) {
throw new Error("Dimensions do not exist to flatten");
}
const newShape = [];
let middleSize = 1;
for (let index = 0; index < this.shape.length; index++) {
// Keep dims before startDim
if (index < startDim) {
newShape.push(this.shape[index]);
}
// Multiply dims from startDim to endDim
if (index >= startDim && index <= endDim) {
middleSize *= this.shape[index];
}
// Push new flatten middle
if (index === endDim) {
newShape.push(middleSize);
}
// Keep dims after endDim
if (index > endDim) {
newShape.push(this.shape[index]);
}
}
return this.reshape(newShape);
}
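// Example (assuming a tensor t of shape [2, 3, 4]):
//   t.flatten()     // shape [24]   (all dims merged)
//   t.flatten(1, 2) // shape [2, 12] (dims 1..2 merged)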
// Transpose
transpose(dim1, dim2) {
// Handle negative indices
if (dim1 < 0) {
dim1 += this.shape.length;
}
if (dim2 < 0) {
dim2 += this.shape.length;
}
// If dimension out of bound, throw error
if (dim1 >= this.shape.length || dim2 >= this.shape.length || dim1 < 0 || dim2 < 0) {
throw new Error("Dimensions do not exist to transpose");
}
// If same dimension, return the tensor itself
if (dim1 === dim2)
return this;
// Create new shape and strides by swapping
const newShape = [...this.shape];
const newStrides = [...this.strides];
[newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
[newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
// Create new tensor with same data but swapped shape/strides
const out = new Tensor(this.value, {
shape: newShape,
strides: newStrides,
offset: this.offset,
numel: this.numel,
device: this.device,
dtype: this.dtype
});
out.requiresGrad = this.requiresGrad;
// Handle gradient if needed
if (this.requiresGrad) {
out.children.push(this);
out.gradFn = () => {
Tensor.addGrad(this, out.grad.transpose(dim1, dim2));
};
}
return out;
}
swapaxes = this.transpose;
swapdims = this.transpose;
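// Usage sketch (hypothetical t of shape [2, 3]): transpose swaps shape and
// strides only, sharing the underlying buffer with the original:
//   t.transpose(0, 1).shape  // => [3, 2]
//   t.transpose(-1, -2)      // negative dims count from the end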
// Transpose 2D
t() {
// Verify matrix shape
if (this.shape.length !== 2) {
throw new Error("Input is not a matrix");
}
return this.transpose(0, 1);
}
// Permute
permute(dims) {
dims = Tensor.normalizeDims(dims, this.shape.length);
if (dims.length !== this.shape.length) {
throw new Error("Permutation must specify all dimensions");
}
// Compute new shape and strides
const newShape = new Array(dims.length);
const newStrides = new Array(dims.length);
for (let index = 0; index < dims.length; index++) {
const dim = dims[index];
newShape[index] = this.shape[dim];
newStrides[index] = this.strides[dim];
}
const out = new Tensor(this.value, {
shape: newShape,
strides: newStrides,
offset: this.offset,
numel: this.numel,
device: this.device,
dtype: this.dtype
});
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
// Compute inverse permutation
const inverseAxes = new Array(dims.length);
for (let i = 0; i < dims.length; i++) {
inverseAxes[dims[i]] = i;
}
// Permute gradient back to original order
const permutedGrad = out.grad.permute(inverseAxes);
Tensor.addGrad(this, permutedGrad);
};
}
return out;
}
// Utility for indexing with array of indices
indexWithArray(indices) {
if (this.shape.length === 0)
return this;
indices = Tensor.normalizeDims(indices, this.shape[0]);
// Init necessary stuff for indexing
const reducedShape = this.shape.slice(1);
const reducedStrides = this.strides.slice(1);
const elementsPerIndex = Tensor.shapeToSize(reducedShape);
// Init output data
const outputShape = [indices.length, ...reducedShape];
const outputSize = Tensor.shapeToSize(outputShape);
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
for (let i = 0; i < indices.length; i++) {
const sourceRowIndex = indices[i];
const targetStart = i * elementsPerIndex;
for (let j = 0; j < elementsPerIndex; j++) {
const fullCoords = Tensor.indexToCoords(j, reducedStrides);
fullCoords.unshift(sourceRowIndex);
const sourceIndex = Tensor.coordsToIndex(fullCoords, this.strides);
outputValue[targetStart + j] = this.value[this.offset + sourceIndex];
}
}
const out = new Tensor(outputValue, {
shape: outputShape,
offset: 0,
numel: outputSize,
device: this.device,
dtype: this.dtype
});
// Handle gradient
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
const outGrad = out.grad;
// Create zero gradient tensor with original shape
const grad = Tensor.zerosLike(this);
// Scatter gradients back to original positions
for (let i = 0; i < indices.length; i++) {
const originalRowIndex = indices[i];
const sourceStart = i * elementsPerIndex;
for (let j = 0; j < elementsPerIndex; j++) {
const fullCoords = Tensor.indexToCoords(j, reducedStrides);
fullCoords.unshift(originalRowIndex);
const targetIndex = Tensor.coordsToIndex(fullCoords, this.strides);
grad.value[targetIndex] += outGrad.value[sourceStart + j];
}
}
Tensor.addGrad(this, grad);
};
}
return out;
}
// Tensor indexing
index(indices) {
const tensorIndices = this.handleOther(indices).clone();
if (tensorIndices.shape.length === 0) {
return this.indexWithArray([tensorIndices.value[0]]).squeeze(0);
}
else {
const originalShape = tensorIndices.shape;
const flatIndices = tensorIndices.value;
const result = this.indexWithArray(Array.from(flatIndices));
// Reshape to preserve input shape
const outputShape = [...originalShape, ...this.shape.slice(1)];
return result.reshape(outputShape);
}
}
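// Usage sketch: index() gathers along dim 0 and preserves the shape of the
// index tensor:
//   const t = new Tensor([[1, 2], [3, 4], [5, 6]]);
//   t.index(new Tensor([2, 0])) // => [[5, 6], [1, 2]]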
// Tensor slicing
slice(ranges) {
// Handle scalars
if (this.shape.length === 0)
return this;
const newShape = [];
const newStrides = [];
let newOffset = this.offset || 0;
// Pad ranges to match tensor dimensions
const paddedRanges = [...ranges];
while (paddedRanges.length < this.shape.length) {
paddedRanges.push([]);
}
for (let i = 0; i < this.shape.length; i++) {
const range = paddedRanges[i] || [];
const dimSize = this.shape[i];
const stride = this.strides[i];
// Default values
let start = range[0] ?? 0;
let end = range[1] ?? dimSize;
let step = range[2] ?? 1;
// Handle negative indices
if (start < 0)
start += dimSize;
if (end < 0)
end += dimSize;
// Clamp to valid range
start = Math.max(0, Math.min(start, dimSize));
end = Math.max(0, Math.min(end, dimSize));
// Calculate new dimension size
const newDimSize = step > 0
? Math.max(0, Math.ceil((end - start) / step))
: Math.max(0, Math.ceil((start - end) / Math.abs(step)));
newShape.push(newDimSize);
newStrides.push(stride * step);
newOffset += start * stride;
}
const out = new Tensor(this.value, {
shape: newShape,
strides: newStrides,
offset: newOffset,
device: this.device,
dtype: this.dtype
});
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
// Create zero tensor of original shape
const grad = Tensor.zerosLike(this);
// Upstream grad
const outGrad = out.grad;
const totalElements = outGrad.numel;
for (let i = 0; i < totalElements; i++) {
// Convert flat index to coordinates in sliced tensor
const slicedCoords = Tensor.indexToCoords(i, outGrad.strides);
// Map back to original coordinates
const originalCoords = new Array(slicedCoords.length);
for (let dim = 0; dim < slicedCoords.length; dim++) {
const coord = slicedCoords[dim];
const range = paddedRanges[dim] || [];
const start = range[0] ?? 0;
const step = range[2] ?? 1;
const normalizedStart = start < 0 ? start + this.shape[dim] : start;
originalCoords[dim] = normalizedStart + coord * step;
}
// Get flat indices with offsets
const srcIndex = Tensor.coordsToIndex(slicedCoords, outGrad.strides) + outGrad.offset;
const targetIndex = Tensor.coordsToIndex(originalCoords, grad.strides) + grad.offset;
// Accumulate gradient
grad.value[targetIndex] += outGrad.value[srcIndex];
}
Tensor.addGrad(this, grad);
};
}
return out;
}
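// Usage sketch: each range is [start, end, step] (all optional, negatives
// count from the end); slicing only adjusts shape/strides/offset, no copy:
//   const t = new Tensor([[1, 2, 3], [4, 5, 6]]);
//   t.slice([[0, 1]])     // => [[1, 2, 3]]
//   t.slice([[], [1, 3]]) // => [[2, 3], [5, 6]]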
// Tensor chunk
chunk(chunks, dim = 0) {
// Handle negative indices
if (dim < 0) {
dim += this.shape.length;
}
// If dimension out of bound, throw error
if (dim >= this.shape.length || dim < 0) {
throw new Error("Dimension do not exist to chunk");
}
const sliceOpt = new Array(this.shape.length);
for (let index = 0; index < sliceOpt.length; index++) {
sliceOpt[index] = [];
}
const dimSize = this.shape[dim];
const chunkDimSize = Math.ceil(dimSize / chunks);
const results = [];
for (let index = 0; index < dimSize; index += chunkDimSize) {
sliceOpt[dim] = [index, Math.min(index + chunkDimSize, dimSize)];
results.push(this.slice(sliceOpt));
}
return results;
}
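// Example: a size-5 dim split into 2 chunks yields sizes ceil(5/2) = 3 and 2:
//   new Tensor([1, 2, 3, 4, 5]).chunk(2) // => two tensors: [1, 2, 3] and [4, 5]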
// Tensor expansion
expand(newShape) {
// Handle scalars
let self = this;
if (this.shape.length === 0) {
self = self.unsqueeze(0);
}
// Pad shapes to same length
const ndim = Math.max(self.shape.length, newShape.length);
const oldShape = [...Array(ndim - self.shape.length).fill(1), ...self.shape];
const oldStrides = [...Array(ndim - self.strides.length).fill(0), ...self.strides];
const targetShape = [...Array(ndim - newShape.length).fill(1), ...newShape];
const newStrides = new Array(ndim);
for (let i = 0; i < ndim; i++) {
if (oldShape[i] === targetShape[i]) {
newStrides[i] = oldStrides[i];
}
else if (oldShape[i] === 1) {
newStrides[i] = 0;
}
else {
throw new Error(`Cannot expand dimension of size ${oldShape[i]} to ${targetShape[i]}`);
}
}
const out = new Tensor(self.value, {
shape: targetShape,
strides: newStrides,
offset: self.offset,
device: self.device,
dtype: self.dtype
});
if (self.requiresGrad) {
out.requiresGrad = true;
out.children.push(self);
out.gradFn = () => {
Tensor.addGrad(self, out.grad);
};
}
return out;
}
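// Usage sketch: expand broadcasts size-1 dims by giving them stride 0, so no
// data is copied; expanding a non-singleton dim throws:
//   const t = new Tensor([[1], [2]]); // shape [2, 1]
//   t.expand([2, 3]) // => [[1, 1, 1], [2, 2, 2]]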
// Tensor concatenation
cat(other, dim = 0) {
other = this.handleOther(other);
// Handle scalars
if (this.shape.length === 0 || other.shape.length === 0) {
throw new Error("Can not concatenate scalars");
}
// Handle negative indices
if (dim < 0) {
dim += this.shape.length;
}
// If dimension out of bound, throw error
if (dim >= this.shape.length || dim < 0) {
throw new Error("Dimension does not exist to concatenate");
}
// If shape does not match, throw error
if (this.shape.length !== other.shape.length) {
throw new Error("Shape does not match to concatenate");
}
const outputShape = new Array(this.shape.length);
for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
if (currentDim === dim) {
outputShape[currentDim] = this.shape[currentDim] + other.shape[currentDim];
}
else if (this.shape[currentDim] !== other.shape[currentDim]) {
throw new Error("Shape does not match to concatenate");
}
else {
outputShape[currentDim] = this.shape[currentDim];
}
}
const outputSize = Tensor.shapeToSize(outputShape);
const outputStrides = Tensor.getStrides(outputShape);
const outputDtype = Tensor.getResultDtype(this.dtype, other.dtype);
const outputValue = new dtype_1.TypedArray[outputDtype](outputSize);
for (let outIndex = 0; outIndex < outputSize; outIndex++) {
const coords = Tensor.indexToCoords(outIndex, outputStrides);
// Check which tensor this output position comes from
if (coords[dim] < this.shape[dim]) {
// Comes from this tensor
const srcIndex = Tensor.coordsToIndex(coords, this.strides);
outputValue[outIndex] = this.value[srcIndex + this.offset];
}
else {
// Comes from other tensor - adjust coordinate in concat dimension
const otherCoords = [...coords];
otherCoords[dim] -= this.shape[dim];
const srcIndex = Tensor.coordsToIndex(otherCoords, other.strides);
outputValue[outIndex] = other.value[srcIndex + other.offset];
}
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides,
offset: 0,
numel: outputSize,
device: this.device,
dtype: outputDtype
});
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
}
if (other.requiresGrad) {
out.requiresGrad = true;
out.children.push(other);
}
if (out.requiresGrad) {
out.gradFn = () => {
const outGrad = out.grad;
const thisRanges = new Array(this.shape.length);
const otherRanges = new Array(other.shape.length);
for (let currentDim = 0; currentDim < this.shape.length; currentDim++) {
if (currentDim === dim) {
thisRanges[currentDim] = [0, this.shape[currentDim], 1];
otherRanges[currentDim] = [this.shape[currentDim], outputShape[currentDim], 1];
}
else {
thisRanges[currentDim] = [];
otherRanges[currentDim] = [];
}
}
Tensor.addGrad(this, outGrad.slice(thisRanges));
Tensor.addGrad(other, outGrad.slice(otherRanges));
};
}
return out;
}
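// Usage sketch: all dims except the concatenation dim must match:
//   const a = new Tensor([[1, 2]]);
//   const b = new Tensor([[3, 4]]);
//   a.cat(b, 0) // => [[1, 2], [3, 4]]
//   a.cat(b, 1) // => [[1, 2, 3, 4]]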
// Tensor squeeze
squeeze(dims) {
if (this.shape.length === 0)
return this;
if (typeof dims === "number") {
dims = [dims];
}
if (typeof dims === "undefined") {
const shape = this.shape;
dims = [];
for (let index = 0; index < shape.length; index++) {
if (shape[index] === 1) {
dims.push(index);
}
}
}
dims = Tensor.normalizeDims(dims, this.shape.length);
// Remove size-1 dims only
const outShape = [], outStrides = [];
for (let index = 0; index < this.shape.length; index++) {
const dim = this.shape[index];
const stride = this.strides[index];
if (dims.includes(index)) {
if (dim !== 1)
throw new Error(`Can not squeeze dim with size ${dim}`);
}
else {
outShape.push(dim);
outStrides.push(stride);
}
}
const outValue = outShape.length === 0 ? this.value[this.offset] : this.value;
const out = new Tensor(outValue, {
shape: outShape,
strides: outStrides,
offset: outShape.length === 0 ? 0 : this.offset, // a scalar result gets a fresh 1-element buffer
numel: this.numel,
device: this.device,
dtype: this.dtype
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
let restoredGrad = out.grad;
for (let i = dims.length - 1; i >= 0; i--) {
restoredGrad = restoredGrad.unsqueeze(dims[i]);
}
Tensor.addGrad(this, restoredGrad);
};
}
return out;
}
// Tensor unsqueeze - adds a dimension of size 1 at the specified position
unsqueeze(dim) {
// Handle negative indices
if (dim < 0) {
dim += this.shape.length;
}
let thisValue = this.value;
// Insert size-1 dimension at specified position
const newShape = [...this.shape];
newShape.splice(dim, 0, 1);
// New stride
const newStrides = [...this.strides];
let newDimStride;
if (dim >= this.shape.length) {
// Inserting at the back: use 1
newDimStride = 1;
}
else {
// Inserting before dim: use current stride * current shape
newDimStride = this.strides[dim] * this.shape[dim];
}
newStrides.splice(dim, 0, newDimStride);
const out = new Tensor(thisValue, {
shape: newShape,
strides: newStrides,
offset: this.offset,
numel: this.numel,
device: this.device,
dtype: this.dtype
});
// Set up gradient if needed
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
Tensor.addGrad(this, out.grad.squeeze(dim));
};
}
return out;
}
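// Example: squeeze removes size-1 dims, unsqueeze inserts one; for a given
// dim they are inverses:
//   const t = new Tensor([[1, 2, 3]]); // shape [1, 3]
//   t.squeeze(0).shape        // => [3]
//   t.squeeze(0).unsqueeze(0) // shape [1, 3] again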
// Tensor sort
sort(dim = -1, descending = false) {
if (dim < 0) {
dim += this.shape.length;
}
// If dimension out of bound, throw error
if (dim >= this.shape.length || dim < 0) {
throw new Error("Dimension do not exist to sort");
}
// Copy if not contiguous
const outputSize = this.numel;
const outputShape = this.shape;
const outputValue = new dtype_1.TypedArray[this.dtype](outputSize);
const outputStrides = Tensor.getStrides(outputShape);
if (this.isContiguous()) {
// Fast path: direct copy
outputValue.set(this.value.subarray(this.offset, this.offset + outputSize));
}
else {
// Slow path: coordinate conversion
for (let flatIndex = 0; flatIndex < outputSize; flatIndex++) {
const coords = Tensor.indexToCoords(flatIndex, outputStrides);
const originalIndex = Tensor.coordsToIndex(coords, this.strides);
outputValue[flatIndex] = this.value[originalIndex + this.offset];
}
}
// Calculate dimensions for gather-scatter
const dimSize = outputShape[dim];
const outerSize = outputShape.slice(0, dim).reduce((a, b) => a * b, 1);
const innerSize = outputShape.slice(dim + 1).reduce((a, b) => a * b, 1);
// Store permutation indices for gradient
const permutation = new Array(outputSize);
// Sort each group independently
for (let outer = 0; outer < outerSize; outer++) {
for (let inner = 0; inner < innerSize; inner++) {
const group = [];
for (let i = 0; i < dimSize; i++) {
const flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner;
group.push({
value: outputValue[flatIdx],
dimIdx: i
});
}
// Sort this group by value
group.sort((a, b) => descending ? b.value - a.value : a.value - b.value);
// Scatter: write back sorted values and record permutation
for (let i = 0; i < dimSize; i++) {
const flatIdx = outer * (dimSize * innerSize) + i * innerSize + inner;
outputValue[flatIdx] = group[i].value;
// Record where this element came from (for gradient)
const originalFlatIdx = outer * (dimSize * innerSize) + group[i].dimIdx * innerSize + inner;
permutation[flatIdx] = originalFlatIdx;
}
}
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides,
offset: 0,
numel: outputSize,
device: this.device,
dtype: this.dtype
});
// Gradient setup
if (this.requiresGrad) {
out.requiresGrad = true;
out.children.push(this);
out.gradFn = () => {
const outGrad = out.grad;
// Scatter output gradients back to original positions
const inputGradValue = new dtype_1.TypedArray[this.dtype](outputSize);
for (let sortedIdx = 0; sortedIdx < outputSize; sortedIdx++) {
const originalIdx = permutation[sortedIdx];
inputGradValue[originalIdx] = outGrad.value[sortedIdx];
}
const inputGrad = new Tensor(inputGradValue, {
shape: outputShape,
strides: outputStrides,
offset: 0,
numel: outputSize,
device: this.device,
dtype: this.dtype
});
Tensor.addGrad(this, inputGrad);
};
}
return out;
}
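// Usage sketch: sorts along a dim and routes gradients back through the
// recorded permutation:
//   new Tensor([3, 1, 2]).sort()         // => [1, 2, 3]
//   new Tensor([3, 1, 2]).sort(-1, true) // => [3, 2, 1]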
// Top-k selection along a dimension
topk(k, dim = -1, largest = true) {
if (dim < 0) {
dim += this.shape.length;
}
// If dimension out of bound, throw error
if (dim >= this.shape.length || dim < 0) {
throw new Error("Dimension do not exist to get topk");
}
const dimRanges = new Array(this.shape.length);
for (let index = 0; index < dimRanges.length; index++) {
if (index === dim) {
dimRanges[index] = [0, k];
}
else {
dimRanges[index] = [];
}
}
return this.sort(dim, largest).slice(dimRanges);
}
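// Example: topk is a sort followed by a slice of the first k entries:
//   new Tensor([1, 5, 3]).topk(2)            // => [5, 3]
//   new Tensor([1, 5, 3]).topk(2, -1, false) // => [1, 3] (smallest)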
// Generic reduction operation handler
static reduce(tensor, dims, keepDims, config) {
if (tensor.shape.length === 0)
return tensor;
if (typeof dims === "undefined") {
dims = new Array(tensor.shape.length);
for (let index = 0; index < dims.length; index++) {
dims[index] = index;
}
}
if (Array.isArray(dims)) {
dims = Tensor.normalizeDims(dims, tensor.shape.length);
const sortedDims = dims.sort((a, b) => b - a);
let reducedThis = tensor;
for (let i = 0; i < sortedDims.length; i++) {
reducedThis = Tensor.reduce(reducedThis, sortedDims[i], true, config);
}
return keepDims ? reducedThis : reducedThis.squeeze(dims);
}
const dimSize = tensor.shape[dims];
const outputShape = tensor.shape.map((dim, i) => dims === i ? 1 : dim);
const outputStrides = Tensor.getStrides(outputShape);
const outputSize = tensor.numel / dimSize;
const outputValue = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(config.identity);
const originalSize = tensor.numel;
const originalValue = tensor.value;
const linearStrides = Tensor.getStrides(tensor.shape);
// Forward pass
for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
// Convert linear index to coordinates using contiguous strides
const coords = Tensor.indexToCoords(flatIndex, linearStrides);
// Convert coordinates to actual strided index
const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
// Convert coords to reduced index
coords[dims] = 0;
const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
// Apply op
outputValue[outFlatIndex] = config.operation(outputValue[outFlatIndex], originalValue[realFlatIndex]);
}
// Post-process if needed (e.g., divide by count for mean)
if (config.postProcess) {
config.postProcess({ values: outputValue, dimSize });
}
const out = new Tensor(outputValue, {
shape: outputShape,
strides: outputStrides,
offset: 0,
numel: outputSize,
device: tensor.device,
dtype: tensor.dtype
});
// Gradient setup
if (tensor.requiresGrad) {
out.requiresGrad = true;
out.children.push(tensor);
out.gradFn = () => {
let shareCounts = new dtype_1.TypedArray[tensor.dtype]();
if (config.needsShareCounts) {
shareCounts = new dtype_1.TypedArray[tensor.dtype](outputSize).fill(0);
for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
// Convert linear index to coordinates using contiguous strides
const coords = Tensor.indexToCoords(flatIndex, linearStrides);
// Convert coordinates to actual strided index
const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
// Convert coords to reduced index
coords[dims] = 0;
const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
// Count how many elements share the reduced value (e.g. the max) first
shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
}
}
const gradValue = new dtype_1.TypedArray[tensor.dtype](originalSize);
for (let flatIndex = 0; flatIndex < originalSize; flatIndex++) {
// Convert linear index to coordinates using contiguous strides
const coords = Tensor.indexToCoords(flatIndex, linearStrides);
// Convert coordinates to actual strided index
const realFlatIndex = Tensor.coordsToIndex(coords, tensor.strides) + tensor.offset;
// Convert coords to reduced index
coords[dims] = 0;
const outFlatIndex = Tensor.coordsToIndex(coords, outputStrides);
gradValue[flatIndex] = config.gradientFn({
outputValue,
originalValue: tensor.value,
dimSize,
shareCounts,
realIndex: realFlatIndex,
outIndex: outFlatIndex
});
}
const localGrad = new Tensor(gradValue, {
shape: tensor.shape,
offset: 0,
numel: tensor.numel,
device: tensor.device,
dtype: tensor.dtype
});
Tensor.addGrad(tensor, out.grad.mul(localGrad));
};
}
return keepDims ? out : out.squeeze(dims);
}
// Simplified reduction operations
sum(dims, keepDims = false) {
return Tensor.reduce(this, dims, keepDims, {
identity: 0,
operation: (a, b) => a + b,
gradientFn: ({}) => 1
});
}
prod(dims, keepDims = false) {
return Tensor.reduce(this, dims, keepDims, {
identity: 1,
operation: (a, b) => a * b,
gradientFn: ({ outputValue, originalValue, realIndex, outIndex }) => outputValue[outIndex] / originalValue[realIndex]
});
}
mean(dims, keepDims = false) {
return Tensor.reduce(this, dims, keepDims, {
identity: 0,
operation: (a, b) => a + b,
postProcess: ({ values, dimSize }) => {
for (let i = 0; i < values.length; i++) {
values[i] /= dimSize;
}
},
gradientFn: ({ dimSize }) => 1 / dimSize
});
}
max(dims, keepDims = false) {
return Tensor.reduce(this, dims, keepDims, {
identity: -Infinity,
operation: (a, b) => Math.max(a, b),
needsShareCounts: true,
gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0
});
}
min(dims, keepDims = false) {
return Tensor.reduce(this, dims, keepDims, {
identity: Infinity,
operation: (a, b) => Math.min(a, b),
needsShareCounts: true,
gradientFn: ({ outputValue, originalValue, shareCounts, realIndex, outIndex }) => outputValue[outIndex] === originalValue[realIndex] ? 1 / shareCounts[outIndex] : 0
});
}
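// Example reductions (assuming t = new Tensor([[1, 2], [3, 4]])):
//   t.sum()        // => 10 (scalar: all dims reduced)
//   t.sum(0)       // => [4, 6]
//   t.sum(1, true) // => [[3], [7]] (keepDims preserves the reduced dim)
//   t.max()        // => 4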
// Tensor all condition reduction
all(dims, keepDims = false) {
return this.min(dims, keepDims).ne(0);
}
// Tensor any condition reduction
any(dims, keepDims = false) {
return this.max(dims, keepDims).ne(0);
}
// Tensor variance reduction (population variance: E[x^2] - (E[x])^2)
var(dims, keepDims = false) {
const meanOfSquares = this.square().mean(dims, keepDims);
const squaredMean = this.mean(dims, keepDims).square();
return meanOfSquares.sub(squaredMean);
}
// Tensor standard deviation reduction
std(dims, keepDims = false) {
return this.var(dims, keepDims).sqrt();
}
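// Worked example (population variance): for [1, 2, 3], E[x^2] = 14/3 and
// (E[x])^2 = 4, so var = 14/3 - 4 = 2/3 and std = sqrt(2/3) ≈ 0.816.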
// Tensor softmax
softmax(dim = -1) {
if (this.shape.length === 0)
return this;
// Handle negative indexing
if (dim < 0) {
dim += this.shape.length;
}
// If dimension out of bound, throw error
if (dim >= this.shape.length || dim < 0) {
throw new Error("Dimension do not exist to apply softmax");
}
const maxVals = this.max(dim, true);
const shifted = this.sub(maxVals);
const expVals = shifted.exp();
const sumExp = expVals.sum(dim, true);
return expVals.div(sumExp);
}
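// Worked example: softmax shifts by the max for numerical stability, then
// normalizes the exponentials; for [0, ln 3]:
//   exp(0 - ln 3) = 1/3, exp(ln 3 - ln 3) = 1, so softmax => [0.25, 0.75]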
// Tensor softmin
softmin(dim = -1) {
if (this.shape.length === 0)
return this;
// Handle negative indexing
if (dim < 0) {
dim += this.shape.length;
}
// If dimension out of bound, throw error
if (dim >= this.shape.length || dim < 0) {
throw new Error("Dimension do not exist to apply softmin");
}
const maxVals = this.max(dim, true);
const shifted = maxVals.sub(this);
const expVals = shifted.exp();
const sumExp = expVals.sum(dim, true);
return expVals.div(sumExp);
}
// Tensor element-wise addition
add(other) {
return this.elementWiseABDAG(other, (a, b) => a + b, (_a, _b, outGrad) => outGrad, (_a, _b, outGrad) => outGrad);
}