UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

452 lines 56 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/**
 * Executor: Evaluates SymbolicTensor based on feeds.
 */
import { cast, dispose, memory, util } from '@tensorflow/tfjs-core';
import { ValueError } from '../errors';
import { LruCache } from '../utils/executor_utils';
import { toList } from '../utils/generic_utils';
import { InputLayer } from './input_layer';
import { SymbolicTensor } from './topology';
/**
 * Check the dtype of a feed value against the dtype of its key, casting the
 * value when the two differ. Only the dtype is checked; the shape is not.
 *
 * @param key The `SymbolicTensor` being fed.
 * @param val The concrete `Tensor` fed for `key`.
 * @returns `val` itself if `key` declares no dtype or the dtypes already
 *   match; otherwise the result of `cast(val, key.dtype)`.
 * @throws ValueError If the cast fails.
 */
function assertFeedCompatibility(key, val) {
    // a. No declared dtype, or the dtypes match: use the tensor as is.
    if (key.dtype == null || key.dtype === val.dtype) {
        return val;
    }
    try {
        // b. Attempt to convert to the expected dtype.
        // NOTE(review): the cast result is stored by the FeedDict; nothing in
        // this file appears to dispose it -- confirm whether callers do.
        return cast(val, key.dtype);
    }
    catch (err) {
        // c. Conversion failed: surface a descriptive error instead.
        throw new ValueError(`The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +
            `of the key '${key.name}' (${key.dtype}).`);
    }
}
/**
 * FeedDict: A mapping from unique SymbolicTensors to feed values for them.
 * A feed value is a concrete value represented as an `Tensor`. Entries are
 * keyed by `SymbolicTensor.id`; a secondary index maps names to ids.
 */
export class FeedDict {
    /**
     * Constructor, optionally does copy-construction.
     * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case
     *   copy-construction will be performed. A copy shares (does not clone)
     *   the value and mask Tensors of its source.
     */
    constructor(feeds) {
        this.id2Value = {};
        this.id2Mask = {};
        this.name2Id = {};
        if (feeds instanceof FeedDict) {
            for (const id in feeds.id2Value) {
                this.id2Value[id] = feeds.id2Value[id];
                if (id in feeds.id2Mask) {
                    this.id2Mask[id] = feeds.id2Mask[id];
                }
            }
            // Copy the name index too, so that `names()` and lookups by name
            // behave the same on the copy as on the source.
            Object.assign(this.name2Id, feeds.name2Id);
        }
        else {
            if (feeds == null) {
                return;
            }
            for (const feed of feeds) {
                this.add(feed.key, feed.value);
            }
        }
    }
    /**
     * Add a key-value pair to the FeedDict.
     *
     * @param key The key of the feed.
     * @param value The value of the tensor feed. It is passed through
     *   `assertFeedCompatibility`, so it may be stored as a cast copy.
     * @param mask The value of the mask feed (optional).
     * @returns This `FeedDict`.
     * @throws ValueError: If the key `SymbolicTensor` already exists in the
     *   `FeedDict`.
     */
    add(key, value, mask) {
        if (this.id2Value[key.id] == null) {
            this.id2Value[key.id] = assertFeedCompatibility(key, value);
            this.name2Id[key.name] = key.id;
            if (mask != null) {
                this.id2Mask[key.id] = mask;
            }
        }
        else {
            throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);
        }
        return this;
    }
    /**
     * Add a Feed to the FeedDict.
     * @param feed The new `Feed` to add.
     * @returns This `FeedDict`.
     * @throws ValueError: If the feed's key already exists in the `FeedDict`.
     */
    addFeed(feed) {
        return this.add(feed.key, feed.value);
    }
    /**
     * Probe whether a key already exists in the FeedDict.
     * @param key The `SymbolicTensor` to look up (matched by `id`).
     * @returns Whether a non-null value is stored for `key`.
     */
    hasKey(key) {
        return this.id2Value[key.id] != null;
    }
    /**
     * Get the names of all SymbolicTensors available in this FeedDict.
     * @returns A freshly allocated Array of names.
     */
    names() {
        return Object.keys(this.name2Id);
    }
    /**
     * Get the feed value for given key.
     * @param key The SymbolicTensor, or its name (as a string), of which the
     *     value is sought.
     * @returns If `key` exists, the corresponding feed value.
     * @throws ValueError: If `key` does not exist in this `FeedDict`.
     */
    getValue(key) {
        if (key instanceof SymbolicTensor) {
            if (this.id2Value[key.id] == null) {
                throw new ValueError(`Nonexistent key: ${key.name}`);
            }
            else {
                return this.id2Value[key.id];
            }
        }
        else {
            const id = this.name2Id[key];
            if (id == null) {
                throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
            }
            return this.id2Value[id];
        }
    }
    /**
     * Get the feed mask for given key.
     * @param key The SymbolicTensor, or its name (as a string), of which the
     *     mask is sought.
     * @returns If `key` exists, the corresponding feed mask (possibly
     *     undefined when no mask was stored for it).
     * @throws ValueError: If `key` does not exist in this `FeedDict`.
     */
    getMask(key) {
        if (key instanceof SymbolicTensor) {
            if (this.id2Value[key.id] == null) {
                throw new ValueError(`Nonexistent key: ${key.name}`);
            }
            else {
                return this.id2Mask[key.id];
            }
        }
        else {
            const id = this.name2Id[key];
            if (id == null) {
                throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
            }
            return this.id2Mask[id];
        }
    }
    /**
     * Dispose all mask Tensors held by this object.
     *
     * NOTE(review): masks obtained through copy-construction are shared with
     * the source FeedDict, so this disposes them for the source as well.
     */
    disposeMasks() {
        if (this.id2Mask != null) {
            dispose(this.id2Mask);
        }
    }
}
// Cache for topologically sorted SymbolicTensors for given execution
// targets (i.e., fetches), keyed by "<fetch names>|<sorted feed names>".
export const cachedSorted = new LruCache();
// Cache for recipient count maps for given execution targets (i.e., fetches),
// under the same keys as `cachedSorted`.
export const cachedRecipientCounts = new LruCache();
/**
 * Set the maximum number of entries of both execution caches.
 * @param maxEntries The new capacity, forwarded to `setMaxEntries` of each
 *   cache.
 */
export function updateCacheMaxEntries(maxEntries) {
    if (cachedSorted != null) {
        cachedSorted.setMaxEntries(maxEntries);
    }
    if (cachedRecipientCounts != null) {
        cachedRecipientCounts.setMaxEntries(maxEntries);
    }
}
/**
 * Execute SymbolicTensors by using concrete feed values.
 *
 * A `SymbolicTensor` object is a node in a computation graph of TF.js
 * Layers. The object is backed by a source layer and input
 * `SymbolicTensor`s to the source layer. This function topologically sorts
 * the graph upstream of `fetches` (stopping at fed tensors) and evaluates the
 * `apply()` method of each source layer in that order. The concrete values
 * of each layer's inputs come either from `feedDict` or from a layer
 * evaluated earlier in the same call.
 *
 * When not training, intermediate tensors are disposed as soon as their last
 * consumer has run. This does not apply to fed values, fetched outputs or
 * outputs of stateful layers.
 *
 * @param fetches The `SymbolicTensor` (or Array of them) to execute.
 * @param feedDict The feed values that terminate the graph traversal. Feed
 *   values are never disposed here; feed masks are (see the end of the body).
 * @param kwargs Optional keyword arguments forwarded to every layer's
 *   `apply()`. The caller's object is not modified.
 * @param probe A probe object (of interface `ExecutionProbe`) used for
 *   testing memory footprint of `execute` calls.
 * @returns A single Tensor if `fetches` is a single SymbolicTensor, otherwise
 *   an Array of Tensors in the order of `fetches`.
 * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s
 *   encountered during the execution lacks a feed value in `feedDict`.
 */
export function execute(fetches, feedDict, kwargs, probe) {
    const training = kwargs == null ? false : kwargs['training'];
    const arrayFetches = Array.isArray(fetches);
    const fetchArray = arrayFetches ? fetches : [fetches];
    const outputNames = fetchArray.map(t => t.name);
    // Fetches that are themselves fed are answered directly from the feeds.
    const finalOutputs = [];
    const feedNames = feedDict.names();
    for (const outputName of outputNames) {
        if (feedNames.indexOf(outputName) !== -1) {
            finalOutputs.push(feedDict.getValue(outputName));
        }
        else {
            finalOutputs.push(null);
        }
    }
    if (probe != null) {
        // For optional probing of memory footprint during execution.
        probe.maxNumTensors = -Infinity;
        probe.minNumTensors = Infinity;
    }
    // Check cache.
    const fetchAndFeedKey = outputNames.join(',') + '|' + feedNames.slice().sort().join(',');
    let sorted = cachedSorted.get(fetchAndFeedKey);
    let cachedCounts = cachedRecipientCounts.get(fetchAndFeedKey);
    if (sorted == null || cachedCounts == null) {
        // Either cache lacks this combination of fetches and feeds. The two
        // caches are separate LRU instances, so one may have evicted the
        // entry while the other kept it. Recompute and store both so that the
        // sort order and the recipient counts always correspond. A missing
        // count map would otherwise make every decrement below NaN and
        // silently disable the disposal of intermediate tensors.
        const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
        sorted = out.sorted;
        cachedCounts = out.recipientCounts;
        cachedSorted.put(fetchAndFeedKey, sorted);
        cachedRecipientCounts.put(fetchAndFeedKey, cachedCounts);
    }
    // Per-call working copy of the counts, decremented as consumers run. Left
    // empty in training mode, where no intermediate tensor is disposed here.
    const recipientCounts = {};
    if (!training) {
        Object.assign(recipientCounts, cachedCounts);
    }
    const internalFeedDict = new FeedDict(feedDict);
    // kwargs actually handed to the layers. A private copy is made the first
    // time a mask must be injected, so the caller's object is never mutated.
    // Otherwise the caller would be left holding a 'mask' that is disposed by
    // `disposeMasks()` below.
    let layerKwargs = kwargs;
    let ownsLayerKwargs = false;
    // Start iterative execution on the topologically-sorted SymbolicTensors.
    for (let i = 0; i < sorted.length; ++i) {
        if (probe != null) {
            // For optional probing of memory usage during execution.
            const numTensors = memory().numTensors;
            if (numTensors > probe.maxNumTensors) {
                probe.maxNumTensors = numTensors;
            }
            if (numTensors < probe.minNumTensors) {
                probe.minNumTensors = numTensors;
            }
        }
        const symbolic = sorted[i];
        const srcLayer = symbolic.sourceLayer;
        if (srcLayer instanceof InputLayer) {
            // Input values come from the feeds; nothing to compute.
            continue;
        }
        const inputValues = [];
        const inputMasks = [];
        const tensorsToDispose = [];
        let maskExists = false;
        for (const input of symbolic.inputs) {
            const value = internalFeedDict.getValue(input);
            const mask = internalFeedDict.getMask(input);
            inputValues.push(value);
            inputMasks.push(mask);
            if (mask != null) {
                maskExists = true;
            }
            if (!training) {
                recipientCounts[input.name]--;
                // Dispose once the last consumer has run, unless the tensor was
                // fed, is being fetched, is already disposed, or is produced by
                // a stateful layer.
                if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&
                    outputNames.indexOf(input.name) === -1 && !value.isDisposed &&
                    input.sourceLayer.stateful !== true) {
                    tensorsToDispose.push(value);
                }
            }
        }
        if (maskExists) {
            if (!ownsLayerKwargs) {
                layerKwargs = Object.assign({}, layerKwargs);
                ownsLayerKwargs = true;
            }
            // NOTE(review): only the first input's mask is forwarded, and once
            // set, 'mask' stays in layerKwargs for later layers even when their
            // inputs carry no mask (behavior kept from the original).
            layerKwargs['mask'] = inputMasks[0];
        }
        const outputTensors = toList(srcLayer.apply(inputValues, layerKwargs));
        let outputMask = null;
        if (srcLayer.supportsMasking) {
            outputMask = srcLayer.computeMask(inputValues, inputMasks);
        }
        const layerOutputs = getNodeOutputs(symbolic);
        const outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
        for (let j = 0; j < outputSymbolicTensors.length; ++j) {
            if (!internalFeedDict.hasKey(outputSymbolicTensors[j])) {
                internalFeedDict.add(outputSymbolicTensors[j], outputTensors[j], Array.isArray(outputMask) ? outputMask[0] : outputMask);
            }
            const index = outputNames.indexOf(outputSymbolicTensors[j].name);
            if (index !== -1) {
                finalOutputs[index] = outputTensors[j];
            }
        }
        if (!training) {
            // Clean up Tensors that are no longer needed.
            dispose(tensorsToDispose);
        }
    }
    // NOTE(cais): Unlike intermediate tensors, we don't discard mask
    // tensors as we go, because these tensors are sometimes passed over a
    // series of multiple layers, i.e., not obeying the immediate input
    // relations in the graph. If this becomes a memory-usage concern,
    // we can improve this in the future.
    // The internal dict shares mask references with `feedDict`, so fed masks
    // are disposed here as well.
    internalFeedDict.disposeMasks();
    return arrayFetches ? finalOutputs : finalOutputs[0];
}
/**
 * Sort the `SymbolicTensor`s topologically, for an array of fetches.
 *
 * This function calls getTopologicalSortAndRecipientCountsForOneFetch and
 * merges their results (deduplicating by tensor name, first occurrence wins).
 *
 * @param fetches The array of fetches requested. Must be a non-empty array.
 * @param feedDict The dictionary of fed values.
 * @returns sorted: Topologically-sorted array of SymbolicTensors.
 *   recipientCounts: Recipient counts for all SymbolicTensors in `sorted`.
 */
function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
    util.assert(fetches != null && fetches.length > 0, () => `Expected at least one fetch, got none`);
    let finalSorted = [];
    let finalRecipientMap = {};
    if (fetches.length === 1) {
        // Special-casing 1 fetch for efficiency.
        const out = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
        finalSorted = out.sorted;
        finalRecipientMap = out.recipientMap;
    }
    else {
        const visited = new Set();
        for (const fetch of fetches) {
            const { sorted, recipientMap } = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);
            // Merge sorted SymbolicTensor Arrays.
            for (const symbolicTensor of sorted) {
                if (!visited.has(symbolicTensor.name)) {
                    finalSorted.push(symbolicTensor);
                    visited.add(symbolicTensor.name);
                }
            }
            // Merge recipient maps (union of recipient names per tensor).
            for (const name in recipientMap) {
                if (finalRecipientMap[name] == null) {
                    finalRecipientMap[name] = new Set();
                }
                recipientMap[name].forEach(recipient => finalRecipientMap[name].add(recipient));
            }
        }
    }
    return {
        sorted: finalSorted,
        recipientCounts: recipientMap2Counts(finalRecipientMap)
    };
}
/**
 * Convert a map of tensor name -> Set of recipient names into a map of
 * tensor name -> number of distinct recipients.
 */
function recipientMap2Counts(recipientMap) {
    const recipientCounts = {};
    for (const name in recipientMap) {
        recipientCounts[name] = recipientMap[name].size;
    }
    return recipientCounts;
}
/**
 * Sort the `SymbolicTensor`s topologically, for a single fetch.
 *
 * This helper function processes the upstream SymbolicTensors of a single
 * fetch with an iterative depth-first traversal. Tensors whose names are keys
 * of `feedDict` are treated as already visited and are not walked.
 *
 * @param fetch The single fetch requested.
 * @param feedDict The dictionary of fed values.
 * @returns sorted: Topologically-sorted array of SymbolicTensors.
 *   recipientMap: Recipient names for all SymbolicTensors in `sorted`.
 */
export function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
    const visited = new Set();
    const sorted = [];
    const recipientMap = {};
    // Put keys of the feedDict into visited first, so they don't have to be
    // walked. This is needed in case where there are feeds for intermediate
    // SymbolicTensors of the graph.
    for (const key of feedDict.names()) {
        visited.add(key);
    }
    const stack = [];
    // Stack positions whose inputs have already been pushed ("expanded").
    const marks = [];
    // Initial population of stack and marks.
    stack.push(fetch);
    while (stack.length > 0) {
        const top = stack[stack.length - 1];
        if (visited.has(top.name)) {
            stack.pop();
            continue;
        }
        const topIsMarked = marks[marks.length - 1] === stack.length - 1;
        if (top.inputs.length === 0 || topIsMarked) {
            // Input SymbolicTensor or all children have been visited.
            stack.pop();
            sorted.push(top);
            visited.add(top.name);
            if (topIsMarked) {
                marks.pop();
            }
        }
        else {
            // A non-input SymbolicTensor whose upstream SymbolicTensors haven't
            // been visited yet. Push them onto the stack.
            marks.push(stack.length - 1);
            for (const input of top.inputs) {
                // Increment the recipient count. Note that this needs to happen
                // regardless of whether the SymbolicTensor has been visited before.
                if (recipientMap[input.name] == null) {
                    recipientMap[input.name] = new Set();
                }
                recipientMap[input.name].add(top.name);
                if (visited.has(input.name)) {
                    continue; // Avoid repeated visits to the same SymbolicTensor.
                }
                stack.push(input);
            }
        }
    }
    return { sorted, recipientMap };
}
/**
 * Get the symbolic output tensors of the node to which a given fetch belongs.
 * @param fetch The fetched symbolic tensor.
 * @returns The symbolic tensor(s) output by the node to which `fetch`
 *   belongs.
 */
function getNodeOutputs(fetch) {
    const layer = fetch.sourceLayer;
    if (layer.inboundNodes.length === 1) {
        return layer.output;
    }
    // Find the inbound node whose outputs include `fetch` (matched by id) and
    // stop scanning as soon as it is found.
    let nodeIndex = null;
    search: for (let i = 0; i < layer.inboundNodes.length; ++i) {
        for (const outputTensor of layer.inboundNodes[i].outputTensors) {
            if (outputTensor.id === fetch.id) {
                nodeIndex = i;
                break search;
            }
        }
    }
    // NOTE(review): if no node produces `fetch`, nodeIndex stays null and is
    // passed to getOutputAt unchanged, as before.
    return layer.getOutputAt(nodeIndex);
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"executor.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/engine/executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,EAAC,IAAI,EAAE,OAAO,EAAE,MAAM,EAAU,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE1E,OAAO,EAAC,UAAU,EAAC,MAAM,WAAW,CAAC;AAErC,OAAO,EAAC,QAAQ,EAAC,MAAM,yBAAyB,CAAC;AACjD,OAAO,EAAC,MAAM,EAAC,MAAM,wBAAwB,CAAC;AAE9C,OAAO,EAAC,UAAU,EAAC,MAAM,eAAe,CAAC;AACzC,OAAO,EAAC,cAAc,EAAC,MAAM,YAAY,CAAC;AAE1C;;GAEG;AACH,SAAS,uBAAuB,CAAC,GAAmB,EAAE,GAAW;IAC/D,6BAA6B;IAC7B,IAAI,GAAG,CAAC,KAAK,IAAI,IAAI,IAAI,GAAG,CAAC,KAAK,KAAK,GAAG,CAAC,KAAK,EAAE;QAChD,gDAAgD;QAChD,OAAO,GAAG,CAAC;KACZ;IACD,IAAI;QACF,2CAA2C;QAC3C,OAAO,IAAI,CAAC,GAAG,EAAE,GAAG,CAAC,KAAK,CAAC,CAAC;KAC7B;IAAC,OAAO,GAAG,EAAE;QACZ,iDAAiD;QACjD,MAAM,IAAI,UAAU,CAChB,0BAA0B,GAAG,CAAC,KAAK,iCAAiC;YACpE,eAAe,GAAG,CAAC,IAAI,MAAM,GAAG,CAAC,KAAK,IAAI,CAAC,CAAC;KACjD;AACH,CAAC;AAUD;;;GAGG;AACH,MAAM,OAAO,QAAQ;IAKnB;;;;OAIG;IACH,YAAY,KAAuB;QAT3B,aAAQ,GAA2B,EAAE,CAAC;QACtC,YAAO,GAA2B,EAAE,CAAC;QACrC,YAAO,GAA6B,EAAE,CAAC;QAQ7C,IAAI,KAAK,YAAY,QAAQ,EAAE;YAC7B,KAAK,MAAM,EAAE,IAAI,KAAK,CAAC,QAAQ,EAAE;gBAC/B,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC;gBACvC,IAAI,EAAE,IAAI,KAAK,CAAC,OAAO,EAAE;oBACvB,IAAI,CAAC,OAAO,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;iBACtC;aACF;SACF;aAAM;YACL,IAAI,KAAK,IAAI,IAAI,EAAE;gBACjB,OAAO;aACR;YACD,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE;gBACxB,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;aAChC;SAC
F;IACH,CAAC;IAED;;;;;;;;;OASG;IACH,GAAG,CAAC,GAAmB,EAAE,KAAa,EAAE,IAAa;QACnD,IAAI,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE;YACjC,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,uBAAuB,CAAC,GAAG,EAAE,KAAK,CAAC,CAAC;YAC5D,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG,GAAG,CAAC,EAAE,CAAC;YAChC,IAAI,IAAI,IAAI,IAAI,EAAE;gBAChB,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC;aAC7B;SACF;aAAM;YACL,MAAM,IAAI,UAAU,CAAC,uBAAuB,GAAG,CAAC,IAAI,QAAQ,GAAG,CAAC,EAAE,EAAE,CAAC,CAAC;SACvE;QACD,OAAO,IAAI,CAAC;IACd,CAAC;IAED;;;;OAIG;IACH,OAAO,CAAC,IAAU;QAChB,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;IACjC,CAAC;IAED;;;OAGG;IACH,MAAM,CAAC,GAAmB;QACxB,OAAO,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,IAAI,CAAC;IACvC,CAAC;IAED;;OAEG;IACH,KAAK;QACH,OAAO,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;IACnC,CAAC;IAED;;;;;;OAMG;IACH,QAAQ,CAAC,GAA0B;QACjC,IAAI,GAAG,YAAY,cAAc,EAAE;YACjC,IAAI,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE;gBACjC,MAAM,IAAI,UAAU,CAAC,oBAAoB,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC;aACtD;iBAAM;gBACL,OAAO,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;aAC9B;SACF;aAAM;YACL,MAAM,EAAE,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;YAC7B,IAAI,EAAE,IAAI,IAAI,EAAE;gBACd,MAAM,IAAI,UAAU,CAAC,yCAAyC,GAAG,EAAE,CAAC,CAAC;aACtE;YACD,OAAO,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC;SAC1B;IACH,CAAC;IAED;;;;;;OAMG;IACH,OAAO,CAAC,GAA0B;QAChC,IAAI,GAAG,YAAY,cAAc,EAAE;YACjC,IAAI,IAAI,CAAC,QAAQ,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE;gBACjC,MAAM,IAAI,UAAU,CAAC,oBAAoB,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC;aACtD;iBAAM;gBACL,OAAO,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;aAC7B;SACF;aAAM;YACL,MAAM,EAAE,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;YAC7B,IAAI,EAAE,IAAI,IAAI,EAAE;gBACd,MAAM,IAAI,UAAU,CAAC,yCAAyC,GAAG,EAAE,CAAC,CAAC;aACtE;YACD,OAAO,IAAI,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;SACzB;IACH,CAAC;IAED,oDAAoD;IACpD,YAAY;QACV,IAAI,IAAI,CAAC,OAAO,IAAI,IAAI,EAAE;YACxB,OAAO,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SACvB;IACH,CAAC;CACF;AAED,qEAAqE;AACrE,2BAA2B;AAC3B,MAAM,CAAC,MAAM,YAAY,GACrB,IAAI,QAAQ
,EAAoB,CAAC;AAErC,8EAA8E;AAC9E,MAAM,CAAC,MAAM,qBAAqB,GAC9B,IAAI,QAAQ,EAAmB,CAAC;AAEpC,MAAM,UAAU,qBAAqB,CAAC,UAAkB;IACtD,IAAI,YAAY,IAAI,IAAI,EAAE;QACxB,YAAY,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;KACxC;IACD,IAAI,qBAAqB,IAAI,IAAI,EAAE;QACjC,qBAAqB,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;KACjD;AACH,CAAC;AAsBD;;;;;;;;;;;;;;;;;;;;GAoBG;AACH,MAAM,UAAU,OAAO,CACnB,OAAwC,EAAE,QAAkB,EAC5D,MAAe,EAAE,KAAsB;IAEzC,MAAM,QAAQ,GAAY,MAAM,IAAI,IAAI,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,MAAM,CAAC,UAAU,CAAC,CAAC;IAEtE,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC;IAC5C,MAAM,UAAU,GACZ,YAAY,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC;IAEvC,MAAM,WAAW,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;IAChD,MAAM,YAAY,GAAa,EAAE,CAAC;IAClC,MAAM,SAAS,GAAG,QAAQ,CAAC,KAAK,EAAE,CAAC;IACnC,KAAK,MAAM,UAAU,IAAI,WAAW,EAAE;QACpC,IAAI,SAAS,CAAC,OAAO,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,EAAE;YACxC,YAAY,CAAC,IAAI,CAAC,QAAQ,CAAC,QAAQ,CAAC,UAAU,CAAC,CAAC,CAAC;SAClD;aAAM;YACL,YAAY,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SACzB;KACF;IAED,IAAI,KAAK,IAAI,IAAI,EAAE;QACjB,6DAA6D;QAC7D,KAAK,CAAC,aAAa,GAAG,CAAC,QAAQ,CAAC;QAChC,KAAK,CAAC,aAAa,GAAG,QAAQ,CAAC;KAChC;IAED,eAAe;IACf,MAAM,eAAe,GACjB,WAAW,CAAC,IAAI,CAAC,GAAG,CAAC,GAAG,GAAG,GAAG,QAAQ,CAAC,KAAK,EAAE,CAAC,IAAI,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;IACpE,IAAI,MAAM,GAAqB,YAAY,CAAC,GAAG,CAAC,eAAe,CAAC,CAAC;IACjE,IAAI,eAA8C,CAAC;IACnD,IAAI,MAAM,IAAI,IAAI,EAAE;QAClB,oEAAoE;QACpE,2DAA2D;QAC3D,MAAM,GAAG,GAAG,oCAAoC,CAAC,UAAU,EAAE,QAAQ,CAAC,CAAC;QACvE,MAAM,GAAG,GAAG,CAAC,MAAM,CAAC;QACpB,eAAe,GAAG,GAAG,CAAC,eAAe,CAAC;QAEtC,yCAAyC;QACzC,YAAY,CAAC,GAAG,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC;QAC1C,qBAAqB,CAAC,GAAG,CAAC,eAAe,EAAE,eAAe,CAAC,CAAC;KAC7D;IACD,eAAe,GAAG,EAAE,CAAC;IACrB,IAAI,CAAC,QAAQ,EAAE;QACb,MAAM,CAAC,MAAM,CAAC,eAAe,EAAE,qBAAqB,CAAC,GAAG,CAAC,eAAe,CAAC,CAAC,CAAC;KAC5E;IAED,MAAM,gBAAgB,GAAG,IAAI,QAAQ,CAAC,QAAQ,CAAC,CAAC;IAEhD,yEAAyE;IACzE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACtC,IAAI,KAAK,IAAI,IAAI,EAAE;YACjB,yDAAyD;YACzD,MAAM,UAAU,GAAG
,MAAM,EAAE,CAAC,UAAU,CAAC;YACvC,IAAI,UAAU,GAAG,KAAK,CAAC,aAAa,EAAE;gBACpC,KAAK,CAAC,aAAa,GAAG,UAAU,CAAC;aAClC;YACD,IAAI,UAAU,GAAG,KAAK,CAAC,aAAa,EAAE;gBACpC,KAAK,CAAC,aAAa,GAAG,UAAU,CAAC;aAClC;SACF;QAED,MAAM,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QAC3B,MAAM,QAAQ,GAAG,QAAQ,CAAC,WAAW,CAAC;QACtC,IAAI,QAAQ,YAAY,UAAU,EAAE;YAClC,SAAS;SACV;QACD,MAAM,WAAW,GAAa,EAAE,CAAC;QACjC,MAAM,UAAU,GAAa,EAAE,CAAC;QAChC,MAAM,gBAAgB,GAAa,EAAE,CAAC;QAEtC,IAAI,UAAU,GAAG,KAAK,CAAC;QACvB,KAAK,MAAM,KAAK,IAAI,QAAQ,CAAC,MAAM,EAAE;YACnC,MAAM,KAAK,GAAG,gBAAgB,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC;YAC/C,MAAM,IAAI,GAAG,gBAAgB,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;YAC7C,WAAW,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;YACxB,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;YACtB,IAAI,IAAI,IAAI,IAAI,EAAE;gBAChB,UAAU,GAAG,IAAI,CAAC;aACnB;YACD,IAAI,CAAC,QAAQ,EAAE;gBACb,eAAe,CAAC,KAAK,CAAC,IAAI,CAAC,EAAE,CAAC;gBAC9B,IAAI,eAAe,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,KAAK,CAAC;oBAC5D,WAAW,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,UAAU;oBAC3D,KAAK,CAAC,WAAW,CAAC,QAAQ,KAAK,IAAI,EAAE;oBACvC,gBAAgB,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;iBAC9B;aACF;SACF;QAED,IAAI,UAAU,EAAE;YACd,MAAM,GAAG,MAAM,IAAI,EAAE,CAAC;YACtB,MAAM,CAAC,MAAM,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;SAChC;QACD,MAAM,aAAa,GACf,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAC,CAAa,CAAC;QAC5D,IAAI,UAAU,GAAoB,IAAI,CAAC;QACvC,IAAI,QAAQ,CAAC,eAAe,EAAE;YAC5B,UAAU,GAAG,QAAQ,CAAC,WAAW,CAAC,WAAW,EAAE,UAAU,CAAC,CAAC;SAC5D;QACD,MAAM,YAAY,GAAG,cAAc,CAAC,QAAQ,CAAC,CAAC;QAC9C,MAAM,qBAAqB,GACvB,KAAK,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC;QAChE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,qBAAqB,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YACrD,IAAI,CAAC,gBAAgB,CAAC,MAAM,CAAC,qBAAqB,CAAC,CAAC,CAAC,CAAC,EAAE;gBACtD,gBAAgB,CAAC,GAAG,CAChB,qBAAqB,CAAC,CAAC,CAAC,EAAE,aAAa,CAAC,CAAC,CAAC,EAC1C,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC;aAC7D;YACD,MAAM,KAAK,GAAG,WAAW,CAAC,OAAO,CAAC,qBAAqB,CAAC,CAAC,CAA
C,CAAC,IAAI,CAAC,CAAC;YACjE,IAAI,KAAK,KAAK,CAAC,CAAC,EAAE;gBAChB,YAAY,CAAC,KAAK,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;aACxC;SACF;QAED,IAAI,CAAC,QAAQ,EAAE;YACb,8CAA8C;YAC9C,OAAO,CAAC,gBAAgB,CAAC,CAAC;SAC3B;KACF;IACD,iEAAiE;IACjE,sEAAsE;IACtE,mEAAmE;IACnE,kEAAkE;IAClE,qCAAqC;IACrC,gBAAgB,CAAC,YAAY,EAAE,CAAC;IAEhC,OAAO,YAAY,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC;AACvD,CAAC;AAUD;;;;;;;;;;GAUG;AACH,SAAS,oCAAoC,CACzC,OAAyB,EAAE,QAAkB;IAE/C,IAAI,CAAC,MAAM,CACP,OAAO,IAAI,IAAI,IAAI,OAAO,CAAC,MAAM,GAAG,CAAC,EACrC,GAAG,EAAE,CAAC,uCAAuC,CAAC,CAAC;IAEnD,IAAI,WAAW,GAAqB,EAAE,CAAC;IACvC,IAAI,iBAAiB,GAAiB,EAAE,CAAC;IACzC,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;QACxB,yCAAyC;QACzC,MAAM,GAAG,GACL,+CAA+C,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;QAC1E,WAAW,GAAG,GAAG,CAAC,MAAM,CAAC;QACzB,iBAAiB,GAAG,GAAG,CAAC,YAAY,CAAC;KACtC;SAAM;QACL,MAAM,OAAO,GAAG,IAAI,GAAG,EAAU,CAAC;QAClC,KAAK,MAAM,KAAK,IAAI,OAAO,EAAE;YAC3B,MAAM,EAAC,MAAM,EAAE,YAAY,EAAC,GACxB,+CAA+C,CAAC,KAAK,EAAE,QAAQ,CAAC,CAAC;YAErE,sCAAsC;YACtC,KAAK,MAAM,cAAc,IAAI,MAAM,EAAE;gBACnC,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,cAAc,CAAC,IAAI,CAAC,EAAE;oBACrC,WAAW,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;oBACjC,OAAO,CAAC,GAAG,CAAC,cAAc,CAAC,IAAI,CAAC,CAAC;iBAClC;aACF;YAED,wBAAwB;YACxB,KAAK,MAAM,IAAI,IAAI,YAAY,EAAE;gBAC/B,IAAI,iBAAiB,CAAC,IAAI,CAAC,IAAI,IAAI,EAAE;oBACnC,iBAAiB,CAAC,IAAI,CAAC,GAAG,IAAI,GAAG,EAAU,CAAC;iBAC7C;gBACD,YAAY,CAAC,IAAI,CAAC,CAAC,OAAO,CACtB,SAAS,CAAC,EAAE,CAAC,iBAAiB,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,CAAC;aAC1D;SACF;KACF;IACD,OAAO;QACL,MAAM,EAAE,WAAW;QACnB,eAAe,EAAE,mBAAmB,CAAC,iBAAiB,CAAC;KACxD,CAAC;AACJ,CAAC;AAED,SAAS,mBAAmB,CAAC,YAA0B;IACrD,MAAM,eAAe,GAAoB,EAAE,CAAC;IAC5C,KAAK,MAAM,IAAI,IAAI,YAAY,EAAE;QAC/B,eAAe,CAAC,IAAI,CAAC,GAAG,YAAY,CAAC,IAAI,CAAC,CAAC,IAAI,CAAC;KACjD;IACD,OAAO,eAAe,CAAC;AACzB,CAAC;AAED;;;;;;;;;;GAUG;AACH,MAAM,UAAU,+CAA+C,CAC3D,KAAqB,EAAE,QAAkB;IAE3C,MAAM,OAAO,GAAG,IAAI,GAAG,EAAU,CAAC;IAClC,MAAM,MAAM,GAAqB,EAAE,CAAC;IACpC,MAAM,YAAY,GAAiB,EAAE,CAAC;IAEtC,wEAAwE;IACxE,wEAAwE;IACxE,gCAAgC;
IAChC,KAAK,MAAM,GAAG,IAAI,QAAQ,CAAC,KAAK,EAAE,EAAE;QAClC,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;KAClB;IAED,MAAM,KAAK,GAAqB,EAAE,CAAC;IACnC,MAAM,KAAK,GAAa,EAAE,CAAC;IAE3B,yCAAyC;IACzC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;IAElB,OAAO,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE;QACvB,MAAM,GAAG,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACpC,IAAI,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE;YACzB,KAAK,CAAC,GAAG,EAAE,CAAC;YACZ,SAAS;SACV;QACD,MAAM,WAAW,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,KAAK,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;QACjE,IAAI,GAAG,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,IAAI,WAAW,EAAE;YAC1C,0DAA0D;YAC1D,KAAK,CAAC,GAAG,EAAE,CAAC;YACZ,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;YACjB,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;YACtB,IAAI,WAAW,EAAE;gBACf,KAAK,CAAC,GAAG,EAAE,CAAC;aACb;SACF;aAAM;YACL,oEAAoE;YACpE,8CAA8C;YAC9C,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC7B,KAAK,MAAM,KAAK,IAAI,GAAG,CAAC,MAAM,EAAE;gBAC9B,gEAAgE;gBAChE,oEAAoE;gBACpE,IAAI,YAAY,CAAC,KAAK,CAAC,IAAI,CAAC,IAAI,IAAI,EAAE;oBACpC,YAAY,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,IAAI,GAAG,EAAU,CAAC;iBAC9C;gBACD,YAAY,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;gBAEvC,IAAI,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC,EAAE;oBAC3B,SAAS,CAAE,oDAAoD;iBAChE;gBACD,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;aACnB;SACF;KACF;IACD,OAAO,EAAC,MAAM,EAAE,YAAY,EAAC,CAAC;AAChC,CAAC;AAED;;;;;GAKG;AACH,SAAS,cAAc,CAAC,KAAqB;IAE3C,IAAI,YAA6C,CAAC;IAClD,IAAI,KAAK,CAAC,WAAW,CAAC,YAAY,CAAC,MAAM,KAAK,CAAC,EAAE;QAC/C,YAAY,GAAG,KAAK,CAAC,WAAW,CAAC,MAAM,CAAC;KACzC;SAAM;QACL,IAAI,SAAS,GAAW,IAAI,CAAC;QAC7B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,WAAW,CAAC,YAAY,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YAC9D,KAAK,MAAM,YAAY,IAAI,KAAK,CAAC,WAAW,CAAC,YAAY,CAAC,CAAC,CAAC;iBAClD,aAAa,EAAE;gBACvB,IAAI,YAAY,CAAC,EAAE,KAAK,KAAK,CAAC,EAAE,EAAE;oBAChC,SAAS,GAAG,CAAC,CAAC;oBACd,MAAM;iBACP;aACF;SACF;QACD,YAAY,GAAG,KAAK,CAAC,WAAW,CAAC,WAAW,CAAC,SAAS,CAAC,CAAC;KACzD;IACD,OAAO,YAAY,CAAC;AACtB,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google 
LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * Executor: Evaluates SymbolicTensor based on feeds.\n */\n\nimport {cast, dispose, memory, Tensor, util} from '@tensorflow/tfjs-core';\n\nimport {ValueError} from '../errors';\nimport {Kwargs} from '../types';\nimport {LruCache} from '../utils/executor_utils';\nimport {toList} from '../utils/generic_utils';\n\nimport {InputLayer} from './input_layer';\nimport {SymbolicTensor} from './topology';\n\n/**\n * Helper function to check the dtype and shape compatibility of a feed value.\n */\nfunction assertFeedCompatibility(key: SymbolicTensor, val: Tensor): Tensor {\n  // Check dtype compatibility.\n  if (key.dtype == null || key.dtype === val.dtype) {\n    //  a.  If types match, return val tensor as is.\n    return val;\n  }\n  try {\n    //  b. Attempt to convert to expected type.\n    return cast(val, key.dtype);\n  } catch (err) {\n    //  c. 
If conversion fails, return helpful error.\n    throw new ValueError(\n        `The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +\n        `of the key '${key.name}' (${key.dtype}).`);\n  }\n}\n\n/**\n * A concrete Tensor value for a symbolic tensor as the key.\n */\nexport interface Feed {\n  key: SymbolicTensor;\n  value: Tensor;\n}\n\n/**\n * FeedDict: A mapping from unique SymbolicTensors to feed values for them.\n * A feed value is a concrete value represented as an `Tensor`.\n */\nexport class FeedDict {\n  private id2Value: {[id: number]: Tensor} = {};\n  private id2Mask: {[id: number]: Tensor} = {};\n  private name2Id: {[name: string]: number} = {};\n\n  /**\n   * Constructor, optionally does copy-construction.\n   * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case\n   *   copy-construction will be performed.\n   */\n  constructor(feeds?: Feed[]|FeedDict) {\n    if (feeds instanceof FeedDict) {\n      for (const id in feeds.id2Value) {\n        this.id2Value[id] = feeds.id2Value[id];\n        if (id in feeds.id2Mask) {\n          this.id2Mask[id] = feeds.id2Mask[id];\n        }\n      }\n    } else {\n      if (feeds == null) {\n        return;\n      }\n      for (const feed of feeds) {\n        this.add(feed.key, feed.value);\n      }\n    }\n  }\n\n  /**\n   * Add a key-value pair to the FeedDict.\n   *\n   * @param key The key of the feed.\n   * @param value The value of the tensor feed.\n   * @param mask The value of the mask feed (optional).\n   * @returns This `FeedDict`.\n   * @throws ValueError: If the key `SymbolicTensor` already exists in the\n   *   `FeedDict`.\n   */\n  add(key: SymbolicTensor, value: Tensor, mask?: Tensor): FeedDict {\n    if (this.id2Value[key.id] == null) {\n      this.id2Value[key.id] = assertFeedCompatibility(key, value);\n      this.name2Id[key.name] = key.id;\n      if (mask != null) {\n        this.id2Mask[key.id] = mask;\n      }\n    } else {\n      throw new 
ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);\n    }\n    return this;\n  }\n\n  /**\n   * Add a Feed to the FeedDict.\n   * @param feed The new `Feed` to add.\n   * @returns This `FeedDict`.\n   */\n  addFeed(feed: Feed) {\n    this.add(feed.key, feed.value);\n  }\n\n  /**\n   * Probe whether a key already exists in the FeedDict.\n   * @param key\n   */\n  hasKey(key: SymbolicTensor): boolean {\n    return this.id2Value[key.id] != null;\n  }\n\n  /**\n   * Get all the SymbolicTensor available in this FeedDict.\n   */\n  names(): string[] {\n    return Object.keys(this.name2Id);\n  }\n\n  /**\n   * Get the feed value for given key.\n   * @param key The SymbolicTensor, or its name (as a string), of which the\n   *     value is sought.\n   * @returns If `key` exists, the corresponding feed value.\n   * @throws ValueError: If `key` does not exist in this `FeedDict`.\n   */\n  getValue(key: SymbolicTensor|string): Tensor {\n    if (key instanceof SymbolicTensor) {\n      if (this.id2Value[key.id] == null) {\n        throw new ValueError(`Nonexistent key: ${key.name}`);\n      } else {\n        return this.id2Value[key.id];\n      }\n    } else {\n      const id = this.name2Id[key];\n      if (id == null) {\n        throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);\n      }\n      return this.id2Value[id];\n    }\n  }\n\n  /**\n   * Get the feed mask for given key.\n   * @param key The SymbolicTensor, or its name (as a string), of which the\n   *     value is sought.\n   * @returns If `key` exists, the corresponding feed mask.\n   * @throws ValueError: If `key` does not exist in this `FeedDict`.\n   */\n  getMask(key: SymbolicTensor|string): Tensor {\n    if (key instanceof SymbolicTensor) {\n      if (this.id2Value[key.id] == null) {\n        throw new ValueError(`Nonexistent key: ${key.name}`);\n      } else {\n        return this.id2Mask[key.id];\n      }\n    } else {\n      const id = this.name2Id[key];\n      if (id == null) {\n  
      throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);\n      }\n      return this.id2Mask[id];\n    }\n  }\n\n  /** Dispose all mask Tensors held by this object. */\n  disposeMasks() {\n    if (this.id2Mask != null) {\n      dispose(this.id2Mask);\n    }\n  }\n}\n\n// Cache for topologically sorted SymbolicTensors for given execution\n// targets (i.e., fetches).\nexport const cachedSorted: LruCache<SymbolicTensor[]> =\n    new LruCache<SymbolicTensor[]>();\n\n// Cache for recipient count maps for given execution targets (i.e., fetches).\nexport const cachedRecipientCounts: LruCache<RecipientCounts> =\n    new LruCache<RecipientCounts>();\n\nexport function updateCacheMaxEntries(maxEntries: number) {\n  if (cachedSorted != null) {\n    cachedSorted.setMaxEntries(maxEntries);\n  }\n  if (cachedRecipientCounts != null) {\n    cachedRecipientCounts.setMaxEntries(maxEntries);\n  }\n}\n\n/**\n * Interface for the optional object used for probing the memory\n * usage and other statistics during execution.\n */\nexport interface ExecutionProbe {\n  /**\n   * Maximum number of tensors that exist during all steps of the\n   * execution. Tensor counts are measured at the beginning of every\n   * step.\n   */\n  maxNumTensors?: number;\n\n  /**\n   * Minimum number of tensors that exist during all steps of the\n   * execution. Tensor counts are measured at the beginning of every\n   * step.\n   */\n  minNumTensors?: number;\n}\n\n/**\n * Execute a SymbolicTensor by using concrete feed values.\n *\n * A `SymbolicTensor` object is a node in a computation graph of TF.js\n * Layers. The object is backed by a source layer and input\n * `SymbolicTensor`s to the source layer. 
This method evaluates\n * the `call()` method of the source layer, using concrete values of the\n * inputs obtained from either\n * * `feedDict`, if the input key exists in `feedDict`, or else,\n * * a recursive call to `execute()` itself.\n *\n * @param x: The `SymbolicTensor` to execute.\n * @param feedDict: The feed values, as base condition of the recursion.\n *   execution.\n * @param kwargs: Optional keyword arguments.\n * @param probe: A probe object (of interface `ExecutionProbe`) used for\n *   testing memory footprint of `execute` calls.\n * @returns Result of the execution.\n * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s\n *   encountered during the execution lacks a feed value in `feedDict`.\n */\nexport function execute(\n    fetches: SymbolicTensor|SymbolicTensor[], feedDict: FeedDict,\n    kwargs?: Kwargs, probe?: ExecutionProbe): Tensor|\n    Tensor[]|[Tensor | Tensor[]] {\n  const training: boolean = kwargs == null ? false : kwargs['training'];\n\n  const arrayFetches = Array.isArray(fetches);\n  const fetchArray: SymbolicTensor[] =\n      arrayFetches ? fetches : [fetches];\n\n  const outputNames = fetchArray.map(t => t.name);\n  const finalOutputs: Tensor[] = [];\n  const feedNames = feedDict.names();\n  for (const outputName of outputNames) {\n    if (feedNames.indexOf(outputName) !== -1) {\n      finalOutputs.push(feedDict.getValue(outputName));\n    } else {\n      finalOutputs.push(null);\n    }\n  }\n\n  if (probe != null) {\n    // For optional probing of memory footprint during execution.\n    probe.maxNumTensors = -Infinity;\n    probe.minNumTensors = Infinity;\n  }\n\n  // Check cache.\n  const fetchAndFeedKey =\n      outputNames.join(',') + '|' + feedDict.names().sort().join(',');\n  let sorted: SymbolicTensor[] = cachedSorted.get(fetchAndFeedKey);\n  let recipientCounts: {[fetchName: string]: number};\n  if (sorted == null) {\n    // Cache doesn't contain the desired combination of fetches. 
Compute\n    // topological sort for the combination for the first time.\n    const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);\n    sorted = out.sorted;\n    recipientCounts = out.recipientCounts;\n\n    // Store results in cache for future use.\n    cachedSorted.put(fetchAndFeedKey, sorted);\n    cachedRecipientCounts.put(fetchAndFeedKey, recipientCounts);\n  }\n  recipientCounts = {};\n  if (!training) {\n    Object.assign(recipientCounts, cachedRecipientCounts.get(fetchAndFeedKey));\n  }\n\n  const internalFeedDict = new FeedDict(feedDict);\n\n  // Start iterative execution on the topologically-sorted SymbolicTensors.\n  for (let i = 0; i < sorted.length; ++i) {\n    if (probe != null) {\n      // For optional probing of memory usage during execution.\n      const numTensors = memory().numTensors;\n      if (numTensors > probe.maxNumTensors) {\n        probe.maxNumTensors = numTensors;\n      }\n      if (numTensors < probe.minNumTensors) {\n        probe.minNumTensors = numTensors;\n      }\n    }\n\n    const symbolic = sorted[i];\n    const srcLayer = symbolic.sourceLayer;\n    if (srcLayer instanceof InputLayer) {\n      continue;\n    }\n    const inputValues: Tensor[] = [];\n    const inputMasks: Tensor[] = [];\n    const tensorsToDispose: Tensor[] = [];\n\n    let maskExists = false;\n    for (const input of symbolic.inputs) {\n      const value = internalFeedDict.getValue(input);\n      const mask = internalFeedDict.getMask(input);\n      inputValues.push(value);\n      inputMasks.push(mask);\n      if (mask != null) {\n        maskExists = true;\n      }\n      if (!training) {\n        recipientCounts[input.name]--;\n        if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&\n            outputNames.indexOf(input.name) === -1 && !value.isDisposed &&\n            input.sourceLayer.stateful !== true) {\n          tensorsToDispose.push(value);\n        }\n      }\n    }\n\n    if (maskExists) {\n      kwargs = 
kwargs || {};\n      kwargs['mask'] = inputMasks[0];\n    }\n    const outputTensors =\n        toList(srcLayer.apply(inputValues, kwargs)) as Tensor[];\n    let outputMask: Tensor|Tensor[] = null;\n    if (srcLayer.supportsMasking) {\n      outputMask = srcLayer.computeMask(inputValues, inputMasks);\n    }\n    const layerOutputs = getNodeOutputs(symbolic);\n    const outputSymbolicTensors =\n        Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];\n    for (let i = 0; i < outputSymbolicTensors.length; ++i) {\n      if (!internalFeedDict.hasKey(outputSymbolicTensors[i])) {\n        internalFeedDict.add(\n            outputSymbolicTensors[i], outputTensors[i],\n            Array.isArray(outputMask) ? outputMask[0] : outputMask);\n      }\n      const index = outputNames.indexOf(outputSymbolicTensors[i].name);\n      if (index !== -1) {\n        finalOutputs[index] = outputTensors[i];\n      }\n    }\n\n    if (!training) {\n      // Clean up Tensors that are no longer needed.\n      dispose(tensorsToDispose);\n    }\n  }\n  // NOTE(cais): Unlike intermediate tensors, we don't discard mask\n  // tensors as we go, because these tensors are sometimes passed over a\n  // series of mutliple layers, i.e., not obeying the immediate input\n  // relations in the graph. If this becomes a memory-usage concern,\n  // we can improve this in the future.\n  internalFeedDict.disposeMasks();\n\n  return arrayFetches ? finalOutputs : finalOutputs[0];\n}\n\ntype RecipientCounts = {\n  [fetchName: string]: number\n};\n\nexport type RecipientMap = {\n  [fetchName: string]: Set<string>;\n};\n\n/**\n * Sort the `SymbolicTensor`s topologically, for an array of fetches.\n *\n * This function calls getTopologicalSortAndRecipientCountsForOneFetch and\n * merges their results.\n *\n * @param fetch The array of fetches requested. 
Must be a non-empty array.\n * @param feedDict The dictionary of fed values.\n * @returns sorted: Topologically-sorted array of SymbolicTensors.\n *   recipientCounts: Recipient counts for all SymbolicTensors in `sorted`.\n */\nfunction getTopologicalSortAndRecipientCounts(\n    fetches: SymbolicTensor[], feedDict: FeedDict):\n    {sorted: SymbolicTensor[], recipientCounts: RecipientCounts} {\n  util.assert(\n      fetches != null && fetches.length > 0,\n      () => `Expected at least one fetch, got none`);\n\n  let finalSorted: SymbolicTensor[] = [];\n  let finalRecipientMap: RecipientMap = {};\n  if (fetches.length === 1) {\n    // Special-casing 1 fetch for efficiency.\n    const out =\n        getTopologicalSortAndRecipientCou