UNPKG

ml-matrix

Version:

Matrix manipulation and computation library

1,736 lines (1,581 loc) 45.5 kB
import { isAnyArray } from 'is-any-array'; import rescale from 'ml-array-rescale'; import { inspectMatrix, inspectMatrixWithOptions } from './inspect'; import { installMathOperations } from './mathOperations'; import { sumByRow, sumByColumn, sumAll, productByRow, productByColumn, productAll, varianceByRow, varianceByColumn, varianceAll, centerByRow, centerByColumn, centerAll, scaleByRow, scaleByColumn, scaleAll, getScaleByRow, getScaleByColumn, getScaleAll, } from './stat'; import { checkRowVector, checkRowIndex, checkColumnIndex, checkColumnVector, checkRange, checkNonEmpty, checkRowIndices, checkColumnIndices, } from './util'; export class AbstractMatrix { static from1DArray(newRows, newColumns, newData) { let length = newRows * newColumns; if (length !== newData.length) { throw new RangeError('data length does not match given dimensions'); } let newMatrix = new Matrix(newRows, newColumns); for (let row = 0; row < newRows; row++) { for (let column = 0; column < newColumns; column++) { newMatrix.set(row, column, newData[row * newColumns + column]); } } return newMatrix; } static rowVector(newData) { let vector = new Matrix(1, newData.length); for (let i = 0; i < newData.length; i++) { vector.set(0, i, newData[i]); } return vector; } static columnVector(newData) { let vector = new Matrix(newData.length, 1); for (let i = 0; i < newData.length; i++) { vector.set(i, 0, newData[i]); } return vector; } static zeros(rows, columns) { return new Matrix(rows, columns); } static ones(rows, columns) { return new Matrix(rows, columns).fill(1); } static rand(rows, columns, options = {}) { if (typeof options !== 'object') { throw new TypeError('options must be an object'); } const { random = Math.random } = options; let matrix = new Matrix(rows, columns); for (let i = 0; i < rows; i++) { for (let j = 0; j < columns; j++) { matrix.set(i, j, random()); } } return matrix; } static randInt(rows, columns, options = {}) { if (typeof options !== 'object') { throw new TypeError('options must be an object'); } const { min = 0, max = 1000, random = Math.random } = options; if (!Number.isInteger(min)) throw new TypeError('min must be an integer'); if (!Number.isInteger(max)) throw new TypeError('max must be an integer'); if (min >= max) throw new RangeError('min must be smaller than max'); let interval = max - min; let matrix = new Matrix(rows, columns); for (let i = 0; i < rows; i++) { for (let j = 0; j < columns; j++) { let value = min + Math.round(random() * interval); matrix.set(i, j, value); } } return matrix; } static eye(rows, columns, value) { if (columns === undefined) columns = rows; if (value === undefined) value = 1; let min = Math.min(rows, columns); let matrix = this.zeros(rows, columns); for (let i = 0; i < min; i++) { matrix.set(i, i, value); } return matrix; } static diag(data, rows, columns) { let l = data.length; if (rows === undefined) rows = l; if (columns === undefined) columns = rows; let min = Math.min(l, rows, columns); let matrix = this.zeros(rows, columns); for (let i = 0; i < min; i++) { matrix.set(i, i, data[i]); } return matrix; } static min(matrix1, matrix2) { matrix1 = this.checkMatrix(matrix1); matrix2 = this.checkMatrix(matrix2); let rows = matrix1.rows; let columns = matrix1.columns; let result = new Matrix(rows, columns); for (let i = 0; i < rows; i++) { for (let j = 0; j < columns; j++) { result.set(i, j, Math.min(matrix1.get(i, j), matrix2.get(i, j))); } } return result; } static max(matrix1, matrix2) { matrix1 = this.checkMatrix(matrix1); matrix2 = this.checkMatrix(matrix2); let rows = matrix1.rows; let columns = matrix1.columns; let result = new this(rows, columns); for (let i = 0; i < rows; i++) { for (let j = 0; j < columns; j++) { result.set(i, j, Math.max(matrix1.get(i, j), matrix2.get(i, j))); } } return result; } static checkMatrix(value) { return AbstractMatrix.isMatrix(value) ? value : new Matrix(value); } static isMatrix(value) { return value != null && value.klass === 'Matrix'; } get size() { return this.rows * this.columns; } apply(callback) { if (typeof callback !== 'function') { throw new TypeError('callback must be a function'); } for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { callback.call(this, i, j); } } return this; } to1DArray() { let array = []; for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { array.push(this.get(i, j)); } } return array; } to2DArray() { let copy = []; for (let i = 0; i < this.rows; i++) { copy.push([]); for (let j = 0; j < this.columns; j++) { copy[i].push(this.get(i, j)); } } return copy; } toJSON() { return this.to2DArray(); } isRowVector() { return this.rows === 1; } isColumnVector() { return this.columns === 1; } isVector() { return this.rows === 1 || this.columns === 1; } isSquare() { return this.rows === this.columns; } isEmpty() { return this.rows === 0 || this.columns === 0; } isSymmetric() { if (this.isSquare()) { for (let i = 0; i < this.rows; i++) { for (let j = 0; j <= i; j++) { if (this.get(i, j) !== this.get(j, i)) { return false; } } } return true; } return false; } isDistance() { if (!this.isSymmetric()) return false; for (let i = 0; i < this.rows; i++) { if (this.get(i, i) !== 0) return false; } return true; } isEchelonForm() { let i = 0; let j = 0; let previousColumn = -1; let isEchelonForm = true; let checked = false; while (i < this.rows && isEchelonForm) { j = 0; checked = false; while (j < this.columns && checked === false) { if (this.get(i, j) === 0) { j++; } else if (this.get(i, j) === 1 && j > previousColumn) { checked = true; previousColumn = j; } else { isEchelonForm = false; checked = true; } } i++; } return isEchelonForm; } isReducedEchelonForm() { let i = 0; let j = 0; let previousColumn = -1; let isReducedEchelonForm = true; let checked = false; while (i < this.rows && isReducedEchelonForm) { j = 0; checked = false; while (j < this.columns && checked === false) { if (this.get(i, j) === 0) { j++; } else if (this.get(i, j) === 1 && j > previousColumn) { checked = true; previousColumn = j; } else { isReducedEchelonForm = false; checked = true; } } for (let k = j + 1; k < this.rows; k++) { if (this.get(i, k) !== 0) { isReducedEchelonForm = false; } } i++; } return isReducedEchelonForm; } echelonForm() { let result = this.clone(); let h = 0; let k = 0; while (h < result.rows && k < result.columns) { let iMax = h; for (let i = h; i < result.rows; i++) { if (result.get(i, k) > result.get(iMax, k)) { iMax = i; } } if (result.get(iMax, k) === 0) { k++; } else { result.swapRows(h, iMax); let tmp = result.get(h, k); for (let j = k; j < result.columns; j++) { result.set(h, j, result.get(h, j) / tmp); } for (let i = h + 1; i < result.rows; i++) { let factor = result.get(i, k) / result.get(h, k); result.set(i, k, 0); for (let j = k + 1; j < result.columns; j++) { result.set(i, j, result.get(i, j) - result.get(h, j) * factor); } } h++; k++; } } return result; } reducedEchelonForm() { let result = this.echelonForm(); let m = result.columns; let n = result.rows; let h = n - 1; while (h >= 0) { if (result.maxRow(h) === 0) { h--; } else { let p = 0; let pivot = false; while (p < n && pivot === false) { if (result.get(h, p) === 1) { pivot = true; } else { p++; } } for (let i = 0; i < h; i++) { let factor = result.get(i, p); for (let j = p; j < m; j++) { let tmp = result.get(i, j) - factor * result.get(h, j); result.set(i, j, tmp); } } h--; } } return result; } set() { throw new Error('set method is unimplemented'); } get() { throw new Error('get method is unimplemented'); } repeat(options = {}) { if (typeof options !== 'object') { throw new TypeError('options must be an object'); } const { rows = 1, columns = 1 } = options; if (!Number.isInteger(rows) || rows <= 0) { throw new TypeError('rows must be a positive integer'); } if (!Number.isInteger(columns) || columns <= 0) { throw new TypeError('columns must be a positive integer'); } let matrix = new Matrix(this.rows * rows, this.columns * columns); for (let i = 0; i < rows; i++) { for (let j = 0; j < columns; j++) { matrix.setSubMatrix(this, this.rows * i, this.columns * j); } } return matrix; } fill(value) { for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, value); } } return this; } neg() { return this.mulS(-1); } getRow(index) { checkRowIndex(this, index); let row = []; for (let i = 0; i < this.columns; i++) { row.push(this.get(index, i)); } return row; } getRowVector(index) { return Matrix.rowVector(this.getRow(index)); } setRow(index, array) { checkRowIndex(this, index); array = checkRowVector(this, array); for (let i = 0; i < this.columns; i++) { this.set(index, i, array[i]); } return this; } swapRows(row1, row2) { checkRowIndex(this, row1); checkRowIndex(this, row2); for (let i = 0; i < this.columns; i++) { let temp = this.get(row1, i); this.set(row1, i, this.get(row2, i)); this.set(row2, i, temp); } return this; } getColumn(index) { checkColumnIndex(this, index); let column = []; for (let i = 0; i < this.rows; i++) { column.push(this.get(i, index)); } return column; } getColumnVector(index) { return Matrix.columnVector(this.getColumn(index)); } setColumn(index, array) { checkColumnIndex(this, index); array = checkColumnVector(this, array); for (let i = 0; i < this.rows; i++) { this.set(i, index, array[i]); } return this; } swapColumns(column1, column2) { checkColumnIndex(this, column1); checkColumnIndex(this, column2); for (let i = 0; i < this.rows; i++) { let temp = this.get(i, column1); this.set(i, column1, this.get(i, column2)); this.set(i, column2, temp); } return this; } addRowVector(vector) { vector = checkRowVector(this, vector); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, this.get(i, j) + vector[j]); } } return this; } subRowVector(vector) { vector = checkRowVector(this, vector); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, this.get(i, j) - vector[j]); } } return this; } mulRowVector(vector) { vector = checkRowVector(this, vector); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, this.get(i, j) * vector[j]); } } return this; } divRowVector(vector) { vector = checkRowVector(this, vector); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, this.get(i, j) / vector[j]); } } return this; } addColumnVector(vector) { vector = checkColumnVector(this, vector); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, this.get(i, j) + vector[i]); } } return this; } subColumnVector(vector) { vector = checkColumnVector(this, vector); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, this.get(i, j) - vector[i]); } } return this; } mulColumnVector(vector) { vector = checkColumnVector(this, vector); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, this.get(i, j) * vector[i]); } } return this; } divColumnVector(vector) { vector = checkColumnVector(this, vector); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { this.set(i, j, this.get(i, j) / vector[i]); } } return this; } mulRow(index, value) { checkRowIndex(this, index); for (let i = 0; i < this.columns; i++) { this.set(index, i, this.get(index, i) * value); } return this; } mulColumn(index, value) { checkColumnIndex(this, index); for (let i = 0; i < this.rows; i++) { this.set(i, index, this.get(i, index) * value); } return this; } max(by) { if (this.isEmpty()) { return NaN; } switch (by) { case 'row': { const max = new Array(this.rows).fill(Number.NEGATIVE_INFINITY); for (let row = 0; row < this.rows; row++) { for (let column = 0; column < this.columns; column++) { if (this.get(row, column) > max[row]) { max[row] = this.get(row, column); } } } return max; } case 'column': { const max = new Array(this.columns).fill(Number.NEGATIVE_INFINITY); for (let row = 0; row < this.rows; row++) { for (let column = 0; column < this.columns; column++) { if (this.get(row, column) > max[column]) { max[column] = this.get(row, column); } } } return max; } case undefined: { let max = this.get(0, 0); for (let row = 0; row < this.rows; row++) { for (let column = 0; column < this.columns; column++) { if (this.get(row, column) > max) { max = this.get(row, column); } } } return max; } default: throw new Error(`invalid option: ${by}`); } } maxIndex() { checkNonEmpty(this); let v = this.get(0, 0); let idx = [0, 0]; for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { if (this.get(i, j) > v) { v = this.get(i, j); idx[0] = i; idx[1] = j; } } } return idx; } min(by) { if (this.isEmpty()) { return NaN; } switch (by) { case 'row': { const min = new Array(this.rows).fill(Number.POSITIVE_INFINITY); for (let row = 0; row < this.rows; row++) { for (let column = 0; column < this.columns; column++) { if (this.get(row, column) < min[row]) { min[row] = this.get(row, column); } } } return min; } case 'column': { const min = new Array(this.columns).fill(Number.POSITIVE_INFINITY); for (let row = 0; row < this.rows; row++) { for (let column = 0; column < this.columns; column++) { if (this.get(row, column) < min[column]) { min[column] = this.get(row, column); } } } return min; } case undefined: { let min = this.get(0, 0); for (let row = 0; row < this.rows; row++) { for (let column = 0; column < this.columns; column++) { if (this.get(row, column) < min) { min = this.get(row, column); } } } return min; } default: throw new Error(`invalid option: ${by}`); } } minIndex() { checkNonEmpty(this); let v = this.get(0, 0); let idx = [0, 0]; for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { if (this.get(i, j) < v) { v = this.get(i, j); idx[0] = i; idx[1] = j; } } } return idx; } maxRow(row) { checkRowIndex(this, row); if (this.isEmpty()) { return NaN; } let v = this.get(row, 0); for (let i = 1; i < this.columns; i++) { if (this.get(row, i) > v) { v = this.get(row, i); } } return v; } maxRowIndex(row) { checkRowIndex(this, row); checkNonEmpty(this); let v = this.get(row, 0); let idx = [row, 0]; for (let i = 1; i < this.columns; i++) { if (this.get(row, i) > v) { v = this.get(row, i); idx[1] = i; } } return idx; } minRow(row) { checkRowIndex(this, row); if (this.isEmpty()) { return NaN; } let v = this.get(row, 0); for (let i = 1; i < this.columns; i++) { if (this.get(row, i) < v) { v = this.get(row, i); } } return v; } minRowIndex(row) { checkRowIndex(this, row); checkNonEmpty(this); let v = this.get(row, 0); let idx = [row, 0]; for (let i = 1; i < this.columns; i++) { if (this.get(row, i) < v) { v = this.get(row, i); idx[1] = i; } } return idx; } maxColumn(column) { checkColumnIndex(this, column); if (this.isEmpty()) { return NaN; } let v = this.get(0, column); for (let i = 1; i < this.rows; i++) { if (this.get(i, column) > v) { v = this.get(i, column); } } return v; } maxColumnIndex(column) { checkColumnIndex(this, column); checkNonEmpty(this); let v = this.get(0, column); let idx = [0, column]; for (let i = 1; i < this.rows; i++) { if (this.get(i, column) > v) { v = this.get(i, column); idx[0] = i; } } return idx; } minColumn(column) { checkColumnIndex(this, column); if (this.isEmpty()) { return NaN; } let v = this.get(0, column); for (let i = 1; i < this.rows; i++) { if (this.get(i, column) < v) { v = this.get(i, column); } } return v; } minColumnIndex(column) { checkColumnIndex(this, column); checkNonEmpty(this); let v = this.get(0, column); let idx = [0, column]; for (let i = 1; i < this.rows; i++) { if (this.get(i, column) < v) { v = this.get(i, column); idx[0] = i; } } return idx; } diag() { let min = Math.min(this.rows, this.columns); let diag = []; for (let i = 0; i < min; i++) { diag.push(this.get(i, i)); } return diag; } norm(type = 'frobenius') { switch (type) { case 'max': return this.max(); case 'frobenius': return Math.sqrt(this.dot(this)); default: throw new RangeError(`unknown norm type: ${type}`); } } cumulativeSum() { let sum = 0; for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { sum += this.get(i, j); this.set(i, j, sum); } } return this; } dot(vector2) { if (AbstractMatrix.isMatrix(vector2)) vector2 = vector2.to1DArray(); let vector1 = this.to1DArray(); if (vector1.length !== vector2.length) { throw new RangeError('vectors do not have the same size'); } let dot = 0; for (let i = 0; i < vector1.length; i++) { dot += vector1[i] * vector2[i]; } return dot; } mmul(other) { other = Matrix.checkMatrix(other); let m = this.rows; let n = this.columns; let p = other.columns; let result = new Matrix(m, p); let Bcolj = new Float64Array(n); for (let j = 0; j < p; j++) { for (let k = 0; k < n; k++) { Bcolj[k] = other.get(k, j); } for (let i = 0; i < m; i++) { let s = 0; for (let k = 0; k < n; k++) { s += this.get(i, k) * Bcolj[k]; } result.set(i, j, s); } } return result; } mpow(scalar) { if (!this.isSquare()) { throw new RangeError('Matrix must be square'); } if (!Number.isInteger(scalar) || scalar < 0) { throw new RangeError('Exponent must be a non-negative integer'); } // Russian Peasant exponentiation, i.e. exponentiation by squaring let result = Matrix.eye(this.rows); let bb = this; // Note: Don't bit shift. In JS, that would truncate at 32 bits for (let e = scalar; e > 1; e /= 2) { if ((e & 1) !== 0) { result = result.mmul(bb); } bb = bb.mmul(bb); } return result; } strassen2x2(other) { other = Matrix.checkMatrix(other); let result = new Matrix(2, 2); const a11 = this.get(0, 0); const b11 = other.get(0, 0); const a12 = this.get(0, 1); const b12 = other.get(0, 1); const a21 = this.get(1, 0); const b21 = other.get(1, 0); const a22 = this.get(1, 1); const b22 = other.get(1, 1); // Compute intermediate values. const m1 = (a11 + a22) * (b11 + b22); const m2 = (a21 + a22) * b11; const m3 = a11 * (b12 - b22); const m4 = a22 * (b21 - b11); const m5 = (a11 + a12) * b22; const m6 = (a21 - a11) * (b11 + b12); const m7 = (a12 - a22) * (b21 + b22); // Combine intermediate values into the output. const c00 = m1 + m4 - m5 + m7; const c01 = m3 + m5; const c10 = m2 + m4; const c11 = m1 - m2 + m3 + m6; result.set(0, 0, c00); result.set(0, 1, c01); result.set(1, 0, c10); result.set(1, 1, c11); return result; } strassen3x3(other) { other = Matrix.checkMatrix(other); let result = new Matrix(3, 3); const a00 = this.get(0, 0); const a01 = this.get(0, 1); const a02 = this.get(0, 2); const a10 = this.get(1, 0); const a11 = this.get(1, 1); const a12 = this.get(1, 2); const a20 = this.get(2, 0); const a21 = this.get(2, 1); const a22 = this.get(2, 2); const b00 = other.get(0, 0); const b01 = other.get(0, 1); const b02 = other.get(0, 2); const b10 = other.get(1, 0); const b11 = other.get(1, 1); const b12 = other.get(1, 2); const b20 = other.get(2, 0); const b21 = other.get(2, 1); const b22 = other.get(2, 2); const m1 = (a00 + a01 + a02 - a10 - a11 - a21 - a22) * b11; const m2 = (a00 - a10) * (-b01 + b11); const m3 = a11 * (-b00 + b01 + b10 - b11 - b12 - b20 + b22); const m4 = (-a00 + a10 + a11) * (b00 - b01 + b11); const m5 = (a10 + a11) * (-b00 + b01); const m6 = a00 * b00; const m7 = (-a00 + a20 + a21) * (b00 - b02 + b12); const m8 = (-a00 + a20) * (b02 - b12); const m9 = (a20 + a21) * (-b00 + b02); const m10 = (a00 + a01 + a02 - a11 - a12 - a20 - a21) * b12; const m11 = a21 * (-b00 + b02 + b10 - b11 - b12 - b20 + b21); const m12 = (-a02 + a21 + a22) * (b11 + b20 - b21); const m13 = (a02 - a22) * (b11 - b21); const m14 = a02 * b20; const m15 = (a21 + a22) * (-b20 + b21); const m16 = (-a02 + a11 + a12) * (b12 + b20 - b22); const m17 = (a02 - a12) * (b12 - b22); const m18 = (a11 + a12) * (-b20 + b22); const m19 = a01 * b10; const m20 = a12 * b21; const m21 = a10 * b02; const m22 = a20 * b01; const m23 = a22 * b22; const c00 = m6 + m14 + m19; const c01 = m1 + m4 + m5 + m6 + m12 + m14 + m15; const c02 = m6 + m7 + m9 + m10 + m14 + m16 + m18; const c10 = m2 + m3 + m4 + m6 + m14 + m16 + m17; const c11 = m2 + m4 + m5 + m6 + m20; const c12 = m14 + m16 + m17 + m18 + m21; const c20 = m6 + m7 + m8 + m11 + m12 + m13 + m14; const c21 = m12 + m13 + m14 + m15 + m22; const c22 = m6 + m7 + m8 + m9 + m23; result.set(0, 0, c00); result.set(0, 1, c01); result.set(0, 2, c02); result.set(1, 0, c10); result.set(1, 1, c11); result.set(1, 2, c12); result.set(2, 0, c20); result.set(2, 1, c21); result.set(2, 2, c22); return result; } mmulStrassen(y) { y = Matrix.checkMatrix(y); let x = this.clone(); let r1 = x.rows; let c1 = x.columns; let r2 = y.rows; let c2 = y.columns; if (c1 !== r2) { // eslint-disable-next-line no-console console.warn( `Multiplying ${r1} x ${c1} and ${r2} x ${c2} matrix: dimensions do not match.`, ); } // Put a matrix into the top left of a matrix of zeros. // `rows` and `cols` are the dimensions of the output matrix. function embed(mat, rows, cols) { let r = mat.rows; let c = mat.columns; if (r === rows && c === cols) { return mat; } else { let resultat = AbstractMatrix.zeros(rows, cols); resultat = resultat.setSubMatrix(mat, 0, 0); return resultat; } } // Make sure both matrices are the same size. // This is exclusively for simplicity: // this algorithm can be implemented with matrices of different sizes. let r = Math.max(r1, r2); let c = Math.max(c1, c2); x = embed(x, r, c); y = embed(y, r, c); // Our recursive multiplication function. function blockMult(a, b, rows, cols) { // For small matrices, resort to naive multiplication. if (rows <= 512 || cols <= 512) { return a.mmul(b); // a is equivalent to this } // Apply dynamic padding. if (rows % 2 === 1 && cols % 2 === 1) { a = embed(a, rows + 1, cols + 1); b = embed(b, rows + 1, cols + 1); } else if (rows % 2 === 1) { a = embed(a, rows + 1, cols); b = embed(b, rows + 1, cols); } else if (cols % 2 === 1) { a = embed(a, rows, cols + 1); b = embed(b, rows, cols + 1); } let halfRows = parseInt(a.rows / 2, 10); let halfCols = parseInt(a.columns / 2, 10); // Subdivide input matrices. let a11 = a.subMatrix(0, halfRows - 1, 0, halfCols - 1); let b11 = b.subMatrix(0, halfRows - 1, 0, halfCols - 1); let a12 = a.subMatrix(0, halfRows - 1, halfCols, a.columns - 1); let b12 = b.subMatrix(0, halfRows - 1, halfCols, b.columns - 1); let a21 = a.subMatrix(halfRows, a.rows - 1, 0, halfCols - 1); let b21 = b.subMatrix(halfRows, b.rows - 1, 0, halfCols - 1); let a22 = a.subMatrix(halfRows, a.rows - 1, halfCols, a.columns - 1); let b22 = b.subMatrix(halfRows, b.rows - 1, halfCols, b.columns - 1); // Compute intermediate values. let m1 = blockMult( AbstractMatrix.add(a11, a22), AbstractMatrix.add(b11, b22), halfRows, halfCols, ); let m2 = blockMult(AbstractMatrix.add(a21, a22), b11, halfRows, halfCols); let m3 = blockMult(a11, AbstractMatrix.sub(b12, b22), halfRows, halfCols); let m4 = blockMult(a22, AbstractMatrix.sub(b21, b11), halfRows, halfCols); let m5 = blockMult(AbstractMatrix.add(a11, a12), b22, halfRows, halfCols); let m6 = blockMult( AbstractMatrix.sub(a21, a11), AbstractMatrix.add(b11, b12), halfRows, halfCols, ); let m7 = blockMult( AbstractMatrix.sub(a12, a22), AbstractMatrix.add(b21, b22), halfRows, halfCols, ); // Combine intermediate values into the output. let c11 = AbstractMatrix.add(m1, m4); c11.sub(m5); c11.add(m7); let c12 = AbstractMatrix.add(m3, m5); let c21 = AbstractMatrix.add(m2, m4); let c22 = AbstractMatrix.sub(m1, m2); c22.add(m3); c22.add(m6); // Crop output to the desired size (undo dynamic padding). let result = AbstractMatrix.zeros(2 * c11.rows, 2 * c11.columns); result = result.setSubMatrix(c11, 0, 0); result = result.setSubMatrix(c12, c11.rows, 0); result = result.setSubMatrix(c21, 0, c11.columns); result = result.setSubMatrix(c22, c11.rows, c11.columns); return result.subMatrix(0, rows - 1, 0, cols - 1); } return blockMult(x, y, r, c); } scaleRows(options = {}) { if (typeof options !== 'object') { throw new TypeError('options must be an object'); } const { min = 0, max = 1 } = options; if (!Number.isFinite(min)) throw new TypeError('min must be a number'); if (!Number.isFinite(max)) throw new TypeError('max must be a number'); if (min >= max) throw new RangeError('min must be smaller than max'); let newMatrix = new Matrix(this.rows, this.columns); for (let i = 0; i < this.rows; i++) { const row = this.getRow(i); if (row.length > 0) { rescale(row, { min, max, output: row }); } newMatrix.setRow(i, row); } return newMatrix; } scaleColumns(options = {}) { if (typeof options !== 'object') { throw new TypeError('options must be an object'); } const { min = 0, max = 1 } = options; if (!Number.isFinite(min)) throw new TypeError('min must be a number'); if (!Number.isFinite(max)) throw new TypeError('max must be a number'); if (min >= max) throw new RangeError('min must be smaller than max'); let newMatrix = new Matrix(this.rows, this.columns); for (let i = 0; i < this.columns; i++) { const column = this.getColumn(i); if (column.length) { rescale(column, { min, max, output: column, }); } newMatrix.setColumn(i, column); } return newMatrix; } flipRows() { const middle = Math.ceil(this.columns / 2); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < middle; j++) { let first = this.get(i, j); let last = this.get(i, this.columns - 1 - j); this.set(i, j, last); this.set(i, this.columns - 1 - j, first); } } return this; } flipColumns() { const middle = Math.ceil(this.rows / 2); for (let j = 0; j < this.columns; j++) { for (let i = 0; i < middle; i++) { let first = this.get(i, j); let last = this.get(this.rows - 1 - i, j); this.set(i, j, last); this.set(this.rows - 1 - i, j, first); } } return this; } kroneckerProduct(other) { other = Matrix.checkMatrix(other); let m = this.rows; let n = this.columns; let p = other.rows; let q = other.columns; let result = new Matrix(m * p, n * q); for (let i = 0; i < m; i++) { for (let j = 0; j < n; j++) { for (let k = 0; k < p; k++) { for (let l = 0; l < q; l++) { result.set(p * i + k, q * j + l, this.get(i, j) * other.get(k, l)); } } } } return result; } kroneckerSum(other) { other = Matrix.checkMatrix(other); if (!this.isSquare() || !other.isSquare()) { throw new Error('Kronecker Sum needs two Square Matrices'); } let m = this.rows; let n = other.rows; let AxI = this.kroneckerProduct(Matrix.eye(n, n)); let IxB = Matrix.eye(m, m).kroneckerProduct(other); return AxI.add(IxB); } transpose() { let result = new Matrix(this.columns, this.rows); for (let i = 0; i < this.rows; i++) { for (let j = 0; j < this.columns; j++) { result.set(j, i, this.get(i, j)); } } return result; } sortRows(compareFunction = compareNumbers) { for (let i = 0; i < this.rows; i++) { this.setRow(i, this.getRow(i).sort(compareFunction)); } return this; } sortColumns(compareFunction = compareNumbers) { for (let i = 0; i < this.columns; i++) { this.setColumn(i, this.getColumn(i).sort(compareFunction)); } return this; } subMatrix(startRow, endRow, startColumn, endColumn) { checkRange(this, startRow, endRow, startColumn, endColumn); let newMatrix = new Matrix( endRow - startRow + 1, endColumn - startColumn + 1, ); for (let i = startRow; i <= endRow; i++) { for (let j = startColumn; j <= endColumn; j++) { newMatrix.set(i - startRow, j - startColumn, this.get(i, j)); } } return newMatrix; } subMatrixRow(indices, startColumn, endColumn) { if (startColumn === undefined) startColumn = 0; if (endColumn === undefined) endColumn = this.columns - 1; if ( startColumn > endColumn || startColumn < 0 || startColumn >= this.columns || endColumn < 0 || endColumn >= this.columns ) { throw new RangeError('Argument out of range'); } let newMatrix = new Matrix(indices.length, endColumn - startColumn + 1); for (let i = 0; i < indices.length; i++) { for (let j = startColumn; j <= endColumn; j++) { if (indices[i] < 0 || indices[i] >= this.rows) { throw new RangeError(`Row index out of range: ${indices[i]}`); } newMatrix.set(i, j - startColumn, this.get(indices[i], j)); } } return newMatrix; } subMatrixColumn(indices, startRow, endRow) { if (startRow === undefined) startRow = 0; if (endRow === undefined) endRow = this.rows - 1; if ( startRow > endRow || startRow < 0 || startRow >= this.rows || endRow < 0 || endRow >= this.rows ) { throw new RangeError('Argument out of range'); } let newMatrix = new Matrix(endRow - startRow + 1, indices.length); for (let i = 0; i < indices.length; i++) { for (let j = startRow; j <= endRow; j++) { if (indices[i] < 0 || indices[i] >= this.columns) { throw new RangeError(`Column index out of range: ${indices[i]}`); } newMatrix.set(j - startRow, i, this.get(j, indices[i])); } } return newMatrix; } setSubMatrix(matrix, startRow, startColumn) { matrix = Matrix.checkMatrix(matrix); if (matrix.isEmpty()) { return this; } let endRow = startRow + matrix.rows - 1; let endColumn = startColumn + matrix.columns - 1; checkRange(this, startRow, endRow, startColumn, endColumn); for (let i = 0; i < matrix.rows; i++) { for (let j = 0; j < matrix.columns; j++) { this.set(startRow + i, startColumn + j, matrix.get(i, j)); } } return this; } selection(rowIndices, columnIndices) { checkRowIndices(this, rowIndices); checkColumnIndices(this, columnIndices); let newMatrix = new Matrix(rowIndices.length, columnIndices.length); for (let i = 0; i < rowIndices.length; i++) { let rowIndex = rowIndices[i]; for (let j = 0; j < columnIndices.length; j++) { let columnIndex = columnIndices[j]; newMatrix.set(i, j, this.get(rowIndex, columnIndex)); } } return newMatrix; } trace() { let min = Math.min(this.rows, this.columns); let trace = 0; for (let i = 0; i < min; i++) { trace += this.get(i, i); } return trace; } clone() { return this.constructor.copy(this, new Matrix(this.rows, this.columns)); } /** * @template {AbstractMatrix} M * @param {AbstractMatrix} from * @param {M} to * @return {M} */ static copy(from, to) { for (const [row, column, value] of from.entries()) { to.set(row, column, value); } return to; } sum(by) { switch (by) { case 'row': return sumByRow(this); case 'column': return sumByColumn(this); case undefined: return sumAll(this); default: throw new Error(`invalid option: ${by}`); } } product(by) { switch (by) { case 'row': return productByRow(this); case 'column': return productByColumn(this); case undefined: return productAll(this); default: throw new Error(`invalid option: ${by}`); } } mean(by) { const sum = this.sum(by); switch (by) { case 'row': { for (let i = 0; i < this.rows; i++) { sum[i] /= this.columns; } return sum; } case 'column': { for (let i = 0; i < this.columns; i++) { sum[i] /= this.rows; } return sum; } case undefined: return sum / this.size; default: throw new Error(`invalid option: ${by}`); } } variance(by, options = {}) { if (typeof by === 'object') { options = by; by = undefined; } if (typeof options !== 'object') { throw new TypeError('options must be an object'); } const { unbiased = true, mean = this.mean(by) } = options; if (typeof unbiased !== 'boolean') { throw new TypeError('unbiased must be a boolean'); } switch (by) { case 'row': { if (!isAnyArray(mean)) { throw new TypeError('mean must be an array'); } return varianceByRow(this, unbiased, mean); } case 'column': { if (!isAnyArray(mean)) { throw new TypeError('mean must be an array'); } return varianceByColumn(this, unbiased, mean); } case undefined: { if (typeof mean !== 'number') { throw new TypeError('mean must be a number'); } return varianceAll(this, unbiased, mean); } default: throw new Error(`invalid option: ${by}`); } } standardDeviation(by, options) { if (typeof by === 'object') { options = by; by = undefined; } const variance = this.variance(by, options); if (by === undefined) { return Math.sqrt(variance); } else { for (let i = 0; i < variance.length; i++) { variance[i] = Math.sqrt(variance[i]); } return variance; } } center(by, options = {}) { if (typeof by === 'object') { options = by; by = undefined; } if (typeof options !== 'object') { throw new TypeError('options must be an object'); } const { center = this.mean(by) } = options; switch (by) { case 'row': { if (!isAnyArray(center)) { throw new TypeError('center must be an array'); } centerByRow(this, center); return this; } case 'column': { if (!isAnyArray(center)) { throw new TypeError('center must be an array'); } centerByColumn(this, center); return this; } case undefined: { if (typeof center !== 'number') { throw new TypeError('center must be a number'); } centerAll(this, center); return this; } default: throw new Error(`invalid option: ${by}`); } } scale(by, options = {}) { if (typeof by === 'object') { options = by; by = undefined; } if (typeof options !== 'object') { throw new TypeError('options must be an object'); } let scale = options.scale; switch (by) { case 'row': { if (scale === undefined) { scale = getScaleByRow(this); } else if (!isAnyArray(scale)) { throw new TypeError('scale must be an array'); } scaleByRow(this, scale); return this; } case 'column': { if (scale === undefined) { scale = getScaleByColumn(this); } else if (!isAnyArray(scale)) { throw new TypeError('scale must be an array'); } scaleByColumn(this, scale); return this; } case undefined: { if (scale === undefined) { scale = getScaleAll(this); } else if (typeof scale !== 'number') { throw new TypeError('scale must be a number'); } scaleAll(this, scale); return this; } default: throw new Error(`invalid option: ${by}`); } } toString(options) { return inspectMatrixWithOptions(this, options); } [Symbol.iterator]() { return this.entries(); } /** * iterator from left to right, from top to bottom * yield [row, column, value] * @returns {Generator<[number, number, number], void, void>} */ *entries() { for (let row = 0; row < this.rows; row++) { for (let col = 0; col < this.columns; col++) { yield [row, col, this.get(row, col)]; } } } /** * iterator from left to right, from top to bottom * yield value * @returns {Generator<number, void, void>} */ *values() { for (let row = 0; row < this.rows; row++) { for (let col = 0; col < this.columns; col++) { yield this.get(row, col); } } } } AbstractMatrix.prototype.klass = 'Matrix'; if (typeof Symbol !== 'undefined') { AbstractMatrix.prototype[Symbol.for('nodejs.util.inspect.custom')] = inspectMatrix; } function compareNumbers(a, b) { return a - b; } function isArrayOfNumbers(array) { return array.every((element) => { return typeof element === 'number'; }); } // Synonyms AbstractMatrix.random = AbstractMatrix.rand; AbstractMatrix.randomInt = AbstractMatrix.randInt; AbstractMatrix.diagonal = AbstractMatrix.diag; AbstractMatrix.prototype.diagonal = AbstractMatrix.prototype.diag; AbstractMatrix.identity = AbstractMatrix.eye; AbstractMatrix.prototype.negate = AbstractMatrix.prototype.neg; AbstractMatrix.prototype.tensorProduct = AbstractMatrix.prototype.kroneckerProduct; export default class Matrix extends AbstractMatrix { /** * @type {Float64Array[]} */ data; /** * Init an empty matrix * @param {number} nRows * @param {number} nColumns */ #initData(nRows, nColumns) { this.data = []; if (Number.isInteger(nColumns) && nColumns >= 0) { for (let i = 0; i < nRows; i++) { this.data.push(new Float64Array(nColumns)); } } else { throw new TypeError('nColumns must be a positive integer'); } this.rows = nRows; this.columns = nColumns; } constructor(nRows, nColumns) { super(); if (Matrix.isMatrix(nRows)) { this.#initData(nRows.rows, nRows.columns); Matrix.copy(nRows, this); } else if (Number.isInteger(nRows) && nRows >= 0) { this.#initData(nRows, nColumns); } else if (isAnyArray(nRows)) { // Copy the values from the 2D array const arrayData = nRows; nRows = arrayData.length; nColumns = nRows ? arrayData[0].length : 0; if (typeof nColumns !== 'number') { throw new TypeError( 'Data must be a 2D array with at least one element', ); } this.data = []; for (let i = 0; i < nRows; i++) { if (arrayData[i].length !== nColumns) { throw new RangeError('Inconsistent array dimensions'); } if (!isArrayOfNumbers(arrayData[i])) { throw new TypeError('Input data contains non-numeric values'); } this.data.push(Float64Array.from(arrayData[i])); } this.rows = nRows; this.columns = nColumns; } else { throw new TypeError( 'First argument must be a positive number or an array', ); } } set(rowIndex, columnIndex, value) { this.data[rowIndex][columnIndex] = value; return this; } get(rowIndex, columnIndex) { return this.data[rowIndex][columnIndex]; } removeRow(index) { checkRowIndex(this, index); this.data.splice(index, 1); this.rows -= 1; return this; } addRow(index, array) { if (array === undefined) { array = index; index = this.rows; } checkRowIndex(this, index, true); array = Float64Array.from(checkRowVector(this, array)); this.data.splice(index, 0, array); this.rows += 1; return this; } removeColumn(index) { checkColumnIndex(this, index); for (let i = 0; i < this.rows; i++) { const newRow = new Float64Array(this.columns - 1); for (let j = 0; j < index; j++) { newRow[j] = this.data[i][j]; } for (let j = index + 1; j < this.columns; j++) { newRow[j - 1] = this.data[i][j]; } this.data[i] = newRow; } this.columns -= 1; return this; } addColumn(index, array) { if (typeof array === 'undefined') { array = index; index = this.columns; } checkColumnIndex(this, index, true); array = checkColumnVector(this, array); for (let i = 0; i < this.rows; i++) { const newRow = new Float64Array(this.columns + 1); let j = 0; for (; j < index; j++) { newRow[j] = this.data[i][j]; } newRow[j++] = array[i]; for (; j < this.columns + 1; j++) { newRow[j] = this.data[i][j - 1]; } this.data[i] = newRow; } this.columns += 1; return this; } } installMathOperations(AbstractMatrix, Matrix);