// @tensorflow/tfjs-core
// Version:
// Hardware-accelerated JavaScript library for machine intelligence
// 121 lines • 4.79 kB
// JavaScript
;
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
Object.defineProperty(exports, "__esModule", { value: true });
var tensor_util_env_1 = require("../tensor_util_env");
var util_1 = require("../util");
var axis_util = require("./axis_util");
var operation_1 = require("./operation");
var tensor_ops_1 = require("./tensor_ops");
/**
 * Computes the norm of a scalar, a vector, or a matrix, or of a batch of
 * vectors or matrices.
 *
 * Supported vector norms are the 1-norm, the Euclidean (2-) norm and the
 * +/-infinity norms (max/min of absolute values). Supported matrix norms are
 * the Frobenius norm, the 1-norm and the +/-infinity norms.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 *
 * x.norm().print(); // or tf.norm(x)
 * ```
 *
 * @param x The input tensor.
 * @param ord Optional. Order of the norm; defaults to 'euclidean'. Supported
 *     orders:
 *
 * | ord        | norm for matrices         | norm for vectors
 * |------------|---------------------------|---------------------
 * |'euclidean' |Frobenius norm             |2-norm
 * |'fro'       |Frobenius norm             |
 * |Infinity    |max(sum(abs(x), axis=1))   |max(abs(x))
 * |-Infinity   |min(sum(abs(x), axis=1))   |min(abs(x))
 * |1           |max(sum(abs(x), axis=0))   |sum(abs(x))
 * |2           |                           |sum(abs(x)^2)^(1/2)
 *
 *     An order that is not listed, or that has a blank cell for the selected
 *     mode, throws an Error.
 *
 * @param axis Optional. If null (the default), `x` is treated as one vector,
 *     so `norm(x, ord)` is equivalent to `norm(x.reshape([-1]), ord)`. If an
 *     integer, `x` is treated as a batch of vectors and the vector norm is
 *     taken along that axis. If a pair of integers, `x` is treated as a batch
 *     of matrices and the matrix norm is taken over those two axes.
 * @param keepDims Optional. If true, the reduced axes are kept as size-1
 *     dimensions so the result has the same rank as the input. Defaults to
 *     false.
 */
/** @doc {heading: 'Operations', subheading: 'Matrices'} */
function norm_(x, ord, axis, keepDims) {
    var order = ord === undefined ? 'euclidean' : ord;
    var reduceAxes = axis === undefined ? null : axis;
    var input = tensor_util_env_1.convertToTensor(x, 'x', 'norm');
    var result = normImpl(input, order, reduceAxes);
    // Without keepDims the norm keeps its own (reduced) shape.
    if (!keepDims) {
        return result.reshape(result.shape);
    }
    // Re-insert every reduced axis as a size-1 dimension.
    var resolvedAxes = util_1.parseAxisParam(reduceAxes, input.shape);
    return result.reshape(
        axis_util.expandShapeToKeepDim(result.shape, resolvedAxes));
}
/**
 * Computes the norm of `x` without applying `keepDims` (`norm_` does that).
 *
 * Dispatch:
 *  - rank 0: returns the absolute value of `x`, whatever `p` is;
 *  - `axis === null` on a non-vector: `x` is flattened and one vector norm is
 *    taken;
 *  - rank 1, a numeric `axis`, or a 1-element `axis` array: vector norm(s)
 *    along that axis;
 *  - a 2-element `axis` array `[rowAxis, colAxis]`: matrix norm(s) over those
 *    two axes. Negative entries are resolved against `x.rank`, and the two
 *    axes may be given in either order.
 *
 * @param x Input tensor.
 * @param p Norm order. Vector mode accepts 1, 2, 'euclidean', Infinity and
 *     -Infinity; matrix mode accepts 1, Infinity, -Infinity, 'fro' and
 *     'euclidean'.
 * @param axis See above; defaults to null.
 * @returns The norm tensor, without the reduced axes.
 * @throws Error if `p` is not supported for the selected mode, if `axis` has
 *     an unsupported form, or if a matrix axis pair names the same axis twice.
 */
function normImpl(x, p, axis) {
    if (axis === void 0) { axis = null; }
    if (x.rank === 0) {
        return x.abs();
    }
    // consider vector when no axis is specified
    if (x.rank !== 1 && axis === null) {
        return normImpl(x.reshape([-1]), p, axis);
    }
    // vector
    if (x.rank === 1 || typeof axis === 'number' ||
        Array.isArray(axis) && axis.length === 1) {
        if (p === 1) {
            return x.abs().sum(axis);
        }
        if (p === Infinity) {
            return x.abs().max(axis);
        }
        if (p === -Infinity) {
            return x.abs().min(axis);
        }
        if (p === 'euclidean' || p === 2) {
            // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2
            return x.abs().pow(tensor_ops_1.scalar(2, 'int32')).sum(axis).sqrt();
        }
        throw new Error("Error in norm: invalid ord value: " + p);
    }
    // matrix: axis = [rowAxis, colAxis]
    if (Array.isArray(axis) && axis.length === 2) {
        // Resolve negative axes against the rank so the index arithmetic
        // below is valid for negative and for reversed axis pairs.
        var rowAxis = axis[0] < 0 ? axis[0] + x.rank : axis[0];
        var colAxis = axis[1] < 0 ? axis[1] + x.rank : axis[1];
        if (rowAxis === colAxis) {
            throw new Error("Error in norm: invalid axis: " + axis);
        }
        // Reducing one axis shifts every later axis down by one, so the
        // second reduction must use the shifted index.
        var colAxisAfterRowSum = colAxis > rowAxis ? colAxis - 1 : colAxis;
        var rowAxisAfterColSum = rowAxis > colAxis ? rowAxis - 1 : rowAxis;
        if (p === 1) {
            // Largest absolute column sum.
            return x.abs().sum(rowAxis).max(colAxisAfterRowSum);
        }
        if (p === Infinity) {
            // Largest absolute row sum.
            return x.abs().sum(colAxis).max(rowAxisAfterColSum);
        }
        if (p === -Infinity) {
            // Smallest absolute row sum.
            return x.abs().sum(colAxis).min(rowAxisAfterColSum);
        }
        if (p === 'fro' || p === 'euclidean') {
            // norm(x) = sqrt(sum(pow(x, 2)))
            return x.square().sum(axis).sqrt();
        }
        throw new Error("Error in norm: invalid ord value: " + p);
    }
    throw new Error("Error in norm: invalid axis: " + axis);
}
// Public `norm` op: the `norm_` implementation is passed to `op()` keyed by
// its own name. NOTE(review): op() presumably uses that key to name/wrap the
// op — confirm in ./operation before renaming `norm_`.
var normOpSpec = { norm_: norm_ };
exports.norm = operation_1.op(normOpSpec);
//# sourceMappingURL=norm.js.map