UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

112 lines 5.93 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var environment_1 = require("./environment"); var tensor_1 = require("./tensor"); var util = require("./util"); function gradScope(nameOrScopeFn, scopeFn) { return environment_1.ENV.engine.tidy(nameOrScopeFn, scopeFn, true); } exports.gradScope = gradScope; function grad(f) { util.assert(util.isFunction(f), 'The f passed in grad(f) must be a function'); return function (x, dy) { util.assert(x instanceof tensor_1.Tensor, 'The x passed in grad(f)(x) must be a tensor'); util.assert(dy == null || dy instanceof tensor_1.Tensor, 'The dy passed in grad(f)(x, dy) must be a tensor'); return environment_1.ENV.engine.tidy(function () { var _a = environment_1.ENV.engine.gradients(function () { return f(x); }, [x], dy), value = _a.value, grads = _a.grads; if (dy != null) { util.assertShapesMatch(value.shape, dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' + 'returned by f(x)'); } checkGrads(grads); return grads[0]; }); }; } exports.grad = grad; function grads(f) { util.assert(util.isFunction(f), 'The f passed in grads(f) must be a function'); return function (args, dy) { util.assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof tensor_1.Tensor; }), 'The args passed in grads(f)(args) must be an array of tensors'); util.assert(dy == null || dy instanceof tensor_1.Tensor, 'The dy passed in grads(f)(args, dy) must be a tensor'); return environment_1.ENV.engine.tidy(function () { var _a = environment_1.ENV.engine.gradients(function () { return f.apply(void 0, args); }, args, dy), value = _a.value, grads = _a.grads; if (dy != null) { util.assertShapesMatch(value.shape, dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' + 'match the shape returned by f([x1,...])'); } checkGrads(grads); return grads; }); }; } exports.grads = grads; function valueAndGrad(f) { util.assert(util.isFunction(f), 'The f passed in valueAndGrad(f) must be a function'); return function (x, dy) { util.assert(x instanceof tensor_1.Tensor, 'The x passed in valueAndGrad(f)(x) must be a tensor'); util.assert(dy == null || dy instanceof tensor_1.Tensor, 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor'); var _a = environment_1.ENV.engine.gradients(function () { return f(x); }, [x], dy), grads = _a.grads, value = _a.value; checkGrads(grads); return { grad: grads[0], value: value }; }; } exports.valueAndGrad = valueAndGrad; function valueAndGrads(f) { util.assert(util.isFunction(f), 'The f passed in valueAndGrads(f) must be a function'); return function (args, dy) { util.assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof tensor_1.Tensor; }), 'The args passed in valueAndGrads(f)(args) must be array of tensors'); util.assert(dy == null || dy instanceof tensor_1.Tensor, 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor'); var res = environment_1.ENV.engine.gradients(function () { return f.apply(void 0, args); }, args, dy); if (dy != null) { util.assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' + 'match the shape returned by f([x1,...])'); } checkGrads(res.grads); return res; }; } exports.valueAndGrads = valueAndGrads; function variableGrads(f, varList) { util.assert(util.isFunction(f), 'The f passed in variableGrads(f) must be a function'); util.assert(varList == null || Array.isArray(varList) && varList.every(function (v) { return v instanceof tensor_1.Variable; }), 'The varList passed in variableGrads(f, varList) must be an array ' + 'of variables'); if (varList == null) { varList = []; for (var varName in environment_1.ENV.engine.registeredVariables) { varList.push(environment_1.ENV.engine.registeredVariables[varName]); } } var originalVarCount = varList.length; varList = varList.filter(function (variable) { return variable.trainable; }); util.assert(varList.length > 0, "variableGrads() expects at least one of the input variables to be " + ("trainable, but none of the " + originalVarCount + " variables is ") + "trainable."); var allowNoGradients = true; var _a = environment_1.ENV.engine.gradients(f, varList, null, allowNoGradients), value = _a.value, grads = _a.grads; util.assert(grads.some(function (g) { return g != null; }), 'Cannot find a connection between any variable and the result of the ' + 'loss function y=f(x). Please make sure the operations that use ' + 'variables are inside the function f passed to minimize().'); util.assert(value.rank === 0, "The f passed in variableGrads(f) must return a scalar, but it " + ("returned a rank-" + value.rank + " tensor")); var namedGrads = {}; varList.forEach(function (v, i) { if (grads[i] != null) { namedGrads[v.name] = grads[i]; } }); return { value: value, grads: namedGrads }; } exports.variableGrads = variableGrads; function customGrad(f) { return environment_1.ENV.engine.customGrad(f); } exports.customGrad = customGrad; function checkGrads(grads) { var numNullGradients = grads.filter(function (g) { return g == null; }).length; if (numNullGradients > 0) { throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y."); } } //# sourceMappingURL=gradients.js.map