UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

116 lines 4.61 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var tensor_1 = require("./tensor");
var util = require("./util");

/**
 * Prunes a tape down to the nodes that sit on a path from any tensor in `xs`
 * to the tensor `y`.
 *
 * Three passes over `tape`:
 *   1. front-to-back, marking every tensor/node reachable from `xs`;
 *   2. back-to-front, marking every tensor/node that leads to `y`;
 *   3. front-to-back, keeping nodes marked by both passes.
 *
 * Each kept node is a shallow copy (via `Object.assign`) whose `inputs` map
 * only retains the inputs reachable from `xs`; `outputs` is shared with the
 * original node. The nodes in `tape` themselves are not modified.
 *
 * @param {Array} tape Nodes exposing `id`, `inputs` (name -> tensor) and
 *     `outputs` (array of tensors). Presumably recorded in execution order,
 *     since pass 1 relies on producers appearing before consumers.
 * @param {Array} xs Tensors (objects with an `id`) to start from.
 * @param {Object} y Tensor (object with an `id`) to end at.
 * @returns {Array} The filtered, pruned copies of the tape nodes, in tape order.
 */
function getFilteredNodesXToY(tape, xs, y) {
    var i;
    var node;
    var inputName;

    // Pass 1: forward reachability from xs.
    var tensorsFromX = {};
    var nodesFromX = {};
    for (i = 0; i < xs.length; i++) {
        tensorsFromX[xs[i].id] = true;
    }
    var markOutputFromX = function (output) {
        tensorsFromX[output.id] = true;
    };
    for (i = 0; i < tape.length; i++) {
        node = tape[i];
        for (inputName in node.inputs) {
            // One input reachable from xs is enough: all outputs of the node
            // then become reachable too, so stop scanning its inputs.
            if (tensorsFromX[node.inputs[inputName].id]) {
                node.outputs.forEach(markOutputFromX);
                nodesFromX[node.id] = true;
                break;
            }
        }
    }

    // Pass 2: backward reachability from y.
    var tensorsLeadToY = {};
    tensorsLeadToY[y.id] = true;
    var nodesToY = {};
    for (i = tape.length - 1; i >= 0; i--) {
        node = tape[i];
        var producesTensorLeadingToY = false;
        for (var j = 0; j < node.outputs.length; j++) {
            if (tensorsLeadToY[node.outputs[j].id]) {
                producesTensorLeadingToY = true;
                break;
            }
        }
        if (producesTensorLeadingToY) {
            // A node without inputs may get marked here, but it can never be
            // in nodesFromX, so the final intersection is unaffected.
            nodesToY[node.id] = true;
            for (inputName in node.inputs) {
                tensorsLeadToY[node.inputs[inputName].id] = true;
            }
        }
    }

    // Pass 3: keep the intersection and drop inputs not reachable from xs.
    var filteredTape = [];
    for (i = 0; i < tape.length; i++) {
        node = tape[i];
        if (!(nodesFromX[node.id] && nodesToY[node.id])) {
            continue;
        }
        var prunedInputs = {};
        for (inputName in node.inputs) {
            var nodeInput = node.inputs[inputName];
            if (tensorsFromX[nodeInput.id]) {
                prunedInputs[inputName] = nodeInput;
            }
        }
        var prunedNode = Object.assign({}, node);
        prunedNode.inputs = prunedInputs;
        prunedNode.outputs = node.outputs;
        filteredTape.push(prunedNode);
    }
    return filteredTape;
}
exports.getFilteredNodesXToY = getFilteredNodesXToY;

/**
 * Backpropagates through a single tape node, adding the gradient of each of
 * its inputs into `tensorAccumulatedGradientMap`.
 *
 * @param {Object} tensorAccumulatedGradientMap Map from tensor id to the
 *     gradient accumulated so far; updated in place.
 * @param {Object} node Tape node exposing `name`, `inputs`, `outputs` and
 *     `gradient`.
 * @throws {Error} If the node has no gradient function, if the gradient
 *     function does not cover one of the node's inputs, or if a produced
 *     gradient is not 'float32' or does not match its input's shape.
 */
function backpropagateNode(tensorAccumulatedGradientMap, node) {
    // Fail before allocating anything: the zero tensors created below would
    // otherwise be abandoned when this error is thrown.
    if (node.gradient == null) {
        throw new Error("Cannot compute gradient: gradient function not found " +
            ("for " + node.name + "."));
    }

    // One dy per output; outputs with no accumulated gradient get zeros of
    // the output's own shape and dtype.
    var dys = node.outputs.map(function (o) {
        var gradTensor = tensorAccumulatedGradientMap[o.id];
        if (gradTensor != null) {
            return gradTensor;
        }
        return tensor_1.Tensor.make(o.shape, { values: util.makeZerosTypedArray(o.size, o.dtype) }, o.dtype);
    });

    // The gradient function receives a bare tensor for single-output nodes
    // and the full array otherwise.
    var inputGradients = node.gradient(node.outputs.length === 1 ? dys[0] : dys);

    for (var inputName in node.inputs) {
        if (!(inputName in inputGradients)) {
            throw new Error("Cannot backprop through input " + inputName + ". " +
                ("Available gradients found: " + Object.keys(inputGradients) + "."));
        }
        // Each entry is a thunk; it is only invoked for inputs kept on the tape.
        var dx = inputGradients[inputName]();
        if (dx.dtype !== 'float32') {
            throw new Error("Error in gradient for op " + node.name + ". The gradient of input " +
                (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'"));
        }
        var x = node.inputs[inputName];
        if (!util.arraysEqual(dx.shape, x.shape)) {
            throw new Error("Error in gradient for op " + node.name + ". The gradient of input " +
                ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") +
                ("the shape of the input '" + x.shape + "'"));
        }
        var curGradient = tensorAccumulatedGradientMap[x.id];
        if (curGradient == null) {
            tensorAccumulatedGradientMap[x.id] = dx;
        }
        else {
            // Replace the running sum and release the tensor it supersedes.
            tensorAccumulatedGradientMap[x.id] = curGradient.add(dx);
            curGradient.dispose();
        }
    }
}

/**
 * Walks `filteredTape` from its last node to its first, accumulating the
 * gradient of every node input into `tensorAccumulatedGradientMap`.
 *
 * @param {Object} tensorAccumulatedGradientMap Map from tensor id to gradient
 *     tensor. Presumably seeded by the caller with the gradient of the final
 *     output; it is mutated in place and nothing is returned.
 * @param {Array} filteredTape Tape nodes, e.g. as returned by
 *     `getFilteredNodesXToY`.
 * @throws {Error} See `backpropagateNode`.
 */
function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape) {
    for (var i = filteredTape.length - 1; i >= 0; i--) {
        backpropagateNode(tensorAccumulatedGradientMap, filteredTape[i]);
    }
}
exports.backpropagateGradients = backpropagateGradients;
//# sourceMappingURL=tape.js.map