@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
368 lines (339 loc) • 12.3 kB
text/typescript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Linear algebra ops.
*/
import {ENGINE} from '../engine';
import {dispose} from '../globals';
import {Tensor, Tensor1D, Tensor2D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assert} from '../util';
import {squeeze, stack, unstack} from './array_ops';
import {sub} from './binary_ops';
import {split} from './concat_split';
import {eye} from './eye';
import {logicalAnd, where} from './logical_ops';
import {norm} from './norm';
import {op} from './operation';
import {sum} from './reduction_ops';
import {range, scalar, tensor2d, zeros} from './tensor_ops';
/**
* Copy a tensor setting everything outside a central band in each innermost
* matrix to zero.
*
* The band part is computed as follows: Assume input has `k` dimensions
* `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
* `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
* The indicator function
* `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)`
* `&& (num_upper < 0 || (n-m) <= num_upper)`
*
* ```js
* const x = tf.tensor2d([[ 0, 1, 2, 3],
* [-1, 0, 1, 2],
* [-2, -1, 0, 1],
* [-3, -2, -1, 0]]);
* let y = tf.linalg.bandPart(x, 1, -1);
* y.print(); // [[ 0, 1, 2, 3],
* // [-1, 0, 1, 2],
* // [ 0, -1, 0, 1],
* // [ 0, 0 , -1, 0]]
* let z = tf.linalg.bandPart(x, 2, 1);
* z.print(); // [[ 0, 1, 0, 0],
* // [-1, 0, 1, 0],
* // [-2, -1, 0, 1],
* // [ 0, -2, -1, 0]]
* ```
*
* @param x Rank `k` tensor
* @param numLower Number of subdiagonals to keep.
* If negative, keep entire lower triangle.
* @param numUpper Number of superdiagonals to keep.
* If negative, keep entire upper triangle.
* @returns Rank `k` tensor of the same shape as input.
* The extracted banded tensor.
*/
/**
* @doc {heading:'Operations',
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function bandPart_<T extends Tensor>(
    a: T|TensorLike, numLower: number, numUpper: number): T {
  // Band widths must be whole numbers. `% 1` also rejects NaN and +/-Infinity.
  const requireInteger = (value: number, label: string): void => {
    if (value % 1 !== 0) {
      throw new Error(
          `bandPart(): ${label} must be an integer, got ${value}.`);
    }
  };
  requireInteger(numLower, 'numLower');
  requireInteger(numUpper, 'numUpper');

  const input = convertToTensor(a, 'a', 'bandPart');
  if (input.rank < 2) {
    throw new Error(
        `bandPart(): Rank must be at least 2, got ${input.rank}.`);
  }

  // The two innermost dimensions are the rows/columns of each matrix.
  const originalShape = input.shape;
  const [rows, cols] = input.shape.slice(-2);
  if (!(numLower <= rows)) {
    throw new Error(
        `bandPart(): numLower (${numLower}) must not be greater than ` +
        `the number of rows (${rows}).`);
  }
  if (!(numUpper <= cols)) {
    throw new Error(
        `bandPart(): numUpper (${numUpper}) must not be greater than ` +
        `the number of columns (${cols}).`);
  }

  // A negative width means "keep the whole triangle": widen the band to the
  // full matrix extent in that direction.
  const lowerWidth = numLower < 0 ? rows : numLower;
  const upperWidth = numUpper < 0 ? cols : numUpper;

  // offset[m, n] = m - n, built by broadcasting a column of row indices
  // against a row of column indices.
  const rowIndex = range(0, rows, 1, 'int32').reshape([-1, 1]);
  const colIndex = range(0, cols, 1, 'int32');
  const offset = sub(rowIndex, colIndex);

  // Inside the band when -upperWidth <= m - n <= lowerWidth.
  const withinLower = offset.lessEqual(scalar(+lowerWidth, 'int32'));
  const withinUpper = offset.greaterEqual(scalar(-upperWidth, 'int32'));
  const inBand = logicalAnd(withinLower, withinUpper);

  // Apply the same mask to every innermost matrix, then restore the shape.
  const zero = zeros([rows, cols], input.dtype);
  const matrices = unstack(input.reshape([-1, rows, cols]));
  const banded = matrices.map(mat => where(inBand, mat, zero));
  return stack(banded).reshape(originalShape) as T;
}
/**
* Gram-Schmidt orthogonalization.
*
* ```js
* const x = tf.tensor2d([[1, 2], [3, 4]]);
* let y = tf.linalg.gramSchmidt(x);
* y.print();
* console.log('Orthogonalized:');
* y.dot(y.transpose()).print(); // should be nearly the identity matrix.
* console.log('First row direction maintained:');
* const data = await y.array();
* console.log(data[0][1] / data[0][0]); // should be nearly 2.
* ```
*
* @param xs The vectors to be orthogonalized, in one of the two following
* formats:
* - An Array of `tf.Tensor1D`.
* - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows
* of `xs`.
* In each case, all the vectors must have the same length and the length
* must be greater than or equal to the number of vectors.
* @returns The orthogonalized and normalized vectors or matrix.
* Orthogonalization means that the vectors or the rows of the matrix
* are orthogonal (zero inner products). Normalization means that each
* vector or each row of the matrix has an L2 norm that equals `1`.
*/
/**
* @doc {heading:'Operations',
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D {
  // Remember the input form so the result is returned in the same form.
  const inputIsTensor2D = !Array.isArray(xs);

  // Normalize both accepted input forms to a list of 1D vectors.
  let vectors: Tensor1D[];
  if (Array.isArray(xs)) {
    assert(
        xs != null && xs.length > 0,
        () => 'Gram-Schmidt process: input must not be null, undefined, or ' +
            'empty');
    // Every vector must have the same length as the first one.
    const expectedLength = xs[0].shape[0];
    for (const vec of xs.slice(1)) {
      assert(
          vec.shape[0] === expectedLength,
          () =>
              'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +
              `(${vec.shape[0]} vs. ${expectedLength})`);
    }
    vectors = xs;
  } else {
    // A matrix is split row by row; each [1, n] slice is squeezed to [n].
    vectors = split(xs, xs.shape[0], 0).map(row => squeeze(row, [0]));
  }

  const vectorCount = vectors.length;
  const dimension = vectors[0].shape[0];
  assert(
      vectorCount <= dimension,
      () => `Gram-Schmidt: Number of vectors (${vectorCount}) exceeds ` +
          `number of dimensions (${dimension}).`);

  // Process the vectors in order: subtract the component along every
  // previously produced basis vector, then scale to unit Euclidean norm.
  const basis: Tensor1D[] = [];
  for (const vec of vectors) {
    const unit = ENGINE.tidy(() => {
      let residual = vec;
      for (const b of basis) {
        const projection = sum(b.mulStrict(residual)).mul(b);
        residual = residual.sub(projection);
      }
      return residual.div(norm(residual, 'euclidean'));
    });
    basis.push(unit);
  }

  return inputIsTensor2D ? stack(basis, 0) as Tensor2D : basis;
}
/**
* Compute QR decomposition of m-by-n matrix using Householder transformation.
*
* Implementation based on
* [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]
* (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
*
* ```js
* const a = tf.tensor2d([[1, 2], [3, 4]]);
* let [q, r] = tf.linalg.qr(a);
* console.log('Q');
* q.print();
* console.log('R');
* r.print();
* console.log('Orthogonalized');
* q.dot(q.transpose()).print() // should be nearly the identity matrix.
* console.log('Reconstructed');
* q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
* ```
*
* @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
* it has the shape `[..., M, N]`.
* @param fullMatrices An optional boolean parameter. Defaults to `false`.
* If `true`, compute full-sized `Q`. If `false` (the default),
* compute only the leading N columns of `Q` and the leading N rows of `R`
* (this only changes the result when `M > N`).
* @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
* i.e., its columns all have unit norm and are mutually orthogonal.
* If `M >= N`,
* If `fullMatrices` is `false` (default),
* - `Q` has a shape of `[..., M, N]`,
* - `R` has a shape of `[..., N, N]`.
* If `fullMatrices` is `true`,
* - `Q` has a shape of `[..., M, M]`,
* - `R` has a shape of `[..., M, N]`.
* If `M < N`,
* - `Q` has a shape of `[..., M, M]`,
* - `R` has a shape of `[..., M, N]`.
* @throws If the rank of `x` is less than 2.
*/
/**
* @doc {heading:'Operations',
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function qr_(x: Tensor, fullMatrices = false): [Tensor, Tensor] {
  // Dispatch on rank: a rank-2 input goes straight to qr2d(); a higher-rank
  // input is treated as a batch of matrices over its flattened leading
  // dimensions. Ranks below 2 are rejected.
  if (x.rank < 2) {
    throw new Error(
        `qr() requires input tensor to have a rank >= 2, but got rank ${
            x.rank}`);
  }
  if (x.rank === 2) {
    return qr2d(x as Tensor2D, fullMatrices);
  }

  // Rank > 2.
  // TODO(cais): Below we split the input into individual 2D tensors,
  //   perform QR decomposition on them and then stack the results back
  //   together. We should explore whether this can be parallelized.
  const outerDims = x.shape.slice(0, x.shape.length - 2);
  const outerDimsProd = outerDims.reduce((prod, dim) => prod * dim);
  const [m, n] = x.shape.slice(-2);
  const x2ds = unstack(x.reshape([outerDimsProd, m, n]), 0);

  const q2ds: Tensor2D[] = [];
  const r2ds: Tensor2D[] = [];
  for (const x2d of x2ds) {
    const [q2d, r2d] = qr2d(x2d as Tensor2D, fullMatrices);
    q2ds.push(q2d);
    r2ds.push(r2d);
  }

  // Restore the leading dimensions using the shapes qr2d() actually produced.
  // Q and R only share the input's [M, N] shape when the matrices are square
  // (e.g. reduced R is [N, N] when M > N), so reshaping both to `x.shape`
  // fails on non-square batches.
  const q = stack(q2ds, 0).reshape(outerDims.concat(q2ds[0].shape));
  const r = stack(r2ds, 0).reshape(outerDims.concat(r2ds[0].shape));
  return [q, r];
}
/**
 * QR-decomposes a single matrix with a sequence of Householder reflections.
 *
 * Starting from `q = eye(m)` and `r = x`, each of the `min(m, n)` iterations
 * builds a reflector `H = I - tau * w * w'` from column `j` of `r` and applies
 * `R := HR`, `Q := QH`.
 *
 * @param x The matrix to decompose; must have exactly two dimensions.
 * @param fullMatrices When `false` (the default) and `m > n`, `q` is cut to
 *     its first `n` columns and `r` to its first `n` rows before returning.
 * @returns The pair `[q, r]`.
 * @throws If `x.shape` does not have length 2.
 */
function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
// The outer tidy scopes every tensor created below; only the returned
// `[q, r]` pair is handed back to the caller.
return ENGINE.tidy(() => {
if (x.shape.length !== 2) {
throw new Error(
`qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`);
}
const m = x.shape[0];
const n = x.shape[1];
let q = eye(m); // Orthogonal transform so far.
let r = x.clone(); // Transformed matrix so far.
// 1x1 matrix holding 1; used as the leading entry of every reflector `w`.
const one2D = tensor2d([[1]], [1, 1]);
let w: Tensor2D = one2D.clone();
// One reflection per column, but never more than there are rows.
const iters = m >= n ? n : m;
for (let j = 0; j < iters; ++j) {
// This tidy within the for-loop ensures we clean up temporary
// tensors as soon as they are no longer needed.
// Keep handles to the previous iteration's tensors so they can be disposed
// explicitly once the replacements have been computed.
const rTemp = r;
const wTemp = w;
const qTemp = q;
[w, r, q] = ENGINE.tidy((): [Tensor2D, Tensor2D, Tensor2D] => {
// Find H = I - tau * w * w', to put zeros below R(j, j).
// rjEnd1: column j of r from row j downwards, shape [m - j, 1].
const rjEnd1 = r.slice([j, j], [m - j, 1]);
// NOTE(review): norm() is called with its defaults, presumably the
// Euclidean norm; confirm against the norm op.
const normX = rjEnd1.norm();
// rjj: the diagonal entry R(j, j) as a 1x1 matrix.
const rjj = r.slice([j, j], [1, 1]);
// The sign() function returns 0 on 0, which causes division by zero.
// NOTE(review): presumably s is -1 where R(j, j) > 0 and +1 otherwise;
// confirm the argument order of Tensor.where().
const s = tensor2d([[-1]]).where(rjj.greater(0), tensor2d([[1]]));
// u1 = R(j, j) - s * ||column||.
const u1 = rjj.sub(s.mul(normX));
// Scale the column by u1; its first entry is then replaced by 1 below.
const wPre = rjEnd1.div(u1);
if (wPre.shape[0] === 1) {
// Only one row remains: the reflector is just [[1]].
w = one2D.clone();
} else {
// w = [1; wPre[1:]], stacked along axis 0.
w = one2D.concat(
wPre.slice([1, 0], [wPre.shape[0] - 1, wPre.shape[1]]) as
Tensor2D,
0);
}
// tau = -(s * u1) / ||column||, as a 1x1 matrix.
const tau = s.matMul(u1).div(normX).neg() as Tensor2D;
// -- R := HR, Q := QH.
// rjEndAll: rows j..m-1 of r (all n columns); the part H acts on.
const rjEndAll = r.slice([j, 0], [m - j, n]);
const tauTimesW: Tensor2D = tau.mul(w);
const wT: Tensor2D = w.transpose();
// NOTE(review): the j === 0 branches skip the concat with the first j
// rows/columns, presumably because that prefix would be empty.
if (j === 0) {
r = rjEndAll.sub(tauTimesW.matMul(wT.matMul(rjEndAll)));
} else {
// Rows 0..j-1 of r are kept as they are; rows j.. are replaced.
const rTimesTau: Tensor2D =
rjEndAll.sub(tauTimesW.matMul(wT.matMul(rjEndAll)));
r = r.slice([0, 0], [j, n]).concat(rTimesTau, 0);
}
// (tau * w)', used to update the trailing columns of q.
const tawTimesWT: Tensor2D = tauTimesW.transpose();
// qAllJEnd: columns j.. of q (all m rows).
const qAllJEnd = q.slice([0, j], [m, q.shape[1] - j]);
if (j === 0) {
q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tawTimesWT));
} else {
// Columns 0..j-1 of q are kept as they are; columns j.. are replaced.
const qTimesTau: Tensor2D =
qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tawTimesWT));
q = q.slice([0, 0], [m, j]).concat(qTimesTau, 1);
}
return [w, r, q];
});
// The previous iteration's w, r and q have been superseded.
dispose([rTemp, wTemp, qTemp]);
}
// Reduced form: for a tall matrix keep only the first n columns of q and the
// first n rows of r.
if (!fullMatrices && m > n) {
q = q.slice([0, 0], [m, n]);
r = r.slice([0, 0], [n, n]);
}
return [q, r];
}) as [Tensor2D, Tensor2D];
}
// Public linear-algebra ops. Each underscore-suffixed implementation is
// passed through `op()` (imported from './operation') inside a single-key
// object literal.
// NOTE(review): presumably `op()` derives the op's name from that key; confirm
// in './operation' before renaming any of the implementations.
export const bandPart = op({bandPart_});
export const gramSchmidt = op({gramSchmidt_});
export const qr = op({qr_});