@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
99 lines (98 loc) • 4.09 kB
TypeScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/wrappers" />
import { serialization, Tensor } from '@tensorflow/tfjs-core';
import { Layer, LayerArgs, SymbolicTensor } from '../engine/topology';
import { BidirectionalMergeMode, Shape } from '../keras_format/common';
import { Kwargs } from '../types';
import { RegularizerFn } from '../types';
import { LayerVariable } from '../variables';
import { RNN } from './recurrent';
/**
* Constructor arguments accepted by `Wrapper` layers: everything in the base
* `LayerArgs`, plus the layer that is to be wrapped.
*/
export declare interface WrapperLayerArgs extends LayerArgs {
/**
* The layer to be wrapped.
*/
layer: Layer;
}
/**
* Abstract wrapper base class.
*
* Wrappers take another layer and augment it in various ways.
* Do not use this class as a layer, it is only an abstract base class.
* Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
*
* NOTE(review): this is a declaration only; the member comments below describe
* the declared signatures. Any statement about delegation to the wrapped layer
* is an assumption to be confirmed against the implementation.
*/
export declare abstract class Wrapper extends Layer {
/** The wrapped layer (presumably the `layer` passed in the constructor args). */
readonly layer: Layer;
/**
* @param args Wrapper configuration, including the layer to wrap.
*/
constructor(args: WrapperLayerArgs);
/**
* Builds the wrapper for the given input shape(s).
*
* @param inputShape A single shape or an array of shapes.
*/
build(inputShape: Shape | Shape[]): void;
/** Trainable flag; presumably mirrors the wrapped layer's flag — confirm. */
get trainable(): boolean;
set trainable(value: boolean);
/** Variables reported as trainable for this wrapper. */
get trainableWeights(): LayerVariable[];
/** Variables reported as non-trainable for this wrapper. */
get nonTrainableWeights(): LayerVariable[];
/** Update tensors reported for this wrapper. */
get updates(): Tensor[];
/** Regularizer loss functions reported for this wrapper. */
get losses(): RegularizerFn[];
/** Returns the current weight values as tensors. */
getWeights(): Tensor[];
/**
* Replaces the weight values.
*
* @param weights New weight tensors; presumably in the same order as
*   returned by `getWeights()` — confirm.
*/
setWeights(weights: Tensor[]): void;
/** Returns the serializable configuration of this wrapper. */
getConfig(): serialization.ConfigDict;
/**
* Sets the fast-weight-initialization flag used during build.
*
* @param value The new flag value.
*/
setFastWeightInitDuringBuild(value: boolean): void;
/**
* Creates an instance of `cls` from a configuration dictionary.
*
* @param cls Constructor of the serializable class to instantiate.
* @param config Configuration dictionary to deserialize from.
* @param customObjects Optional dictionary of custom objects; presumably
*   used to resolve user-defined classes referenced by `config` — confirm.
* @returns The deserialized instance.
* @nocollapse
*/
static fromConfig<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>, config: serialization.ConfigDict, customObjects?: serialization.ConfigDict): T;
}
/**
* Concrete `Wrapper` subclass named `TimeDistributed`.
*
* NOTE(review): by Keras convention a time-distributed wrapper applies the
* wrapped layer to every temporal slice of its input; the implementation is
* not visible in this declaration file — confirm before relying on it.
*/
export declare class TimeDistributed extends Wrapper {
/**
* Class name string; presumably the key used for serialization — confirm.
* @nocollapse
*/
static className: string;
/**
* @param args Wrapper configuration, including the layer to wrap.
*/
constructor(args: WrapperLayerArgs);
/**
* Builds the layer for the given input shape(s).
*
* @param inputShape A single shape or an array of shapes.
*/
build(inputShape: Shape | Shape[]): void;
/**
* Computes the output shape(s) for the given input shape(s).
*
* @param inputShape A single shape or an array of shapes.
* @returns The resulting shape or array of shapes.
*/
computeOutputShape(inputShape: Shape | Shape[]): Shape | Shape[];
/**
* Runs the layer's forward computation.
*
* @param inputs A single input tensor or an array of tensors.
* @param kwargs Additional keyword arguments for the call.
* @returns The output tensor or array of tensors.
*/
call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
}
/**
* Checks a candidate merge-mode value for the `Bidirectional` wrapper.
*
* NOTE(review): the function returns nothing, so it presumably signals an
* invalid value by throwing; the implementation is not visible in this
* declaration file — confirm.
*
* @param value Optional merge-mode string to check.
*/
export declare function checkBidirectionalMergeMode(value?: string): void;
/**
* Constructor arguments for the `Bidirectional` wrapper: the wrapper
* arguments, with `layer` narrowed to an `RNN`, plus an optional merge mode.
*/
export declare interface BidirectionalLayerArgs extends WrapperLayerArgs {
/**
* The instance of an `RNN` layer to be wrapped.
*/
layer: RNN;
/**
* Mode by which outputs of the forward and backward RNNs are
* combined.
*
* If `null`, the outputs will not be combined; they will be returned
* as an `Array`.
*
* If `undefined` (i.e., not provided), defaults to `'concat'`.
*/
mergeMode?: BidirectionalMergeMode;
}
/**
* Concrete `Wrapper` subclass that wraps an `RNN` layer and exposes a merge
* mode for combining the outputs of a forward and a backward RNN.
*
* NOTE(review): this is a declaration only; the member comments below describe
* the declared signatures. Behavioral details are assumptions to be confirmed
* against the implementation.
*/
export declare class Bidirectional extends Wrapper {
/**
* Class name string; presumably the key used for serialization — confirm.
* @nocollapse
*/
static className: string;
/** Mode by which forward and backward outputs are combined. */
mergeMode: BidirectionalMergeMode;
// Private state. Types are erased in the declaration output; judging by the
// names, these hold the two directional copies of the wrapped RNN and
// settings read from it — confirm against the implementation.
private forwardLayer;
private backwardLayer;
private returnSequences;
private returnState;
private numConstants?;
private _trainable;
/**
* @param args Bidirectional configuration: the `RNN` layer to wrap and an
*   optional merge mode.
*/
constructor(args: BidirectionalLayerArgs);
/** Trainable flag; presumably backed by the private `_trainable` — confirm. */
get trainable(): boolean;
set trainable(value: boolean);
/** Returns the current weight values as tensors. */
getWeights(): Tensor[];
/**
* Replaces the weight values.
*
* @param weights New weight tensors; presumably in the same order as
*   returned by `getWeights()` — confirm.
*/
setWeights(weights: Tensor[]): void;
/**
* Computes the output shape(s) for the given input shape(s).
*
* @param inputShape A single shape or an array of shapes.
* @returns The resulting shape or array of shapes.
*/
computeOutputShape(inputShape: Shape | Shape[]): Shape | Shape[];
/**
* Applies the layer to concrete or symbolic inputs.
*
* @param inputs Tensor(s) or symbolic tensor(s) to apply the layer to.
* @param kwargs Optional keyword arguments for the call.
* @returns Tensor(s) or symbolic tensor(s) produced by the layer.
*/
apply(inputs: Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[], kwargs?: Kwargs): Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[];
/**
* Runs the layer's forward computation on concrete tensors.
*
* @param inputs A single input tensor or an array of tensors.
* @param kwargs Additional keyword arguments for the call.
* @returns The output tensor or array of tensors.
*/
call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
/**
* Resets the recurrent states.
*
* @param states Optional tensor(s); presumably the values to reset the
*   states to — confirm.
*/
resetStates(states?: Tensor | Tensor[]): void;
/**
* Builds the layer for the given input shape(s).
*
* @param inputShape A single shape or an array of shapes.
*/
build(inputShape: Shape | Shape[]): void;
/**
* Computes the output mask(s) for the given inputs and input mask(s).
*
* @param inputs A single input tensor or an array of tensors.
* @param mask Optional input mask tensor(s).
* @returns The resulting mask tensor or array of mask tensors.
*/
computeMask(inputs: Tensor | Tensor[], mask?: Tensor | Tensor[]): Tensor | Tensor[];
/** Variables reported as trainable for this wrapper. */
get trainableWeights(): LayerVariable[];
/** Variables reported as non-trainable for this wrapper. */
get nonTrainableWeights(): LayerVariable[];
/**
* Sets the fast-weight-initialization flag used during build.
*
* @param value The new flag value.
*/
setFastWeightInitDuringBuild(value: boolean): void;
/** Returns the serializable configuration of this wrapper. */
getConfig(): serialization.ConfigDict;
/**
* Creates an instance of `cls` from a configuration dictionary. Unlike
* `Wrapper.fromConfig`, this overload declares no `customObjects` parameter.
*
* @param cls Constructor of the serializable class to instantiate.
* @param config Configuration dictionary to deserialize from.
* @returns The deserialized instance.
* @nocollapse
*/
static fromConfig<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>, config: serialization.ConfigDict): T;
}