// Retrieved from UNPKG: @tensorflow/tfjs-layers
// (TensorFlow layers API in JavaScript).
// Version: not recorded in this copy.
// Original listing: 130 lines (129 loc), 5.04 kB.
/**
 * @license
 * Copyright 2020 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
// NOTE(review): this looks like compiler-emitted declaration output (every
// definition is `declare`d and the AMD module name below points into `dist/`).
// Hand edits here are likely to be overwritten; prefer fixing the `.ts` source.
//
// The triple-slash directive below must stay on its own line, ahead of any
// statement: a `///` comment runs to end of line, so anything placed after it
// on the same line is commented out.
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/convolutional_recurrent" />
import * as tfc from '@tensorflow/tfjs-core';
import { Tensor } from '@tensorflow/tfjs-core';
import { Activation } from '../activations';
import { Constraint } from '../constraints';
import { Initializer } from '../initializers';
import { DataFormat, PaddingMode, Shape } from '../keras_format/common';
import { Regularizer } from '../regularizers';
import { Kwargs } from '../types';
import { BaseRNNLayerArgs, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, SimpleRNNCellLayerArgs } from './recurrent';
/**
 * Arguments shared by convolutional-recurrent cells.
 *
 * Starts from the simple-RNN cell arguments with `units` removed; the
 * output dimensionality is given by `filters` instead.
 */
declare interface ConvRNN2DCellArgs extends Omit<SimpleRNNCellLayerArgs, 'units'> {
    /**
     * The dimensionality of the output space (i.e. the number of filters in the
     * convolution).
     */
    filters: number;
    /**
     * The dimensions of the convolution window. If kernelSize is a number, the
     * convolutional window will be square.
     */
    kernelSize: number | number[];
    /**
     * The strides of the convolution in each dimension. If strides is a number,
     * strides in both dimensions are equal.
     *
     * Specifying any stride value != 1 is incompatible with specifying any
     * `dilationRate` value != 1.
     */
    strides?: number | number[];
    /**
     * Padding mode (one of the values of `PaddingMode`).
     */
    padding?: PaddingMode;
    /**
     * Format of the data, which determines the ordering of the dimensions in
     * the inputs.
     *
     * `channels_last` corresponds to inputs with shape
     *   `(batch, ..., channels)`
     *
     * `channels_first` corresponds to inputs with shape
     *   `(batch, channels, ...)`.
     *
     * Defaults to `channels_last`.
     */
    dataFormat?: DataFormat;
    /**
     * The dilation rate to use for the dilated convolution in each dimension.
     * Should be an integer, or an array of one or two integers (as allowed by
     * the declared type `number | [number] | [number, number]`).
     *
     * Currently, specifying any `dilationRate` value != 1 is incompatible with
     * specifying any `strides` value != 1.
     */
    dilationRate?: number | [number] | [number, number];
}
/**
 * Abstract base type for the cells used by `ConvRNN2D` layers (it is the
 * declared type of `ConvRNN2D.cell`).
 *
 * Exposes the convolution hyper-parameters together with the activation,
 * bias, initializer, constraint, regularizer and dropout settings of a cell.
 * Note that `kernelSize`, `strides` and `dilationRate` are declared here as
 * `number[]`, whereas `ConvRNN2DCellArgs` also accepts a single number.
 */
declare abstract class ConvRNN2DCell extends RNNCell {
    readonly filters: number;
    readonly kernelSize: number[];
    readonly strides: number[];
    readonly padding: PaddingMode;
    readonly dataFormat: DataFormat;
    readonly dilationRate: number[];
    readonly activation: Activation;
    readonly useBias: boolean;
    readonly kernelInitializer: Initializer;
    readonly recurrentInitializer: Initializer;
    readonly biasInitializer: Initializer;
    readonly kernelConstraint: Constraint;
    readonly recurrentConstraint: Constraint;
    readonly biasConstraint: Constraint;
    readonly kernelRegularizer: Regularizer;
    readonly recurrentRegularizer: Regularizer;
    readonly biasRegularizer: Regularizer;
    readonly dropout: number;
    readonly recurrentDropout: number;
}
/**
 * Constructor arguments for `ConvRNN2D`: the generic RNN layer arguments
 * combined with the convolutional cell arguments.
 */
declare interface ConvRNN2DLayerArgs extends BaseRNNLayerArgs, ConvRNN2DCellArgs {
}
/**
 * Base class for convolutional-recurrent layers.
 */
declare class ConvRNN2D extends RNN {
    /** @nocollapse */
    static className: string;
    /** The recurrent cell driven by this layer. */
    readonly cell: ConvRNN2DCell;
    constructor(args: ConvRNN2DLayerArgs);
    call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
    computeOutputShape(inputShape: Shape): Shape | Shape[];
    getInitialState(inputs: tfc.Tensor): tfc.Tensor[];
    // NOTE(review): the meaning of `training` is not visible in this
    // declaration; check the implementation before relying on it.
    resetStates(states?: Tensor | Tensor[], training?: boolean): void;
    protected computeSingleOutputShape(inputShape: Shape): Shape;
}
/**
 * Arguments for `ConvLSTM2DCell`: the `LSTMCell` arguments without `units`
 * (the output dimensionality comes from `filters`), plus the convolutional
 * cell arguments.
 */
export declare interface ConvLSTM2DCellArgs extends Omit<LSTMCellLayerArgs, 'units'>, ConvRNN2DCellArgs {
}
/**
 * LSTM cell carrying convolutional hyper-parameters. It extends `LSTMCell`
 * and satisfies the `ConvRNN2DCell` shape, so it can serve as the cell of a
 * `ConvRNN2D` layer.
 */
export declare class ConvLSTM2DCell extends LSTMCell implements ConvRNN2DCell {
    /** @nocollapse */
    static className: string;
    readonly filters: number;
    readonly kernelSize: number[];
    readonly strides: number[];
    readonly padding: PaddingMode;
    readonly dataFormat: DataFormat;
    readonly dilationRate: number[];
    constructor(args: ConvLSTM2DCellArgs);
    build(inputShape: Shape | Shape[]): void;
    call(inputs: tfc.Tensor[], kwargs: Kwargs): tfc.Tensor[];
    getConfig(): tfc.serialization.ConfigDict;
    // Presumably the input-to-state convolution (optional bias `b`, optional
    // `padding` override); implementation not visible here.
    inputConv(x: Tensor, w: Tensor, b?: Tensor, padding?: PaddingMode): tfc.Tensor3D;
    // Presumably the state-to-state convolution; implementation not visible.
    recurrentConv(x: Tensor, w: Tensor): tfc.Tensor3D;
}
/**
 * Constructor arguments for `ConvLSTM2D`: the LSTM layer arguments without
 * `units` and `cell` (the layer is expected to build its own cell; not
 * visible here), combined with `ConvRNN2DLayerArgs`.
 */
export declare interface ConvLSTM2DArgs extends Omit<LSTMLayerArgs, 'units' | 'cell'>, ConvRNN2DLayerArgs {
}
/**
 * Convolutional LSTM layer, built on the `ConvRNN2D` base class.
 */
export declare class ConvLSTM2D extends ConvRNN2D {
    /** @nocollapse */
    static className: string;
    constructor(args: ConvLSTM2DArgs);
    /**
     * Deserialization hook: builds an instance of `cls` from `config`.
     * @nocollapse
     */
    static fromConfig<T extends tfc.serialization.Serializable>(cls: tfc.serialization.SerializableConstructor<T>, config: tfc.serialization.ConfigDict): T;
}
export {};