UNPKG

stargrad

Version:

A JavaScript library for automatic gradient calculation, inspired by PyTorch

310 lines (309 loc) 11.4 kB
/**
 * Represents a tensor with automatic gradient calculation capabilities.
 *
 * Values are stored as a flat array (`data`) in row-major order together with
 * the tensor's dimensions (`shape`). Tensors returned by the arithmetic
 * methods record the operation name and operand tensors that produced them;
 * `backward()` walks those records to propagate gradients.
 */
export class Tensor {
  /**
   * @param {number|Array} data - A number or a (nested, rectangular) array of numbers.
   * @param {boolean} [requiresGrad=false] - Whether gradients are tracked for this tensor.
   * @param {string|null} [operation=null] - Name of the operation that produced this tensor.
   * @param {Tensor[]|null} [inputs=null] - Operand tensors of `operation`.
   * @throws {Error} If `data` is not a number / array of numbers, or is ragged.
   */
  constructor(data, requiresGrad = false, operation = null, inputs = null) {
    // Convert input to a flat array and remember its shape.
    if (typeof data === 'number') {
      this.data = [data];
      this.shape = [1];
    } else {
      this.data = this.flattenArray(data);
      this.shape = this.getShapeFromArray(data);
    }
    this.requiresGrad = requiresGrad;
    this.grad = null;
    this.operation = operation;
    this.inputs = inputs;
  }

  /**
   * Flattens a nested array of any depth into a 1D array of numbers.
   *
   * Pushes into a single output array; the previous `reduce` + `concat`
   * version copied the accumulator for every element (O(n^2)).
   *
   * @param {number|Array} arr
   * @returns {number[]}
   * @throws {Error} If `arr` is not a number/array or contains non-numbers.
   */
  flattenArray(arr) {
    if (typeof arr === 'number') {
      return [arr];
    }
    if (!Array.isArray(arr)) {
      throw new Error('Input must be a number or array');
    }
    const flat = [];
    const visit = (items) => {
      for (const item of items) {
        if (Array.isArray(item)) {
          visit(item);
        } else if (typeof item === 'number') {
          flat.push(item);
        } else {
          throw new Error('Array elements must be numbers or arrays of numbers');
        }
      }
    };
    visit(arr);
    return flat;
  }

  /**
   * Gets the shape (length at each nesting level) of a nested array.
   *
   * @param {number|Array} arr
   * @returns {number[]} `[1]` for a plain number.
   * @throws {Error} If `arr` is not a number/array, or is not rectangular.
   */
  getShapeFromArray(arr) {
    if (typeof arr === 'number') {
      return [1];
    }
    if (!Array.isArray(arr)) {
      throw new Error('Input must be a number or array');
    }
    // Read the candidate shape by following the first element at each level...
    const shape = [];
    let current = arr;
    while (Array.isArray(current)) {
      shape.push(current.length);
      if (current.length > 0 && Array.isArray(current[0])) {
        current = current[0];
      } else {
        break;
      }
    }
    // ...then verify every branch agrees, so data.length always equals the
    // product of the shape (otherwise get/set and ops index the wrong elements).
    this.assertRectangular(arr, shape, 0);
    return shape;
  }

  /**
   * Throws unless every sub-array at `depth` has length `shape[depth]` and
   * arrays appear exactly at the non-final levels.
   */
  assertRectangular(arr, shape, depth) {
    const message = 'Nested arrays must be rectangular (sub-arrays at the same depth must have equal length)';
    if (arr.length !== shape[depth]) {
      throw new Error(message);
    }
    const isLastDim = depth === shape.length - 1;
    for (const item of arr) {
      if (Array.isArray(item) === isLastDim) {
        throw new Error(message);
      }
      if (!isLastDim) {
        this.assertRectangular(item, shape, depth + 1);
      }
    }
  }

  /**
   * Converts multi-dimensional indices into an offset into `data`
   * (row-major: the last dimension has stride 1).
   *
   * @param {number[]} indices - One non-negative integer index per dimension.
   * @returns {number}
   * @throws {Error} On a wrong index count or an out-of-range/non-integer index.
   */
  flatIndex(indices) {
    if (indices.length !== this.shape.length) {
      throw new Error(`Expected ${this.shape.length} indices, got ${indices.length}`);
    }
    let index = 0;
    let stride = 1;
    for (let dim = this.shape.length - 1; dim >= 0; dim--) {
      const idx = indices[dim];
      const size = this.shape[dim];
      // Negative or fractional indices previously slipped through and
      // addressed the wrong element.
      if (!Number.isInteger(idx) || idx < 0 || idx >= size) {
        throw new Error(`Index ${idx} out of bounds for dimension ${dim} with size ${size}`);
      }
      index += idx * stride;
      stride *= size;
    }
    return index;
  }

  /**
   * Gets the element at the specified indices.
   * @param {number[]} indices
   * @returns {number}
   */
  get(indices) {
    return this.data[this.flatIndex(indices)];
  }

  /**
   * Sets the element at the specified indices (mutates `data` in place).
   * Note: backward() for mul/div/pow/log reads operand `data`, so changing a
   * tensor that already feeds a graph changes gradients computed afterwards.
   * @param {number[]} indices
   * @param {number} value
   */
  set(indices, value) {
    this.data[this.flatIndex(indices)] = value;
  }

  /**
   * Throws unless `other` has exactly this tensor's shape.
   * @param {Tensor} other
   */
  checkSameShape(other) {
    if (this.shape.length !== other.shape.length) {
      throw new Error('Tensors must have the same number of dimensions');
    }
    if (!this.shape.every((dim, i) => dim === other.shape[i])) {
      throw new Error('Tensors must have the same shape');
    }
  }

  /**
   * Builds a result tensor with this tensor's shape and graph metadata.
   */
  createResult(newData, requiresGrad, operation, inputs) {
    const result = new Tensor(newData, requiresGrad, operation, inputs);
    result.shape = [...this.shape];
    return result;
  }

  /**
   * Applies `combine` element-wise to this tensor and `other` (same shape).
   */
  elementwise(other, operation, combine) {
    this.checkSameShape(other);
    const newData = this.data.map((value, i) => combine(value, other.data[i]));
    return this.createResult(newData, this.requiresGrad || other.requiresGrad, operation, [this, other]);
  }

  /**
   * Adds two tensors element-wise.
   * @param {Tensor} other - Tensor with the same shape.
   * @returns {Tensor}
   */
  add(other) {
    return this.elementwise(other, 'add', (a, b) => a + b);
  }

  /**
   * Multiplies two tensors element-wise.
   * @param {Tensor} other - Tensor with the same shape.
   * @returns {Tensor}
   */
  mul(other) {
    return this.elementwise(other, 'mul', (a, b) => a * b);
  }

  /**
   * Subtracts two tensors element-wise.
   * @param {Tensor} other - Tensor with the same shape.
   * @returns {Tensor}
   */
  sub(other) {
    return this.elementwise(other, 'sub', (a, b) => a - b);
  }

  /**
   * Divides two tensors element-wise.
   * @param {Tensor} other - Tensor with the same shape and no zero elements.
   * @returns {Tensor}
   * @throws {Error} On shape mismatch or division by zero.
   */
  div(other) {
    this.checkSameShape(other);
    if (other.data.some((value) => value === 0)) {
      throw new Error('Division by zero');
    }
    return this.elementwise(other, 'div', (a, b) => a / b);
  }

  /**
   * Raises every element to the number `exponent`.
   * @param {number} exponent
   * @returns {Tensor}
   * @throws {Error} For a zero base with exponent < 1, or a negative base with
   *   a non-integer exponent (which would otherwise silently yield NaN).
   */
  pow(exponent) {
    // Check for invalid cases before performing the operation.
    for (const value of this.data) {
      if (value === 0 && exponent < 1) {
        throw new Error('Cannot compute power for zero base with exponent < 1');
      }
      if (value < 0 && !Number.isInteger(exponent)) {
        throw new Error('Cannot compute non-integer power of a negative number');
      }
    }
    const newData = this.data.map((value) => Math.pow(value, exponent));
    // The exponent is stored as a tensor so backward() can read it.
    return this.createResult(newData, this.requiresGrad, 'pow', [this, new Tensor(exponent)]);
  }

  /**
   * Calculates the natural logarithm of the tensor element-wise.
   * @returns {Tensor}
   * @throws {Error} If any element is <= 0.
   */
  log() {
    if (this.data.some((value) => value <= 0)) {
      throw new Error('Logarithm is only defined for positive numbers');
    }
    const newData = this.data.map((value) => Math.log(value));
    return this.createResult(newData, this.requiresGrad, 'log', [this]);
  }

  /**
   * Converts a `backward()` gradient argument into a flat array of numbers.
   * @param {number|Array|Tensor} gradient - A scalar (broadcast to every
   *   element), a flat or nested array, or a Tensor.
   * @returns {number[]}
   * @throws {Error} If the element count differs from this tensor's.
   */
  normalizeGradient(gradient) {
    let values;
    if (typeof gradient === 'number') {
      values = this.data.map(() => gradient);
    } else if (gradient instanceof Tensor) {
      values = [...gradient.data];
    } else {
      values = this.flattenArray(gradient);
    }
    if (values.length !== this.data.length) {
      throw new Error(`Gradient has ${values.length} elements, expected ${this.data.length}`);
    }
    return values;
  }

  /**
   * Returns this tensor and every gradient-tracking tensor reachable through
   * `inputs`, ordered so each tensor appears after all of its inputs.
   * Iterative DFS, so deep graphs cannot overflow the call stack.
   * @returns {Tensor[]}
   */
  topologicalOrder() {
    const order = [];
    const visited = new Set();
    const stack = [[this, false]];
    while (stack.length > 0) {
      const [node, expanded] = stack.pop();
      if (expanded) {
        order.push(node);
        continue;
      }
      if (visited.has(node)) {
        continue;
      }
      visited.add(node);
      // Marker is popped (and the node emitted) only after its inputs.
      stack.push([node, true]);
      if (node.operation && node.inputs) {
        for (const input of node.inputs) {
          if (input.requiresGrad && !visited.has(input)) {
            stack.push([input, false]);
          }
        }
      }
    }
    return order;
  }

  /**
   * Computes the gradient contribution for each input of this tensor's
   * operation, given the gradient `grad` flowing into this tensor.
   * @param {number[]} grad
   * @returns {Array<[Tensor, number[]]>}
   */
  inputGradients(grad) {
    if (!this.operation || !this.inputs) {
      return [];
    }
    const [a, b] = this.inputs;
    switch (this.operation) {
      case 'add':
        return [[a, grad], [b, grad]];
      case 'sub':
        return [[a, grad], [b, grad.map((g) => -g)]];
      case 'mul':
        return [
          [a, grad.map((g, i) => g * b.data[i])],
          [b, grad.map((g, i) => g * a.data[i])],
        ];
      case 'div':
        // d(a/b)/da = 1/b, d(a/b)/db = -a/b^2
        return [
          [a, grad.map((g, i) => g / b.data[i])],
          [b, grad.map((g, i) => -g * a.data[i] / Math.pow(b.data[i], 2))],
        ];
      case 'pow': {
        // d(x^n)/dx = n * x^(n-1); the exponent tensor gets no gradient.
        const exponent = b.data[0];
        const baseGrad = grad.map((g, i) => {
          if (a.data[i] === 0 && exponent < 1) {
            throw new Error('Cannot compute gradient for zero base with exponent < 1');
          }
          return g * exponent * Math.pow(a.data[i], exponent - 1);
        });
        return [[a, baseGrad]];
      }
      case 'log':
        // d(log(x))/dx = 1/x
        return [[a, grad.map((g, i) => g / a.data[i])]];
      default:
        return [];
    }
  }

  /**
   * Adds `gradArray` into this tensor's accumulated `grad`.
   * The stored gradient is a plain tensor (no autograd history), so repeated
   * calls do not build an ever-growing chain of 'add' nodes.
   * @param {number[]} gradArray
   */
  accumulateGrad(gradArray) {
    const total = this.grad === null
      ? [...gradArray]
      : this.grad.data.map((g, i) => g + gradArray[i]);
    this.grad = new Tensor(total);
    this.grad.shape = [...this.shape];
  }

  /**
   * Back-propagates `gradient` from this tensor to every gradient-tracking
   * tensor it was computed from, adding into each one's `grad` (gradients
   * accumulate across calls). A no-op when `requiresGrad` is false.
   *
   * Nodes are processed in reverse topological order, with contributions
   * from all paths summed before a node propagates further. This gives the
   * same result as per-path recursion but in linear rather than exponential
   * time when subexpressions are shared (e.g. `y.add(y)`).
   *
   * @param {number|Array|Tensor} [gradient=1] - Gradient w.r.t. this tensor.
   * @throws {Error} If `gradient` has the wrong number of elements.
   */
  backward(gradient = 1) {
    if (!this.requiresGrad) {
      return;
    }
    const pending = new Map([[this, this.normalizeGradient(gradient)]]);
    const order = this.topologicalOrder();
    for (let i = order.length - 1; i >= 0; i--) {
      const node = order[i];
      const grad = pending.get(node);
      if (grad === undefined) {
        continue;
      }
      pending.delete(node);
      node.accumulateGrad(grad);
      for (const [input, inputGrad] of node.inputGradients(grad)) {
        if (!input.requiresGrad) {
          continue;
        }
        const previous = pending.get(input);
        pending.set(
          input,
          previous === undefined ? inputGrad : previous.map((g, j) => g + inputGrad[j]),
        );
      }
    }
  }

  /**
   * Returns the current gradient (a Tensor, or null before any backward()).
   */
  getGrad() {
    return this.grad;
  }

  /**
   * Returns the tensor data (flat, row-major).
   */
  getData() {
    return this.data;
  }

  /**
   * Returns the tensor shape.
   */
  getShape() {
    return this.shape;
  }

  /**
   * Returns a string representation of the tensor.
   */
  toString() {
    return `Tensor(shape=${JSON.stringify(this.shape)}, data=${JSON.stringify(this.data)})`;
  }
}