UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

66 lines (65 loc) 3.36 kB
/** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /// <amd-module name="@tensorflow/tfjs-layers/dist/keras_format/layers/convolutional_serialization" /> import { DataFormatSerialization, PaddingMode } from '../common'; import { ConstraintSerialization } from '../constraint_config'; import { InitializerSerialization } from '../initializer_config'; import { RegularizerSerialization } from '../regularizer_config'; import { BaseLayerSerialization, LayerConfig } from '../topology_config'; export interface BaseConvLayerConfig extends LayerConfig { kernel_size: number | number[]; strides?: number | number[]; padding?: PaddingMode; data_format?: DataFormatSerialization; dilation_rate?: number | [number] | [number, number]; activation?: string; use_bias?: boolean; kernel_initializer?: InitializerSerialization; bias_initializer?: InitializerSerialization; kernel_constraint?: ConstraintSerialization; bias_constraint?: ConstraintSerialization; kernel_regularizer?: RegularizerSerialization; bias_regularizer?: RegularizerSerialization; activity_regularizer?: RegularizerSerialization; } export interface ConvLayerConfig extends BaseConvLayerConfig { filters: number; } export type Conv1DLayerSerialization = BaseLayerSerialization<'Conv1D', ConvLayerConfig>; export type Conv2DLayerSerialization = BaseLayerSerialization<'Conv2D', ConvLayerConfig>; export type Conv2DTransposeLayerSerialization = BaseLayerSerialization<'Conv2DTranspose', ConvLayerConfig>; export interface SeparableConvLayerConfig extends ConvLayerConfig { depth_multiplier?: number; depthwise_initializer?: InitializerSerialization; pointwise_initializer?: InitializerSerialization; depthwise_regularizer?: RegularizerSerialization; pointwise_regularizer?: RegularizerSerialization; depthwise_constraint?: ConstraintSerialization; pointwise_constraint?: ConstraintSerialization; } export type SeparableConv2DLayerSerialization = BaseLayerSerialization<'SeparableConv2D', ConvLayerConfig>; export interface Cropping2DLayerConfig extends LayerConfig { cropping: number | [number, number] | [[number, number], [number, number]]; data_format?: DataFormatSerialization; } export type Cropping2DLayerSerialization = BaseLayerSerialization<'Cropping2D', Cropping2DLayerConfig>; export interface UpSampling2DLayerConfig extends LayerConfig { size?: number[]; data_format?: DataFormatSerialization; } export type UpSampling2DLayerSerialization = BaseLayerSerialization<'UpSampling2D', UpSampling2DLayerConfig>; export type ConvolutionalLayerSerialization = Conv1DLayerSerialization | Conv2DLayerSerialization | Conv2DTransposeLayerSerialization | SeparableConv2DLayerSerialization | Cropping2DLayerSerialization | UpSampling2DLayerSerialization; export type ConvolutionalLayerClassName = ConvolutionalLayerSerialization['class_name']; /** * A string array of valid ConvolutionalLayer class names. * * This is guaranteed to match the `ConvolutionalLayerClassName` union type. */ export declare const convolutionalLayerClassNames: ConvolutionalLayerClassName[];