/*
 * stargrad — a JavaScript library for automatic gradient calculation,
 * inspired by PyTorch. (~310 lines)
 */
/*
* Represents a tensor with automatic gradient calculation capabilities
*/
export class Tensor {
  /**
   * Creates a tensor from a scalar or a (possibly nested) numeric array.
   * Data is stored flat in row-major order alongside a `shape` array.
   *
   * @param {number|Array} data - Scalar or nested array of numbers.
   * @param {boolean} [requiresGrad=false] - Whether to track gradients.
   * @param {?string} [operation=null] - Name of the op that produced this tensor
   *   (used by backward() for gradient dispatch).
   * @param {?Tensor[]} [inputs=null] - Operand tensors of that op (graph parents).
   */
  constructor(data, requiresGrad = false, operation = null, inputs = null) {
    if (typeof data === 'number') {
      this.data = [data];
      this.shape = [1];
    } else {
      this.data = this.flattenArray(data);
      this.shape = this.getShapeFromArray(data);
    }
    this.requiresGrad = requiresGrad;
    this.grad = null; // Tensor of accumulated gradients, populated by backward()
    this.operation = operation;
    this.inputs = inputs;
  }

  /**
   * Flattens a nested array of any depth into a 1D array of numbers.
   *
   * @param {number|Array} arr - Scalar or nested numeric array.
   * @returns {number[]} Flat copy of all elements.
   * @throws {Error} If the input or any element is not a number or array.
   */
  flattenArray(arr) {
    if (typeof arr === 'number') {
      return [arr];
    }
    if (!Array.isArray(arr)) {
      throw new Error('Input must be a number or array');
    }
    const flat = [];
    for (const item of arr) {
      if (Array.isArray(item)) {
        flat.push(...this.flattenArray(item));
      } else if (typeof item === 'number') {
        flat.push(item);
      } else {
        throw new Error('Array elements must be numbers or arrays of numbers');
      }
    }
    return flat;
  }

  /**
   * Derives the shape of a nested array by walking its first elements.
   * NOTE: assumes the array is rectangular (not ragged) — only index 0 of
   * each level is inspected.
   *
   * @param {number|Array} arr - Scalar or nested numeric array.
   * @returns {number[]} Dimension sizes, outermost first.
   * @throws {Error} If the input is not a number or array.
   */
  getShapeFromArray(arr) {
    if (typeof arr === 'number') {
      return [1];
    }
    if (!Array.isArray(arr)) {
      throw new Error('Input must be a number or array');
    }
    const shape = [];
    let current = arr;
    while (Array.isArray(current)) {
      shape.push(current.length);
      if (current.length > 0 && Array.isArray(current[0])) {
        current = current[0];
      } else {
        break;
      }
    }
    return shape;
  }

  /**
   * Converts multi-dimensional indices to a flat row-major offset,
   * with full bounds checking (shared by get() and set()).
   *
   * @param {number[]} indices - One index per dimension.
   * @returns {number} Offset into `this.data`.
   * @throws {Error} On wrong index count or any out-of-range index
   *   (including negative indices, which are not supported).
   */
  toFlatIndex(indices) {
    if (indices.length !== this.shape.length) {
      throw new Error(`Expected ${this.shape.length} indices, got ${indices.length}`);
    }
    let index = 0;
    let stride = 1;
    for (let i = this.shape.length - 1; i >= 0; i--) {
      // Reject negative indices too — they would silently address data[<0].
      if (indices[i] < 0 || indices[i] >= this.shape[i]) {
        throw new Error(`Index ${indices[i]} out of bounds for dimension ${i} with size ${this.shape[i]}`);
      }
      index += indices[i] * stride;
      stride *= this.shape[i];
    }
    return index;
  }

  /**
   * Gets the element at the specified indices.
   *
   * @param {number[]} indices - One index per dimension.
   * @returns {number} The stored value.
   */
  get(indices) {
    return this.data[this.toFlatIndex(indices)];
  }

  /**
   * Sets the element at the specified indices.
   *
   * @param {number[]} indices - One index per dimension.
   * @param {number} value - Value to store.
   */
  set(indices, value) {
    this.data[this.toFlatIndex(indices)] = value;
  }

  /**
   * Throws unless `other` has exactly the same shape as this tensor.
   *
   * @param {Tensor} other - Tensor to compare against.
   * @throws {Error} On rank or dimension mismatch.
   */
  assertSameShape(other) {
    if (this.shape.length !== other.shape.length) {
      throw new Error('Tensors must have the same number of dimensions');
    }
    if (!this.shape.every((dim, i) => dim === other.shape[i])) {
      throw new Error('Tensors must have the same shape');
    }
  }

  /**
   * Applies a binary function element-wise and wires the result into the
   * autograd graph. Shared implementation behind add/mul/sub/div.
   *
   * @param {Tensor} other - Second operand (same shape required).
   * @param {function(number, number): number} fn - Element-wise operation.
   * @param {string} opName - Operation tag used by backward().
   * @returns {Tensor} New tensor with the same shape as the operands.
   */
  elementwiseBinaryOp(other, fn, opName) {
    this.assertSameShape(other);
    const newData = this.data.map((v, i) => fn(v, other.data[i]));
    const result = new Tensor(newData, this.requiresGrad || other.requiresGrad, opName, [this, other]);
    // The constructor only sees flat data, so restore the original shape.
    result.shape = [...this.shape];
    return result;
  }

  /**
   * Adds two tensors element-wise.
   *
   * @param {Tensor} other - Same-shaped tensor.
   * @returns {Tensor} this + other.
   */
  add(other) {
    return this.elementwiseBinaryOp(other, (a, b) => a + b, 'add');
  }

  /**
   * Multiplies two tensors element-wise.
   *
   * @param {Tensor} other - Same-shaped tensor.
   * @returns {Tensor} this * other.
   */
  mul(other) {
    return this.elementwiseBinaryOp(other, (a, b) => a * b, 'mul');
  }

  /**
   * Subtracts two tensors element-wise.
   *
   * @param {Tensor} other - Same-shaped tensor.
   * @returns {Tensor} this - other.
   */
  sub(other) {
    return this.elementwiseBinaryOp(other, (a, b) => a - b, 'sub');
  }

  /**
   * Divides two tensors element-wise.
   *
   * @param {Tensor} other - Same-shaped tensor; must contain no zeros.
   * @returns {Tensor} this / other.
   * @throws {Error} On shape mismatch or any zero divisor.
   */
  div(other) {
    this.assertSameShape(other);
    // Reject zero divisors up front, before computing any elements.
    if (other.data.some((v) => v === 0)) {
      throw new Error('Division by zero');
    }
    return this.elementwiseBinaryOp(other, (a, b) => a / b, 'div');
  }

  /**
   * Raises the tensor to a scalar power element-wise.
   *
   * @param {number} exponent - Scalar exponent.
   * @returns {Tensor} this ** exponent, element-wise.
   * @throws {Error} If any element is 0 and exponent < 1 (gradient undefined).
   */
  pow(exponent) {
    for (const v of this.data) {
      if (v === 0 && exponent < 1) {
        throw new Error('Cannot compute power for zero base with exponent < 1');
      }
    }
    const newData = this.data.map((v) => Math.pow(v, exponent));
    // The exponent is stored as a tensor so backward() can read it from inputs[1].
    const result = new Tensor(newData, this.requiresGrad, 'pow', [this, new Tensor(exponent)]);
    result.shape = [...this.shape];
    return result;
  }

  /**
   * Computes the natural logarithm element-wise.
   *
   * @returns {Tensor} ln(this), element-wise.
   * @throws {Error} If any element is not strictly positive.
   */
  log() {
    for (const v of this.data) {
      if (v <= 0) {
        throw new Error('Logarithm is only defined for positive numbers');
      }
    }
    const newData = this.data.map((v) => Math.log(v));
    const result = new Tensor(newData, this.requiresGrad, 'log', [this]);
    result.shape = [...this.shape];
    return result;
  }

  /**
   * Runs reverse-mode autodiff: accumulates `gradient` into this.grad and
   * recursively propagates chain-rule gradients to input tensors.
   *
   * Accumulation is done on plain data (not via Tensor.add) so gradient
   * tensors never become nodes of the computation graph themselves.
   *
   * @param {number|number[]} [gradient=1] - Upstream gradient; a scalar is
   *   broadcast across all elements.
   * @throws {Error} If an explicit gradient array's length does not match
   *   this tensor's element count.
   */
  backward(gradient = 1) {
    if (!this.requiresGrad) {
      return;
    }
    const gradArray = typeof gradient === 'number'
      ? new Array(this.data.length).fill(gradient)
      : gradient;
    if (gradArray.length !== this.data.length) {
      throw new Error(`Gradient length ${gradArray.length} does not match tensor size ${this.data.length}`);
    }
    if (this.grad === null) {
      this.grad = new Tensor([...gradArray]);
      this.grad.shape = [...this.shape];
    } else {
      // Accumulate in place — repeated backward calls sum their gradients.
      for (let i = 0; i < gradArray.length; i++) {
        this.grad.data[i] += gradArray[i];
      }
    }
    if (!this.operation || !this.inputs) {
      return;
    }
    switch (this.operation) {
      case 'add':
        // d(a+b)/da = d(a+b)/db = 1 — pass the gradient straight through.
        this.inputs.forEach((input) => input.backward(gradArray));
        break;
      case 'mul': {
        // d(a*b)/da = b, d(a*b)/db = a
        const [a, b] = this.inputs;
        a.backward(gradArray.map((g, i) => g * b.data[i]));
        b.backward(gradArray.map((g, i) => g * a.data[i]));
        break;
      }
      case 'sub': {
        // d(a-b)/da = 1, d(a-b)/db = -1
        const [a, b] = this.inputs;
        a.backward(gradArray);
        b.backward(gradArray.map((g) => -g));
        break;
      }
      case 'div': {
        // d(a/b)/da = 1/b, d(a/b)/db = -a/b^2
        const [a, b] = this.inputs;
        a.backward(gradArray.map((g, i) => g / b.data[i]));
        b.backward(gradArray.map((g, i) => (-g * a.data[i]) / (b.data[i] * b.data[i])));
        break;
      }
      case 'pow': {
        // d(x^n)/dx = n * x^(n-1); inputs[1] holds the scalar exponent.
        const base = this.inputs[0].data;
        const n = this.inputs[1].data[0];
        this.inputs[0].backward(gradArray.map((g, i) => {
          if (base[i] === 0 && n < 1) {
            throw new Error('Cannot compute gradient for zero base with exponent < 1');
          }
          return g * n * Math.pow(base[i], n - 1);
        }));
        break;
      }
      case 'log': {
        // d(ln x)/dx = 1/x
        const source = this.inputs[0];
        source.backward(gradArray.map((g, i) => g / source.data[i]));
        break;
      }
    }
  }

  /**
   * Returns the accumulated gradient tensor, or null if backward() has
   * not run (or gradients are not tracked).
   *
   * @returns {?Tensor}
   */
  getGrad() {
    return this.grad;
  }

  /**
   * Returns the flat (row-major) data array.
   *
   * @returns {number[]}
   */
  getData() {
    return this.data;
  }

  /**
   * Returns the tensor shape (dimension sizes, outermost first).
   *
   * @returns {number[]}
   */
  getShape() {
    return this.shape;
  }

  /**
   * Returns a human-readable representation of shape and data.
   *
   * @returns {string}
   */
  toString() {
    return `Tensor(shape=${JSON.stringify(this.shape)}, data=${JSON.stringify(this.data)})`;
  }
}