UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

389 lines (373 loc) 13.2 kB
/** * @license * Copyright 2018 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ import {CustomGradientFunc, ENGINE} from './engine'; import {Scalar, Tensor, Variable} from './tensor'; import {NamedTensorMap} from './tensor_types'; import {convertToTensor, convertToTensorArray} from './tensor_util_env'; import {TensorLike} from './types'; import * as util from './util'; /** * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the * gradient of `f(x)` with respect to `x`. * * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to * `x` is computed instead. `f(x)` must take a single tensor `x` and return a * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead. * * ```js * // f(x) = x ^ 2 * const f = x => x.square(); * // f'(x) = 2x * const g = tf.grad(f); * * const x = tf.tensor1d([2, 3]); * g(x).print(); * ``` * * ```js * // f(x) = x ^ 3 * const f = x => x.pow(tf.scalar(3, 'int32')); * // f'(x) = 3x ^ 2 * const g = tf.grad(f); * // f''(x) = 6x * const gg = tf.grad(g); * * const x = tf.tensor1d([2, 3]); * gg(x).print(); * ``` * * @param f The function f(x), to compute gradient for. 
*/
/** @doc {heading: 'Training', subheading: 'Gradients'} */
function grad(f: (x: Tensor) => Tensor): (
    x: TensorLike|Tensor, dy?: TensorLike|Tensor) => Tensor {
  util.assert(
      util.isFunction(f), () => 'The f passed in grad(f) must be a function');

  const gradFn = (x: TensorLike|Tensor, dy?: TensorLike|Tensor): Tensor => {
    // `x` may have any dtype, which is why `null` is given as the dtype
    // argument of the conversion.
    const xTensor = convertToTensor(x, 'x', 'tf.grad', null);

    // `dy` is optional: it stays `null` unless the caller supplied one.
    let dyTensor: Tensor = null;
    if (dy != null) {
      dyTensor = convertToTensor(dy, 'dy', 'tf.grad');
    }

    return ENGINE.tidy(() => {
      const result = ENGINE.gradients(() => f(xTensor), [xTensor], dyTensor);

      // The shape of f(x) is only known once f has run, so `dy` is
      // validated against it after the gradients call.
      if (dyTensor != null) {
        util.assertShapesMatch(
            result.value.shape, dyTensor.shape,
            'The shape of dy passed in grad(f)(x, dy) must match the shape ' +
                'returned by f(x)');
      }

      // Throws when any gradient is missing.
      checkGrads(result.grads);
      // Only one input was passed, so its gradient is the first entry.
      return result.grads[0];
    });
  };

  return gradFn;
}

/**
 * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
 * which gives an array of gradients of `f()` with respect to each input
 * [`x1`,`x2`,...].
 *
 * If `dy` is passed when calling `g()`, the gradient of
 * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
 * The provided `f` must take one or more tensors and return a single tensor
 * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
 *
 * ```js
 * // f(a, b) = a * b
 * const f = (a, b) => a.mul(b);
 * // df / da = b, df / db = a
 * const g = tf.grads(f);
 *
 * const a = tf.tensor1d([2, 3]);
 * const b = tf.tensor1d([-2, -3]);
 * const [da, db] = g([a, b]);
 * console.log('da');
 * da.print();
 * console.log('db');
 * db.print();
 * ```
 *
 * @param f The function `f(x1, x2,...)` to compute gradients for.
*/
/** @doc {heading: 'Training', subheading: 'Gradients'} */
function grads(f: (...args: Tensor[]) => Tensor): (
    args: Array<Tensor|TensorLike>, dy?: Tensor|TensorLike) => Tensor[] {
  util.assert(
      util.isFunction(f), () => 'The f passed in grads(f) must be a function');

  const gradsFn =
      (args: Array<Tensor|TensorLike>, dy?: Tensor|TensorLike): Tensor[] => {
        util.assert(
            Array.isArray(args),
            () => 'The args passed in grads(f)(args) must be an array ' +
                'of `Tensor`s or `TensorLike`s');

        // The inputs may have any dtype, which is why `null` is given as the
        // dtype argument of the conversion.
        const inputs = convertToTensorArray(args, 'args', 'tf.grads', null);

        // `dy` is optional: it stays `null` unless the caller supplied one.
        let dyTensor: Tensor = null;
        if (dy != null) {
          dyTensor = convertToTensor(dy, 'dy', 'tf.grads');
        }

        return ENGINE.tidy(() => {
          const result =
              ENGINE.gradients(() => f(...inputs), inputs, dyTensor);

          // The shape of f(...) is only known once f has run, so `dy` is
          // validated against it after the gradients call.
          if (dyTensor != null) {
            util.assertShapesMatch(
                result.value.shape, dyTensor.shape,
                'The shape of dy passed in grads(f)([x1,...], dy) must ' +
                    'match the shape returned by f([x1,...])');
          }

          // Throws when any gradient is missing.
          checkGrads(result.grads);
          return result.grads;
        });
      };

  return gradsFn;
}

/**
 * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
 * returns a metric you want to show.
 *
 * The result is a rich object with the following properties:
 * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`).
 * - value: The value returned by `f(x)`.
*
 * ```js
 * // f(x) = x ^ 2
 * const f = x => x.square();
 * // f'(x) = 2x
 * const g = tf.valueAndGrad(f);
 *
 * const x = tf.tensor1d([2, 3]);
 * const {value, grad} = g(x);
 *
 * console.log('value');
 * value.print();
 * console.log('grad');
 * grad.print();
 * ```
 *
 * @param f The function f(x), to compute the value and gradient for. The
 *     returned function throws if `x` is not a tensor, if the optional `dy`
 *     is not a tensor, or if `dy` does not have the same shape as the value
 *     returned by `f(x)`.
 */
/** @doc {heading: 'Training', subheading: 'Gradients'} */
function valueAndGrad<I extends Tensor, O extends Tensor>(f: (x: I) => O): (
    x: I, dy?: O) => {
  value: O;
  grad: I;
} {
  util.assert(
      util.isFunction(f),
      () => 'The f passed in valueAndGrad(f) must be a function');
  return (x: I, dy?: O) => {
    // `x` and `dy` must already be tensors; no conversion is done here.
    util.assert(
        x instanceof Tensor,
        () => 'The x passed in valueAndGrad(f)(x) must be a tensor');
    util.assert(
        dy == null || dy instanceof Tensor,
        () => 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
    const {grads, value} = ENGINE.gradients(() => f(x), [x], dy);
    // Same check as in grad(), grads() and valueAndGrads(): a `dy` whose
    // shape differs from f(x) is rejected with an explicit message. The
    // shape of f(x) is only known once f has run, hence the check is here.
    if (dy != null) {
      util.assertShapesMatch(
          value.shape, dy.shape,
          'The shape of dy passed in valueAndGrad(f)(x, dy) must match the ' +
              'shape returned by f(x)');
    }
    // Throws when the gradient is missing.
    checkGrads(grads);
    return {grad: grads[0] as I, value: value as O};
  };
}

/**
 * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`
 * returns a metric you want to show.
 *
 * The result is a rich object with the following properties:
 * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`).
 * - value: The value returned by `f(x)`.
*
 * ```js
 * // f(a, b) = a * b
 * const f = (a, b) => a.mul(b);
 * // df/da = b, df/db = a
 * const g = tf.valueAndGrads(f);
 *
 * const a = tf.tensor1d([2, 3]);
 * const b = tf.tensor1d([-2, -3]);
 * const {value, grads} = g([a, b]);
 *
 * const [da, db] = grads;
 *
 * console.log('value');
 * value.print();
 *
 * console.log('da');
 * da.print();
 * console.log('db');
 * db.print();
 * ```
 */
/** @doc {heading: 'Training', subheading: 'Gradients'} */
function valueAndGrads<O extends Tensor>(f: (...args: Tensor[]) => O): (
    args: Tensor[], dy?: O) => {
  grads: Tensor[];
  value: O;
} {
  util.assert(
      util.isFunction(f),
      () => 'The f passed in valueAndGrads(f) must be a function');

  const valueAndGradsFn = (args: Tensor[], dy?: O) => {
    // Inputs must already be tensors; no conversion is done here.
    const argsAreTensors =
        Array.isArray(args) && args.every(arg => arg instanceof Tensor);
    util.assert(
        argsAreTensors,
        () => 'The args passed in valueAndGrads(f)(args) must be array of ' +
            'tensors');

    const dyIsValid = dy == null || dy instanceof Tensor;
    util.assert(
        dyIsValid,
        () => 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');

    const result = ENGINE.gradients(() => f(...args), args, dy);

    // The shape of f(...) is only known once f has run, so `dy` is
    // validated against it after the gradients call.
    if (dy != null) {
      util.assertShapesMatch(
          result.value.shape, dy.shape,
          'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +
              'match the shape returned by f([x1,...])');
    }

    // Throws when any gradient is missing.
    checkGrads(result.grads);
    // The object returned by ENGINE.gradients is handed back unchanged.
    return result;
  };

  return valueAndGradsFn;
}

/**
 * Computes and returns the gradient of f(x) with respect to the list of
 * trainable variables provided by `varList`. If no list is provided, it
 * defaults to all trainable variables.
 *
 * ```js
 * const a = tf.variable(tf.tensor1d([3, 4]));
 * const b = tf.variable(tf.tensor1d([5, 6]));
 * const x = tf.tensor1d([1, 2]);
 *
 * // f(a, b) = a * x ^ 2 + b * x
 * const f = () => a.mul(x.square()).add(b.mul(x)).sum();
 * // df/da = x ^ 2, df/db = x
 * const {value, grads} = tf.variableGrads(f);
 *
 * Object.keys(grads).forEach(varName => grads[varName].print());
 * ```
 *
 * @param f The function to execute. f() should return a scalar.
 * @param varList The list of variables to compute the gradients with respect
 *     to.
*     Defaults to all trainable variables.
 * @returns An object with the following keys and values:
 *   - `value`: The value of the function `f`.
 *   - `grads`: A map from the names of the variables to the gradients.
 *     If the `varList` argument is provided explicitly and contains a subset of
 *     non-trainable variables, this map in the return value will contain keys
 *     that map the names of the non-trainable variables to `null`.
 */
/** @doc {heading: 'Training', subheading: 'Gradients'} */
function variableGrads(f: () => Scalar, varList?: Variable[]):
    {value: Scalar, grads: NamedTensorMap} {
  util.assert(
      util.isFunction(f),
      () => 'The f passed in variableGrads(f) must be a function');
  util.assert(
      varList == null ||
          Array.isArray(varList) && varList.every(v => v instanceof Variable),
      () =>
          'The varList passed in variableGrads(f, varList) must be an array ' +
          'of variables');

  // Use the caller's list when given, otherwise every registered variable.
  const callerGaveVars = varList != null;
  let candidates: Variable[];
  if (callerGaveVars) {
    candidates = varList;
  } else {
    candidates = [];
    for (const name in ENGINE.registeredVariables) {
      candidates.push(ENGINE.registeredVariables[name]);
    }
  }

  // Non-trainable variables are tracked only when the caller listed them
  // explicitly, so they can be reported back with `null` gradients.
  const frozenVars: Variable[] =
      callerGaveVars ? candidates.filter(v => !v.trainable) : null;

  // Gradients are computed for the trainable variables only.
  const candidateCount = candidates.length;
  const trainableVars = candidates.filter(v => v.trainable);
  util.assert(
      trainableVars.length > 0,
      () => `variableGrads() expects at least one of the input variables to ` +
          `be trainable, but none of the ${candidateCount} variables is ` +
          `trainable.`);

  // Individual gradients may be null; what is required below is that at
  // least one of them is not.
  const allowNoGradients = true;
  const {value, grads} =
      ENGINE.gradients(f, trainableVars, null, allowNoGradients);

  util.assert(
      grads.some(g => g != null),
      () => 'Cannot find a connection between any variable and the result of ' +
          'the loss function y=f(x). Please make sure the operations that ' +
          'use variables are inside the function f passed to minimize().');
  util.assert(
      value.rank === 0,
      () => `The f passed in variableGrads(f) must return a scalar, but it ` +
          `returned a rank-${value.rank} tensor`);

  // Key the non-null gradients by variable name.
  const namedGrads: NamedTensorMap = {};
  for (let i = 0; i < trainableVars.length; i++) {
    const g = grads[i];
    if (g != null) {
      namedGrads[trainableVars[i].name] = g;
    }
  }

  // If varList is explicitly provided and contains non-trainable values,
  // add them to the returned gradients with `null` values.
  if (frozenVars != null) {
    for (let i = 0; i < frozenVars.length; i++) {
      namedGrads[frozenVars[i].name] = null;
    }
  }

  return {value, grads: namedGrads};
}

/**
 * Overrides the gradient computation of a function `f`.
 *
 * Takes a function
 * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
 * and returns another function `g(...inputs)` which takes the same inputs as
 * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
 * with respect to each input of `f` are computed using `f().gradFunc`.
 *
 * The `save` function passed to `f` should be used for saving tensors needed
 * in the gradient. And the `saved` passed to the `gradFunc` is a
 * `NamedTensorMap`, which contains those saved tensor.
 *
 * ```js
 * const customOp = tf.customGrad((x, save) => {
 *   // Save x to make sure it's available later for the gradient.
 *   save([x]);
 *   // Override gradient of our custom x ^ 2 op to be dy * abs(x);
 *   return {
 *     value: x.square(),
 *     // Note `saved[0]` which points to the `x` we saved earlier.
 *     gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
 *   };
 * });
 *
 * const x = tf.tensor1d([-1, -2, 3]);
 * const dx = tf.grad(x => customOp(x));
 *
 * console.log(`f(x):`);
 * customOp(x).print();
 * console.log(`f'(x):`);
 * dx(x).print();
 * ```
 *
 * @param f The function to evaluate in forward mode, which should return
 *     `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
 *     returns the custom gradients of `f` with respect to its inputs.
*/
/** @doc {heading: 'Training', subheading: 'Gradients'} */
function customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
    (...args: Tensor[]) => T {
  // All of the work is delegated to the engine.
  const wrapped = ENGINE.customGrad(f);
  return wrapped;
}

/**
 * Throws if any entry of `grads` is `null` or `undefined`.
 *
 * @param grads The gradients to validate.
 */
function checkGrads(grads: Tensor[]) {
  const hasMissingGradient = grads.some(g => g == null);
  if (!hasMissingGradient) {
    return;
  }
  throw new Error(
      'Cannot compute gradient of y=f(x) with respect to x. Make sure that ' +
      'the f you passed encloses all operations that lead from x to y.');
}

export {
  customGrad,
  variableGrads,
  valueAndGrad,
  valueAndGrads,
  grad,
  grads,
};