// Package: @tensorflow/tfjs-layers (TensorFlow layers API in JavaScript)
// Version: not recorded in this listing
// Listing: 130 lines (129 loc), 5.04 kB, TypeScript
/**
* @license
* Copyright 2020 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/convolutional_recurrent" />
import * as tfc from '@tensorflow/tfjs-core';
import { Tensor } from '@tensorflow/tfjs-core';
import { Activation } from '../activations';
import { Constraint } from '../constraints';
import { Initializer } from '../initializers';
import { DataFormat, PaddingMode, Shape } from '../keras_format/common';
import { Regularizer } from '../regularizers';
import { Kwargs } from '../types';
import { BaseRNNLayerArgs, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, SimpleRNNCellLayerArgs } from './recurrent';
/**
* Constructor arguments shared by 2D convolutional-recurrent cells.
*
* Based on `SimpleRNNCellLayerArgs` with `units` removed; the size of the
* output space is given by `filters` instead.
*/
declare interface ConvRNN2DCellArgs extends Omit<SimpleRNNCellLayerArgs, 'units'> {
/**
* The dimensionality of the output space (i.e. the number of filters in the
* convolution).
*/
filters: number;
/**
* The dimensions of the convolution window. If kernelSize is a number, the
* convolutional window will be square; otherwise the size is given per
* dimension.
*/
kernelSize: number | number[];
/**
* The strides of the convolution in each dimension. If strides is a number,
* strides in both dimensions are equal.
*
* Specifying any stride value != 1 is incompatible with specifying any
* `dilationRate` value != 1.
*/
strides?: number | number[];
/**
* Padding mode used by the convolutions (see `PaddingMode`).
*/
padding?: PaddingMode;
/**
* Format of the data, which determines the ordering of the dimensions in
* the inputs.
*
* `channels_last` corresponds to inputs with shape
* `(batch, ..., channels)`
*
* `channels_first` corresponds to inputs with shape `(batch, channels,
* ...)`.
*
* Defaults to `channels_last`.
*/
dataFormat?: DataFormat;
/**
* The dilation rate to use for the dilated convolution in each dimension.
* Should be a single integer, or an array of one or two integers, as the
* declared type below allows (this is a 2D cell, so there are at most two
* dimensions to dilate).
*
* Currently, specifying any `dilationRate` value != 1 is incompatible with
* specifying any `strides` value != 1.
*/
dilationRate?: number | [number] | [number, number];
}
/**
* Abstract contract for 2D convolutional-recurrent cells (not exported).
*
* Exposes, as read-only fields, the configuration a cell is built with.
* `ConvLSTM2DCell` in this file declares that it implements this contract.
*/
declare abstract class ConvRNN2DCell extends RNNCell {
// Convolution configuration. `kernelSize`, `strides` and `dilationRate` are
// typed as arrays here, although `ConvRNN2DCellArgs` also accepts a plain
// number for them; presumably the constructor normalizes them (TODO confirm).
readonly filters: number;
readonly kernelSize: number[];
readonly strides: number[];
readonly padding: PaddingMode;
readonly dataFormat: DataFormat;
readonly dilationRate: number[];
// Activation and bias usage.
readonly activation: Activation;
readonly useBias: boolean;
// Initializers for the kernel, recurrent kernel and bias weights.
readonly kernelInitializer: Initializer;
readonly recurrentInitializer: Initializer;
readonly biasInitializer: Initializer;
// Constraints for the kernel, recurrent kernel and bias weights.
readonly kernelConstraint: Constraint;
readonly recurrentConstraint: Constraint;
readonly biasConstraint: Constraint;
// Regularizers for the kernel, recurrent kernel and bias weights.
readonly kernelRegularizer: Regularizer;
readonly recurrentRegularizer: Regularizer;
readonly biasRegularizer: Regularizer;
// Dropout fractions. Presumably these apply to the inputs and to the
// recurrent state, respectively (verify against the implementation).
readonly dropout: number;
readonly recurrentDropout: number;
}
/**
* Constructor arguments for `ConvRNN2D`: the generic RNN-layer options from
* `BaseRNNLayerArgs` combined with the convolutional cell options from
* `ConvRNN2DCellArgs`.
*/
declare interface ConvRNN2DLayerArgs extends BaseRNNLayerArgs, ConvRNN2DCellArgs {
}
/**
* Base class for convolutional-recurrent layers.
*
* An `RNN` layer whose `cell` is typed as a `ConvRNN2DCell`. Not exported
* from this module; `ConvLSTM2D` in this file extends it.
*/
declare class ConvRNN2D extends RNN {
/** @nocollapse */
static className: string;
/** The recurrent cell, typed as a 2D convolutional-recurrent cell. */
readonly cell: ConvRNN2DCell;
constructor(args: ConvRNN2DLayerArgs);
/** Applies the layer to `inputs`; `kwargs` carries extra keyword arguments. */
call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
/** Returns the output shape, or shapes, for the given input shape. */
computeOutputShape(inputShape: Shape): Shape | Shape[];
/** Returns the initial recurrent-state tensors to use for `inputs`. */
getInitialState(inputs: tfc.Tensor): tfc.Tensor[];
/**
* Resets the layer's recurrent states, optionally to the given `states`.
* The effect of `training` is not visible in this declaration (TODO confirm).
*/
resetStates(states?: Tensor | Tensor[], training?: boolean): void;
/** Protected helper: the output shape for a single input shape. */
protected computeSingleOutputShape(inputShape: Shape): Shape;
}
/**
* Constructor arguments for `ConvLSTM2DCell`: the LSTM cell options without
* `units`, combined with the 2D convolution options of `ConvRNN2DCellArgs`
* (where `filters` sets the output size).
*/
export declare interface ConvLSTM2DCellArgs extends Omit<LSTMCellLayerArgs, 'units'>, ConvRNN2DCellArgs {
}
/**
* LSTM cell that uses convolutions (see `inputConv` and `recurrentConv`).
*
* Extends `LSTMCell` and also declares that it implements the `ConvRNN2DCell`
* contract, adding the convolution configuration fields below.
*/
export declare class ConvLSTM2DCell extends LSTMCell implements ConvRNN2DCell {
/** @nocollapse */
static className: string;
readonly filters: number;
readonly kernelSize: number[];
readonly strides: number[];
readonly padding: PaddingMode;
readonly dataFormat: DataFormat;
readonly dilationRate: number[];
constructor(args: ConvLSTM2DCellArgs);
/** Prepares the cell for the given input shape(s); presumably creates its weights. */
build(inputShape: Shape | Shape[]): void;
/**
* Runs one step of the cell on `inputs` and returns a tensor array.
* NOTE(review): the element order of `inputs` and of the result (input and
* states) is not visible here; confirm against the implementation.
*/
call(inputs: tfc.Tensor[], kwargs: Kwargs): tfc.Tensor[];
/** Returns the serializable configuration of this cell. */
getConfig(): tfc.serialization.ConfigDict;
/**
* Convolves `x` with kernel `w`, optionally adding bias `b`.
* `padding` presumably overrides the cell's own padding mode when given.
*/
inputConv(x: Tensor, w: Tensor, b?: Tensor, padding?: PaddingMode): tfc.Tensor3D;
/** Convolves `x` with the recurrent kernel `w`; this signature takes no bias. */
recurrentConv(x: Tensor, w: Tensor): tfc.Tensor3D;
}
/**
* Constructor arguments for `ConvLSTM2D`: the LSTM layer options with `units`
* and `cell` removed (the caller does not supply a cell), combined with the
* `ConvRNN2DLayerArgs` options.
*/
export declare interface ConvLSTM2DArgs extends Omit<LSTMLayerArgs, 'units' | 'cell'>, ConvRNN2DLayerArgs {
}
/**
* Convolutional LSTM layer.
*
* A `ConvRNN2D` layer configured through `ConvLSTM2DArgs`. Because `cell` is
* omitted from those args, the cell is presumably constructed from them
* internally (TODO confirm).
*/
export declare class ConvLSTM2D extends ConvRNN2D {
/** @nocollapse */
static className: string;
constructor(args: ConvLSTM2DArgs);
/**
* Creates an instance of `cls` from the serialized `config`.
*
* @nocollapse
*/
static fromConfig<T extends tfc.serialization.Serializable>(cls: tfc.serialization.SerializableConstructor<T>, config: tfc.serialization.ConfigDict): T;
}
// Presumably emitted by the TypeScript compiler so that the declarations
// above that lack an `export` modifier (e.g. `ConvRNN2DCellArgs`,
// `ConvRNN2DCell`, `ConvRNN2D`) stay private to this module rather than
// being treated as exported. TODO confirm.
export {};