/*
 * @tensorflow/tfjs-core
 * Version: (not specified in the listing)
 * Hardware-accelerated JavaScript library for machine intelligence.
 * Listing: 112 lines • 5.93 kB, JavaScript.
 */
;
Object.defineProperty(exports, "__esModule", { value: true });
var environment_1 = require("./environment");
var tensor_1 = require("./tensor");
var util = require("./util");
function gradScope(nameOrScopeFn, scopeFn) {
return environment_1.ENV.engine.tidy(nameOrScopeFn, scopeFn, true);
}
exports.gradScope = gradScope;
function grad(f) {
util.assert(util.isFunction(f), 'The f passed in grad(f) must be a function');
return function (x, dy) {
util.assert(x instanceof tensor_1.Tensor, 'The x passed in grad(f)(x) must be a tensor');
util.assert(dy == null || dy instanceof tensor_1.Tensor, 'The dy passed in grad(f)(x, dy) must be a tensor');
return environment_1.ENV.engine.tidy(function () {
var _a = environment_1.ENV.engine.gradients(function () { return f(x); }, [x], dy), value = _a.value, grads = _a.grads;
if (dy != null) {
util.assertShapesMatch(value.shape, dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' +
'returned by f(x)');
}
checkGrads(grads);
return grads[0];
});
};
}
exports.grad = grad;
function grads(f) {
util.assert(util.isFunction(f), 'The f passed in grads(f) must be a function');
return function (args, dy) {
util.assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof tensor_1.Tensor; }), 'The args passed in grads(f)(args) must be an array of tensors');
util.assert(dy == null || dy instanceof tensor_1.Tensor, 'The dy passed in grads(f)(args, dy) must be a tensor');
return environment_1.ENV.engine.tidy(function () {
var _a = environment_1.ENV.engine.gradients(function () { return f.apply(void 0, args); }, args, dy), value = _a.value, grads = _a.grads;
if (dy != null) {
util.assertShapesMatch(value.shape, dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' +
'match the shape returned by f([x1,...])');
}
checkGrads(grads);
return grads;
});
};
}
exports.grads = grads;
function valueAndGrad(f) {
util.assert(util.isFunction(f), 'The f passed in valueAndGrad(f) must be a function');
return function (x, dy) {
util.assert(x instanceof tensor_1.Tensor, 'The x passed in valueAndGrad(f)(x) must be a tensor');
util.assert(dy == null || dy instanceof tensor_1.Tensor, 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor');
var _a = environment_1.ENV.engine.gradients(function () { return f(x); }, [x], dy), grads = _a.grads, value = _a.value;
checkGrads(grads);
return { grad: grads[0], value: value };
};
}
exports.valueAndGrad = valueAndGrad;
function valueAndGrads(f) {
util.assert(util.isFunction(f), 'The f passed in valueAndGrads(f) must be a function');
return function (args, dy) {
util.assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof tensor_1.Tensor; }), 'The args passed in valueAndGrads(f)(args) must be array of tensors');
util.assert(dy == null || dy instanceof tensor_1.Tensor, 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor');
var res = environment_1.ENV.engine.gradients(function () { return f.apply(void 0, args); }, args, dy);
if (dy != null) {
util.assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +
'match the shape returned by f([x1,...])');
}
checkGrads(res.grads);
return res;
};
}
exports.valueAndGrads = valueAndGrads;
function variableGrads(f, varList) {
util.assert(util.isFunction(f), 'The f passed in variableGrads(f) must be a function');
util.assert(varList == null ||
Array.isArray(varList) && varList.every(function (v) { return v instanceof tensor_1.Variable; }), 'The varList passed in variableGrads(f, varList) must be an array ' +
'of variables');
if (varList == null) {
varList = [];
for (var varName in environment_1.ENV.engine.registeredVariables) {
varList.push(environment_1.ENV.engine.registeredVariables[varName]);
}
}
var originalVarCount = varList.length;
varList = varList.filter(function (variable) { return variable.trainable; });
util.assert(varList.length > 0, "variableGrads() expects at least one of the input variables to be " +
("trainable, but none of the " + originalVarCount + " variables is ") +
"trainable.");
var allowNoGradients = true;
var _a = environment_1.ENV.engine.gradients(f, varList, null, allowNoGradients), value = _a.value, grads = _a.grads;
util.assert(grads.some(function (g) { return g != null; }), 'Cannot find a connection between any variable and the result of the ' +
'loss function y=f(x). Please make sure the operations that use ' +
'variables are inside the function f passed to minimize().');
util.assert(value.rank === 0, "The f passed in variableGrads(f) must return a scalar, but it " +
("returned a rank-" + value.rank + " tensor"));
var namedGrads = {};
varList.forEach(function (v, i) {
if (grads[i] != null) {
namedGrads[v.name] = grads[i];
}
});
return { value: value, grads: namedGrads };
}
exports.variableGrads = variableGrads;
function customGrad(f) {
return environment_1.ENV.engine.customGrad(f);
}
exports.customGrad = customGrad;
/**
 * Throws an `Error` if any entry of `grads` is null or undefined. Otherwise it
 * returns nothing.
 *
 * @param grads Array of gradients as returned by `ENV.engine.gradients`.
 */
function checkGrads(grads) {
    var hasMissingGradient = grads.some(function (g) { return g == null; });
    if (!hasMissingGradient) {
        return;
    }
    throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y.");
}
//# sourceMappingURL=gradients.js.map