// Package: @tensorflow/tfjs-layers (TensorFlow layers API in JavaScript).
// Version: not recorded in this copy.
// File: 232 lines (231 loc), 7.63 kB, TypeScript.
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/initializers" />
import { DataType, serialization, Tensor } from '@tensorflow/tfjs-core';
import { Shape } from './keras_format/common';
import { Distribution, FanMode } from './keras_format/initializer_config';
/**
 * Validates a candidate fan-mode string.
 *
 * Declared to return `void`; the implementation is not part of this
 * declaration file. NOTE(review): presumably throws when `value` is not a
 * recognized `FanMode` — confirm against the source.
 *
 * @param value Candidate fan-mode string; may be omitted.
 */
export declare function checkFanMode(value?: string): void;
/**
 * Validates a candidate distribution string.
 *
 * Declared to return `void`; the implementation is not part of this
 * declaration file. NOTE(review): presumably throws when `value` is not a
 * recognized `Distribution` — confirm against the source.
 *
 * @param value Candidate distribution string; may be omitted.
 */
export declare function checkDistribution(value?: string): void;
/**
 * Initializer base class.
 *
 * Subclasses implement `apply` to produce an initial tensor for a requested
 * shape and optional dtype. The class extends `serialization.Serializable`
 * and exposes its settings through `getConfig`.
 *
 * @doc {
 * heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
 */
export declare abstract class Initializer extends serialization.Serializable {
/**
 * Reports whether deserializing this initializer's config uses custom
 * objects. NOTE(review): the returned value is decided by the
 * implementation, which is not visible in this declaration file.
 */
fromConfigUsesCustomObjects(): boolean;
/**
 * Generate an initial value.
 * @param shape Shape of the tensor to generate.
 * @param dtype Optional data type of the tensor. The fallback used when it
 *   is omitted is chosen by each implementation and is not visible here.
 * @return The init value.
 */
abstract apply(shape: Shape, dtype?: DataType): Tensor;
/** Returns this initializer's settings as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
}
/**
 * Initializer that takes no configuration. NOTE(review): per its name it
 * presumably fills the generated tensor with zeros; the implementation is
 * not visible in this declaration file.
 */
export declare class Zeros extends Initializer {
/** @nocollapse */
static className: string;
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
}
/**
 * Initializer that takes no configuration. NOTE(review): per its name it
 * presumably fills the generated tensor with ones; the implementation is
 * not visible in this declaration file.
 */
export declare class Ones extends Initializer {
/** @nocollapse */
static className: string;
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
}
/** Constructor arguments for the `Constant` initializer. */
export interface ConstantArgs {
/** The value for each element in the variable. Required. */
value: number;
}
/**
 * Initializer configured with a single numeric `value`, documented in
 * `ConstantArgs` as "the value for each element in the variable".
 */
export declare class Constant extends Initializer {
/** @nocollapse */
static className: string;
// Declaration emit omits the type of private members.
// NOTE(review): presumably stores `args.value`.
private value;
/** @param args Required; `args.value` is the per-element constant. */
constructor(args: ConstantArgs);
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
/** Returns this initializer's settings as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for `RandomUniform`; every field is optional. */
export interface RandomUniformArgs {
/** Lower bound of the range of random values to generate. */
minval?: number;
/** Upper bound of the range of random values to generate. */
maxval?: number;
/** Used to seed the random generator. */
seed?: number;
}
/**
 * Initializer that generates random values between a lower and an upper
 * bound (see `RandomUniformArgs`). NOTE(review): the name suggests uniform
 * sampling; the sampling code is not visible in this declaration file.
 */
export declare class RandomUniform extends Initializer {
/** @nocollapse */
static className: string;
/** NOTE(review): presumably the fallback when `args.minval` is omitted. */
readonly DEFAULT_MINVAL = -0.05;
/** NOTE(review): presumably the fallback when `args.maxval` is omitted. */
readonly DEFAULT_MAXVAL = 0.05;
// Declaration emit omits the types of these private members.
private minval;
private maxval;
private seed;
/**
 * @param args Bounds and seed. The object itself is required even though
 *   every field is optional; pass `{}` to rely on the defaults.
 */
constructor(args: RandomUniformArgs);
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
/** Returns this initializer's settings as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for `RandomNormal`; every field is optional. */
export interface RandomNormalArgs {
/** Mean of the random values to generate. */
mean?: number;
/** Standard deviation of the random values to generate. */
stddev?: number;
/** Used to seed the random generator. */
seed?: number;
}
/**
 * Initializer that generates random values with a configurable mean and
 * standard deviation (see `RandomNormalArgs`). NOTE(review): the name
 * suggests a normal distribution; the sampling code is not visible here.
 */
export declare class RandomNormal extends Initializer {
/** @nocollapse */
static className: string;
/** NOTE(review): presumably the fallback when `args.mean` is omitted. */
readonly DEFAULT_MEAN = 0;
/** NOTE(review): presumably the fallback when `args.stddev` is omitted. */
readonly DEFAULT_STDDEV = 0.05;
// Declaration emit omits the types of these private members.
private mean;
private stddev;
private seed;
/**
 * @param args Mean, stddev and seed. The object itself is required even
 *   though every field is optional; pass `{}` to rely on the defaults.
 */
constructor(args: RandomNormalArgs);
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
/** Returns this initializer's settings as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for `TruncatedNormal`; every field is optional. */
export interface TruncatedNormalArgs {
/** Mean of the random values to generate. */
mean?: number;
/** Standard deviation of the random values to generate. */
stddev?: number;
/** Used to seed the random generator. */
seed?: number;
}
/**
 * Initializer that generates random values with a configurable mean and
 * standard deviation (see `TruncatedNormalArgs`). NOTE(review): the name
 * suggests a truncated normal distribution; the truncation rule is not
 * visible in this declaration file.
 */
export declare class TruncatedNormal extends Initializer {
/** @nocollapse */
static className: string;
/** NOTE(review): presumably the fallback when `args.mean` is omitted. */
readonly DEFAULT_MEAN = 0;
/** NOTE(review): presumably the fallback when `args.stddev` is omitted. */
readonly DEFAULT_STDDEV = 0.05;
// Declaration emit omits the types of these private members.
private mean;
private stddev;
private seed;
/**
 * @param args Mean, stddev and seed. The object itself is required even
 *   though every field is optional; pass `{}` to rely on the defaults.
 */
constructor(args: TruncatedNormalArgs);
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
/** Returns this initializer's settings as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for the `Identity` initializer. */
export interface IdentityArgs {
/**
 * Multiplicative factor to apply to the identity matrix. Optional; the
 * fallback value is not visible in this declaration file.
 */
gain?: number;
}
/**
 * Initializer based on an identity matrix scaled by `gain` (see
 * `IdentityArgs`). NOTE(review): which shapes it accepts is decided by the
 * implementation, which is not visible here.
 */
export declare class Identity extends Initializer {
/** @nocollapse */
static className: string;
// Declaration emit omits the type of private members.
private gain;
/**
 * @param args The object itself is required even though `gain` is
 *   optional; pass `{}` to rely on the default.
 */
constructor(args: IdentityArgs);
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
/** Returns this initializer's settings as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for `VarianceScaling`; every field is optional. */
export interface VarianceScalingArgs {
/** Scaling factor (positive float). */
scale?: number;
/** Fanning mode for inputs and outputs. */
mode?: FanMode;
/** Probabilistic distribution of the values. */
distribution?: Distribution;
/** Random number generator seed. */
seed?: number;
}
/**
 * Initializer configured by a scale, a fan mode and a distribution (see
 * `VarianceScalingArgs`). It is the base class of the `GlorotUniform`,
 * `GlorotNormal`, `HeNormal`, `HeUniform`, `LeCunNormal` and `LeCunUniform`
 * initializers declared in this file.
 */
export declare class VarianceScaling extends Initializer {
/** @nocollapse */
static className: string;
// Declaration emit omits the types of these private members.
private scale;
private mode;
private distribution;
private seed;
/**
 * Constructor of VarianceScaling.
 * @param args Scale, mode, distribution and seed. The object itself is
 *   required even though every field is optional.
 * @throws ValueError for invalid value in scale.
 */
constructor(args: VarianceScalingArgs);
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
/** Returns this initializer's settings as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
}
/**
 * Constructor arguments for initializers whose only configurable option is
 * the seed (the `VarianceScaling` subclasses in this file).
 * `OrthogonalArgs` also extends it.
 */
export interface SeedOnlyInitializerArgs {
/** Random number generator seed. */
seed?: number;
}
/**
 * `VarianceScaling` subclass whose constructor accepts only a seed.
 * NOTE(review): scale, mode and distribution are presumably fixed by the
 * subclass (the name suggests Glorot scaling with a uniform distribution);
 * confirm against the implementation.
 */
export declare class GlorotUniform extends VarianceScaling {
/** @nocollapse */
static className: string;
/**
 * Constructor of GlorotUniform.
 * @param args Optional; only `seed` may be supplied (see
 *   `SeedOnlyInitializerArgs`).
 */
constructor(args?: SeedOnlyInitializerArgs);
/** NOTE(review): presumably returns the static `className`. */
getClassName(): string;
}
/**
 * `VarianceScaling` subclass whose constructor accepts only a seed.
 * NOTE(review): scale, mode and distribution are presumably fixed by the
 * subclass (the name suggests Glorot scaling with a normal-type
 * distribution); confirm against the implementation.
 */
export declare class GlorotNormal extends VarianceScaling {
/** @nocollapse */
static className: string;
/**
 * Constructor of GlorotNormal.
 * @param args Optional; only `seed` may be supplied (see
 *   `SeedOnlyInitializerArgs`).
 */
constructor(args?: SeedOnlyInitializerArgs);
/** NOTE(review): presumably returns the static `className`. */
getClassName(): string;
}
/**
 * `VarianceScaling` subclass whose constructor accepts only a seed.
 * NOTE(review): scale, mode and distribution are presumably fixed by the
 * subclass (the name suggests a normal-type distribution); confirm against
 * the implementation.
 */
export declare class HeNormal extends VarianceScaling {
/** @nocollapse */
static className: string;
/** @param args Optional; only `seed` may be supplied. */
constructor(args?: SeedOnlyInitializerArgs);
/** NOTE(review): presumably returns the static `className`. */
getClassName(): string;
}
/**
 * `VarianceScaling` subclass whose constructor accepts only a seed.
 * NOTE(review): scale, mode and distribution are presumably fixed by the
 * subclass (the name suggests a uniform distribution); confirm against the
 * implementation.
 */
export declare class HeUniform extends VarianceScaling {
/** @nocollapse */
static className: string;
/** @param args Optional; only `seed` may be supplied. */
constructor(args?: SeedOnlyInitializerArgs);
/** NOTE(review): presumably returns the static `className`. */
getClassName(): string;
}
/**
 * `VarianceScaling` subclass whose constructor accepts only a seed.
 * NOTE(review): scale, mode and distribution are presumably fixed by the
 * subclass (the name suggests a normal-type distribution); confirm against
 * the implementation.
 */
export declare class LeCunNormal extends VarianceScaling {
/** @nocollapse */
static className: string;
/** @param args Optional; only `seed` may be supplied. */
constructor(args?: SeedOnlyInitializerArgs);
/** NOTE(review): presumably returns the static `className`. */
getClassName(): string;
}
/**
 * `VarianceScaling` subclass whose constructor accepts only a seed.
 * NOTE(review): scale, mode and distribution are presumably fixed by the
 * subclass (the name suggests a uniform distribution); confirm against the
 * implementation.
 */
export declare class LeCunUniform extends VarianceScaling {
/** @nocollapse */
static className: string;
/** @param args Optional; only `seed` may be supplied. */
constructor(args?: SeedOnlyInitializerArgs);
/** NOTE(review): presumably returns the static `className`. */
getClassName(): string;
}
/**
 * Constructor arguments for `Orthogonal`: the inherited optional `seed`
 * plus an optional `gain`.
 */
export interface OrthogonalArgs extends SeedOnlyInitializerArgs {
/**
 * Multiplicative factor to apply to the orthogonal matrix. Defaults to 1.
 */
gain?: number;
}
/**
 * Initializer based on an orthogonal matrix scaled by `gain` (see
 * `OrthogonalArgs`). Unlike `Identity` and the random initializers, its
 * constructor argument is optional.
 */
export declare class Orthogonal extends Initializer {
/** @nocollapse */
static className: string;
/** Default for `args.gain`; matches "Defaults to 1" in `OrthogonalArgs`. */
readonly DEFAULT_GAIN = 1;
/**
 * NOTE(review): presumably an element-count threshold above which a
 * "may be slow" warning is emitted. Confirm against the implementation.
 */
readonly ELEMENTS_WARN_SLOW = 2000;
protected readonly gain: number;
// NOTE(review): declared as `number`, but `OrthogonalArgs.seed` is optional,
// so this may be `undefined` at runtime when no seed is passed. Confirm.
protected readonly seed: number;
/** @param args Optional gain and seed. */
constructor(args?: OrthogonalArgs);
/** Generates the initial tensor for `shape`, optionally using `dtype`. */
apply(shape: Shape, dtype?: DataType): Tensor;
/** Returns this initializer's settings as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
}
// String names accepted wherever an initializer can be identified by name.
// The literal members list the built-in initializers, one per class declared
// in this file. Because the union also includes plain `string`, the alias as
// a whole reduces to `string`: any string type-checks, and editors may not
// offer the literals as completions.
// NOTE(review): the `| string` is presumably deliberate, so that custom or
// registered initializer names are accepted. Confirm before narrowing it.
/** @docinline */
export type InitializerIdentifier = 'constant' | 'glorotNormal' | 'glorotUniform' | 'heNormal' | 'heUniform' | 'identity' | 'leCunNormal' | 'leCunUniform' | 'ones' | 'orthogonal' | 'randomNormal' | 'randomUniform' | 'truncatedNormal' | 'varianceScaling' | 'zeros' | string;
/**
 * Lookup table keyed by `InitializerIdentifier`, with string values.
 *
 * `InitializerIdentifier` reduces to `string`, so this mapped type is
 * effectively `{[key: string]: string}` and indexing it with an unknown key
 * still type-checks. NOTE(review): presumably maps the camelCase
 * identifiers to registered class names (e.g. 'zeros' to 'Zeros'). Confirm
 * against the implementation.
 */
export declare const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP: {
[identifier in InitializerIdentifier]: string;
};
/**
 * Serializes an initializer into a `serialization.ConfigDictValue`.
 *
 * @param initializer The initializer instance to serialize.
 * @returns Its serialized form. NOTE(review): the exact layout of the value
 *   is decided by the implementation, which is not visible here.
 */
export declare function serializeInitializer(initializer: Initializer): serialization.ConfigDictValue;
/**
 * Resolves any of the accepted initializer forms to an `Initializer`.
 *
 * @param identifier One of: a string name (`InitializerIdentifier`), an
 *   existing `Initializer` instance, or a `serialization.ConfigDict`.
 *   NOTE(review): how unknown names or malformed configs are handled is not
 *   visible in this declaration file.
 * @returns The resolved `Initializer`.
 */
export declare function getInitializer(identifier: InitializerIdentifier | Initializer | serialization.ConfigDict): Initializer;