// Package: @tensorflow/tfjs-layers (TensorFlow layers API in JavaScript)
// Version: (not stated in this listing)
// Listing metadata: 68 lines (67 loc), 2.81 kB, TypeScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/keras_format/initializer_config" />
import { BaseSerialization } from './types';
/**
 * Fan-mode names spelled in camelCase. The snake_case counterparts used in
 * serialized configs are `FanModeSerialization`.
 *
 * @docinline
 */
export type FanMode = 'fanIn' | 'fanOut' | 'fanAvg';
/**
 * List of accepted fan-mode strings, declared only as `string[]`.
 * NOTE(review): presumably holds the `FanMode` values; this declaration does
 * not tie its contents to that union.
 */
export declare const VALID_FAN_MODE_VALUES: string[];
/**
 * snake_case fan-mode names, as used by `VarianceScalingConfig.mode`.
 */
export type FanModeSerialization = 'fan_in' | 'fan_out' | 'fan_avg';
/**
 * Distribution names spelled in camelCase. The snake_case counterparts used
 * in serialized configs are `DistributionSerialization`.
 *
 * @docinline
 */
export type Distribution = 'normal' | 'uniform' | 'truncatedNormal';
/**
 * List of accepted distribution strings, declared only as `string[]`.
 * NOTE(review): presumably holds the `Distribution` values; this declaration
 * does not tie its contents to that union.
 */
export declare const VALID_DISTRIBUTION_VALUES: string[];
/**
 * snake_case distribution names, as used by
 * `VarianceScalingConfig.distribution`.
 */
export type DistributionSerialization =
  | 'normal'
  | 'uniform'
  | 'truncated_normal';
/**
 * Serialized `Zeros` initializer. Its config declares no properties;
 * `Record<never, never>` resolves to the same empty object type as `{}`.
 */
export type ZerosSerialization = BaseSerialization<
  'Zeros',
  Record<never, never>
>;
/**
 * Serialized `Ones` initializer. Its config declares no properties;
 * `Record<never, never>` resolves to the same empty object type as `{}`.
 */
export type OnesSerialization = BaseSerialization<
  'Ones',
  Record<never, never>
>;
/**
 * Config for the `Constant` initializer. `value` is its only field and is
 * required.
 */
export type ConstantConfig = {
  value: number;
};
/** Serialized `Constant` initializer (`BaseSerialization` keyed by `'Constant'`). */
export type ConstantSerialization = BaseSerialization<
  'Constant',
  ConstantConfig
>;
/**
 * Config for the `RandomNormal` initializer. Every field is optional, and no
 * default values are declared in this file.
 */
export type RandomNormalConfig = {
  mean?: number;
  stddev?: number;
  seed?: number;
};
/**
 * Serialized `RandomNormal` initializer (`BaseSerialization` keyed by
 * `'RandomNormal'`).
 */
export type RandomNormalSerialization = BaseSerialization<
  'RandomNormal',
  RandomNormalConfig
>;
/**
 * Config for the `RandomUniform` initializer. Every field is optional, and no
 * default values are declared in this file.
 */
export type RandomUniformConfig = {
  minval?: number;
  maxval?: number;
  seed?: number;
};
/**
 * Serialized `RandomUniform` initializer (`BaseSerialization` keyed by
 * `'RandomUniform'`).
 */
export type RandomUniformSerialization = BaseSerialization<
  'RandomUniform',
  RandomUniformConfig
>;
/**
 * Config for the `TruncatedNormal` initializer. Every field is optional, and
 * no default values are declared in this file.
 */
export type TruncatedNormalConfig = {
  mean?: number;
  stddev?: number;
  seed?: number;
};
/**
 * Serialized `TruncatedNormal` initializer (`BaseSerialization` keyed by
 * `'TruncatedNormal'`).
 */
export type TruncatedNormalSerialization = BaseSerialization<
  'TruncatedNormal',
  TruncatedNormalConfig
>;
/**
 * Config for the `VarianceScaling` initializer. Every field is optional.
 * `mode` and `distribution` take the snake_case serialized spellings
 * (`FanModeSerialization`, `DistributionSerialization`), not the camelCase
 * `FanMode` / `Distribution` names.
 */
export type VarianceScalingConfig = {
  scale?: number;
  mode?: FanModeSerialization;
  distribution?: DistributionSerialization;
  seed?: number;
};
/**
 * Serialized `VarianceScaling` initializer (`BaseSerialization` keyed by
 * `'VarianceScaling'`).
 */
export type VarianceScalingSerialization = BaseSerialization<
  'VarianceScaling',
  VarianceScalingConfig
>;
/**
 * Config for the `Orthogonal` initializer. Both fields are optional, and no
 * default values are declared in this file.
 */
export type OrthogonalConfig = {
  seed?: number;
  gain?: number;
};
/**
 * Serialized `Orthogonal` initializer (`BaseSerialization` keyed by
 * `'Orthogonal'`).
 */
export type OrthogonalSerialization = BaseSerialization<
  'Orthogonal',
  OrthogonalConfig
>;
/**
 * Config for the `Identity` initializer. Its single field `gain` is optional,
 * and no default is declared in this file.
 */
export type IdentityConfig = {
  gain?: number;
};
/**
 * Serialized `Identity` initializer (`BaseSerialization` keyed by
 * `'Identity'`).
 */
export type IdentitySerialization = BaseSerialization<
  'Identity',
  IdentityConfig
>;
/**
 * Union of the nine initializer serialization types declared in this file.
 */
export type InitializerSerialization =
  | ZerosSerialization
  | OnesSerialization
  | ConstantSerialization
  | RandomUniformSerialization
  | RandomNormalSerialization
  | TruncatedNormalSerialization
  | IdentitySerialization
  | VarianceScalingSerialization
  | OrthogonalSerialization;
/**
 * The `class_name` values of `InitializerSerialization`. The type is derived
 * by indexed access, so it follows the union automatically.
 */
export type InitializerClassName = InitializerSerialization['class_name'];
/**
 * A string array of valid Initializer class names.
 *
 * Every element is typed as `InitializerClassName`, so each entry must be a
 * member of that union. The type annotation alone does NOT enforce the
 * converse: it does not ensure that every union member appears, or that
 * entries are unique. Completeness depends on the implementation that defines
 * this constant.
 * NOTE(review): confirm that a test or compile-time check covers this.
 */
export declare const initializerClassNames: InitializerClassName[];