/*
 * UNPKG
 *
 * ml-matrix
 *
 * Version: (not captured)
 *
 * Matrix manipulation and computation library
 *
 * 528 lines (483 loc) 13.1 kB
 */
import Matrix from '../matrix';
import WrapperMatrix2D from '../wrap/WrapperMatrix2D';

import { hypotenuse } from './util';

/**
 * Singular value decomposition of a real matrix: A = U * diag(s) * V^T.
 *
 * The input is first reduced to bidiagonal form with Householder
 * reflections (the Householder vectors are accumulated into U and V).
 * The bidiagonal form is then driven to diagonal form with sweeps of
 * Givens rotations (the four `kase` branches of the main loop). The
 * structure follows the JAMA / LINPACK SVD routine.
 *
 * After construction:
 * - `s` holds the singular values, made non-negative and sorted in
 *   descending order.
 * - `U` holds the left singular vectors.
 * - `V` holds the right singular vectors.
 */
export default class SingularValueDecomposition {
  /**
   * @param {*} value - Matrix-like input, normalised through
   *   `WrapperMatrix2D.checkMatrix` (see there for the accepted forms).
   * @param {object} [options]
   * @param {boolean} [options.computeLeftSingularVectors=true] - Accumulate U.
   * @param {boolean} [options.computeRightSingularVectors=true] - Accumulate V.
   * @param {boolean} [options.autoTranspose=false] - When the input has more
   *   columns than rows, decompose its transpose and swap U and V at the end.
   *   Without this option a warning is logged and the wide matrix is
   *   decomposed directly.
   * @throws {Error} If the matrix is empty.
   */
  constructor(value, options = {}) {
    const input = WrapperMatrix2D.checkMatrix(value);
    if (input.isEmpty()) {
      throw new Error('Matrix must be non-empty');
    }

    let m = input.rows;
    let n = input.columns;

    const {
      computeLeftSingularVectors = true,
      computeRightSingularVectors = true,
      autoTranspose = false,
    } = options;

    let wantu = Boolean(computeLeftSingularVectors);
    let wantv = Boolean(computeRightSingularVectors);

    let swapped = false;
    let a;
    if (m < n) {
      if (!autoTranspose) {
        a = input.clone();
        // eslint-disable-next-line no-console
        console.warn(
          'Computing SVD on a matrix with more columns than rows. Consider enabling autoTranspose',
        );
      } else {
        a = input.transpose();
        m = a.rows;
        n = a.columns;
        swapped = true;
        // The roles of U and V are exchanged when working on the transpose.
        [wantu, wantv] = [wantv, wantu];
      }
    } else {
      a = input.clone();
    }

    const nu = Math.min(m, n);
    const ni = Math.min(m + 1, n);
    const s = new Float64Array(ni);
    let U = new Matrix(m, nu);
    let V = new Matrix(n, n);

    const e = new Float64Array(n);
    const work = new Float64Array(m);

    // --- Householder bidiagonalisation -------------------------------------
    // Diagonal elements go into s, super-diagonal elements into e.
    const nct = Math.min(m - 1, n);
    const nrt = Math.max(0, Math.min(n - 2, m));
    const mrc = Math.max(nct, nrt);

    for (let k = 0; k < mrc; k++) {
      if (k < nct) {
        // Column transformation: s[k] becomes the 2-norm of a(k:m, k),
        // with a sign chosen to avoid cancellation.
        s[k] = 0;
        for (let i = k; i < m; i++) {
          s[k] = hypotenuse(s[k], a.get(i, k));
        }
        if (s[k] !== 0) {
          if (a.get(k, k) < 0) {
            s[k] = -s[k];
          }
          for (let i = k; i < m; i++) {
            a.set(i, k, a.get(i, k) / s[k]);
          }
          a.set(k, k, a.get(k, k) + 1);
        }
        s[k] = -s[k];
      }

      for (let j = k + 1; j < n; j++) {
        // Apply the column reflection to the remaining columns.
        if (k < nct && s[k] !== 0) {
          let t = 0;
          for (let i = k; i < m; i++) {
            t += a.get(i, k) * a.get(i, j);
          }
          t = -t / a.get(k, k);
          for (let i = k; i < m; i++) {
            a.set(i, j, a.get(i, j) + t * a.get(i, k));
          }
        }
        // Row k of the updated matrix seeds the row transformation.
        e[j] = a.get(k, j);
      }

      if (wantu && k < nct) {
        // Keep the column Householder vector for building U later.
        for (let i = k; i < m; i++) {
          U.set(i, k, a.get(i, k));
        }
      }

      if (k < nrt) {
        // Row transformation on e(k+1:n).
        e[k] = 0;
        for (let i = k + 1; i < n; i++) {
          e[k] = hypotenuse(e[k], e[i]);
        }
        if (e[k] !== 0) {
          if (e[k + 1] < 0) {
            e[k] = -e[k];
          }
          for (let i = k + 1; i < n; i++) {
            e[i] /= e[k];
          }
          e[k + 1] += 1;
        }
        e[k] = -e[k];

        if (k + 1 < m && e[k] !== 0) {
          // Apply the row reflection to the trailing sub-matrix.
          for (let i = k + 1; i < m; i++) {
            work[i] = 0;
          }
          for (let i = k + 1; i < m; i++) {
            for (let j = k + 1; j < n; j++) {
              work[i] += e[j] * a.get(i, j);
            }
          }
          for (let j = k + 1; j < n; j++) {
            const t = -e[j] / e[k + 1];
            for (let i = k + 1; i < m; i++) {
              a.set(i, j, a.get(i, j) + t * work[i]);
            }
          }
        }

        if (wantv) {
          // Keep the row Householder vector for building V later.
          for (let i = k + 1; i < n; i++) {
            V.set(i, k, e[i]);
          }
        }
      }
    }

    // --- Final bidiagonal matrix of order p --------------------------------
    let p = Math.min(n, m + 1);
    if (nct < n) {
      s[nct] = a.get(nct, nct);
    }
    if (m < p) {
      s[p - 1] = 0;
    }
    if (nrt + 1 < p) {
      e[nrt] = a.get(nrt, p - 1);
    }
    e[p - 1] = 0;

    // --- Generate U from the stored column Householder vectors -------------
    if (wantu) {
      for (let j = nct; j < nu; j++) {
        for (let i = 0; i < m; i++) {
          U.set(i, j, 0);
        }
        U.set(j, j, 1);
      }
      for (let k = nct - 1; k >= 0; k--) {
        if (s[k] !== 0) {
          for (let j = k + 1; j < nu; j++) {
            let t = 0;
            for (let i = k; i < m; i++) {
              t += U.get(i, k) * U.get(i, j);
            }
            t = -t / U.get(k, k);
            for (let i = k; i < m; i++) {
              U.set(i, j, U.get(i, j) + t * U.get(i, k));
            }
          }
          for (let i = k; i < m; i++) {
            U.set(i, k, -U.get(i, k));
          }
          U.set(k, k, 1 + U.get(k, k));
          // Clear the entries above the diagonal of column k.
          for (let i = 0; i < k; i++) {
            U.set(i, k, 0);
          }
        } else {
          for (let i = 0; i < m; i++) {
            U.set(i, k, 0);
          }
          U.set(k, k, 1);
        }
      }
    }

    // --- Generate V from the stored row Householder vectors ----------------
    if (wantv) {
      for (let k = n - 1; k >= 0; k--) {
        if (k < nrt && e[k] !== 0) {
          for (let j = k + 1; j < n; j++) {
            let t = 0;
            for (let i = k + 1; i < n; i++) {
              t += V.get(i, k) * V.get(i, j);
            }
            t = -t / V.get(k + 1, k);
            for (let i = k + 1; i < n; i++) {
              V.set(i, j, V.get(i, j) + t * V.get(i, k));
            }
          }
        }
        for (let i = 0; i < n; i++) {
          V.set(i, k, 0);
        }
        V.set(k, k, 1);
      }
    }

    // --- Iterate until the super-diagonal e vanishes -----------------------
    const pp = p - 1;
    const eps = Number.EPSILON;
    while (p > 0) {
      let k;
      let kase;

      // Find the largest k < p - 1 whose e[k] is negligible relative to its
      // neighbouring diagonal entries (k === -1 if there is none).
      for (k = p - 2; k >= -1; k--) {
        if (k === -1) {
          break;
        }
        // Both diagonal magnitudes are summed separately. Taking the
        // absolute value of a signed sum would let opposite-signed
        // neighbours cancel and shrink the threshold.
        const alpha =
          Number.MIN_VALUE + eps * (Math.abs(s[k]) + Math.abs(s[k + 1]));
        if (Math.abs(e[k]) <= alpha || Number.isNaN(e[k])) {
          e[k] = 0;
          break;
        }
      }

      // kase 1: s[p-1] is negligible and k < p - 1    -> deflate
      // kase 2: s[k] is negligible and k < p - 1      -> split
      // kase 3: e[k-1] negligible, k < p - 1, no s negligible -> QR step
      // kase 4: e[p-2] is negligible                  -> converged
      if (k === p - 2) {
        kase = 4;
      } else {
        let ks;
        for (ks = p - 1; ks >= k; ks--) {
          if (ks === k) {
            break;
          }
          const t =
            (ks !== p ? Math.abs(e[ks]) : 0) +
            (ks !== k + 1 ? Math.abs(e[ks - 1]) : 0);
          if (Math.abs(s[ks]) <= eps * t) {
            s[ks] = 0;
            break;
          }
        }
        if (ks === k) {
          kase = 3;
        } else if (ks === p - 1) {
          kase = 1;
        } else {
          kase = 2;
          k = ks;
        }
      }
      k++;

      switch (kase) {
        case 1: {
          // Deflate negligible s[p-1] by chasing e[p-2] up with rotations.
          let f = e[p - 2];
          e[p - 2] = 0;
          for (let j = p - 2; j >= k; j--) {
            let t = hypotenuse(s[j], f);
            const cs = s[j] / t;
            const sn = f / t;
            s[j] = t;
            if (j !== k) {
              f = -sn * e[j - 1];
              e[j - 1] = cs * e[j - 1];
            }
            if (wantv) {
              for (let i = 0; i < n; i++) {
                t = cs * V.get(i, j) + sn * V.get(i, p - 1);
                V.set(i, p - 1, -sn * V.get(i, j) + cs * V.get(i, p - 1));
                V.set(i, j, t);
              }
            }
          }
          break;
        }
        case 2: {
          // Split at negligible s[k-1] by chasing e[k-1] down with rotations.
          let f = e[k - 1];
          e[k - 1] = 0;
          for (let j = k; j < p; j++) {
            let t = hypotenuse(s[j], f);
            const cs = s[j] / t;
            const sn = f / t;
            s[j] = t;
            f = -sn * e[j];
            e[j] = cs * e[j];
            if (wantu) {
              for (let i = 0; i < m; i++) {
                t = cs * U.get(i, j) + sn * U.get(i, k - 1);
                U.set(i, k - 1, -sn * U.get(i, j) + cs * U.get(i, k - 1));
                U.set(i, j, t);
              }
            }
          }
          break;
        }
        case 3: {
          // One shifted QR step on the block k..p-1. Values are scaled by
          // the largest magnitude involved to keep intermediates in range.
          const scale = Math.max(
            Math.abs(s[p - 1]),
            Math.abs(s[p - 2]),
            Math.abs(e[p - 2]),
            Math.abs(s[k]),
            Math.abs(e[k]),
          );
          const sp = s[p - 1] / scale;
          const spm1 = s[p - 2] / scale;
          const epm1 = e[p - 2] / scale;
          const sk = s[k] / scale;
          const ek = e[k] / scale;
          const b = ((spm1 + sp) * (spm1 - sp) + epm1 * epm1) / 2;
          const c = sp * epm1 * (sp * epm1);
          let shift = 0;
          if (b !== 0 || c !== 0) {
            shift = Math.sqrt(b * b + c);
            if (b < 0) {
              shift = -shift;
            }
            shift = c / (b + shift);
          }
          let f = (sk + sp) * (sk - sp) + shift;
          let g = sk * ek;

          // Chase the bulge down the bidiagonal.
          for (let j = k; j < p - 1; j++) {
            let t = hypotenuse(f, g);
            if (t === 0) t = Number.MIN_VALUE;
            let cs = f / t;
            let sn = g / t;
            if (j !== k) {
              e[j - 1] = t;
            }
            f = cs * s[j] + sn * e[j];
            e[j] = cs * e[j] - sn * s[j];
            g = sn * s[j + 1];
            s[j + 1] = cs * s[j + 1];
            if (wantv) {
              for (let i = 0; i < n; i++) {
                t = cs * V.get(i, j) + sn * V.get(i, j + 1);
                V.set(i, j + 1, -sn * V.get(i, j) + cs * V.get(i, j + 1));
                V.set(i, j, t);
              }
            }
            t = hypotenuse(f, g);
            if (t === 0) t = Number.MIN_VALUE;
            cs = f / t;
            sn = g / t;
            s[j] = t;
            f = cs * e[j] + sn * s[j + 1];
            s[j + 1] = -sn * e[j] + cs * s[j + 1];
            g = sn * e[j + 1];
            e[j + 1] = cs * e[j + 1];
            if (wantu && j < m - 1) {
              for (let i = 0; i < m; i++) {
                t = cs * U.get(i, j) + sn * U.get(i, j + 1);
                U.set(i, j + 1, -sn * U.get(i, j) + cs * U.get(i, j + 1));
                U.set(i, j, t);
              }
            }
          }
          e[p - 2] = f;
          break;
        }
        case 4: {
          // Converged singular value: make it non-negative...
          if (s[k] <= 0) {
            s[k] = s[k] < 0 ? -s[k] : 0;
            if (wantv) {
              // ...negating the whole column of V (V has n rows), so that
              // V stays orthogonal even when n > pp + 1.
              for (let i = 0; i < n; i++) {
                V.set(i, k, -V.get(i, k));
              }
            }
          }
          // ...then bubble it into descending order, swapping the matching
          // columns of U and V.
          while (k < pp) {
            if (s[k] >= s[k + 1]) {
              break;
            }
            let t = s[k];
            s[k] = s[k + 1];
            s[k + 1] = t;
            if (wantv && k < n - 1) {
              for (let i = 0; i < n; i++) {
                t = V.get(i, k + 1);
                V.set(i, k + 1, V.get(i, k));
                V.set(i, k, t);
              }
            }
            if (wantu && k < m - 1) {
              for (let i = 0; i < m; i++) {
                t = U.get(i, k + 1);
                U.set(i, k + 1, U.get(i, k));
                U.set(i, k, t);
              }
            }
            k++;
          }
          p--;
          break;
        }
        // no default
      }
    }

    if (swapped) {
      [U, V] = [V, U];
    }

    // When autoTranspose was used, m and n describe the transposed matrix.
    // The getters below only use them symmetrically (min/max).
    this.m = m;
    this.n = n;
    this.s = s;
    this.U = U;
    this.V = V;
  }

  /**
   * Computes V * diag(s+) * U^T * value, i.e. applies the pseudo-inverse to
   * `value`. Here s+ inverts every singular value above `this.threshold` and
   * sets the rest to zero.
   * @param {Matrix} value - Right-hand side. It must be usable as the
   *   argument of `Matrix#mmul` with U.rows rows.
   * @returns {Matrix}
   */
  solve(value) {
    const Y = value;
    const e = this.threshold;
    const scols = this.s.length;
    const Ls = Matrix.zeros(scols, scols);

    for (let i = 0; i < scols; i++) {
      if (Math.abs(this.s[i]) <= e) {
        Ls.set(i, i, 0);
      } else {
        Ls.set(i, i, 1 / this.s[i]);
      }
    }

    const U = this.U;
    const V = this.rightSingularVectors;

    const VL = V.mmul(Ls);
    const vrows = V.rows;
    const urows = U.rows;
    // VLU = VL * U^T, computed without materialising the transpose.
    const VLU = Matrix.zeros(vrows, urows);

    for (let i = 0; i < vrows; i++) {
      for (let j = 0; j < urows; j++) {
        let sum = 0;
        for (let k = 0; k < scols; k++) {
          sum += VL.get(i, k) * U.get(j, k);
        }
        VLU.set(i, j, sum);
      }
    }

    return VLU.mmul(Y);
  }

  /**
   * Equivalent to `solve(Matrix.diag(value))`.
   * @param {*} value - Diagonal entries, as accepted by `Matrix.diag`.
   * @returns {Matrix}
   */
  solveForDiagonal(value) {
    return this.solve(Matrix.diag(value));
  }

  /**
   * Pseudo-inverse V * diag(s+) * U^T. Singular values at or below
   * `this.threshold` contribute zero.
   * @returns {Matrix}
   */
  inverse() {
    const V = this.V;
    const e = this.threshold;
    const vrows = V.rows;
    const vcols = V.columns;
    const X = new Matrix(vrows, this.s.length);

    // X = V * diag(s+). Columns with a negligible (or missing) singular
    // value stay zero.
    for (let i = 0; i < vrows; i++) {
      for (let j = 0; j < vcols; j++) {
        if (Math.abs(this.s[j]) > e) {
          X.set(i, j, V.get(i, j) / this.s[j]);
        }
      }
    }

    const U = this.U;
    const urows = U.rows;
    const ucols = U.columns;
    // Y = X * U^T.
    const Y = new Matrix(vrows, urows);

    for (let i = 0; i < vrows; i++) {
      for (let j = 0; j < urows; j++) {
        let sum = 0;
        for (let k = 0; k < ucols; k++) {
          sum += X.get(i, k) * U.get(j, k);
        }
        Y.set(i, j, sum);
      }
    }

    return Y;
  }

  /** Ratio of the largest to the smallest of the first min(m, n) singular values. */
  get condition() {
    return this.s[0] / this.s[Math.min(this.m, this.n) - 1];
  }

  /** Largest singular value (the matrix 2-norm). */
  get norm2() {
    return this.s[0];
  }

  /** Number of singular values above max(m, n) * s[0] * Number.EPSILON. */
  get rank() {
    const tol = Math.max(this.m, this.n) * this.s[0] * Number.EPSILON;
    let r = 0;
    const s = this.s;
    for (let i = 0, ii = s.length; i < ii; i++) {
      if (s[i] > tol) {
        r++;
      }
    }
    return r;
  }

  /** Singular values as a plain array (a copy of `s`). */
  get diagonal() {
    return Array.from(this.s);
  }

  /** Cut-off used by `solve` and `inverse`: (EPSILON / 2) * max(m, n) * s[0]. */
  get threshold() {
    return (Number.EPSILON / 2) * Math.max(this.m, this.n) * this.s[0];
  }

  /** The matrix U. */
  get leftSingularVectors() {
    return this.U;
  }

  /** The matrix V. */
  get rightSingularVectors() {
    return this.V;
  }

  /** Singular values laid out as a diagonal matrix via `Matrix.diag`. */
  get diagonalMatrix() {
    return Matrix.diag(this.s);
  }
}