@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
206 lines (190 loc) • 7.29 kB
text/typescript
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {whereImpl} from '../backends/where_impl';
import {ENGINE} from '../engine';
import {Tensor, Tensor2D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assert, assertShapesMatch} from '../util';
import {assertAndGetBroadcastShape} from './broadcast_util';
import {op} from './operation';
import {zerosLike} from './tensor_ops';
/**
 * Computes the element-wise logical negation (`NOT x`) of a boolean tensor.
 *
 * ```js
 * const a = tf.tensor1d([false, true], 'bool');
 *
 * a.logicalNot().print();
 * ```
 *
 * @param x The tensor (or tensor-like value) to negate. It is converted to a
 *     tensor with an expected dtype of 'bool'.
 * @returns The result of the backend's `logicalNot` kernel on the converted
 *     input.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function logicalNot_<T extends Tensor>(x: T|TensorLike): T {
  const input = convertToTensor(x, 'x', 'logicalNot', 'bool');
  // The converted input is registered with the engine under the `$x` key.
  return ENGINE.runKernel(
      backend => backend.logicalNot(input), {$x: input});
}
/**
 * Computes `a AND b` element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([false, false, true, true], 'bool');
 * const b = tf.tensor1d([false, true, false, true], 'bool');
 *
 * a.logicalAnd(b).print();
 * ```
 *
 * @param a The left-hand operand. Must be of dtype bool.
 * @param b The right-hand operand. Must be of dtype bool.
 * @returns The result of the backend's `logicalAnd` kernel, cast to `T`.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function logicalAnd_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  const lhs = convertToTensor(a, 'a', 'logicalAnd', 'bool');
  const rhs = convertToTensor(b, 'b', 'logicalAnd', 'bool');

  // The returned shape is discarded; the call is presumably made only so that
  // incompatible operand shapes are rejected before the kernel runs.
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);

  const conjunction = ENGINE.runKernel(
      backend => backend.logicalAnd(lhs, rhs), {$a: lhs, $b: rhs});
  return conjunction as T;
}
/**
 * Computes `a OR b` element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([false, false, true, true], 'bool');
 * const b = tf.tensor1d([false, true, false, true], 'bool');
 *
 * a.logicalOr(b).print();
 * ```
 *
 * @param a The left-hand operand. Must be of dtype bool.
 * @param b The right-hand operand. Must be of dtype bool.
 * @returns The result of the backend's `logicalOr` kernel, cast to `T`.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function logicalOr_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  const lhs = convertToTensor(a, 'a', 'logicalOr', 'bool');
  const rhs = convertToTensor(b, 'b', 'logicalOr', 'bool');

  // The returned shape is discarded; the call is presumably made only so that
  // incompatible operand shapes are rejected before the kernel runs.
  assertAndGetBroadcastShape(lhs.shape, rhs.shape);

  const disjunction = ENGINE.runKernel(
      backend => backend.logicalOr(lhs, rhs), {$a: lhs, $b: rhs});
  return disjunction as T;
}
/**
 * Returns the truth value of `a XOR b` element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([false, false, true, true], 'bool');
 * const b = tf.tensor1d([false, true, false, true], 'bool');
 *
 * a.logicalXor(b).print();
 * ```
 *
 * @param a The first input tensor. Must be of dtype bool.
 * @param b The second input tensor. Must be of dtype bool.
 * @returns `(a OR b) AND NOT (a AND b)`, cast to `T`.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function logicalXor_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'logicalXor', 'bool');
  const $b = convertToTensor(b, 'b', 'logicalXor', 'bool');
  // The returned shape is discarded; the call is presumably made only so that
  // incompatible operand shapes are rejected up front.
  assertAndGetBroadcastShape($a.shape, $b.shape);

  // x ^ y = (x | y) & ~(x & y)
  // Pass the already-converted tensors on, so that tensor-like inputs are not
  // converted a second and third time inside logicalOr and logicalAnd.
  const eitherSet = logicalOr($a, $b);
  const bothSet = logicalAnd($a, $b);
  return eitherSet.logicalAnd(bothSet.logicalNot()) as T;
}
/**
 * Returns the elements, either `a` or `b` depending on the `condition`.
 *
 * If the condition is true, select from `a`, otherwise select from `b`.
 *
 * ```js
 * const cond = tf.tensor1d([false, false, true], 'bool');
 * const a = tf.tensor1d([1 , 2, 3]);
 * const b = tf.tensor1d([-1, -2, -3]);
 *
 * a.where(cond, b).print();
 * ```
 *
 * @param condition The input condition. Must be of dtype bool. Unless it is
 *     rank 1, it must have the same shape as `a` and `b`.
 * @param a If `condition` is rank 1, `a` may have a higher rank but
 *     its first dimension must match the size of `condition`.
 * @param b A tensor with the same shape and type as `a`.
 * @returns The result of the backend's `select` kernel, cast to `T`.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function where_<T extends Tensor>(
    condition: Tensor|TensorLike, a: T|TensorLike, b: T|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'where');
  const $b = convertToTensor(b, 'b', 'where');
  const $condition = convertToTensor(condition, 'condition', 'where', 'bool');

  assertShapesMatch($a.shape, $b.shape, 'Error in where: ');

  if ($condition.rank === 1) {
    // If condition rank is 1, then the first dimension must match the size of
    // condition.
    assert(
        $condition.shape[0] === $a.shape[0],
        () => 'The first dimension of `a` must match the size of `condition`.');
  } else {
    // `a` and `b` were already checked to have equal shapes, so comparing the
    // condition against `b` constrains `a` as well.
    assertShapesMatch($condition.shape, $b.shape, 'Error in where: ');
  }

  // TODO(julianoks): Return null for condition gradient
  // when backprop supports it.
  const grad = (dy: T, saved: Tensor[]) => {
    const [$condition] = saved;
    // A rank-1 condition is matched against the FIRST dimension of `a`/`b`
    // (see the assert above). Multiplying `dy` by it directly would let
    // broadcasting line it up with the trailing dimension instead, so a rank-1
    // mask is reshaped to [n, 1, ..., 1] before it meets a higher-rank `dy`.
    // This assumes `dy` has the shape of the op's output, i.e. of `a`.
    const alignWithDy = (mask: Tensor) => mask.rank === 1 && dy.rank > 1 ?
        mask.reshape(dy.shape.map((dim, axis) => axis === 0 ? dim : 1)) :
        mask;
    return {
      // The boolean condition receives an all-zero float gradient.
      $condition: () => zerosLike($condition).toFloat(),
      // `a` receives `dy` where the condition is true...
      $a: () => dy.mul(alignWithDy($condition).cast(dy.dtype)) as T,
      // ...and `b` receives `dy` where it is false.
      $b: () =>
          dy.mul(alignWithDy($condition.logicalNot()).cast(dy.dtype)) as T
    };
  };

  return ENGINE.runKernel((backend, save) => {
    const res = backend.select($condition, $a, $b);
    // Only the condition is needed by the gradient above.
    save([$condition]);
    return res;
  }, {$condition, $a, $b}, grad) as T;
}
/**
 * Asynchronously finds the coordinates of the true elements of `condition`.
 *
 * The result is a 2-D tensor of shape `[numTrueElems, condition.rank]`: each
 * row corresponds to one true element and each column to one coordinate of
 * that element. The number of rows therefore depends on how many true values
 * the input holds. Indices are output in row-major order.
 *
 * This is analogous to calling the python `tf.where(cond)` without an x or y.
 *
 * ```js
 * const cond = tf.tensor1d([false, false, true], 'bool');
 * const result = await tf.whereAsync(cond);
 * result.print();
 * ```
 *
 * @param condition The boolean tensor (or tensor-like value) to scan.
 * @returns A promise resolving to the tensor computed by `whereImpl` from the
 *     condition's shape and downloaded values.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
async function whereAsync_(condition: Tensor|TensorLike): Promise<Tensor2D> {
  const $condition =
      convertToTensor(condition, 'condition', 'whereAsync', 'bool');
  const coordinates = whereImpl($condition.shape, await $condition.data());

  // When conversion yielded a different object than the caller passed in, the
  // tensor is a temporary created here, so release it before returning.
  const createdHere = condition !== $condition;
  if (createdHere) {
    $condition.dispose();
  }
  return coordinates;
}
// Public entry points. Each synchronous implementation above is passed to
// `op()` inside a single-key object literal before being exported.
// NOTE(review): `op()` presumably derives the public op name from that key
// (dropping the trailing underscore) — confirm in './operation'.
export const logicalAnd = op({logicalAnd_});
export const logicalNot = op({logicalNot_});
export const logicalOr = op({logicalOr_});
export const logicalXor = op({logicalXor_});
export const where = op({where_});
// `whereAsync_` is exported as-is, without the `op()` wrapper.
// NOTE(review): presumably because it is async — confirm.
export const whereAsync = whereAsync_;