UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

858 lines (815 loc) 27.4 kB
/** * @license * Copyright 2018 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ import {ENGINE} from '../engine'; import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {makeTypesMatch} from '../tensor_util'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike, upcastType} from '../types'; import * as util from '../util'; import * as broadcast_util from './broadcast_util'; import {op} from './operation'; import {scalar, zerosLike} from './tensor_ops'; import {neg} from './unary_ops'; /** * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting. * * We also expose `tf.addStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). * * ```js * const a = tf.tensor1d([1, 2, 3, 4]); * const b = tf.tensor1d([10, 20, 30, 40]); * * a.add(b).print(); // or tf.add(a, b) * ``` * * ```js * // Broadcast add a with b. * const a = tf.scalar(5); * const b = tf.tensor1d([10, 20, 30, 40]); * * a.add(b).print(); // or tf.add(a, b) * ``` * @param a The first `tf.Tensor` to add. * @param b The second `tf.Tensor` to add. Must have the same type as `a`. 
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function add_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'add');
  let $b = convertToTensor(b, 'b', 'add');
  [$a, $b] = makeTypesMatch($a, $b);

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  // Maps an upstream gradient (shaped like the broadcast output) back onto
  // `input`: sum over every axis `input` was stretched along, then reshape.
  const unbroadcast = (grad: Tensor, input: Tensor): Tensor => {
    const axes = broadcast_util.getReductionAxes(input.shape, outShape);
    const reduced = axes.length > 0 ? grad.sum(axes) : grad;
    return reduced.reshape(input.shape);
  };

  // Both partial derivatives of a + b are 1, so each input receives `dy`
  // with only the broadcast undone.
  const der = (dy: Tensor) => ({
    $a: () => unbroadcast(dy, $a),
    $b: () => unbroadcast(dy, $b)
  });

  return ENGINE.runKernel(
             backend => backend.add($a, $b), {$a, $b}, der) as T;
}

/**
 * Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * const c = tf.tensor1d([5, 6]);
 *
 * tf.addN([a, b, c]).print();
 * ```
 * @param tensors A list of tensors with the same shape and dtype.
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function addN_<T extends Tensor>(tensors: Array<T|TensorLike>): T {
  util.assert(
      Array.isArray(tensors),
      () => 'The argument passed to tf.addN() must be a list of tensors');
  util.assert(
      tensors.length >= 1,
      () => `Must pass at least one tensor to tf.addN(), but got ` +
          `${tensors.length}`);

  const $tensors =
      tensors.map((t, i) => convertToTensor(t, `tensors${i}`, 'addN'));
  const [reference] = $tensors;

  // Every dtype is checked before any shape is, so a dtype mismatch is the
  // error reported when both kinds of mismatch are present.
  for (const t of $tensors) {
    if (t.dtype !== reference.dtype) {
      throw new Error(
          'All tensors passed to tf.addN() must have the same dtype');
    }
  }
  for (const t of $tensors) {
    if (!util.arraysEqual(t.shape, reference.shape)) {
      throw new Error(
          'All tensors passed to tf.addN() must have the same shape');
    }
  }

  // The sum has a partial derivative of 1 with respect to every input, so
  // each input's gradient is a copy of `dy`. Gradients are keyed by list
  // index, matching the index-keyed `inputs` map below.
  const der = (dy: T) => {
    const ders: {[key: string]: () => Tensor} = {};
    for (let i = 0; i < $tensors.length; i++) {
      ders[i] = () => dy.clone();
    }
    return ders;
  };

  const inputs: NamedTensorMap = $tensors as {} as NamedTensorMap;
  return ENGINE.runKernel(backend => backend.addN($tensors), inputs, der);
}

/**
 * Adds two `tf.Tensor`s element-wise, A + B.
 *
 * Inputs must be the same shape. For broadcasting support, use add() instead.
 *
 * @param a The first Tensor to add element-wise.
 * @param b The second Tensor to add element-wise.
 */
function addStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  const lhs = convertToTensor(a, 'a', 'addStrict');
  const rhs = convertToTensor(b, 'b', 'addStrict');
  util.assertShapesMatch(lhs.shape, rhs.shape, 'Error in addStrict: ');
  return lhs.add(rhs);
}

/**
 * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.
 *
 * We also expose `tf.subStrict` which has the same signature as this op and
 * asserts that `a` and `b` are the same shape (does not broadcast).
 *
 * ```js
 * const a = tf.tensor1d([10, 20, 30, 40]);
 * const b = tf.tensor1d([1, 2, 3, 4]);
 *
 * a.sub(b).print();  // or tf.sub(a, b)
 * ```
 *
 * ```js
 * // Broadcast subtract a with b.
* const a = tf.tensor1d([10, 20, 30, 40]);
 * const b = tf.scalar(5);
 *
 * a.sub(b).print();  // or tf.sub(a, b)
 * ```
 * @param a The first `tf.Tensor` to subtract from.
 * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as
 * `a`.
 */
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function sub_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'sub');
  let $b = convertToTensor(b, 'b', 'sub');
  [$a, $b] = makeTypesMatch($a, $b);

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  // Sums `grad` over the axes along which `input` was broadcast (no-op when
  // there are none). The caller restores the input's exact shape.
  const collapse = (grad: Tensor, input: Tensor): Tensor => {
    const axes = broadcast_util.getReductionAxes(input.shape, outShape);
    return axes.length > 0 ? grad.sum(axes) : grad;
  };

  // d(a - b)/da = 1 and d(a - b)/db = -1.
  const der = (dy: Tensor) => {
    const derA = () => collapse(dy, $a).reshape($a.shape);
    const derB = () => collapse(dy, $b).neg().reshape($b.shape);
    return {$a: derA, $b: derB};
  };

  return ENGINE.runKernel(
             backend => backend.subtract($a, $b), {$a, $b}, der) as T;
}

/**
 * Subtracts two `tf.Tensor`s element-wise, A - B. Inputs must
 * be the same shape.
 *
 * For broadcasting support, use `tf.sub` instead.
 *
 * @param a The first Tensor to subtract element-wise.
 * @param b The second Tensor to subtract element-wise.
 */
function subStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  const minuend = convertToTensor(a, 'a', 'subStrict');
  const subtrahend = convertToTensor(b, 'b', 'subStrict');
  util.assertShapesMatch(
      minuend.shape, subtrahend.shape, 'Error in subStrict: ');
  return minuend.sub(subtrahend);
}

/**
 * Computes the power of one `tf.Tensor` to another. Supports broadcasting.
 *
 * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
 * corresponding elements in x and y. The result's dtype will be the upcasted
 * type of the `base` and `exp` dtypes.
*
 * ```js
 * const a = tf.tensor([[2, 3], [4, 5]]);
 * const b = tf.tensor([[1, 2], [3, 0]]).toInt();
 *
 * a.pow(b).print();  // or tf.pow(a, b)
 * ```
 *
 * ```js
 * const a = tf.tensor([[1, 2], [3, 4]]);
 * const b = tf.tensor(2).toInt();
 *
 * a.pow(b).print();  // or tf.pow(a, b)
 * ```
 * We also expose `tf.powStrict` which has the same signature as this op and
 * asserts that `base` and `exp` are the same shape (does not broadcast).
 *
 * @param base The base `tf.Tensor` to pow element-wise.
 * @param exp The exponent `tf.Tensor` to pow element-wise.
 */
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function pow_<T extends Tensor>(base: T|TensorLike, exp: Tensor|TensorLike): T {
  const $rawBase = convertToTensor(base, 'base', 'pow');
  const $rawExp = convertToTensor(exp, 'exp', 'pow');
  const outShape =
      broadcast_util.assertAndGetBroadcastShape($rawBase.shape, $rawExp.shape);

  // Bring both operands to their common (upcast) dtype before running the
  // kernel, so the result has the dtype promised above. The previous code
  // assigned these casts to the `base`/`exp` parameters, which are never read
  // again, so the kernel silently received the uncast tensors.
  const dtype = upcastType($rawBase.dtype, $rawExp.dtype);
  const $base = $rawBase.cast(dtype);
  const $exp = $rawExp.cast(dtype);

  const grad = (dy: Tensor, saved: Tensor[]) => {
    const [$base, $exp, y] = saved;
    // d(base^exp)/d(base) = exp * base^(exp - 1)
    const derBase = () => {
      const expFloat = $exp.toFloat();
      let res = dy.mul(expFloat.mul($base.pow(expFloat.sub(scalar(1)))));
      const reduceAxes =
          broadcast_util.getReductionAxes($base.shape, outShape);
      if (reduceAxes.length > 0) {
        res = res.sum(reduceAxes);
      }
      return res.reshape($base.shape) as T;
    };
    // d(base^exp)/d(exp) = base^exp * ln(base). ln(base) is replaced by 0
    // wherever base <= 0, where the log is not a finite real number.
    const derExp = () => {
      const condition = $base.greater(0);
      const logBase = $base.log().where(condition, zerosLike($base));
      let res = dy.mul(y.mul(logBase));
      const reduceAxes = broadcast_util.getReductionAxes($exp.shape, outShape);
      if (reduceAxes.length > 0) {
        res = res.sum(reduceAxes);
      }
      return res.reshape($exp.shape);
    };
    return {$base: derBase, $exp: derExp};
  };

  return ENGINE.runKernel((backend, save) => {
    const y = backend.pow($base, $exp);
    // The output is saved as well because derExp reuses base^exp.
    save([$base, $exp, y]);
    return y;
  }, {$base, $exp}, grad) as T;
}

/**
 * Computes the power of one `tf.Tensor` to another. Inputs must
 * be the same shape.
*
 * For broadcasting support, use `tf.pow` instead.
 *
 * @param base The base tensor to pow element-wise.
 * @param exp The exponent tensor to pow element-wise.
 */
function powStrict_<T extends Tensor>(
    base: T|TensorLike, exp: Tensor|TensorLike): T {
  // Accept TensorLike inputs and convert them with convertToTensor, like the
  // other *Strict ops in this file (previously `.shape` was read directly,
  // so only Tensor arguments worked).
  const $base = convertToTensor(base, 'base', 'powStrict');
  const $exp = convertToTensor(exp, 'exp', 'powStrict');
  util.assertShapesMatch($base.shape, $exp.shape, 'Error in powStrict: ');
  return $base.pow($exp) as T;
}

/**
 * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.
 *
 * We also expose `tf.mulStrict` which has the same signature as this op and
 * asserts that `a` and `b` are the same shape (does not broadcast).
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3, 4]);
 * const b = tf.tensor1d([2, 3, 4, 5]);
 *
 * a.mul(b).print();  // or tf.mul(a, b)
 * ```
 *
 * ```js
 * // Broadcast mul a with b.
 * const a = tf.tensor1d([1, 2, 3, 4]);
 * const b = tf.scalar(5);
 *
 * a.mul(b).print();  // or tf.mul(a, b)
 * ```
 * @param a The first tensor to multiply.
 * @param b The second tensor to multiply. Must have the same dtype as `a`.
 */
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function mul_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'mul');
  let $b = convertToTensor(b, 'b', 'mul');
  [$a, $b] = makeTypesMatch($a, $b);

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  const der = (dy: Tensor, saved: Tensor[]) => {
    const [$a, $b] = saved;
    // Sums `grad` back down to `shape` when that input was broadcast;
    // otherwise hands `grad` back untouched.
    const fold = (grad: Tensor, shape: number[]): Tensor => {
      const axes = broadcast_util.getReductionAxes(shape, outShape);
      return axes.length === 0 ? grad : grad.sum(axes).reshape(shape);
    };
    // d(a * b)/da = b and d(a * b)/db = a.
    return {
      $a: () => fold(dy.mul($b.toFloat()), $a.shape),
      $b: () => fold(dy.mul($a.toFloat()), $b.shape)
    };
  };

  return ENGINE.runKernel((backend, save) => {
    const product = backend.multiply($a, $b);
    save([$a, $b]);
    return product;
  }, {$a, $b}, der) as T;
}

/**
 * Multiplies two
* `tf.Tensor`s element-wise, A * B.
 *
 * Inputs must be the same shape. For broadcasting support, use `tf.mul`.
 *
 * @param a The first tensor to multiply.
 * @param b The second tensor to multiply. Must have the same
 * dtype as `a`.
 */
function mulStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  // NOTE(review): the argument label passed to convertToTensor is 'mul' (not
  // 'mulStrict') and the shape-mismatch prefix says 'multiplyStrict'. Both are
  // kept unchanged because they appear in user-visible error messages.
  const $a = convertToTensor(a, 'a', 'mul');
  const $b = convertToTensor(b, 'b', 'mul');
  util.assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: ');
  return $a.mul($b) as T;
}

/**
 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
 *
 * If both inputs are int32, the result is computed with `tf.floorDiv`.
 *
 * We also expose `tf.divStrict` which has the same signature as this op and
 * asserts that `a` and `b` are the same shape (does not broadcast).
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 9, 16]);
 * const b = tf.tensor1d([1, 2, 3, 4]);
 *
 * a.div(b).print();  // or tf.div(a, b)
 * ```
 *
 * ```js
 * // Broadcast div a with b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(2);
 *
 * a.div(b).print();  // or tf.div(a, b)
 * ```
 *
 * @param a The first tensor as the numerator.
 * @param b The second tensor as the denominator. Must have the same dtype as
 * `a`.
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'div');
  let $b = convertToTensor(b, 'b', 'div');
  [$a, $b] = makeTypesMatch($a, $b);

  // Integer inputs are delegated to floorDiv (quotient rounded down).
  if ($a.dtype === 'int32' && $b.dtype === 'int32') {
    return floorDiv($a, $b);
  }

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  const der = (dy: Tensor, saved: Tensor[]) => {
    const [$a, $b] = saved;
    // d(a / b)/da = 1 / b
    const derA = () => {
      const res = dy.div($b.toFloat());
      const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape($a.shape);
      }
      return res;
    };
    // d(a / b)/db = -a / b^2
    const derB = () => {
      let res = dy.mul($a.toFloat());
      const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
      if (reduceAxes.length > 0) {
        res = res.sum(reduceAxes).reshape($b.shape);
      }
      const tmp = $b.square() as Tensor;
      return res.div(tmp.toFloat()).neg() as Tensor;
    };
    return {$a: derA, $b: derB};
  };

  return ENGINE.runKernel((backend, save) => {
    const res = backend.realDivide($a, $b);
    save([$a, $b]);
    return res;
  }, {$a, $b}, der) as T;
}

/**
 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
 * The result is rounded with floor function.
 *
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 9, 16]);
 * const b = tf.tensor1d([1, 2, 3, 4]);
 *
 * a.floorDiv(b).print();  // or tf.floorDiv(a, b)
 * ```
 *
 * ```js
 * // Broadcast div a with b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(2);
 *
 * a.floorDiv(b).print();  // or tf.floorDiv(a, b)
 * ```
 *
 * @param a The first tensor as the numerator.
 * @param b The second tensor as the denominator. Must have the same dtype as
 * `a`.
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function floorDiv_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'floorDiv');
  let $b = convertToTensor(b, 'b', 'floorDiv');
  [$a, $b] = makeTypesMatch($a, $b);

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  // NOTE(review): floor division is piecewise constant, so its true
  // derivative is zero almost everywhere. This gradient instead mirrors that
  // of real division a / b. Confirm this is intended before relying on it.
  const der = (dy: Tensor, saved: Tensor[]) => {
    const [$a, $b] = saved;
    const reductionFor = (shape: number[]) =>
        broadcast_util.getReductionAxes(shape, outShape);

    const derA = () => {
      const quotient = dy.div($b.toFloat());
      const axes = reductionFor($a.shape);
      return axes.length > 0 ? quotient.sum(axes).reshape($a.shape) :
                               quotient;
    };
    const derB = () => {
      let scaled = dy.mul($a.toFloat());
      const axes = reductionFor($b.shape);
      if (axes.length > 0) {
        scaled = scaled.sum(axes).reshape($b.shape);
      }
      const bSquared = $b.square() as Tensor;
      return scaled.div(bSquared.toFloat()).neg() as Tensor;
    };
    return {$a: derA, $b: derB};
  };

  return ENGINE.runKernel((backend, save) => {
    const quotient = backend.floorDiv($a, $b);
    save([$a, $b]);
    return quotient;
  }, {$a, $b}, der) as T;
}

/**
 * Divides two `tf.Tensor`s element-wise, A / B. Inputs must
 * be the same shape.
 *
 * @param a The first tensor as the numerator for element-wise division.
 * @param b The second tensor as the denominator for element-wise division.
 */
function divStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  const numerator = convertToTensor(a, 'a', 'div');
  const denominator = convertToTensor(b, 'b', 'div');
  util.assertShapesMatch(
      numerator.shape, denominator.shape, 'Error in divideStrict: ');
  return numerator.div(denominator) as T;
}

/**
 * Returns the mod of a and b element-wise.
 * `floor(x / y) * y + mod(x, y) = x`
 * Supports broadcasting.
 *
 * We also expose `tf.modStrict` which has the same signature as this op and
 * asserts that `a` and `b` are the same shape (does not broadcast).
*
 * ```js
 * const a = tf.tensor1d([1, 4, 3, 16]);
 * const b = tf.tensor1d([1, 2, 9, 4]);
 *
 * a.mod(b).print();  // or tf.mod(a, b)
 * ```
 *
 * ```js
 * // Broadcast a mod b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(5);
 *
 * a.mod(b).print();  // or tf.mod(a, b)
 * ```
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same type as `a`.
 */
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function mod_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'mod');
  let $b = convertToTensor(b, 'b', 'mod');
  [$a, $b] = makeTypesMatch($a, $b);

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  const der = (dy: Tensor, saved: Tensor[]) => {
    const [$a, $b] = saved;
    // d(mod(a, b))/da = 1
    const derA = () => {
      const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
      if (reduceAxes.length > 0) {
        return dy.sum(reduceAxes).reshape($a.shape);
      }
      return dy;
    };
    // d(mod(a, b))/db = -floor(a / b)
    const derB = () => {
      const res = dy.mul($a.div($b).floor().neg());
      const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape($b.shape);
      }
      return res;
    };
    return {$a: derA, $b: derB};
  };

  return ENGINE.runKernel((backend, save) => {
    const res = backend.mod($a, $b);
    save([$a, $b]);
    return res;
  }, {$a, $b}, der) as T;
}

/**
 * Returns the mod of a and b element-wise, where
 * `floor(a / b) * b + mod(a, b) = a`. Inputs must be the same shape. For
 * broadcasting support, use `tf.mod`.
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same dtype as `a`.
 */
function modStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'modStrict');
  const $b = convertToTensor(b, 'b', 'modStrict');
  util.assertShapesMatch($a.shape, $b.shape, 'Error in modStrict: ');
  return $a.mod($b);
}

/**
 * Returns the min of a and b (`a < b ? a : b`) element-wise.
 * Supports broadcasting.
*
 * We also expose `minimumStrict` which has the same signature as this op and
 * asserts that `a` and `b` are the same shape (does not broadcast).
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 3, 16]);
 * const b = tf.tensor1d([1, 2, 9, 4]);
 *
 * a.minimum(b).print();  // or tf.minimum(a, b)
 * ```
 *
 * ```js
 * // Broadcast minimum a with b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(5);
 *
 * a.minimum(b).print();  // or tf.minimum(a, b)
 * ```
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same type as `a`.
 */
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function minimum_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'minimum');
  let $b = convertToTensor(b, 'b', 'minimum');
  [$a, $b] = makeTypesMatch($a, $b);

  // bool inputs are compared as int32.
  if ($a.dtype === 'bool') {
    $a = $a.toInt();
    $b = $b.toInt();
  }

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  const der = (dy: Tensor, saved: Tensor[]) => {
    const [$a, $b] = saved;
    // `dy` has the broadcast output shape. Each gradient is summed over the
    // axes its input was broadcast along and reshaped back, as the other
    // broadcasting ops in this file do. Previously this step was missing,
    // so a broadcast input got an output-shaped gradient.
    // On ties (a == b) the whole gradient flows to `a`.
    const derA = () => {
      const res = dy.mul($a.lessEqual($b).toFloat());
      const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape($a.shape);
      }
      return res;
    };
    const derB = () => {
      const res = dy.mul($a.greater($b).toFloat());
      const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape($b.shape);
      }
      return res;
    };
    return {$a: derA, $b: derB};
  };

  return ENGINE.runKernel((backend, save) => {
    const res = backend.minimum($a, $b);
    save([$a, $b]);
    return res;
  }, {$a, $b}, der) as T;
}

/**
 * Returns the min of a and b (`a < b ? a : b`) element-wise. Inputs must
 * be the same shape. For broadcasting support, use minimum().
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same dtype as `a`.
 */
function minimumStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'minimumStrict');
  const $b = convertToTensor(b, 'b', 'minimumStrict');
  util.assertShapesMatch($a.shape, $b.shape, 'Error in minimumStrict: ');
  return $a.minimum($b);
}

/**
 * Returns the max of a and b (`a > b ? a : b`) element-wise.
 * Supports broadcasting.
*
 * We also expose `tf.maximumStrict` which has the same signature as this op and
 * asserts that `a` and `b` are the same shape (does not broadcast).
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 3, 16]);
 * const b = tf.tensor1d([1, 2, 9, 4]);
 *
 * a.maximum(b).print();  // or tf.maximum(a, b)
 * ```
 *
 * ```js
 * // Broadcast maximum a with b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(5);
 *
 * a.maximum(b).print();  // or tf.maximum(a, b)
 * ```
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same type as `a`.
 */
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function maximum_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'maximum');
  let $b = convertToTensor(b, 'b', 'maximum');
  [$a, $b] = makeTypesMatch($a, $b);

  // bool inputs are compared as int32.
  if ($a.dtype === 'bool') {
    $a = $a.toInt();
    $b = $b.toInt();
  }

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  const der = (dy: Tensor, saved: Tensor[]) => {
    const [$a, $b] = saved;
    // `dy` has the broadcast output shape. Each gradient is summed over the
    // axes its input was broadcast along and reshaped back, as the other
    // broadcasting ops in this file do. Previously this step was missing,
    // so a broadcast input got an output-shaped gradient.
    // On ties (a == b) the whole gradient flows to `a`.
    const derA = () => {
      const res = dy.mul($a.greaterEqual($b).toFloat());
      const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape($a.shape);
      }
      return res;
    };
    const derB = () => {
      const res = dy.mul($a.less($b).toFloat());
      const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape($b.shape);
      }
      return res;
    };
    return {$a: derA, $b: derB};
  };

  return ENGINE.runKernel((backend, save) => {
    const res = backend.maximum($a, $b);
    save([$a, $b]);
    return res;
  }, {$a, $b}, der) as T;
}

/**
 * Returns the max of a and b (`a > b ? a : b`) element-wise. Inputs must
 * be the same shape. For broadcasting support, use maximum().
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same dtype as `a`.
 */
function maximumStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'maximumStrict');
  const $b = convertToTensor(b, 'b', 'maximumStrict');
  util.assertShapesMatch($a.shape, $b.shape, 'Error in maximumStrict: ');
  return $a.maximum($b);
}

/**
 * Returns (a - b) * (a - b) element-wise.
 * Supports broadcasting.
*
 * We also expose `tf.squaredDifferenceStrict` which has the same signature as
 * this op and asserts that `a` and `b` are the same shape (does not
 * broadcast).
 *
 * ```js
 * const a = tf.tensor1d([1, 4, 3, 16]);
 * const b = tf.tensor1d([1, 2, 9, 4]);
 *
 * a.squaredDifference(b).print();  // or tf.squaredDifference(a, b)
 * ```
 *
 * ```js
 * // Broadcast squared difference a with b.
 * const a = tf.tensor1d([2, 4, 6, 8]);
 * const b = tf.scalar(5);
 *
 * a.squaredDifference(b).print();  // or tf.squaredDifference(a, b)
 * ```
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same type as `a`.
 */
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function squaredDifference_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'squaredDifference');
  let $b = convertToTensor(b, 'b', 'squaredDifference');
  [$a, $b] = makeTypesMatch($a, $b);

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  const der = (dy: Tensor, saved: Tensor[]) => {
    const [$a, $b] = saved;
    const two = scalar(2);
    // d((a - b)^2)/da = 2(a - b) and d((a - b)^2)/db = 2(b - a).
    // `dy` has the broadcast output shape, so each gradient is summed over
    // the axes its input was broadcast along and reshaped back. Previously
    // this step was missing, so a broadcast input got an output-shaped
    // gradient.
    const derA = () => {
      const res = dy.mul($a.sub($b).mul(two));
      const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape($a.shape);
      }
      return res;
    };
    const derB = () => {
      const res = dy.mul($b.sub($a).mul(two));
      const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape($b.shape);
      }
      return res;
    };
    return {$a: derA, $b: derB};
  };

  return ENGINE.runKernel((backend, save) => {
    const res = backend.squaredDifference($a, $b);
    save([$a, $b]);
    return res;
  }, {$a, $b}, der) as T;
}

/**
 * Returns (a - b) * (a - b) element-wise.
 *
 * Inputs must be the same shape. For broadcasting support, use
 * `tf.squaredDifference` instead.
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same type as `a`.
*/
function squaredDifferenceStrict_<T extends Tensor>(
    a: T|TensorLike, b: T|TensorLike): T {
  const x = convertToTensor(a, 'a', 'squaredDifferenceStrict');
  const y = convertToTensor(b, 'b', 'squaredDifferenceStrict');
  util.assertShapesMatch(
      x.shape, y.shape, 'Error in squaredDifferenceStrict: ');
  return x.squaredDifference(y);
}

/**
 * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.
 * Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);
 * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);
 *
 * tf.atan2(a, b).print()
 * ```
 *
 * @param a The first tensor.
 * @param b The second tensor. Must have the same dtype as `a`.
 *
 */
/** @doc {heading: 'Operations', subheading: 'Basic math'} */
function atan2_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  let $a = convertToTensor(a, 'a', 'atan2');
  let $b = convertToTensor(b, 'b', 'atan2');
  [$a, $b] = makeTypesMatch($a, $b);

  const outShape =
      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);

  const der = (dy: Tensor, saved: Tensor[]) => {
    const [$a, $b] = saved;
    // Sum over the broadcast axes of `input` (if any), then restore its shape.
    const toInputShape = (grad: Tensor, input: Tensor): Tensor => {
      const axes = broadcast_util.getReductionAxes(input.shape, outShape);
      return (axes.length > 0 ? grad.sum(axes) : grad).reshape(input.shape);
    };
    // Both partials share the denominator a^2 + b^2:
    //   d/da atan2(a, b) =  b / (a^2 + b^2)
    //   d/db atan2(a, b) = -a / (a^2 + b^2)
    const derA = () => {
      const denom = add($a.square(), $b.square());
      return toInputShape(dy.mul($b.div(denom)), $a);
    };
    const derB = () => {
      const denom = add($a.square(), $b.square());
      return toInputShape(neg(dy.mul($a.div(denom))), $b);
    };
    return {$a: derA, $b: derB};
  };

  return ENGINE.runKernel((backend, save) => {
    const angle = backend.atan2($a, $b);
    save([$a, $b]);
    return angle;
  }, {$a, $b}, der) as T;
}

export const add = op({add_});
export const addN = op({addN_});
export const addStrict = op({addStrict_});
export const atan2 = op({atan2_});
export const div = op({div_});
export const divStrict = op({divStrict_});
// Remaining public exports: each `*_` implementation defined above is wrapped
// with `op()` and exported under its public name.
export const floorDiv = op({floorDiv_});
export const maximum = op({maximum_});
export const maximumStrict = op({maximumStrict_});
export const minimum = op({minimum_});
export const minimumStrict = op({minimumStrict_});
export const mod = op({mod_});
export const modStrict = op({modStrict_});
export const mul = op({mul_});
export const mulStrict = op({mulStrict_});
export const pow = op({pow_});
export const powStrict = op({powStrict_});
export const squaredDifference = op({squaredDifference_});
export const squaredDifferenceStrict = op({squaredDifferenceStrict_});
export const sub = op({sub_});
export const subStrict = op({subStrict_});