@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
100 lines • 3.87 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var util = require("./util");
/**
 * Returns the sub-tape of nodes lying on some path from any tensor in `xs`
 * to the tensor `y`.
 *
 * Three passes over `tape`:
 *   1. Forward: mark every tensor and node that is transitively a function
 *      of at least one tensor in `xs`.
 *   2. Backward: mark every tensor and node that transitively feeds `y`.
 *   3. Keep nodes marked by both passes, in tape order.
 *
 * @param tape Array of recorded nodes. Each node is read for `id`, `inputs`
 *     (an object mapping input names to tensors) and `output` (a tensor).
 *     The array is assumed to be in recording order (producers before
 *     consumers), since the forward pass relies on it.
 * @param xs Array of tensors whose `id`s mark the sources.
 * @param y Tensor whose `id` marks the sink.
 * @returns A new array of shallow copies of the selected nodes, in tape
 *     order. Each copy's `inputs` keeps only the inputs that are a function
 *     of `xs`. The nodes in `tape` are not mutated.
 */
function getFilteredNodesXToY(tape, xs, y) {
    // Pass 1 (forward): tensors/nodes that are a function of some x.
    var tensorsFromX = {};
    var nodesFromX = {};
    for (var i = 0; i < xs.length; i++) {
        tensorsFromX[xs[i].id] = true;
    }
    for (var i = 0; i < tape.length; i++) {
        var node = tape[i];
        for (var inputName in node.inputs) {
            // A single membership test suffices; it does not depend on which
            // x the input came from, so there is no need to loop over xs.
            if (tensorsFromX[node.inputs[inputName].id]) {
                tensorsFromX[node.output.id] = true;
                nodesFromX[node.id] = true;
                break;
            }
        }
    }
    // Pass 2 (backward): tensors/nodes whose value flows into y.
    var tensorsLeadToY = {};
    tensorsLeadToY[y.id] = true;
    var nodesToY = {};
    for (var i = tape.length - 1; i >= 0; i--) {
        var node = tape[i];
        if (!tensorsLeadToY[node.output.id]) {
            continue;
        }
        for (var inputName in node.inputs) {
            tensorsLeadToY[node.inputs[inputName].id] = true;
        }
        nodesToY[node.id] = true;
    }
    // Pass 3: keep nodes on an x -> y path and prune inputs not derived from x.
    var filteredTape = [];
    for (var i = 0; i < tape.length; i++) {
        var node = tape[i];
        if (!(nodesFromX[node.id] && nodesToY[node.id])) {
            continue;
        }
        var prunedInputs = {};
        for (var inputName in node.inputs) {
            var nodeInput = node.inputs[inputName];
            if (tensorsFromX[nodeInput.id]) {
                prunedInputs[inputName] = nodeInput;
            }
        }
        // Shallow copy so the caller's tape node is left untouched.
        var prunedNode = Object.assign({}, node);
        prunedNode.inputs = prunedInputs;
        // Object.assign copies only own enumerable properties; set `output`
        // explicitly so the copy always carries it.
        prunedNode.output = node.output;
        filteredTape.push(prunedNode);
    }
    return filteredTape;
}
exports.getFilteredNodesXToY = getFilteredNodesXToY;
/**
 * Performs reverse-mode gradient accumulation over a filtered tape, visiting
 * nodes from the last recorded one to the first.
 *
 * For each node, the upstream gradient is read from
 * `tensorAccumulatedGradientMap[node.output.id]` and handed to
 * `node.gradient`. That call must return an object mapping every input name
 * to a zero-argument function that produces the gradient for that input.
 * Each produced gradient is shape-checked against its input tensor. It is
 * then either stored or summed (via `add`) into the gradient already
 * accumulated for the same tensor id. In the summing case, the previously
 * accumulated tensor is disposed.
 *
 * @param tensorAccumulatedGradientMap Object keyed by tensor id; mutated in
 *     place. Presumably seeded by the caller with the gradient of the final
 *     output (TODO: confirm against the caller).
 * @param filteredTape Array of nodes, each read for `output`, `inputs`,
 *     `gradient` and `name`.
 * @throws Error if a node lacks a gradient function, if its gradient
 *     function provides no entry for one of its inputs, or if a produced
 *     gradient's shape differs from the corresponding input's shape.
 */
function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape) {
    var gradMap = tensorAccumulatedGradientMap;
    for (var remaining = filteredTape.length; remaining > 0; remaining--) {
        var tapeNode = filteredTape[remaining - 1];
        var upstreamGrad = gradMap[tapeNode.output.id];
        if (tapeNode.gradient == null) {
            throw new Error("Cannot compute gradient: gradient function not found " +
                ("for " + tapeNode.name + "."));
        }
        var gradFnsByInput = tapeNode.gradient(upstreamGrad);
        for (var key in tapeNode.inputs) {
            var hasGradFn = key in gradFnsByInput;
            if (!hasGradFn) {
                throw new Error("Cannot backprop through input " + key + ". " +
                    ("Available gradients found: " + Object.keys(gradFnsByInput) + "."));
            }
            var inputGrad = gradFnsByInput[key]();
            var inputTensor = tapeNode.inputs[key];
            var shapesMatch = util.arraysEqual(inputGrad.shape, inputTensor.shape);
            if (!shapesMatch) {
                throw new Error("Error in gradient for op " + tapeNode.name + ". The gradient of input " +
                    ("'" + key + "' has shape '" + inputGrad.shape + "', which does not match ") +
                    ("the shape of the input '" + inputTensor.shape + "'"));
            }
            var accumulated = gradMap[inputTensor.id];
            if (accumulated == null) {
                // First gradient seen for this tensor: store it as-is.
                gradMap[inputTensor.id] = inputGrad;
                continue;
            }
            // Tensor feeds several nodes/inputs: sum contributions, then
            // release the superseded running total.
            gradMap[inputTensor.id] = accumulated.add(inputGrad);
            accumulated.dispose();
        }
    }
}
exports.backpropagateGradients = backpropagateGradients;
//# sourceMappingURL=tape.js.map