/*
 * ml-matrix
 * Version: (unspecified)
 * Matrix manipulation and computation library
 * 1,736 lines (1,581 loc) • 45.5 kB
 * JavaScript
 */
import { isAnyArray } from 'is-any-array';
import rescale from 'ml-array-rescale';
import { inspectMatrix, inspectMatrixWithOptions } from './inspect';
import { installMathOperations } from './mathOperations';
import {
sumByRow,
sumByColumn,
sumAll,
productByRow,
productByColumn,
productAll,
varianceByRow,
varianceByColumn,
varianceAll,
centerByRow,
centerByColumn,
centerAll,
scaleByRow,
scaleByColumn,
scaleAll,
getScaleByRow,
getScaleByColumn,
getScaleAll,
} from './stat';
import {
checkRowVector,
checkRowIndex,
checkColumnIndex,
checkColumnVector,
checkRange,
checkNonEmpty,
checkRowIndices,
checkColumnIndices,
} from './util';
export class AbstractMatrix {
static from1DArray(newRows, newColumns, newData) {
let length = newRows * newColumns;
if (length !== newData.length) {
throw new RangeError('data length does not match given dimensions');
}
let newMatrix = new Matrix(newRows, newColumns);
for (let row = 0; row < newRows; row++) {
for (let column = 0; column < newColumns; column++) {
newMatrix.set(row, column, newData[row * newColumns + column]);
}
}
return newMatrix;
}
static rowVector(newData) {
let vector = new Matrix(1, newData.length);
for (let i = 0; i < newData.length; i++) {
vector.set(0, i, newData[i]);
}
return vector;
}
static columnVector(newData) {
let vector = new Matrix(newData.length, 1);
for (let i = 0; i < newData.length; i++) {
vector.set(i, 0, newData[i]);
}
return vector;
}
static zeros(rows, columns) {
return new Matrix(rows, columns);
}
static ones(rows, columns) {
return new Matrix(rows, columns).fill(1);
}
static rand(rows, columns, options = {}) {
if (typeof options !== 'object') {
throw new TypeError('options must be an object');
}
const { random = Math.random } = options;
let matrix = new Matrix(rows, columns);
for (let i = 0; i < rows; i++) {
for (let j = 0; j < columns; j++) {
matrix.set(i, j, random());
}
}
return matrix;
}
static randInt(rows, columns, options = {}) {
if (typeof options !== 'object') {
throw new TypeError('options must be an object');
}
const { min = 0, max = 1000, random = Math.random } = options;
if (!Number.isInteger(min)) throw new TypeError('min must be an integer');
if (!Number.isInteger(max)) throw new TypeError('max must be an integer');
if (min >= max) throw new RangeError('min must be smaller than max');
let interval = max - min;
let matrix = new Matrix(rows, columns);
for (let i = 0; i < rows; i++) {
for (let j = 0; j < columns; j++) {
let value = min + Math.round(random() * interval);
matrix.set(i, j, value);
}
}
return matrix;
}
static eye(rows, columns, value) {
if (columns === undefined) columns = rows;
if (value === undefined) value = 1;
let min = Math.min(rows, columns);
let matrix = this.zeros(rows, columns);
for (let i = 0; i < min; i++) {
matrix.set(i, i, value);
}
return matrix;
}
static diag(data, rows, columns) {
let l = data.length;
if (rows === undefined) rows = l;
if (columns === undefined) columns = rows;
let min = Math.min(l, rows, columns);
let matrix = this.zeros(rows, columns);
for (let i = 0; i < min; i++) {
matrix.set(i, i, data[i]);
}
return matrix;
}
static min(matrix1, matrix2) {
matrix1 = this.checkMatrix(matrix1);
matrix2 = this.checkMatrix(matrix2);
let rows = matrix1.rows;
let columns = matrix1.columns;
let result = new Matrix(rows, columns);
for (let i = 0; i < rows; i++) {
for (let j = 0; j < columns; j++) {
result.set(i, j, Math.min(matrix1.get(i, j), matrix2.get(i, j)));
}
}
return result;
}
static max(matrix1, matrix2) {
matrix1 = this.checkMatrix(matrix1);
matrix2 = this.checkMatrix(matrix2);
let rows = matrix1.rows;
let columns = matrix1.columns;
let result = new this(rows, columns);
for (let i = 0; i < rows; i++) {
for (let j = 0; j < columns; j++) {
result.set(i, j, Math.max(matrix1.get(i, j), matrix2.get(i, j)));
}
}
return result;
}
static checkMatrix(value) {
return AbstractMatrix.isMatrix(value) ? value : new Matrix(value);
}
static isMatrix(value) {
return value != null && value.klass === 'Matrix';
}
get size() {
return this.rows * this.columns;
}
apply(callback) {
if (typeof callback !== 'function') {
throw new TypeError('callback must be a function');
}
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
callback.call(this, i, j);
}
}
return this;
}
to1DArray() {
let array = [];
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
array.push(this.get(i, j));
}
}
return array;
}
to2DArray() {
let copy = [];
for (let i = 0; i < this.rows; i++) {
copy.push([]);
for (let j = 0; j < this.columns; j++) {
copy[i].push(this.get(i, j));
}
}
return copy;
}
toJSON() {
return this.to2DArray();
}
isRowVector() {
return this.rows === 1;
}
isColumnVector() {
return this.columns === 1;
}
isVector() {
return this.rows === 1 || this.columns === 1;
}
isSquare() {
return this.rows === this.columns;
}
isEmpty() {
return this.rows === 0 || this.columns === 0;
}
isSymmetric() {
if (this.isSquare()) {
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j <= i; j++) {
if (this.get(i, j) !== this.get(j, i)) {
return false;
}
}
}
return true;
}
return false;
}
isDistance() {
if (!this.isSymmetric()) return false;
for (let i = 0; i < this.rows; i++) {
if (this.get(i, i) !== 0) return false;
}
return true;
}
isEchelonForm() {
let i = 0;
let j = 0;
let previousColumn = -1;
let isEchelonForm = true;
let checked = false;
while (i < this.rows && isEchelonForm) {
j = 0;
checked = false;
while (j < this.columns && checked === false) {
if (this.get(i, j) === 0) {
j++;
} else if (this.get(i, j) === 1 && j > previousColumn) {
checked = true;
previousColumn = j;
} else {
isEchelonForm = false;
checked = true;
}
}
i++;
}
return isEchelonForm;
}
isReducedEchelonForm() {
let i = 0;
let j = 0;
let previousColumn = -1;
let isReducedEchelonForm = true;
let checked = false;
while (i < this.rows && isReducedEchelonForm) {
j = 0;
checked = false;
while (j < this.columns && checked === false) {
if (this.get(i, j) === 0) {
j++;
} else if (this.get(i, j) === 1 && j > previousColumn) {
checked = true;
previousColumn = j;
} else {
isReducedEchelonForm = false;
checked = true;
}
}
for (let k = j + 1; k < this.rows; k++) {
if (this.get(i, k) !== 0) {
isReducedEchelonForm = false;
}
}
i++;
}
return isReducedEchelonForm;
}
echelonForm() {
let result = this.clone();
let h = 0;
let k = 0;
while (h < result.rows && k < result.columns) {
let iMax = h;
for (let i = h; i < result.rows; i++) {
if (result.get(i, k) > result.get(iMax, k)) {
iMax = i;
}
}
if (result.get(iMax, k) === 0) {
k++;
} else {
result.swapRows(h, iMax);
let tmp = result.get(h, k);
for (let j = k; j < result.columns; j++) {
result.set(h, j, result.get(h, j) / tmp);
}
for (let i = h + 1; i < result.rows; i++) {
let factor = result.get(i, k) / result.get(h, k);
result.set(i, k, 0);
for (let j = k + 1; j < result.columns; j++) {
result.set(i, j, result.get(i, j) - result.get(h, j) * factor);
}
}
h++;
k++;
}
}
return result;
}
reducedEchelonForm() {
let result = this.echelonForm();
let m = result.columns;
let n = result.rows;
let h = n - 1;
while (h >= 0) {
if (result.maxRow(h) === 0) {
h--;
} else {
let p = 0;
let pivot = false;
while (p < n && pivot === false) {
if (result.get(h, p) === 1) {
pivot = true;
} else {
p++;
}
}
for (let i = 0; i < h; i++) {
let factor = result.get(i, p);
for (let j = p; j < m; j++) {
let tmp = result.get(i, j) - factor * result.get(h, j);
result.set(i, j, tmp);
}
}
h--;
}
}
return result;
}
set() {
throw new Error('set method is unimplemented');
}
get() {
throw new Error('get method is unimplemented');
}
repeat(options = {}) {
if (typeof options !== 'object') {
throw new TypeError('options must be an object');
}
const { rows = 1, columns = 1 } = options;
if (!Number.isInteger(rows) || rows <= 0) {
throw new TypeError('rows must be a positive integer');
}
if (!Number.isInteger(columns) || columns <= 0) {
throw new TypeError('columns must be a positive integer');
}
let matrix = new Matrix(this.rows * rows, this.columns * columns);
for (let i = 0; i < rows; i++) {
for (let j = 0; j < columns; j++) {
matrix.setSubMatrix(this, this.rows * i, this.columns * j);
}
}
return matrix;
}
fill(value) {
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, value);
}
}
return this;
}
neg() {
return this.mulS(-1);
}
getRow(index) {
checkRowIndex(this, index);
let row = [];
for (let i = 0; i < this.columns; i++) {
row.push(this.get(index, i));
}
return row;
}
getRowVector(index) {
return Matrix.rowVector(this.getRow(index));
}
setRow(index, array) {
checkRowIndex(this, index);
array = checkRowVector(this, array);
for (let i = 0; i < this.columns; i++) {
this.set(index, i, array[i]);
}
return this;
}
swapRows(row1, row2) {
checkRowIndex(this, row1);
checkRowIndex(this, row2);
for (let i = 0; i < this.columns; i++) {
let temp = this.get(row1, i);
this.set(row1, i, this.get(row2, i));
this.set(row2, i, temp);
}
return this;
}
getColumn(index) {
checkColumnIndex(this, index);
let column = [];
for (let i = 0; i < this.rows; i++) {
column.push(this.get(i, index));
}
return column;
}
getColumnVector(index) {
return Matrix.columnVector(this.getColumn(index));
}
setColumn(index, array) {
checkColumnIndex(this, index);
array = checkColumnVector(this, array);
for (let i = 0; i < this.rows; i++) {
this.set(i, index, array[i]);
}
return this;
}
swapColumns(column1, column2) {
checkColumnIndex(this, column1);
checkColumnIndex(this, column2);
for (let i = 0; i < this.rows; i++) {
let temp = this.get(i, column1);
this.set(i, column1, this.get(i, column2));
this.set(i, column2, temp);
}
return this;
}
addRowVector(vector) {
vector = checkRowVector(this, vector);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, this.get(i, j) + vector[j]);
}
}
return this;
}
subRowVector(vector) {
vector = checkRowVector(this, vector);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, this.get(i, j) - vector[j]);
}
}
return this;
}
mulRowVector(vector) {
vector = checkRowVector(this, vector);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, this.get(i, j) * vector[j]);
}
}
return this;
}
divRowVector(vector) {
vector = checkRowVector(this, vector);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, this.get(i, j) / vector[j]);
}
}
return this;
}
addColumnVector(vector) {
vector = checkColumnVector(this, vector);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, this.get(i, j) + vector[i]);
}
}
return this;
}
subColumnVector(vector) {
vector = checkColumnVector(this, vector);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, this.get(i, j) - vector[i]);
}
}
return this;
}
mulColumnVector(vector) {
vector = checkColumnVector(this, vector);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, this.get(i, j) * vector[i]);
}
}
return this;
}
divColumnVector(vector) {
vector = checkColumnVector(this, vector);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
this.set(i, j, this.get(i, j) / vector[i]);
}
}
return this;
}
mulRow(index, value) {
checkRowIndex(this, index);
for (let i = 0; i < this.columns; i++) {
this.set(index, i, this.get(index, i) * value);
}
return this;
}
mulColumn(index, value) {
checkColumnIndex(this, index);
for (let i = 0; i < this.rows; i++) {
this.set(i, index, this.get(i, index) * value);
}
return this;
}
max(by) {
if (this.isEmpty()) {
return NaN;
}
switch (by) {
case 'row': {
const max = new Array(this.rows).fill(Number.NEGATIVE_INFINITY);
for (let row = 0; row < this.rows; row++) {
for (let column = 0; column < this.columns; column++) {
if (this.get(row, column) > max[row]) {
max[row] = this.get(row, column);
}
}
}
return max;
}
case 'column': {
const max = new Array(this.columns).fill(Number.NEGATIVE_INFINITY);
for (let row = 0; row < this.rows; row++) {
for (let column = 0; column < this.columns; column++) {
if (this.get(row, column) > max[column]) {
max[column] = this.get(row, column);
}
}
}
return max;
}
case undefined: {
let max = this.get(0, 0);
for (let row = 0; row < this.rows; row++) {
for (let column = 0; column < this.columns; column++) {
if (this.get(row, column) > max) {
max = this.get(row, column);
}
}
}
return max;
}
default:
throw new Error(`invalid option: ${by}`);
}
}
maxIndex() {
checkNonEmpty(this);
let v = this.get(0, 0);
let idx = [0, 0];
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
if (this.get(i, j) > v) {
v = this.get(i, j);
idx[0] = i;
idx[1] = j;
}
}
}
return idx;
}
min(by) {
if (this.isEmpty()) {
return NaN;
}
switch (by) {
case 'row': {
const min = new Array(this.rows).fill(Number.POSITIVE_INFINITY);
for (let row = 0; row < this.rows; row++) {
for (let column = 0; column < this.columns; column++) {
if (this.get(row, column) < min[row]) {
min[row] = this.get(row, column);
}
}
}
return min;
}
case 'column': {
const min = new Array(this.columns).fill(Number.POSITIVE_INFINITY);
for (let row = 0; row < this.rows; row++) {
for (let column = 0; column < this.columns; column++) {
if (this.get(row, column) < min[column]) {
min[column] = this.get(row, column);
}
}
}
return min;
}
case undefined: {
let min = this.get(0, 0);
for (let row = 0; row < this.rows; row++) {
for (let column = 0; column < this.columns; column++) {
if (this.get(row, column) < min) {
min = this.get(row, column);
}
}
}
return min;
}
default:
throw new Error(`invalid option: ${by}`);
}
}
minIndex() {
checkNonEmpty(this);
let v = this.get(0, 0);
let idx = [0, 0];
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
if (this.get(i, j) < v) {
v = this.get(i, j);
idx[0] = i;
idx[1] = j;
}
}
}
return idx;
}
maxRow(row) {
checkRowIndex(this, row);
if (this.isEmpty()) {
return NaN;
}
let v = this.get(row, 0);
for (let i = 1; i < this.columns; i++) {
if (this.get(row, i) > v) {
v = this.get(row, i);
}
}
return v;
}
maxRowIndex(row) {
checkRowIndex(this, row);
checkNonEmpty(this);
let v = this.get(row, 0);
let idx = [row, 0];
for (let i = 1; i < this.columns; i++) {
if (this.get(row, i) > v) {
v = this.get(row, i);
idx[1] = i;
}
}
return idx;
}
minRow(row) {
checkRowIndex(this, row);
if (this.isEmpty()) {
return NaN;
}
let v = this.get(row, 0);
for (let i = 1; i < this.columns; i++) {
if (this.get(row, i) < v) {
v = this.get(row, i);
}
}
return v;
}
minRowIndex(row) {
checkRowIndex(this, row);
checkNonEmpty(this);
let v = this.get(row, 0);
let idx = [row, 0];
for (let i = 1; i < this.columns; i++) {
if (this.get(row, i) < v) {
v = this.get(row, i);
idx[1] = i;
}
}
return idx;
}
maxColumn(column) {
checkColumnIndex(this, column);
if (this.isEmpty()) {
return NaN;
}
let v = this.get(0, column);
for (let i = 1; i < this.rows; i++) {
if (this.get(i, column) > v) {
v = this.get(i, column);
}
}
return v;
}
maxColumnIndex(column) {
checkColumnIndex(this, column);
checkNonEmpty(this);
let v = this.get(0, column);
let idx = [0, column];
for (let i = 1; i < this.rows; i++) {
if (this.get(i, column) > v) {
v = this.get(i, column);
idx[0] = i;
}
}
return idx;
}
minColumn(column) {
checkColumnIndex(this, column);
if (this.isEmpty()) {
return NaN;
}
let v = this.get(0, column);
for (let i = 1; i < this.rows; i++) {
if (this.get(i, column) < v) {
v = this.get(i, column);
}
}
return v;
}
minColumnIndex(column) {
checkColumnIndex(this, column);
checkNonEmpty(this);
let v = this.get(0, column);
let idx = [0, column];
for (let i = 1; i < this.rows; i++) {
if (this.get(i, column) < v) {
v = this.get(i, column);
idx[0] = i;
}
}
return idx;
}
diag() {
let min = Math.min(this.rows, this.columns);
let diag = [];
for (let i = 0; i < min; i++) {
diag.push(this.get(i, i));
}
return diag;
}
norm(type = 'frobenius') {
switch (type) {
case 'max':
return this.max();
case 'frobenius':
return Math.sqrt(this.dot(this));
default:
throw new RangeError(`unknown norm type: ${type}`);
}
}
cumulativeSum() {
let sum = 0;
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
sum += this.get(i, j);
this.set(i, j, sum);
}
}
return this;
}
dot(vector2) {
if (AbstractMatrix.isMatrix(vector2)) vector2 = vector2.to1DArray();
let vector1 = this.to1DArray();
if (vector1.length !== vector2.length) {
throw new RangeError('vectors do not have the same size');
}
let dot = 0;
for (let i = 0; i < vector1.length; i++) {
dot += vector1[i] * vector2[i];
}
return dot;
}
mmul(other) {
other = Matrix.checkMatrix(other);
let m = this.rows;
let n = this.columns;
let p = other.columns;
let result = new Matrix(m, p);
let Bcolj = new Float64Array(n);
for (let j = 0; j < p; j++) {
for (let k = 0; k < n; k++) {
Bcolj[k] = other.get(k, j);
}
for (let i = 0; i < m; i++) {
let s = 0;
for (let k = 0; k < n; k++) {
s += this.get(i, k) * Bcolj[k];
}
result.set(i, j, s);
}
}
return result;
}
mpow(scalar) {
if (!this.isSquare()) {
throw new RangeError('Matrix must be square');
}
if (!Number.isInteger(scalar) || scalar < 0) {
throw new RangeError('Exponent must be a non-negative integer');
}
// Russian Peasant exponentiation, i.e. exponentiation by squaring
let result = Matrix.eye(this.rows);
let bb = this;
// Note: Don't bit shift. In JS, that would truncate at 32 bits
for (let e = scalar; e > 1; e /= 2) {
if ((e & 1) !== 0) {
result = result.mmul(bb);
}
bb = bb.mmul(bb);
}
return result;
}
strassen2x2(other) {
other = Matrix.checkMatrix(other);
let result = new Matrix(2, 2);
const a11 = this.get(0, 0);
const b11 = other.get(0, 0);
const a12 = this.get(0, 1);
const b12 = other.get(0, 1);
const a21 = this.get(1, 0);
const b21 = other.get(1, 0);
const a22 = this.get(1, 1);
const b22 = other.get(1, 1);
// Compute intermediate values.
const m1 = (a11 + a22) * (b11 + b22);
const m2 = (a21 + a22) * b11;
const m3 = a11 * (b12 - b22);
const m4 = a22 * (b21 - b11);
const m5 = (a11 + a12) * b22;
const m6 = (a21 - a11) * (b11 + b12);
const m7 = (a12 - a22) * (b21 + b22);
// Combine intermediate values into the output.
const c00 = m1 + m4 - m5 + m7;
const c01 = m3 + m5;
const c10 = m2 + m4;
const c11 = m1 - m2 + m3 + m6;
result.set(0, 0, c00);
result.set(0, 1, c01);
result.set(1, 0, c10);
result.set(1, 1, c11);
return result;
}
/**
 * Multiply two 3x3 matrices using 23 scalar multiplications (m1..m23)
 * instead of the naive 27.
 * NOTE(review): the products match Laderman's (1976) 3x3 scheme; confirm
 * before relying on the attribution.
 * Only entries (0..2, 0..2) of each operand are read; sizes are not validated.
 * @param {AbstractMatrix|number[][]} other - right operand (via Matrix.checkMatrix)
 * @returns {Matrix} new 3x3 matrix whose entry cRC is intended to equal
 *   (this × other)[R][C]
 */
strassen3x3(other) {
other = Matrix.checkMatrix(other);
let result = new Matrix(3, 3);
// Entries of the left operand (aRC = this[R][C]).
const a00 = this.get(0, 0);
const a01 = this.get(0, 1);
const a02 = this.get(0, 2);
const a10 = this.get(1, 0);
const a11 = this.get(1, 1);
const a12 = this.get(1, 2);
const a20 = this.get(2, 0);
const a21 = this.get(2, 1);
const a22 = this.get(2, 2);
// Entries of the right operand (bRC = other[R][C]).
const b00 = other.get(0, 0);
const b01 = other.get(0, 1);
const b02 = other.get(0, 2);
const b10 = other.get(1, 0);
const b11 = other.get(1, 1);
const b12 = other.get(1, 2);
const b20 = other.get(2, 0);
const b21 = other.get(2, 1);
const b22 = other.get(2, 2);
// The 23 scalar products; each mK costs exactly one multiplication.
const m1 = (a00 + a01 + a02 - a10 - a11 - a21 - a22) * b11;
const m2 = (a00 - a10) * (-b01 + b11);
const m3 = a11 * (-b00 + b01 + b10 - b11 - b12 - b20 + b22);
const m4 = (-a00 + a10 + a11) * (b00 - b01 + b11);
const m5 = (a10 + a11) * (-b00 + b01);
const m6 = a00 * b00;
const m7 = (-a00 + a20 + a21) * (b00 - b02 + b12);
const m8 = (-a00 + a20) * (b02 - b12);
const m9 = (a20 + a21) * (-b00 + b02);
const m10 = (a00 + a01 + a02 - a11 - a12 - a20 - a21) * b12;
const m11 = a21 * (-b00 + b02 + b10 - b11 - b12 - b20 + b21);
const m12 = (-a02 + a21 + a22) * (b11 + b20 - b21);
const m13 = (a02 - a22) * (b11 - b21);
const m14 = a02 * b20;
const m15 = (a21 + a22) * (-b20 + b21);
const m16 = (-a02 + a11 + a12) * (b12 + b20 - b22);
const m17 = (a02 - a12) * (b12 - b22);
const m18 = (a11 + a12) * (-b20 + b22);
const m19 = a01 * b10;
const m20 = a12 * b21;
const m21 = a10 * b02;
const m22 = a20 * b01;
const m23 = a22 * b22;
// Each output entry is a sum of products only (no further multiplications),
// e.g. c00 = a00*b00 + a02*b20 + a01*b10.
const c00 = m6 + m14 + m19;
const c01 = m1 + m4 + m5 + m6 + m12 + m14 + m15;
const c02 = m6 + m7 + m9 + m10 + m14 + m16 + m18;
const c10 = m2 + m3 + m4 + m6 + m14 + m16 + m17;
const c11 = m2 + m4 + m5 + m6 + m20;
const c12 = m14 + m16 + m17 + m18 + m21;
const c20 = m6 + m7 + m8 + m11 + m12 + m13 + m14;
const c21 = m12 + m13 + m14 + m15 + m22;
const c22 = m6 + m7 + m8 + m9 + m23;
result.set(0, 0, c00);
result.set(0, 1, c01);
result.set(0, 2, c02);
result.set(1, 0, c10);
result.set(1, 1, c11);
result.set(1, 2, c12);
result.set(2, 0, c20);
result.set(2, 1, c21);
result.set(2, 2, c22);
return result;
}
mmulStrassen(y) {
y = Matrix.checkMatrix(y);
let x = this.clone();
let r1 = x.rows;
let c1 = x.columns;
let r2 = y.rows;
let c2 = y.columns;
if (c1 !== r2) {
// eslint-disable-next-line no-console
console.warn(
`Multiplying ${r1} x ${c1} and ${r2} x ${c2} matrix: dimensions do not match.`,
);
}
// Put a matrix into the top left of a matrix of zeros.
// `rows` and `cols` are the dimensions of the output matrix.
function embed(mat, rows, cols) {
let r = mat.rows;
let c = mat.columns;
if (r === rows && c === cols) {
return mat;
} else {
let resultat = AbstractMatrix.zeros(rows, cols);
resultat = resultat.setSubMatrix(mat, 0, 0);
return resultat;
}
}
// Make sure both matrices are the same size.
// This is exclusively for simplicity:
// this algorithm can be implemented with matrices of different sizes.
let r = Math.max(r1, r2);
let c = Math.max(c1, c2);
x = embed(x, r, c);
y = embed(y, r, c);
// Our recursive multiplication function.
function blockMult(a, b, rows, cols) {
// For small matrices, resort to naive multiplication.
if (rows <= 512 || cols <= 512) {
return a.mmul(b); // a is equivalent to this
}
// Apply dynamic padding.
if (rows % 2 === 1 && cols % 2 === 1) {
a = embed(a, rows + 1, cols + 1);
b = embed(b, rows + 1, cols + 1);
} else if (rows % 2 === 1) {
a = embed(a, rows + 1, cols);
b = embed(b, rows + 1, cols);
} else if (cols % 2 === 1) {
a = embed(a, rows, cols + 1);
b = embed(b, rows, cols + 1);
}
let halfRows = parseInt(a.rows / 2, 10);
let halfCols = parseInt(a.columns / 2, 10);
// Subdivide input matrices.
let a11 = a.subMatrix(0, halfRows - 1, 0, halfCols - 1);
let b11 = b.subMatrix(0, halfRows - 1, 0, halfCols - 1);
let a12 = a.subMatrix(0, halfRows - 1, halfCols, a.columns - 1);
let b12 = b.subMatrix(0, halfRows - 1, halfCols, b.columns - 1);
let a21 = a.subMatrix(halfRows, a.rows - 1, 0, halfCols - 1);
let b21 = b.subMatrix(halfRows, b.rows - 1, 0, halfCols - 1);
let a22 = a.subMatrix(halfRows, a.rows - 1, halfCols, a.columns - 1);
let b22 = b.subMatrix(halfRows, b.rows - 1, halfCols, b.columns - 1);
// Compute intermediate values.
let m1 = blockMult(
AbstractMatrix.add(a11, a22),
AbstractMatrix.add(b11, b22),
halfRows,
halfCols,
);
let m2 = blockMult(AbstractMatrix.add(a21, a22), b11, halfRows, halfCols);
let m3 = blockMult(a11, AbstractMatrix.sub(b12, b22), halfRows, halfCols);
let m4 = blockMult(a22, AbstractMatrix.sub(b21, b11), halfRows, halfCols);
let m5 = blockMult(AbstractMatrix.add(a11, a12), b22, halfRows, halfCols);
let m6 = blockMult(
AbstractMatrix.sub(a21, a11),
AbstractMatrix.add(b11, b12),
halfRows,
halfCols,
);
let m7 = blockMult(
AbstractMatrix.sub(a12, a22),
AbstractMatrix.add(b21, b22),
halfRows,
halfCols,
);
// Combine intermediate values into the output.
let c11 = AbstractMatrix.add(m1, m4);
c11.sub(m5);
c11.add(m7);
let c12 = AbstractMatrix.add(m3, m5);
let c21 = AbstractMatrix.add(m2, m4);
let c22 = AbstractMatrix.sub(m1, m2);
c22.add(m3);
c22.add(m6);
// Crop output to the desired size (undo dynamic padding).
let result = AbstractMatrix.zeros(2 * c11.rows, 2 * c11.columns);
result = result.setSubMatrix(c11, 0, 0);
result = result.setSubMatrix(c12, c11.rows, 0);
result = result.setSubMatrix(c21, 0, c11.columns);
result = result.setSubMatrix(c22, c11.rows, c11.columns);
return result.subMatrix(0, rows - 1, 0, cols - 1);
}
return blockMult(x, y, r, c);
}
scaleRows(options = {}) {
if (typeof options !== 'object') {
throw new TypeError('options must be an object');
}
const { min = 0, max = 1 } = options;
if (!Number.isFinite(min)) throw new TypeError('min must be a number');
if (!Number.isFinite(max)) throw new TypeError('max must be a number');
if (min >= max) throw new RangeError('min must be smaller than max');
let newMatrix = new Matrix(this.rows, this.columns);
for (let i = 0; i < this.rows; i++) {
const row = this.getRow(i);
if (row.length > 0) {
rescale(row, { min, max, output: row });
}
newMatrix.setRow(i, row);
}
return newMatrix;
}
scaleColumns(options = {}) {
if (typeof options !== 'object') {
throw new TypeError('options must be an object');
}
const { min = 0, max = 1 } = options;
if (!Number.isFinite(min)) throw new TypeError('min must be a number');
if (!Number.isFinite(max)) throw new TypeError('max must be a number');
if (min >= max) throw new RangeError('min must be smaller than max');
let newMatrix = new Matrix(this.rows, this.columns);
for (let i = 0; i < this.columns; i++) {
const column = this.getColumn(i);
if (column.length) {
rescale(column, {
min,
max,
output: column,
});
}
newMatrix.setColumn(i, column);
}
return newMatrix;
}
flipRows() {
const middle = Math.ceil(this.columns / 2);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < middle; j++) {
let first = this.get(i, j);
let last = this.get(i, this.columns - 1 - j);
this.set(i, j, last);
this.set(i, this.columns - 1 - j, first);
}
}
return this;
}
flipColumns() {
const middle = Math.ceil(this.rows / 2);
for (let j = 0; j < this.columns; j++) {
for (let i = 0; i < middle; i++) {
let first = this.get(i, j);
let last = this.get(this.rows - 1 - i, j);
this.set(i, j, last);
this.set(this.rows - 1 - i, j, first);
}
}
return this;
}
kroneckerProduct(other) {
other = Matrix.checkMatrix(other);
let m = this.rows;
let n = this.columns;
let p = other.rows;
let q = other.columns;
let result = new Matrix(m * p, n * q);
for (let i = 0; i < m; i++) {
for (let j = 0; j < n; j++) {
for (let k = 0; k < p; k++) {
for (let l = 0; l < q; l++) {
result.set(p * i + k, q * j + l, this.get(i, j) * other.get(k, l));
}
}
}
}
return result;
}
kroneckerSum(other) {
other = Matrix.checkMatrix(other);
if (!this.isSquare() || !other.isSquare()) {
throw new Error('Kronecker Sum needs two Square Matrices');
}
let m = this.rows;
let n = other.rows;
let AxI = this.kroneckerProduct(Matrix.eye(n, n));
let IxB = Matrix.eye(m, m).kroneckerProduct(other);
return AxI.add(IxB);
}
transpose() {
let result = new Matrix(this.columns, this.rows);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.columns; j++) {
result.set(j, i, this.get(i, j));
}
}
return result;
}
sortRows(compareFunction = compareNumbers) {
for (let i = 0; i < this.rows; i++) {
this.setRow(i, this.getRow(i).sort(compareFunction));
}
return this;
}
sortColumns(compareFunction = compareNumbers) {
for (let i = 0; i < this.columns; i++) {
this.setColumn(i, this.getColumn(i).sort(compareFunction));
}
return this;
}
subMatrix(startRow, endRow, startColumn, endColumn) {
checkRange(this, startRow, endRow, startColumn, endColumn);
let newMatrix = new Matrix(
endRow - startRow + 1,
endColumn - startColumn + 1,
);
for (let i = startRow; i <= endRow; i++) {
for (let j = startColumn; j <= endColumn; j++) {
newMatrix.set(i - startRow, j - startColumn, this.get(i, j));
}
}
return newMatrix;
}
subMatrixRow(indices, startColumn, endColumn) {
if (startColumn === undefined) startColumn = 0;
if (endColumn === undefined) endColumn = this.columns - 1;
if (
startColumn > endColumn ||
startColumn < 0 ||
startColumn >= this.columns ||
endColumn < 0 ||
endColumn >= this.columns
) {
throw new RangeError('Argument out of range');
}
let newMatrix = new Matrix(indices.length, endColumn - startColumn + 1);
for (let i = 0; i < indices.length; i++) {
for (let j = startColumn; j <= endColumn; j++) {
if (indices[i] < 0 || indices[i] >= this.rows) {
throw new RangeError(`Row index out of range: ${indices[i]}`);
}
newMatrix.set(i, j - startColumn, this.get(indices[i], j));
}
}
return newMatrix;
}
subMatrixColumn(indices, startRow, endRow) {
if (startRow === undefined) startRow = 0;
if (endRow === undefined) endRow = this.rows - 1;
if (
startRow > endRow ||
startRow < 0 ||
startRow >= this.rows ||
endRow < 0 ||
endRow >= this.rows
) {
throw new RangeError('Argument out of range');
}
let newMatrix = new Matrix(endRow - startRow + 1, indices.length);
for (let i = 0; i < indices.length; i++) {
for (let j = startRow; j <= endRow; j++) {
if (indices[i] < 0 || indices[i] >= this.columns) {
throw new RangeError(`Column index out of range: ${indices[i]}`);
}
newMatrix.set(j - startRow, i, this.get(j, indices[i]));
}
}
return newMatrix;
}
setSubMatrix(matrix, startRow, startColumn) {
matrix = Matrix.checkMatrix(matrix);
if (matrix.isEmpty()) {
return this;
}
let endRow = startRow + matrix.rows - 1;
let endColumn = startColumn + matrix.columns - 1;
checkRange(this, startRow, endRow, startColumn, endColumn);
for (let i = 0; i < matrix.rows; i++) {
for (let j = 0; j < matrix.columns; j++) {
this.set(startRow + i, startColumn + j, matrix.get(i, j));
}
}
return this;
}
selection(rowIndices, columnIndices) {
checkRowIndices(this, rowIndices);
checkColumnIndices(this, columnIndices);
let newMatrix = new Matrix(rowIndices.length, columnIndices.length);
for (let i = 0; i < rowIndices.length; i++) {
let rowIndex = rowIndices[i];
for (let j = 0; j < columnIndices.length; j++) {
let columnIndex = columnIndices[j];
newMatrix.set(i, j, this.get(rowIndex, columnIndex));
}
}
return newMatrix;
}
trace() {
let min = Math.min(this.rows, this.columns);
let trace = 0;
for (let i = 0; i < min; i++) {
trace += this.get(i, i);
}
return trace;
}
clone() {
return this.constructor.copy(this, new Matrix(this.rows, this.columns));
}
/**
* @template {AbstractMatrix} M
* @param {AbstractMatrix} from
* @param {M} to
* @return {M}
*/
static copy(from, to) {
for (const [row, column, value] of from.entries()) {
to.set(row, column, value);
}
return to;
}
sum(by) {
switch (by) {
case 'row':
return sumByRow(this);
case 'column':
return sumByColumn(this);
case undefined:
return sumAll(this);
default:
throw new Error(`invalid option: ${by}`);
}
}
product(by) {
switch (by) {
case 'row':
return productByRow(this);
case 'column':
return productByColumn(this);
case undefined:
return productAll(this);
default:
throw new Error(`invalid option: ${by}`);
}
}
mean(by) {
const sum = this.sum(by);
switch (by) {
case 'row': {
for (let i = 0; i < this.rows; i++) {
sum[i] /= this.columns;
}
return sum;
}
case 'column': {
for (let i = 0; i < this.columns; i++) {
sum[i] /= this.rows;
}
return sum;
}
case undefined:
return sum / this.size;
default:
throw new Error(`invalid option: ${by}`);
}
}
variance(by, options = {}) {
if (typeof by === 'object') {
options = by;
by = undefined;
}
if (typeof options !== 'object') {
throw new TypeError('options must be an object');
}
const { unbiased = true, mean = this.mean(by) } = options;
if (typeof unbiased !== 'boolean') {
throw new TypeError('unbiased must be a boolean');
}
switch (by) {
case 'row': {
if (!isAnyArray(mean)) {
throw new TypeError('mean must be an array');
}
return varianceByRow(this, unbiased, mean);
}
case 'column': {
if (!isAnyArray(mean)) {
throw new TypeError('mean must be an array');
}
return varianceByColumn(this, unbiased, mean);
}
case undefined: {
if (typeof mean !== 'number') {
throw new TypeError('mean must be a number');
}
return varianceAll(this, unbiased, mean);
}
default:
throw new Error(`invalid option: ${by}`);
}
}
standardDeviation(by, options) {
if (typeof by === 'object') {
options = by;
by = undefined;
}
const variance = this.variance(by, options);
if (by === undefined) {
return Math.sqrt(variance);
} else {
for (let i = 0; i < variance.length; i++) {
variance[i] = Math.sqrt(variance[i]);
}
return variance;
}
}
center(by, options = {}) {
if (typeof by === 'object') {
options = by;
by = undefined;
}
if (typeof options !== 'object') {
throw new TypeError('options must be an object');
}
const { center = this.mean(by) } = options;
switch (by) {
case 'row': {
if (!isAnyArray(center)) {
throw new TypeError('center must be an array');
}
centerByRow(this, center);
return this;
}
case 'column': {
if (!isAnyArray(center)) {
throw new TypeError('center must be an array');
}
centerByColumn(this, center);
return this;
}
case undefined: {
if (typeof center !== 'number') {
throw new TypeError('center must be a number');
}
centerAll(this, center);
return this;
}
default:
throw new Error(`invalid option: ${by}`);
}
}
scale(by, options = {}) {
if (typeof by === 'object') {
options = by;
by = undefined;
}
if (typeof options !== 'object') {
throw new TypeError('options must be an object');
}
let scale = options.scale;
switch (by) {
case 'row': {
if (scale === undefined) {
scale = getScaleByRow(this);
} else if (!isAnyArray(scale)) {
throw new TypeError('scale must be an array');
}
scaleByRow(this, scale);
return this;
}
case 'column': {
if (scale === undefined) {
scale = getScaleByColumn(this);
} else if (!isAnyArray(scale)) {
throw new TypeError('scale must be an array');
}
scaleByColumn(this, scale);
return this;
}
case undefined: {
if (scale === undefined) {
scale = getScaleAll(this);
} else if (typeof scale !== 'number') {
throw new TypeError('scale must be a number');
}
scaleAll(this, scale);
return this;
}
default:
throw new Error(`invalid option: ${by}`);
}
}
toString(options) {
return inspectMatrixWithOptions(this, options);
}
[Symbol.iterator]() {
return this.entries();
}
/**
* iterator from left to right, from top to bottom
* yield [row, column, value]
* @returns {Generator<[number, number, number], void, void>}
*/
*entries() {
for (let row = 0; row < this.rows; row++) {
for (let col = 0; col < this.columns; col++) {
yield [row, col, this.get(row, col)];
}
}
}
/**
* iterator from left to right, from top to bottom
* yield value
* @returns {Generator<number, void, void>}
*/
*values() {
for (let row = 0; row < this.rows; row++) {
for (let col = 0; col < this.columns; col++) {
yield this.get(row, col);
}
}
}
}
// Class tag shared by every matrix instance (presumably consulted by
// type checks such as isMatrix elsewhere in the library — TODO confirm).
AbstractMatrix.prototype.klass = 'Matrix';
if (typeof Symbol !== 'undefined') {
  // Node's util.inspect looks up this registered symbol to customise how
  // an object is printed (console.log, REPL).
  const nodeInspectSymbol = Symbol.for('nodejs.util.inspect.custom');
  AbstractMatrix.prototype[nodeInspectSymbol] = inspectMatrix;
}
/**
 * Numeric ascending comparator for `Array.prototype.sort`.
 * @param {number} left
 * @param {number} right
 * @returns {number} negative, zero or positive as `left` is below, equal to
 *   or above `right`.
 */
function compareNumbers(left, right) {
  const difference = left - right;
  return difference;
}
/**
 * True when no element of `array` has a type other than 'number'.
 * NaN counts as a number. Like `every`, `some` skips holes in sparse arrays.
 * @param {ArrayLike<unknown>} array - an Array or typed array.
 * @returns {boolean}
 */
function isArrayOfNumbers(array) {
  const hasNonNumber = array.some((item) => typeof item !== 'number');
  return !hasNonNumber;
}
// Synonyms: alternative public names pointing at the same functions.
Object.assign(AbstractMatrix, {
  random: AbstractMatrix.rand,
  randomInt: AbstractMatrix.randInt,
  diagonal: AbstractMatrix.diag,
  identity: AbstractMatrix.eye,
});
Object.assign(AbstractMatrix.prototype, {
  diagonal: AbstractMatrix.prototype.diag,
  negate: AbstractMatrix.prototype.neg,
  tensorProduct: AbstractMatrix.prototype.kroneckerProduct,
});
/**
 * Dense matrix stored as one Float64Array per row.
 *
 * `new Matrix(rows, columns)` allocates a zero-filled matrix,
 * `new Matrix(otherMatrix)` copies an existing matrix, and
 * `new Matrix(array2D)` copies the numbers of a rectangular 2D array.
 */
export default class Matrix extends AbstractMatrix {
  /**
   * Row storage, addressed as `data[row][column]`.
   * @type {Float64Array[]}
   */
  data;

  /**
   * Replace the storage with `nRows` zero-filled rows of `nColumns` entries
   * and record both dimensions.
   * @param {number} nRows
   * @param {number} nColumns
   * @throws {TypeError} when nColumns is not an integer >= 0 (zero is
   *   accepted despite the "positive" wording of the message).
   */
  #initData(nRows, nColumns) {
    this.data = [];
    if (!Number.isInteger(nColumns) || nColumns < 0) {
      throw new TypeError('nColumns must be a positive integer');
    }
    while (this.data.length < nRows) {
      this.data.push(new Float64Array(nColumns));
    }
    this.rows = nRows;
    this.columns = nColumns;
  }

  /**
   * @param {number|Matrix|ArrayLike<ArrayLike<number>>} nRows - row count,
   *   a matrix to copy, or a 2D array of numbers.
   * @param {number} [nColumns] - column count; only read when nRows is a count.
   * @throws {TypeError} for an unsupported first argument, a 2D array whose
   *   first element has no numeric `length`, or non-numeric entries.
   * @throws {RangeError} when the rows of a 2D array differ in length.
   */
  constructor(nRows, nColumns) {
    super();
    if (Matrix.isMatrix(nRows)) {
      this.#initData(nRows.rows, nRows.columns);
      Matrix.copy(nRows, this);
      return;
    }
    if (Number.isInteger(nRows) && nRows >= 0) {
      this.#initData(nRows, nColumns);
      return;
    }
    if (!isAnyArray(nRows)) {
      throw new TypeError(
        'First argument must be a positive number or an array',
      );
    }
    const source = nRows;
    const rowCount = source.length;
    const columnCount = rowCount ? source[0].length : 0;
    // A first element without a numeric `length` means the input is not 2D.
    if (typeof columnCount !== 'number') {
      throw new TypeError(
        'Data must be a 2D array with at least one element',
      );
    }
    this.data = [];
    for (let r = 0; r < rowCount; r++) {
      const sourceRow = source[r];
      if (sourceRow.length !== columnCount) {
        throw new RangeError('Inconsistent array dimensions');
      }
      if (!isArrayOfNumbers(sourceRow)) {
        throw new TypeError('Input data contains non-numeric values');
      }
      this.data.push(Float64Array.from(sourceRow));
    }
    this.rows = rowCount;
    this.columns = columnCount;
  }

  /**
   * Store `value` at (rowIndex, columnIndex). Indices are not validated.
   * @returns {this}
   */
  set(rowIndex, columnIndex, value) {
    const row = this.data[rowIndex];
    row[columnIndex] = value;
    return this;
  }

  /**
   * Read the element at (rowIndex, columnIndex). Indices are not validated.
   * @returns {number}
   */
  get(rowIndex, columnIndex) {
    const row = this.data[rowIndex];
    return row[columnIndex];
  }

  /**
   * Delete the row at `index` (validated by checkRowIndex).
   * @returns {this}
   */
  removeRow(index) {
    checkRowIndex(this, index);
    this.data.splice(index, 1);
    this.rows--;
    return this;
  }

  /**
   * Insert a row before `index`. Called with a single argument, that
   * argument is the row and it is appended at the end.
   * @returns {this}
   */
  addRow(index, array) {
    let position = index;
    let values = array;
    if (values === undefined) {
      values = index;
      position = this.rows;
    }
    checkRowIndex(this, position, true);
    const newRow = Float64Array.from(checkRowVector(this, values));
    this.data.splice(position, 0, newRow);
    this.rows += 1;
    return this;
  }

  /**
   * Delete the column at `index` (validated by checkColumnIndex) by
   * rebuilding every row one entry shorter.
   * @returns {this}
   */
  removeColumn(index) {
    checkColumnIndex(this, index);
    const narrowedWidth = this.columns - 1;
    for (let r = 0; r < this.rows; r++) {
      const oldRow = this.data[r];
      const narrowed = new Float64Array(narrowedWidth);
      // Entries left of `index` keep their position...
      for (let c = 0; c < index; c++) {
        narrowed[c] = oldRow[c];
      }
      // ...entries right of it move one slot to the left.
      for (let c = index + 1; c < this.columns; c++) {
        narrowed[c - 1] = oldRow[c];
      }
      this.data[r] = narrowed;
    }
    this.columns--;
    return this;
  }

  /**
   * Insert a column before `index`. Called with a single argument, that
   * argument is the column and it is appended on the right.
   * @returns {this}
   */
  addColumn(index, array) {
    let position = index;
    let values = array;
    if (values === undefined) {
      values = index;
      position = this.columns;
    }
    checkColumnIndex(this, position, true);
    values = checkColumnVector(this, values);
    for (let r = 0; r < this.rows; r++) {
      const oldRow = this.data[r];
      const widened = new Float64Array(this.columns + 1);
      let cursor = 0;
      while (cursor < position) {
        widened[cursor] = oldRow[cursor];
        cursor++;
      }
      widened[cursor] = values[r];
      cursor++;
      // Remaining entries shift one slot to the right.
      while (cursor < this.columns + 1) {
        widened[cursor] = oldRow[cursor - 1];
        cursor++;
      }
      this.data[r] = widened;
    }
    this.columns += 1;
    return this;
  }
}
// Attach the method family defined in ./mathOperations to both classes
// (presumably the element-wise arithmetic / Math.* wrappers — see that file).
// Runs once at module load, after both classes exist.
installMathOperations(AbstractMatrix, Matrix);