@tensorflow/tfjs-core
Version:
Hardware-accelerated JavaScript library for machine intelligence
196 lines (174 loc) • 6.64 kB
text/typescript
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Tensor} from './tensor';
import {NamedTensorMap} from './tensor_types';
import * as util from './util';
/**
 * A node recorded on the gradient tape: one executed op, its inputs and
 * outputs, and (when available) its gradient function.
 */
export interface TapeNode {
  // Unique id of this tape node.
  id: number;
  // Name of the op; used in error messages during backprop.
  name: string;
  outputs: Tensor[];
  inputs: NamedTensorMap;
  // Optional params, defined only for ops with gradient impl.
  gradient?: (dy: Tensor|Tensor[]) => NamedGradientMap;
  saved?: Tensor[];
}
/**
 * Maps each input name of an op to a thunk that lazily computes the gradient
 * flowing into that input (invoked during backprop, see
 * `backpropagateGradients`).
 */
export type NamedGradientMap = {
  [inputName: string]: () => Tensor;
};
/**
 * Computes a list of TapeNodes that connect x to y, filtering everything else
 * out and preserving the order of the original tape elements.
 *
 * @param tape The tape elements to filter.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 * @returns The tape nodes that lie on a path from some x to y, in original
 *     tape order, with each node's inputs pruned to those that are
 *     transitively a function of xs.
 */
export function getFilteredNodesXToY(
    tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] {
  // Forward pass to compute all the nodes and Tensors that are transitively a
  // function of x.
  const tensorsFromX: {[tensorId: number]: boolean} = {};
  const nodesFromX: {[nodeId: number]: boolean} = {};
  for (let i = 0; i < xs.length; i++) {
    tensorsFromX[xs[i].id] = true;
  }
  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    const nodeInputs = node.inputs;
    for (const inputName in nodeInputs) {
      const input = nodeInputs[inputName];
      // Fix: the original wrapped this check in a `for (j over xs)` loop whose
      // index was never used; the condition does not depend on xs, so a single
      // check is equivalent and avoids O(|inputs| * |xs|) dead iterations.
      if (tensorsFromX[input.id]) {
        node.outputs.forEach(output => tensorsFromX[output.id] = true);
        nodesFromX[node.id] = true;
        // One input from x is enough to mark the node; skip remaining inputs.
        break;
      }
    }
  }
  // Backward pass to find all of the nodes and Tensors that lead to y.
  const tensorsLeadToY: {[tensorId: number]: boolean} = {};
  tensorsLeadToY[y.id] = true;
  const nodesToY: {[nodeId: number]: boolean} = {};
  for (let i = tape.length - 1; i >= 0; i--) {
    const node = tape[i];
    const nodeInputs = node.inputs;
    // If any of the outputs lead to y, mark all of the inputs as leading to y.
    for (let j = 0; j < node.outputs.length; j++) {
      if (tensorsLeadToY[node.outputs[j].id]) {
        for (const inputName in nodeInputs) {
          tensorsLeadToY[nodeInputs[inputName].id] = true;
          nodesToY[node.id] = true;
        }
        break;
      }
    }
  }
  // Return the paths that come from x and lead to y.
  const filteredTape: TapeNode[] = [];
  for (let i = 0; i < tape.length; i++) {
    const node = tape[i];
    if (nodesFromX[node.id] && nodesToY[node.id]) {
      // Prune the inputs from the node that aren't a function of x.
      const prunedInputs: {[inputName: string]: Tensor} = {};
      for (const inputName in node.inputs) {
        const nodeInput = node.inputs[inputName];
        if (tensorsFromX[nodeInput.id]) {
          prunedInputs[inputName] = nodeInput;
        }
      }
      // Copy the node and overwrite inputs with the pruned version. Outputs
      // are shared with the original node.
      const prunedNode = Object.assign({}, node) as TapeNode;
      prunedNode.inputs = prunedInputs;
      prunedNode.outputs = node.outputs;
      filteredTape.push(prunedNode);
    }
  }
  return filteredTape;
}
/**
 * Backpropagate gradients through the filtered TapeNodes.
 *
 * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
 * is mutated by this method. On entry it must contain the gradient of the
 * final output; on exit it holds accumulated gradients for every tensor on
 * the tape's x-to-y paths.
 * @param filteredTape The filtered TapeNodes to backprop through.
 * @param tidy Wraps each per-input gradient computation so the caller (the
 * engine) can track/clean up intermediate tensors. Typed as
 * `(f: () => Tensor) => Tensor` (the only way it is invoked below) instead of
 * the untyped `Function`; this is contravariantly compatible with existing
 * callers.
 */
export function backpropagateGradients(
    tensorAccumulatedGradientMap: {[tensorId: number]: Tensor},
    filteredTape: TapeNode[], tidy: (f: () => Tensor) => Tensor) {
  // Walk the tape backward and keep a map of Tensor to its gradient.
  for (let i = filteredTape.length - 1; i >= 0; i--) {
    const node = filteredTape[i];
    const dys: Tensor[] = [];
    node.outputs.forEach(o => {
      const gradTensor = tensorAccumulatedGradientMap[o.id];
      if (gradTensor != null) {
        dys.push(gradTensor);
      } else {
        // This particular output is not in the back-propagation subgraph, so it
        // does not affect the final output, thus we put zeros for its dy.
        const dy = Tensor.make(
            o.shape, {values: util.makeZerosTypedArray(o.size, o.dtype)},
            o.dtype);
        dys.push(dy);
      }
    });
    if (node.gradient == null) {
      throw new Error(
          `Cannot compute gradient: gradient function not found ` +
          `for ${node.name}.`);
    }
    // Backprop dy through this node and accumulate gradients over the inputs.
    const inputGradients =
        // Grad functions of ops with single outputs expect a dy, while ops
        // with multiple outputs expect dys (array of dy).
        node.gradient(node.outputs.length === 1 ? dys[0] : dys);
    for (const inputName in node.inputs) {
      if (!(inputName in inputGradients)) {
        throw new Error(
            `Cannot backprop through input ${inputName}. ` +
            `Available gradients found: ${Object.keys(inputGradients)}.`);
      }
      // Call the gradient function (lazily evaluated thunk) inside tidy.
      const dx = tidy(() => inputGradients[inputName]());
      if (dx.dtype !== 'float32') {
        throw new Error(
            `Error in gradient for op ${node.name}. The gradient of input ` +
            `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
      }
      const x = node.inputs[inputName];
      if (!util.arraysEqual(dx.shape, x.shape)) {
        throw new Error(
            `Error in gradient for op ${node.name}. The gradient of input ` +
            `'${inputName}' has shape '${dx.shape}', which does not match ` +
            `the shape of the input '${x.shape}'`);
      }
      if (tensorAccumulatedGradientMap[x.id] == null) {
        // First gradient seen for this input: store it directly.
        tensorAccumulatedGradientMap[x.id] = dx;
      } else {
        // Accumulate into the existing gradient and release the old tensor.
        const curGradient = tensorAccumulatedGradientMap[x.id];
        tensorAccumulatedGradientMap[x.id] = curGradient.add(dx);
        curGradient.dispose();
      }
    }
  }
}