ml-matrix
Version:
Matrix manipulation and computation library
528 lines (483 loc) • 13.1 kB
JavaScript
import Matrix from '../matrix';
import WrapperMatrix2D from '../wrap/WrapperMatrix2D';
import { hypotenuse } from './util';
/**
 * Singular value decomposition of a real matrix: A = U · diag(s) · Vᵀ.
 *
 * The constructor performs three stages:
 *   1. Householder reduction of A to upper-bidiagonal form,
 *   2. accumulation of the reflections into U and V,
 *   3. iterative diagonalization of the bidiagonal matrix with plane (Givens)
 *      rotations, followed by sign fixing and sorting of the singular values.
 * NOTE(review): the structure appears to be a port of JAMA's
 * SingularValueDecomposition (itself derived from LINPACK dsvdc).
 */
export default class SingularValueDecomposition {
  /**
   * Decompose `value`.
   *
   * @param {Matrix|number[][]} value - Matrix to decompose; normalised through
   *   `WrapperMatrix2D.checkMatrix` (accepted inputs are whatever that helper
   *   accepts). It is never modified: the algorithm works on a clone or on
   *   its transpose.
   * @param {object} [options]
   * @param {boolean} [options.computeLeftSingularVectors=true] - Fill U.
   * @param {boolean} [options.computeRightSingularVectors=true] - Fill V.
   * @param {boolean} [options.autoTranspose=false] - For matrices with more
   *   columns than rows, decompose the transpose instead and swap U and V at
   *   the end. Without it such matrices are decomposed directly and a warning
   *   is printed with `console.warn`.
   * @throws {Error} 'Matrix must be non-empty' when `value.isEmpty()`.
   */
  constructor(value, options = {}) {
    value = WrapperMatrix2D.checkMatrix(value);

    if (value.isEmpty()) {
      throw new Error('Matrix must be non-empty');
    }

    let m = value.rows;
    let n = value.columns;

    const {
      computeLeftSingularVectors = true,
      computeRightSingularVectors = true,
      autoTranspose = false,
    } = options;

    let wantu = Boolean(computeLeftSingularVectors);
    let wantv = Boolean(computeRightSingularVectors);

    let swapped = false;
    let a;
    if (m < n) {
      if (!autoTranspose) {
        a = value.clone();
        // eslint-disable-next-line no-console
        console.warn(
          'Computing SVD on a matrix with more columns than rows. Consider enabling autoTranspose',
        );
      } else {
        // Work on the transpose: its left vectors are our right vectors and
        // vice versa, so the two "want" flags trade places as well. U and V
        // are exchanged back once the decomposition is finished.
        a = value.transpose();
        m = a.rows;
        n = a.columns;
        swapped = true;
        const aux = wantu;
        wantu = wantv;
        wantv = aux;
      }
    } else {
      // `a` is overwritten in place below, so always work on a copy.
      a = value.clone();
    }

    const nu = Math.min(m, n);
    const ni = Math.min(m + 1, n);
    const s = new Float64Array(ni); // bidiagonal diagonal -> singular values
    let U = new Matrix(m, nu);
    let V = new Matrix(n, n);

    const e = new Float64Array(n); // bidiagonal superdiagonal
    const work = new Float64Array(m);

    // --- Stage 1: Householder reduction to upper-bidiagonal form ----------
    // Column reflections produce the diagonal in s; their vectors stay in
    // `a` (copied into U when wanted). Row reflections produce the
    // superdiagonal in e; their vectors are built in `e` (copied into V).
    const nct = Math.min(m - 1, n);
    const nrt = Math.max(0, Math.min(n - 2, m));
    const mrc = Math.max(nct, nrt);

    for (let k = 0; k < mrc; k++) {
      if (k < nct) {
        // s[k] = ||a[k:, k]||, signed opposite to a[k][k].
        s[k] = 0;
        for (let i = k; i < m; i++) {
          s[k] = hypotenuse(s[k], a.get(i, k));
        }
        if (s[k] !== 0) {
          if (a.get(k, k) < 0) {
            s[k] = -s[k];
          }
          for (let i = k; i < m; i++) {
            a.set(i, k, a.get(i, k) / s[k]);
          }
          a.set(k, k, a.get(k, k) + 1);
        }
        s[k] = -s[k];
      }

      for (let j = k + 1; j < n; j++) {
        // Apply the column reflection to the remaining columns.
        if (k < nct && s[k] !== 0) {
          let t = 0;
          for (let i = k; i < m; i++) {
            t += a.get(i, k) * a.get(i, j);
          }
          t = -t / a.get(k, k);
          for (let i = k; i < m; i++) {
            a.set(i, j, a.get(i, j) + t * a.get(i, k));
          }
        }
        // Row k (right of the diagonal) feeds the row reflection below.
        e[j] = a.get(k, j);
      }

      if (wantu && k < nct) {
        for (let i = k; i < m; i++) {
          U.set(i, k, a.get(i, k));
        }
      }

      if (k < nrt) {
        // Row reflection: e[k] = ||e[k+1:]||, signed opposite to e[k+1].
        e[k] = 0;
        for (let i = k + 1; i < n; i++) {
          e[k] = hypotenuse(e[k], e[i]);
        }
        if (e[k] !== 0) {
          if (e[k + 1] < 0) {
            e[k] = 0 - e[k];
          }
          for (let i = k + 1; i < n; i++) {
            e[i] /= e[k];
          }
          e[k + 1] += 1;
        }
        e[k] = -e[k];
        if (k + 1 < m && e[k] !== 0) {
          // Apply the row reflection to the trailing rows of `a`.
          for (let i = k + 1; i < m; i++) {
            work[i] = 0;
          }
          for (let i = k + 1; i < m; i++) {
            for (let j = k + 1; j < n; j++) {
              work[i] += e[j] * a.get(i, j);
            }
          }
          for (let j = k + 1; j < n; j++) {
            const t = -e[j] / e[k + 1];
            for (let i = k + 1; i < m; i++) {
              a.set(i, j, a.get(i, j) + t * work[i]);
            }
          }
        }
        if (wantv) {
          for (let i = k + 1; i < n; i++) {
            V.set(i, k, e[i]);
          }
        }
      }
    }

    // Final bidiagonal matrix of order p: diagonal s[0..p-1], superdiagonal
    // e[0..p-2].
    let p = Math.min(n, m + 1);
    if (nct < n) {
      s[nct] = a.get(nct, nct);
    }
    if (m < p) {
      s[p - 1] = 0;
    }
    if (nrt + 1 < p) {
      e[nrt] = a.get(nrt, p - 1);
    }
    e[p - 1] = 0;

    // --- Stage 2: accumulate the reflections into U and V ------------------
    if (wantu) {
      for (let j = nct; j < nu; j++) {
        for (let i = 0; i < m; i++) {
          U.set(i, j, 0);
        }
        U.set(j, j, 1);
      }
      for (let k = nct - 1; k >= 0; k--) {
        if (s[k] !== 0) {
          for (let j = k + 1; j < nu; j++) {
            let t = 0;
            for (let i = k; i < m; i++) {
              t += U.get(i, k) * U.get(i, j);
            }
            t = -t / U.get(k, k);
            for (let i = k; i < m; i++) {
              U.set(i, j, U.get(i, j) + t * U.get(i, k));
            }
          }
          for (let i = k; i < m; i++) {
            U.set(i, k, -U.get(i, k));
          }
          U.set(k, k, 1 + U.get(k, k));
          for (let i = 0; i < k - 1; i++) {
            U.set(i, k, 0);
          }
        } else {
          for (let i = 0; i < m; i++) {
            U.set(i, k, 0);
          }
          U.set(k, k, 1);
        }
      }
    }

    if (wantv) {
      for (let k = n - 1; k >= 0; k--) {
        if (k < nrt && e[k] !== 0) {
          for (let j = k + 1; j < n; j++) {
            let t = 0;
            for (let i = k + 1; i < n; i++) {
              t += V.get(i, k) * V.get(i, j);
            }
            t = -t / V.get(k + 1, k);
            for (let i = k + 1; i < n; i++) {
              V.set(i, j, V.get(i, j) + t * V.get(i, k));
            }
          }
        }
        for (let i = 0; i < n; i++) {
          V.set(i, k, 0);
        }
        V.set(k, k, 1);
      }
    }

    // --- Stage 3: iterate until the bidiagonal matrix is diagonal ----------
    // Each pass looks at the active block ending at index p-1 and chooses
    // one of four actions (`kase`). p shrinks by one per converged value.
    const pp = p - 1;
    const eps = Number.EPSILON;
    while (p > 0) {
      let k, kase;
      // Find the largest k <= p-2 with a negligible (or NaN) e[k], which is
      // then zeroed; k ends at -1 when there is none.
      for (k = p - 2; k >= -1; k--) {
        if (k === -1) {
          break;
        }
        const alpha =
          Number.MIN_VALUE + eps * Math.abs(s[k] + Math.abs(s[k + 1]));
        if (Math.abs(e[k]) <= alpha || Number.isNaN(e[k])) {
          e[k] = 0;
          break;
        }
      }
      if (k === p - 2) {
        // kase 4: e[p-2] is negligible, so s[p-1] has converged.
        kase = 4;
      } else {
        // Look for a negligible s[ks] with k < ks <= p-1.
        let ks;
        for (ks = p - 1; ks >= k; ks--) {
          if (ks === k) {
            break;
          }
          const t =
            (ks !== p ? Math.abs(e[ks]) : 0) +
            (ks !== k + 1 ? Math.abs(e[ks - 1]) : 0);
          if (Math.abs(s[ks]) <= eps * t) {
            s[ks] = 0;
            break;
          }
        }
        if (ks === k) {
          // kase 3: nothing negligible in the block: one QR step.
          kase = 3;
        } else if (ks === p - 1) {
          // kase 1: s[p-1] is negligible: deflate it.
          kase = 1;
        } else {
          // kase 2: s[ks] is negligible: split the problem at ks.
          kase = 2;
          k = ks;
        }
      }

      k++;

      switch (kase) {
        case 1: {
          // Chase e[p-2] out with rotations of V columns j and p-1.
          let f = e[p - 2];
          e[p - 2] = 0;
          for (let j = p - 2; j >= k; j--) {
            let t = hypotenuse(s[j], f);
            const cs = s[j] / t;
            const sn = f / t;
            s[j] = t;
            if (j !== k) {
              f = -sn * e[j - 1];
              e[j - 1] = cs * e[j - 1];
            }
            if (wantv) {
              for (let i = 0; i < n; i++) {
                t = cs * V.get(i, j) + sn * V.get(i, p - 1);
                V.set(i, p - 1, -sn * V.get(i, j) + cs * V.get(i, p - 1));
                V.set(i, j, t);
              }
            }
          }
          break;
        }
        case 2: {
          // Chase e[k-1] out with rotations of U columns j and k-1.
          let f = e[k - 1];
          e[k - 1] = 0;
          for (let j = k; j < p; j++) {
            let t = hypotenuse(s[j], f);
            const cs = s[j] / t;
            const sn = f / t;
            s[j] = t;
            f = -sn * e[j];
            e[j] = cs * e[j];
            if (wantu) {
              for (let i = 0; i < m; i++) {
                t = cs * U.get(i, j) + sn * U.get(i, k - 1);
                U.set(i, k - 1, -sn * U.get(i, j) + cs * U.get(i, k - 1));
                U.set(i, j, t);
              }
            }
          }
          break;
        }
        case 3: {
          // One shifted QR step on s[k..p-1]. The shift comes from the
          // trailing entries; everything is divided by their largest
          // magnitude first (presumably to keep the products from
          // overflowing).
          const scale = Math.max(
            Math.abs(s[p - 1]),
            Math.abs(s[p - 2]),
            Math.abs(e[p - 2]),
            Math.abs(s[k]),
            Math.abs(e[k]),
          );
          const sp = s[p - 1] / scale;
          const spm1 = s[p - 2] / scale;
          const epm1 = e[p - 2] / scale;
          const sk = s[k] / scale;
          const ek = e[k] / scale;
          const b = ((spm1 + sp) * (spm1 - sp) + epm1 * epm1) / 2;
          const c = sp * epm1 * (sp * epm1);
          let shift = 0;
          if (b !== 0 || c !== 0) {
            if (b < 0) {
              shift = 0 - Math.sqrt(b * b + c);
            } else {
              shift = Math.sqrt(b * b + c);
            }
            shift = c / (b + shift);
          }
          let f = (sk + sp) * (sk - sp) + shift;
          let g = sk * ek;
          // Chase the bulge down the band, alternating a right rotation
          // (updates V) and a left rotation (updates U).
          for (let j = k; j < p - 1; j++) {
            let t = hypotenuse(f, g);
            if (t === 0) t = Number.MIN_VALUE; // avoid 0 / 0 below
            let cs = f / t;
            let sn = g / t;
            if (j !== k) {
              e[j - 1] = t;
            }
            f = cs * s[j] + sn * e[j];
            e[j] = cs * e[j] - sn * s[j];
            g = sn * s[j + 1];
            s[j + 1] = cs * s[j + 1];
            if (wantv) {
              for (let i = 0; i < n; i++) {
                t = cs * V.get(i, j) + sn * V.get(i, j + 1);
                V.set(i, j + 1, -sn * V.get(i, j) + cs * V.get(i, j + 1));
                V.set(i, j, t);
              }
            }
            t = hypotenuse(f, g);
            if (t === 0) t = Number.MIN_VALUE; // avoid 0 / 0 below
            cs = f / t;
            sn = g / t;
            s[j] = t;
            f = cs * e[j] + sn * s[j + 1];
            s[j + 1] = -sn * e[j] + cs * s[j + 1];
            g = sn * e[j + 1];
            e[j + 1] = cs * e[j + 1];
            if (wantu && j < m - 1) {
              for (let i = 0; i < m; i++) {
                t = cs * U.get(i, j) + sn * U.get(i, j + 1);
                U.set(i, j + 1, -sn * U.get(i, j) + cs * U.get(i, j + 1));
                U.set(i, j, t);
              }
            }
          }
          e[p - 2] = f;
          break;
        }
        case 4: {
          // Make the converged singular value non-negative, negating the
          // matching column of V to compensate.
          if (s[k] <= 0) {
            s[k] = s[k] < 0 ? -s[k] : 0;
            if (wantv) {
              // Negate the whole column (all n rows). Bounding by pp would
              // skip the bottom rows whenever m + 1 < n and leave V
              // non-orthogonal.
              for (let i = 0; i < n; i++) {
                V.set(i, k, -V.get(i, k));
              }
            }
          }
          // Bubble s[k] into place so the values end up in non-increasing
          // order, swapping the matching columns of U and V along with it.
          while (k < pp) {
            if (s[k] >= s[k + 1]) {
              break;
            }
            let t = s[k];
            s[k] = s[k + 1];
            s[k + 1] = t;
            if (wantv && k < n - 1) {
              for (let i = 0; i < n; i++) {
                t = V.get(i, k + 1);
                V.set(i, k + 1, V.get(i, k));
                V.set(i, k, t);
              }
            }
            if (wantu && k < m - 1) {
              for (let i = 0; i < m; i++) {
                t = U.get(i, k + 1);
                U.set(i, k + 1, U.get(i, k));
                U.set(i, k, t);
              }
            }
            k++;
          }
          p--;
          break;
        }
        // no default
      }
    }

    if (swapped) {
      const tmp = V;
      V = U;
      U = tmp;
    }

    // m and n describe the matrix actually decomposed (the transpose when
    // autoTranspose kicked in).
    this.m = m;
    this.n = n;
    this.s = s;
    this.U = U;
    this.V = V;
  }

  /**
   * Multiply `value` on the left by V · Σ⁺ · Uᵀ, where Σ⁺ holds 1 / s[k] for
   * every singular value whose magnitude exceeds `threshold` and 0 for the
   * others.
   *
   * @param {Matrix} value - Right-hand side; it is multiplied by a
   *   (V.rows × U.rows) matrix, so it needs U.rows rows.
   * @returns {Matrix}
   */
  solve(value) {
    const threshold = this.threshold;
    const s = this.s;
    const U = this.U;
    const V = this.rightSingularVectors;

    // Diagonal of Σ⁺.
    const invS = new Float64Array(s.length);
    for (let k = 0; k < s.length; k++) {
      invS[k] = Math.abs(s[k]) <= threshold ? 0 : 1 / s[k];
    }

    // For wide inputs decomposed without autoTranspose, U (m × m) and V
    // (n × n) do not both have s.length columns. Only sum over the columns
    // all three actually share; the skipped terms belong to zero singular
    // values. This mirrors what inverse() does.
    const inner = Math.min(s.length, U.columns, V.columns);
    const vrows = V.rows;
    const urows = U.rows;
    const VLU = Matrix.zeros(vrows, urows);

    for (let i = 0; i < vrows; i++) {
      for (let j = 0; j < urows; j++) {
        let sum = 0;
        for (let k = 0; k < inner; k++) {
          sum += V.get(i, k) * invS[k] * U.get(j, k);
        }
        VLU.set(i, j, sum);
      }
    }

    return VLU.mmul(value);
  }

  /**
   * `solve` with a diagonal right-hand side built by `Matrix.diag(value)`.
   *
   * @param {*} value - Diagonal entries, in whatever form `Matrix.diag` takes.
   * @returns {Matrix}
   */
  solveForDiagonal(value) {
    return this.solve(Matrix.diag(value));
  }

  /**
   * SVD-based (pseudo-)inverse V · Σ⁺ · Uᵀ. Singular values whose magnitude
   * is not above `threshold` are treated as zero.
   *
   * @returns {Matrix} A (V.rows × U.rows) matrix.
   */
  inverse() {
    const V = this.V;
    const e = this.threshold;
    const vrows = V.rows;
    const vcols = V.columns;
    const X = new Matrix(vrows, this.s.length);

    // X = V · Σ⁺; entries for negligible singular values are not written.
    const xcols = Math.min(vcols, this.s.length);
    for (let i = 0; i < vrows; i++) {
      for (let j = 0; j < xcols; j++) {
        if (Math.abs(this.s[j]) > e) {
          X.set(i, j, V.get(i, j) / this.s[j]);
        }
      }
    }

    const U = this.U;
    const urows = U.rows;
    const ucols = U.columns;
    const Y = new Matrix(vrows, urows);

    // Y = X · Uᵀ
    for (let i = 0; i < vrows; i++) {
      for (let j = 0; j < urows; j++) {
        let sum = 0;
        for (let k = 0; k < ucols; k++) {
          sum += X.get(i, k) * U.get(j, k);
        }
        Y.set(i, j, sum);
      }
    }

    return Y;
  }

  /**
   * Ratio of the largest singular value to s[min(m, n) - 1]. This is
   * Infinity when that value is 0.
   *
   * @returns {number}
   */
  get condition() {
    return this.s[0] / this.s[Math.min(this.m, this.n) - 1];
  }

  /**
   * Largest singular value (s[0]).
   *
   * @returns {number}
   */
  get norm2() {
    return this.s[0];
  }

  /**
   * Number of singular values strictly greater than
   * max(m, n) · s[0] · Number.EPSILON.
   *
   * @returns {number}
   */
  get rank() {
    const tol = Math.max(this.m, this.n) * this.s[0] * Number.EPSILON;
    let r = 0;
    for (const value of this.s) {
      if (value > tol) {
        r++;
      }
    }
    return r;
  }

  /**
   * Singular values as a plain array, in non-increasing order.
   *
   * @returns {number[]}
   */
  get diagonal() {
    return Array.from(this.s);
  }

  /**
   * Cut-off below which `solve` and `inverse` treat a singular value as zero:
   * (Number.EPSILON / 2) · max(m, n) · s[0].
   *
   * @returns {number}
   */
  get threshold() {
    return (Number.EPSILON / 2) * Math.max(this.m, this.n) * this.s[0];
  }

  /** @returns {Matrix} U (left as allocated when left vectors were not requested). */
  get leftSingularVectors() {
    return this.U;
  }

  /** @returns {Matrix} V (left as allocated when right vectors were not requested). */
  get rightSingularVectors() {
    return this.V;
  }

  /** @returns {Matrix} Square matrix with the singular values on its diagonal. */
  get diagonalMatrix() {
    return Matrix.diag(this.s);
  }
}