// Package: @tensorflow/tfjs-layers
// Version: (not specified)
// Description: TensorFlow layers API in JavaScript
// File size: 66 lines (65 loc) • 3.12 kB
// Language: TypeScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/keras_format/layers/core_serialization" />
import { ActivationSerialization } from '../activation_config';
import { Shape } from '../common';
import { ConstraintSerialization } from '../constraint_config';
import { InitializerSerialization } from '../initializer_config';
import { RegularizerSerialization } from '../regularizer_config';
import { BaseLayerSerialization, LayerConfig } from '../topology_config';
/**
 * Keras-format (snake_case keys) configuration for a `Dropout` layer.
 *
 * Only `rate` is required; `noise_shape` and `seed` may be omitted.
 */
export interface DropoutLayerConfig extends LayerConfig {
  rate: number;
  seed?: number;
  noise_shape?: Array<number>;
}

/** `BaseLayerSerialization` specialized for class name `'Dropout'`. */
export type DropoutLayerSerialization =
    BaseLayerSerialization<'Dropout', DropoutLayerConfig>;
/**
 * Keras-format (snake_case keys) configuration for a `Dense` layer.
 *
 * Only `units` is required. The remaining keys are optional and are grouped
 * below by the weight they are named after (kernel vs. bias).
 */
export interface DenseLayerConfig extends LayerConfig {
  units: number;
  activation?: ActivationSerialization;
  use_bias?: boolean;
  input_dim?: number;

  // Kernel-named settings.
  kernel_initializer?: InitializerSerialization;
  kernel_regularizer?: RegularizerSerialization;
  kernel_constraint?: ConstraintSerialization;

  // Bias-named settings.
  bias_initializer?: InitializerSerialization;
  bias_regularizer?: RegularizerSerialization;
  bias_constraint?: ConstraintSerialization;

  activity_regularizer?: RegularizerSerialization;
}

/** `BaseLayerSerialization` specialized for class name `'Dense'`. */
export type DenseLayerSerialization =
    BaseLayerSerialization<'Dense', DenseLayerConfig>;
/**
 * `BaseLayerSerialization` specialized for class name `'Flatten'`.
 *
 * It uses the base `LayerConfig` directly; no Flatten-specific config
 * interface is declared in this module.
 */
export type FlattenLayerSerialization =
    BaseLayerSerialization<'Flatten', LayerConfig>;
/**
 * Keras-format configuration for an `Activation` layer.
 *
 * The single key, `activation`, is required.
 */
export interface ActivationLayerConfig extends LayerConfig {
  activation: ActivationSerialization;
}

/** `BaseLayerSerialization` specialized for class name `'Activation'`. */
export type ActivationLayerSerialization =
    BaseLayerSerialization<'Activation', ActivationLayerConfig>;
/**
 * Keras-format configuration for a `RepeatVector` layer.
 *
 * NOTE(review): `n` is presumably the repetition count — confirm against the
 * layer implementation.
 */
export interface RepeatVectorLayerConfig extends LayerConfig {
  n: number;
}

/** `BaseLayerSerialization` specialized for class name `'RepeatVector'`. */
export type RepeatVectorLayerSerialization =
    BaseLayerSerialization<'RepeatVector', RepeatVectorLayerConfig>;
/**
 * Keras-format (snake_case keys) configuration for a `Reshape` layer.
 *
 * `target_shape` is required and is typed with the shared `Shape` type.
 */
export interface ReshapeLayerConfig extends LayerConfig {
  target_shape: Shape;
}

/** `BaseLayerSerialization` specialized for class name `'Reshape'`. */
export type ReshapeLayerSerialization =
    BaseLayerSerialization<'Reshape', ReshapeLayerConfig>;
/**
 * Keras-format configuration for a `Permute` layer.
 *
 * `dims` is a required list of numbers.
 */
export interface PermuteLayerConfig extends LayerConfig {
  dims: Array<number>;
}

/** `BaseLayerSerialization` specialized for class name `'Permute'`. */
export type PermuteLayerSerialization =
    BaseLayerSerialization<'Permute', PermuteLayerConfig>;
/**
 * Configuration for a `Masking` layer; `maskValue` is required.
 *
 * NOTE(review): `maskValue` is camelCase, while every other key declared in
 * this module (`noise_shape`, `use_bias`, `target_shape`, ...) is snake_case.
 * If this interface is meant to describe the Python/Keras JSON format, the
 * key there is presumably `mask_value` — confirm before relying on it. The
 * name is kept unchanged here because renaming it would break existing
 * consumers of this type.
 */
export interface MaskingLayerConfig extends LayerConfig {
  maskValue: number;
}

/** `BaseLayerSerialization` specialized for class name `'Masking'`. */
export type MaskingLayerSerialization =
    BaseLayerSerialization<'Masking', MaskingLayerConfig>;
/** Union of the serialized forms of every core layer declared in this module. */
export type CoreLayerSerialization =
    | DropoutLayerSerialization
    | DenseLayerSerialization
    | FlattenLayerSerialization
    | ActivationLayerSerialization
    | RepeatVectorLayerSerialization
    | ReshapeLayerSerialization
    | PermuteLayerSerialization
    | MaskingLayerSerialization;

/**
 * The type of the `class_name` property, taken across every member of
 * `CoreLayerSerialization`.
 */
export type CoreLayerClassName = CoreLayerSerialization['class_name'];
/**
 * Runtime list of valid core layer class names.
 *
 * The declared element type is `CoreLayerClassName`, so every entry must be a
 * member of that union. The list is intended to match the union exactly.
 * NOTE(review): this declaration alone does not enforce completeness or
 * uniqueness; confirm against the implementation.
 */
export declare const coreLayerClassNames: Array<CoreLayerClassName>;