UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

100 lines 3.87 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); var util = require("./util"); function getFilteredNodesXToY(tape, xs, y) { var tensorsFromX = {}; var nodesFromX = {}; for (var i = 0; i < xs.length; i++) { tensorsFromX[xs[i].id] = true; } for (var i = 0; i < tape.length; i++) { var node = tape[i]; var nodeInputs = node.inputs; for (var inputName in nodeInputs) { var input = nodeInputs[inputName]; var anyInputFromX = false; for (var j = 0; j < xs.length; j++) { if (tensorsFromX[input.id]) { tensorsFromX[node.output.id] = true; anyInputFromX = true; nodesFromX[node.id] = true; break; } } if (anyInputFromX) { break; } } } var tensorsLeadToY = {}; tensorsLeadToY[y.id] = true; var nodesToY = {}; for (var i = tape.length - 1; i >= 0; i--) { var node = tape[i]; var nodeInputs = node.inputs; var outputs = []; outputs.push(node.output); for (var j = 0; j < outputs.length; j++) { if (tensorsLeadToY[outputs[j].id]) { for (var inputName in nodeInputs) { tensorsLeadToY[nodeInputs[inputName].id] = true; nodesToY[node.id] = true; } break; } } } var filteredTape = []; for (var i = 0; i < tape.length; i++) { var node = tape[i]; if (nodesFromX[node.id] && nodesToY[node.id]) { var prunedInputs = {}; for (var inputName in node.inputs) { var nodeInput = node.inputs[inputName]; if (tensorsFromX[nodeInput.id]) { prunedInputs[inputName] = nodeInput; } } var prunedNode = Object.assign({}, node); prunedNode.inputs = prunedInputs; prunedNode.output = node.output; filteredTape.push(prunedNode); } } return filteredTape; } exports.getFilteredNodesXToY = getFilteredNodesXToY; function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape) { for (var i = filteredTape.length - 1; i >= 0; i--) { var node = filteredTape[i]; var dy = tensorAccumulatedGradientMap[node.output.id]; if (node.gradient == null) { throw new Error("Cannot compute gradient: gradient function not found " + ("for " + node.name + ".")); } var inputGradients = node.gradient(dy); for (var inputName in node.inputs) { if (!(inputName in inputGradients)) { throw new Error("Cannot backprop through input " + inputName + ". " + ("Available gradients found: " + Object.keys(inputGradients) + ".")); } var dx = inputGradients[inputName](); var x = node.inputs[inputName]; if (!util.arraysEqual(dx.shape, x.shape)) { throw new Error("Error in gradient for op " + node.name + ". The gradient of input " + ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") + ("the shape of the input '" + x.shape + "'")); } if (tensorAccumulatedGradientMap[x.id] == null) { tensorAccumulatedGradientMap[x.id] = dx; } else { var curGradient = tensorAccumulatedGradientMap[x.id]; tensorAccumulatedGradientMap[x.id] = curGradient.add(dx); curGradient.dispose(); } } } } exports.backpropagateGradients = backpropagateGradients; //# sourceMappingURL=tape.js.map