UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

99 lines (98 loc) 4.09 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
// NOTE: the triple-slash directive below must sit on its own line. A `///`
// comment extends to the end of the line, so anything placed after it on the
// same line (e.g. the imports) is silently commented out.
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/wrappers" />
import { serialization, Tensor } from '@tensorflow/tfjs-core';
import { Layer, LayerArgs, SymbolicTensor } from '../engine/topology';
import { BidirectionalMergeMode, Shape } from '../keras_format/common';
import { Kwargs, RegularizerFn } from '../types';
import { LayerVariable } from '../variables';
import { RNN } from './recurrent';

/**
 * Constructor arguments shared by all wrapper layers.
 */
export declare interface WrapperLayerArgs extends LayerArgs {
  /**
   * The layer to be wrapped.
   */
  layer: Layer;
}

/**
 * Abstract wrapper base class.
 *
 * Wrappers take another layer and augment it in various ways.
 * Do not use this class as a layer; it is only an abstract base class.
 * Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
 */
export declare abstract class Wrapper extends Layer {
  /** The wrapped layer. */
  readonly layer: Layer;
  constructor(args: WrapperLayerArgs);
  build(inputShape: Shape | Shape[]): void;
  get trainable(): boolean;
  set trainable(value: boolean);
  get trainableWeights(): LayerVariable[];
  get nonTrainableWeights(): LayerVariable[];
  get updates(): Tensor[];
  get losses(): RegularizerFn[];
  getWeights(): Tensor[];
  setWeights(weights: Tensor[]): void;
  getConfig(): serialization.ConfigDict;
  setFastWeightInitDuringBuild(value: boolean): void;
  /** @nocollapse */
  static fromConfig<T extends serialization.Serializable>(
    cls: serialization.SerializableConstructor<T>,
    config: serialization.ConfigDict,
    customObjects?: serialization.ConfigDict
  ): T;
}

/**
 * One of the two concrete `Wrapper` subclasses declared in this module
 * (the other is `Bidirectional`).
 */
export declare class TimeDistributed extends Wrapper {
  /** @nocollapse */
  static className: string;
  constructor(args: WrapperLayerArgs);
  build(inputShape: Shape | Shape[]): void;
  computeOutputShape(inputShape: Shape | Shape[]): Shape | Shape[];
  call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
}

/**
 * Checks a candidate merge-mode value for `Bidirectional`.
 *
 * NOTE(review): only the signature is visible in this declaration file;
 * presumably it throws for values that are not a valid
 * `BidirectionalMergeMode` — confirm against the implementation.
 *
 * @param value Candidate merge mode; may be omitted.
 */
export declare function checkBidirectionalMergeMode(value?: string): void;

/**
 * Constructor arguments for `Bidirectional`.
 */
export declare interface BidirectionalLayerArgs extends WrapperLayerArgs {
  /**
   * The instance of an `RNN` layer to be wrapped.
   */
  layer: RNN;
  /**
   * Mode by which outputs of the forward and backward RNNs are
   * combined.
   *
   * - If `undefined` (i.e., not provided), defaults to `'concat'`.
   * - If explicitly `null`, the outputs are not combined; they are
   *   returned as an `Array`.
   */
  mergeMode?: BidirectionalMergeMode;
}

/**
 * Wrapper for an `RNN` layer (see `BidirectionalLayerArgs.layer`).
 *
 * Declares private `forwardLayer` / `backwardLayer` fields and overrides
 * `apply`, `call`, `computeMask` and `resetStates` in addition to the
 * `Wrapper` members.
 */
export declare class Bidirectional extends Wrapper {
  /** @nocollapse */
  static className: string;
  /** How forward and backward outputs are combined; see `BidirectionalLayerArgs.mergeMode`. */
  mergeMode: BidirectionalMergeMode;
  private forwardLayer;
  private backwardLayer;
  private returnSequences;
  private returnState;
  private numConstants?;
  private _trainable;
  constructor(args: BidirectionalLayerArgs);
  get trainable(): boolean;
  set trainable(value: boolean);
  getWeights(): Tensor[];
  setWeights(weights: Tensor[]): void;
  computeOutputShape(inputShape: Shape | Shape[]): Shape | Shape[];
  apply(
    inputs: Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[],
    kwargs?: Kwargs
  ): Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[];
  call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
  resetStates(states?: Tensor | Tensor[]): void;
  build(inputShape: Shape | Shape[]): void;
  computeMask(
    inputs: Tensor | Tensor[],
    mask?: Tensor | Tensor[]
  ): Tensor | Tensor[];
  get trainableWeights(): LayerVariable[];
  get nonTrainableWeights(): LayerVariable[];
  setFastWeightInitDuringBuild(value: boolean): void;
  getConfig(): serialization.ConfigDict;
  /** @nocollapse */
  static fromConfig<T extends serialization.Serializable>(
    cls: serialization.SerializableConstructor<T>,
    config: serialization.ConfigDict
  ): T;
}