// Package: @tensorflow/tfjs-layers
// Version: (not specified)
// Description: TensorFlow layers API in JavaScript
// Listing size: 96 lines (95 loc) • 4.3 kB
// Language: TypeScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/keras_format/layers/recurrent_serialization" />
import { ActivationSerialization } from '../activation_config';
import { ConstraintSerialization } from '../constraint_config';
import { InitializerSerialization } from '../initializer_config';
import { RegularizerSerialization } from '../regularizer_config';
import { BaseLayerSerialization, LayerConfig } from '../topology_config';
import { BaseSerialization } from '../types';
/**
 * Serialized (keras_format, snake_case) configuration shared by recurrent
 * layers.
 *
 * Every field here is optional. Interfaces that extend this one, such as
 * `SimpleRNNLayerConfig`, add their own required fields (e.g. `units`).
 */
export interface BaseRNNLayerConfig extends LayerConfig {
  /** Either one serialized cell or a list of serialized cells. */
  cell?: RNNCellSerialization | Array<RNNCellSerialization>;

  // Boolean flags.
  return_sequences?: boolean;
  return_state?: boolean;
  go_backwards?: boolean;
  stateful?: boolean;
  unroll?: boolean;

  // Numeric fields.
  input_dim?: number;
  input_length?: number;
}
/**
 * Serialized configuration of a `SimpleRNNCell`.
 *
 * `units` is the only required field. Initializers, regularizers and
 * constraints are each given separately for the kernel, recurrent and bias
 * parts.
 */
export interface SimpleRNNCellConfig extends LayerConfig {
  units: number;
  activation?: ActivationSerialization;
  use_bias?: boolean;

  // Initializers.
  kernel_initializer?: InitializerSerialization;
  recurrent_initializer?: InitializerSerialization;
  bias_initializer?: InitializerSerialization;

  // Regularizers.
  kernel_regularizer?: RegularizerSerialization;
  recurrent_regularizer?: RegularizerSerialization;
  bias_regularizer?: RegularizerSerialization;

  // Constraints.
  kernel_constraint?: ConstraintSerialization;
  recurrent_constraint?: ConstraintSerialization;
  bias_constraint?: ConstraintSerialization;

  // Dropout.
  dropout?: number;
  recurrent_dropout?: number;
}

/**
 * Serialized form of a `SimpleRNNCell`: pairs the class name
 * `'SimpleRNNCell'` with a `SimpleRNNCellConfig` via `BaseSerialization`.
 */
export type SimpleRNNCellSerialization =
  BaseSerialization<'SimpleRNNCell', SimpleRNNCellConfig>;
/**
 * Serialized configuration of a `SimpleRNN` layer.
 *
 * It adds the same per-cell fields that `SimpleRNNCellConfig` declares to
 * the generic recurrent-layer options inherited from `BaseRNNLayerConfig`.
 *
 * NOTE(review): this field list is a verbatim copy of `SimpleRNNCellConfig`'s;
 * keep the two in sync when either changes.
 */
export interface SimpleRNNLayerConfig extends BaseRNNLayerConfig {
  units: number;
  activation?: ActivationSerialization;
  use_bias?: boolean;

  // Initializers.
  kernel_initializer?: InitializerSerialization;
  recurrent_initializer?: InitializerSerialization;
  bias_initializer?: InitializerSerialization;

  // Regularizers.
  kernel_regularizer?: RegularizerSerialization;
  recurrent_regularizer?: RegularizerSerialization;
  bias_regularizer?: RegularizerSerialization;

  // Constraints.
  kernel_constraint?: ConstraintSerialization;
  recurrent_constraint?: ConstraintSerialization;
  bias_constraint?: ConstraintSerialization;

  // Dropout.
  dropout?: number;
  recurrent_dropout?: number;
}

/**
 * Serialized form of a `SimpleRNN` layer: pairs the class name `'SimpleRNN'`
 * with a `SimpleRNNLayerConfig` via `BaseLayerSerialization`.
 */
export type SimpleRNNLayerSerialization =
  BaseLayerSerialization<'SimpleRNN', SimpleRNNLayerConfig>;
/**
 * Serialized configuration of a `GRUCell`: every `SimpleRNNCellConfig` field
 * plus the GRU-specific options below.
 */
export interface GRUCellConfig extends SimpleRNNCellConfig {
  // NOTE(review): typed as plain `string`, whereas `GRULayerConfig`,
  // `LSTMCellConfig` and `LSTMLayerConfig` type `recurrent_activation` as
  // `ActivationSerialization`. Switching this to `ActivationSerialization`
  // could reject values writers currently pass (if that type is narrower
  // than `string`), so it is left unchanged — confirm whether intentional.
  recurrent_activation?: string;
  implementation?: number;
}

/**
 * Serialized form of a `GRUCell`: pairs the class name `'GRUCell'` with a
 * `GRUCellConfig` via `BaseSerialization`.
 */
export type GRUCellSerialization =
  BaseSerialization<'GRUCell', GRUCellConfig>;
/**
 * Serialized configuration of a `GRU` layer: every `SimpleRNNLayerConfig`
 * field plus the GRU-specific options below.
 */
export interface GRULayerConfig extends SimpleRNNLayerConfig {
  recurrent_activation?: ActivationSerialization;
  implementation?: number;
}

/**
 * Serialized form of a `GRU` layer: pairs the class name `'GRU'` with a
 * `GRULayerConfig` via `BaseLayerSerialization`.
 */
export type GRULayerSerialization =
  BaseLayerSerialization<'GRU', GRULayerConfig>;
/**
 * Serialized configuration of an `LSTMCell`: every `SimpleRNNCellConfig`
 * field plus the LSTM-specific options below.
 */
export interface LSTMCellConfig extends SimpleRNNCellConfig {
  recurrent_activation?: ActivationSerialization;
  unit_forget_bias?: boolean;
  implementation?: number;
}

/**
 * Serialized form of an `LSTMCell`: pairs the class name `'LSTMCell'` with an
 * `LSTMCellConfig` via `BaseSerialization`.
 */
export type LSTMCellSerialization =
  BaseSerialization<'LSTMCell', LSTMCellConfig>;
/**
 * Serialized configuration of an `LSTM` layer: every `SimpleRNNLayerConfig`
 * field plus the LSTM-specific options below.
 */
export interface LSTMLayerConfig extends SimpleRNNLayerConfig {
  recurrent_activation?: ActivationSerialization;
  unit_forget_bias?: boolean;
  implementation?: number;
}

/**
 * Serialized form of an `LSTM` layer: pairs the class name `'LSTM'` with an
 * `LSTMLayerConfig` via `BaseLayerSerialization`.
 */
export type LSTMLayerSerialization =
  BaseLayerSerialization<'LSTM', LSTMLayerConfig>;
/** Serialized configuration of `StackedRNNCells`: a list of serialized cells. */
export interface StackedRNNCellsConfig extends LayerConfig {
  cells: Array<RNNCellSerialization>;
}

/**
 * Serialized form of `StackedRNNCells`: pairs the class name
 * `'StackedRNNCells'` with a `StackedRNNCellsConfig` via `BaseSerialization`.
 */
export type StackedRNNCellsSerialization =
  BaseSerialization<'StackedRNNCells', StackedRNNCellsConfig>;

/**
 * Any serialized RNN cell.
 *
 * The union is recursive: `StackedRNNCellsSerialization` holds a list of
 * `RNNCellSerialization`s, so at the type level stacks may be nested.
 */
export type RNNCellSerialization =
  | SimpleRNNCellSerialization
  | GRUCellSerialization
  | LSTMCellSerialization
  | StackedRNNCellsSerialization;
/** Serialized form of any of the recurrent layers declared in this module. */
export type RecurrentLayerSerialization =
  | SimpleRNNLayerSerialization
  | LSTMLayerSerialization
  | GRULayerSerialization;

/**
 * The `class_name` field type of `RecurrentLayerSerialization`.
 *
 * Presumably this resolves to `'SimpleRNN' | 'LSTM' | 'GRU'` (the class-name
 * arguments given to `BaseLayerSerialization`). TODO: confirm against the
 * definition of `BaseLayerSerialization`.
 */
export type RecurrentLayerClassName =
  RecurrentLayerSerialization['class_name'];
/**
 * Array of the valid `RecurrentLayer` class names, as strings.
 *
 * Its contents are guaranteed to match the `RecurrentLayerClassName` union
 * type.
 *
 * NOTE(review): the declared type below only constrains each element to be a
 * member of the union. Completeness is presumably enforced by the
 * implementation, which is not visible in this ambient declaration.
 */
export declare const recurrentLayerClassNames: Array<RecurrentLayerClassName>;