// Package: @tensorflow/tfjs-layers
// Version: (unspecified)
// TensorFlow layers API in JavaScript
// 149 lines (148 loc) • 5.42 kB
// TypeScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/engine/executor" />
/**
* Executor: Evaluates SymbolicTensor based on feeds.
*/
import { Tensor } from '@tensorflow/tfjs-core';
import { Kwargs } from '../types';
import { LruCache } from '../utils/executor_utils';
import { SymbolicTensor } from './topology';
/**
 * A single feed entry: the concrete `Tensor` supplied for one
 * `SymbolicTensor`, which acts as the entry's key.
 */
export interface Feed {
    /** Concrete value fed for `key`. */
    value: Tensor;
    /** The symbolic tensor this entry provides a value for. */
    key: SymbolicTensor;
}
/**
 * FeedDict: A mapping from unique SymbolicTensors to feed values for them.
 * A feed value is a concrete value represented as a `Tensor`; a mask
 * `Tensor` may optionally be stored alongside it (see `add`).
 */
export declare class FeedDict {
    // NOTE(review): presumably maps SymbolicTensor ids to fed values.
    private id2Value;
    // NOTE(review): presumably maps SymbolicTensor ids to optional masks.
    private id2Mask;
    // NOTE(review): presumably maps SymbolicTensor names to ids, backing the
    // string-keyed lookups of `getValue` / `getMask`.
    private name2Id;
    /**
     * Constructor, optionally does copy-construction.
     * @param feeds Optional. An Array of `Feed`s, or another `FeedDict`, in
     *   which case copy-construction will be performed.
     */
    constructor(feeds?: Feed[] | FeedDict);
    /**
     * Add a key-value pair to the FeedDict.
     *
     * @param key The key of the feed.
     * @param value The value of the tensor feed.
     * @param mask The value of the mask feed (optional).
     * @returns This `FeedDict`, so calls can be chained.
     * @throws ValueError: If the key `SymbolicTensor` already exists in the
     *   `FeedDict`.
     */
    add(key: SymbolicTensor, value: Tensor, mask?: Tensor): FeedDict;
    /**
     * Add a Feed to the FeedDict.
     *
     * Unlike `add`, this is declared to return nothing, so it cannot be
     * chained.
     * NOTE(review): `Feed` has no mask field, so presumably no mask is
     * recorded for the key — confirm against the implementation.
     *
     * @param feed The new `Feed` to add.
     */
    addFeed(feed: Feed): void;
    /**
     * Probe whether a key already exists in the FeedDict.
     * @param key The `SymbolicTensor` to look up.
     * @returns `true` if `key` exists in this `FeedDict`, `false` otherwise.
     */
    hasKey(key: SymbolicTensor): boolean;
    /**
     * Get the names of all SymbolicTensors available in this FeedDict.
     */
    names(): string[];
    /**
     * Get the feed value for given key.
     * @param key The SymbolicTensor, or its name (as a string), whose value
     *   is sought.
     * @returns If `key` exists, the corresponding feed value.
     * @throws ValueError: If `key` does not exist in this `FeedDict`.
     */
    getValue(key: SymbolicTensor | string): Tensor;
    /**
     * Get the feed mask for given key.
     * @param key The SymbolicTensor, or its name (as a string), whose mask
     *   is sought.
     * @returns If `key` exists, the corresponding feed mask.
     *   NOTE(review): `mask` is optional in `add`, so this is likely
     *   `undefined` for keys added without one, despite the declared `Tensor`
     *   return type — confirm.
     * @throws ValueError: If `key` does not exist in this `FeedDict`.
     */
    getMask(key: SymbolicTensor | string): Tensor;
    /** Dispose all mask Tensors held by this object. */
    disposeMasks(): void;
}
/**
 * LRU cache whose entries are arrays of `SymbolicTensor`s.
 * NOTE(review): presumably memoizes topologically sorted fetch lists used by
 * `execute` — confirm against the implementation.
 */
export declare const cachedSorted: LruCache<SymbolicTensor[]>;
/**
 * LRU cache whose entries are `RecipientCounts` records (name -> number).
 * NOTE(review): presumably the companion of `cachedSorted`, keyed the same
 * way — confirm against the implementation.
 */
export declare const cachedRecipientCounts: LruCache<RecipientCounts>;
/**
 * Judging by its name, sets the maximum number of entries the executor's
 * caches may hold.
 * NOTE(review): presumably applies to both `cachedSorted` and
 * `cachedRecipientCounts` — confirm.
 * @param maxEntries The new maximum number of entries.
 */
export declare function updateCacheMaxEntries(maxEntries: number): void;
/**
 * Optional statistics collector passed to `execute`, used to observe memory
 * usage (tensor counts) while a graph is being evaluated.
 */
export interface ExecutionProbe {
    /**
     * Smallest tensor count seen over the whole execution. Counts are
     * sampled at the start of each step.
     */
    minNumTensors?: number;
    /**
     * Largest tensor count seen over the whole execution. Counts are
     * sampled at the start of each step.
     */
    maxNumTensors?: number;
}
/**
 * Execute one or more SymbolicTensors by using concrete feed values.
 *
 * A `SymbolicTensor` object is a node in a computation graph of TF.js
 * Layers. The object is backed by a source layer and input
 * `SymbolicTensor`s to the source layer. This method evaluates
 * the `call()` method of the source layer, using concrete values of the
 * inputs obtained from either
 * * `feedDict`, if the input key exists in `feedDict`, or else,
 * * a recursive call to `execute()` itself.
 *
 * @param fetches The `SymbolicTensor`, or array of `SymbolicTensor`s, to
 *   execute.
 * @param feedDict The feed values, serving as the base condition of the
 *   recursion.
 * @param kwargs Optional keyword arguments.
 * @param probe A probe object (of interface `ExecutionProbe`) used for
 *   testing the memory footprint of `execute` calls.
 * @returns Result of the execution.
 *   NOTE(review): presumably a single `Tensor` when `fetches` is a single
 *   `SymbolicTensor` and an array otherwise — confirm; the declared return
 *   type also admits a one-element tuple.
 * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s
 *   encountered during the execution lack a feed value in `feedDict`.
 */
export declare function execute(fetches: SymbolicTensor | SymbolicTensor[], feedDict: FeedDict, kwargs?: Kwargs, probe?: ExecutionProbe): Tensor | Tensor[] | [Tensor | Tensor[]];
/**
 * Numeric count per name. The original index-signature label called the
 * keys fetch names. Module-private: not exported.
 */
type RecipientCounts = Record<string, number>;
/**
 * For each `SymbolicTensor` name, the set of names of its recipients.
 */
export type RecipientMap = Record<string, Set<string>>;
/**
 * Topologically orders the `SymbolicTensor`s upstream of exactly one fetch.
 *
 * Only a single fetch is handled per call.
 * Despite "RecipientCounts" in its name, the result carries a recipient map
 * (sets of recipient names), not numeric counts.
 *
 * @param fetch The one `SymbolicTensor` being requested.
 * @param feedDict The values that have already been fed.
 * @returns An object holding:
 *   - `recipientMap`: the recipient names of every entry in `sorted`;
 *   - `sorted`: the upstream `SymbolicTensor`s in topological order.
 */
export declare function getTopologicalSortAndRecipientCountsForOneFetch(
    fetch: SymbolicTensor,
    feedDict: FeedDict
): {
    recipientMap: RecipientMap;
    sorted: SymbolicTensor[];
};
// Empty export list: adds no exports of its own.
// NOTE(review): likely emitted by tsc so that non-exported declarations in
// this file (e.g. `RecipientCounts`) are not treated as exports — confirm.
export {};