UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

149 lines (148 loc) 5.42 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/engine/executor" />
/**
 * Executor: Evaluates SymbolicTensor based on feeds.
 */
import { Tensor } from '@tensorflow/tfjs-core';
import { Kwargs } from '../types';
import { LruCache } from '../utils/executor_utils';
import { SymbolicTensor } from './topology';
/**
 * A concrete `Tensor` value paired with the `SymbolicTensor` it is fed for.
 */
export interface Feed {
    /** The symbolic tensor being fed. */
    key: SymbolicTensor;
    /** The concrete value fed for `key`. */
    value: Tensor;
}
/**
 * FeedDict: A mapping from unique SymbolicTensors to feed values for them.
 * A feed value is a concrete value represented as a `Tensor`. Each key may
 * optionally also carry a mask `Tensor` (see `add`, `getMask` and
 * `disposeMasks`).
 */
export declare class FeedDict {
    private id2Value;
    private id2Mask;
    private name2Id;
    /**
     * Constructor, optionally does copy-construction.
     * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case
     *   copy-construction will be performed.
     */
    constructor(feeds?: Feed[] | FeedDict);
    /**
     * Add a key-value pair to the FeedDict.
     *
     * @param key The key of the feed.
     * @param value The value of the tensor feed.
     * @param mask The value of the mask feed (optional).
     * @returns This `FeedDict`, which allows calls to be chained.
     * @throws ValueError: If the key `SymbolicTensor` already exists in the
     *   `FeedDict`.
     */
    add(key: SymbolicTensor, value: Tensor, mask?: Tensor): FeedDict;
    /**
     * Add a Feed to the FeedDict.
     *
     * Unlike `add`, this method returns nothing. `Feed` has no mask field,
     * so presumably no mask is recorded for the key (confirm against the
     * implementation).
     *
     * @param feed The new `Feed` to add.
     */
    addFeed(feed: Feed): void;
    /**
     * Probe whether a key already exists in the FeedDict.
     * @param key The `SymbolicTensor` to look up.
     * @returns `true` if a feed value exists for `key`, `false` otherwise.
     */
    hasKey(key: SymbolicTensor): boolean;
    /**
     * Get the names of all the SymbolicTensors available in this FeedDict.
     * @returns The names of the keys, as strings.
     */
    names(): string[];
    /**
     * Get the feed value for given key.
     * @param key The SymbolicTensor, or its name (as a string), of which the
     *   value is sought.
     * @returns If `key` exists, the corresponding feed value.
     * @throws ValueError: If `key` does not exist in this `FeedDict`.
     */
    getValue(key: SymbolicTensor | string): Tensor;
    /**
     * Get the feed mask for given key.
     *
     * NOTE(review): masks are optional in `add`, so despite the declared
     * `Tensor` return type this may yield `undefined` for a key added
     * without a mask. Confirm against the implementation.
     *
     * @param key The SymbolicTensor, or its name (as a string), of which the
     *   mask is sought.
     * @returns If `key` exists, the corresponding feed mask.
     * @throws ValueError: If `key` does not exist in this `FeedDict`.
     */
    getMask(key: SymbolicTensor | string): Tensor;
    /** Dispose all mask Tensors held by this object. */
    disposeMasks(): void;
}
/**
 * LRU cache of topologically-sorted `SymbolicTensor` arrays.
 * Presumably keyed by a signature of the fetches and feeds. TODO confirm.
 */
export declare const cachedSorted: LruCache<SymbolicTensor[]>;
/**
 * LRU cache of recipient counts (see `RecipientCounts`).
 * Presumably shares its keys with `cachedSorted`. TODO confirm.
 */
export declare const cachedRecipientCounts: LruCache<RecipientCounts>;
/**
 * Change the maximum number of entries kept by the executor's LRU caches.
 * Presumably applies to both `cachedSorted` and `cachedRecipientCounts`.
 * Verify against the implementation.
 * @param maxEntries The new maximum number of cache entries.
 */
export declare function updateCacheMaxEntries(maxEntries: number): void;
/**
 * Interface for the optional object used for probing the memory
 * usage and other statistics during execution.
 */
export interface ExecutionProbe {
    /**
     * Maximum number of tensors that exist during all steps of the
     * execution. Tensor counts are measured at the beginning of every
     * step.
     */
    maxNumTensors?: number;
    /**
     * Minimum number of tensors that exist during all steps of the
     * execution. Tensor counts are measured at the beginning of every
     * step.
     */
    minNumTensors?: number;
}
/**
 * Execute a SymbolicTensor by using concrete feed values.
 *
 * A `SymbolicTensor` object is a node in a computation graph of TF.js
 * Layers. The object is backed by a source layer and input
 * `SymbolicTensor`s to the source layer. This method evaluates
 * the `call()` method of the source layer, using concrete values of the
 * inputs obtained from either
 *   * `feedDict`, if the input key exists in `feedDict`, or else,
 *   * a recursive call to `execute()` itself.
 *
 * @param fetches The `SymbolicTensor` (or array of `SymbolicTensor`s) to
 *   execute.
 * @param feedDict The feed values, which serve as the base condition of the
 *   recursion.
 * @param kwargs Optional keyword arguments.
 * @param probe A probe object (of interface `ExecutionProbe`) used for
 *   testing memory footprint of `execute` calls.
 * @returns Result of the execution. Presumably a single `Tensor` when
 *   `fetches` is a single `SymbolicTensor`, and an array when `fetches` is
 *   an array. The declared type also admits a one-element tuple
 *   (`[Tensor | Tensor[]]`). Confirm against the implementation.
 * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s
 *   encountered during the execution lack a feed value in `feedDict`.
 */
export declare function execute(fetches: SymbolicTensor | SymbolicTensor[], feedDict: FeedDict, kwargs?: Kwargs, probe?: ExecutionProbe): Tensor | Tensor[] | [Tensor | Tensor[]];
/** Maps a fetch name to a number of recipients. */
type RecipientCounts = {
    [fetchName: string]: number;
};
/** Maps a fetch name to the set of names of its recipients. */
export type RecipientMap = {
    [fetchName: string]: Set<string>;
};
/**
 * Sort the `SymbolicTensor`s topologically, for a single fetch.
 *
 * This helper function processes the upstream SymbolicTensors of a single
 * fetch. Despite the "RecipientCounts" in its name, it returns a
 * `RecipientMap` (sets of recipient names), not counts.
 *
 * @param fetch The single fetch requested.
 * @param feedDict The dictionary of fed values.
 * @returns An object with:
 *   - `sorted`: Topologically-sorted array of SymbolicTensors.
 *   - `recipientMap`: Recipient names for all SymbolicTensors in `sorted`.
 */
export declare function getTopologicalSortAndRecipientCountsForOneFetch(fetch: SymbolicTensor, feedDict: FeedDict): {
    sorted: SymbolicTensor[];
    recipientMap: RecipientMap;
};
export {};