UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

232 lines (231 loc) 7.63 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/initializers" />
import { DataType, serialization, Tensor } from '@tensorflow/tfjs-core';
import { Shape } from './keras_format/common';
import { Distribution, FanMode } from './keras_format/initializer_config';
/**
 * Checks a fan-mode string (compare with the `FanMode` type).
 *
 * NOTE(review): only the signature is visible in this declaration file;
 * presumably the implementation throws when `value` is not a valid `FanMode`
 * — confirm against the implementation before relying on it.
 *
 * @param value The fan-mode string to check; may be omitted.
 */
export declare function checkFanMode(value?: string): void;
/**
 * Checks a distribution string (compare with the `Distribution` type).
 *
 * NOTE(review): only the signature is visible in this declaration file;
 * presumably the implementation throws when `value` is not a valid
 * `Distribution` — confirm against the implementation before relying on it.
 *
 * @param value The distribution string to check; may be omitted.
 */
export declare function checkDistribution(value?: string): void;
/**
 * Initializer base class.
 *
 * @doc {
 *   heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
 */
export declare abstract class Initializer extends serialization.Serializable {
    /**
     * NOTE(review): the meaning of this flag is defined by the implementation,
     * which is not visible here; presumably it reports whether deserializing
     * this initializer from a config requires custom objects.
     */
    fromConfigUsesCustomObjects(): boolean;
    /**
     * Generate an initial value.
     * @param shape Shape of the tensor to generate.
     * @param dtype Optional data type of the generated tensor.
     * @return The init value.
     */
    abstract apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's serializable configuration. */
    getConfig(): serialization.ConfigDict;
}
/**
 * Initializer class `Zeros`.
 *
 * NOTE(review): presumably produces a tensor of the requested shape filled
 * with zeros — the implementation is not visible in this declaration file.
 */
export declare class Zeros extends Initializer {
    /** @nocollapse */
    static className: string;
    apply(shape: Shape, dtype?: DataType): Tensor;
}
/**
 * Initializer class `Ones`.
 *
 * NOTE(review): presumably produces a tensor of the requested shape filled
 * with ones — the implementation is not visible in this declaration file.
 */
export declare class Ones extends Initializer {
    /** @nocollapse */
    static className: string;
    apply(shape: Shape, dtype?: DataType): Tensor;
}
/** Arguments for the `Constant` initializer. */
export interface ConstantArgs {
    /** The value for each element in the variable. */
    value: number;
}
/** Initializer whose elements all take the value given in `ConstantArgs.value`. */
export declare class Constant extends Initializer {
    /** @nocollapse */
    static className: string;
    private value;
    constructor(args: ConstantArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
/** Arguments for the `RandomUniform` initializer. */
export interface RandomUniformArgs {
    /** Lower bound of the range of random values to generate. */
    minval?: number;
    /** Upper bound of the range of random values to generate. */
    maxval?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
/**
 * Initializer that draws values from a uniform range bounded by
 * `minval`/`maxval` (see `RandomUniformArgs`).
 *
 * NOTE(review): `DEFAULT_MINVAL`/`DEFAULT_MAXVAL` are presumably the bounds
 * used when the corresponding args are omitted — confirm.
 */
export declare class RandomUniform extends Initializer {
    /** @nocollapse */
    static className: string;
    readonly DEFAULT_MINVAL = -0.05;
    readonly DEFAULT_MAXVAL = 0.05;
    private minval;
    private maxval;
    private seed;
    constructor(args: RandomUniformArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
/** Arguments for the `RandomNormal` initializer. */
export interface RandomNormalArgs {
    /** Mean of the random values to generate. */
    mean?: number;
    /** Standard deviation of the random values to generate. */
    stddev?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
/**
 * Initializer that draws values with the given `mean`/`stddev`
 * (see `RandomNormalArgs`).
 *
 * NOTE(review): `DEFAULT_MEAN`/`DEFAULT_STDDEV` are presumably used when the
 * corresponding args are omitted — confirm.
 */
export declare class RandomNormal extends Initializer {
    /** @nocollapse */
    static className: string;
    readonly DEFAULT_MEAN = 0;
    readonly DEFAULT_STDDEV = 0.05;
    private mean;
    private stddev;
    private seed;
    constructor(args: RandomNormalArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
/** Arguments for the `TruncatedNormal` initializer. */
export interface TruncatedNormalArgs {
    /** Mean of the random values to generate. */
    mean?: number;
    /** Standard deviation of the random values to generate. */
    stddev?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
/**
 * Initializer parameterized by `mean`/`stddev` (see `TruncatedNormalArgs`).
 *
 * NOTE(review): the truncation rule lives in the implementation, which is not
 * visible here; `DEFAULT_MEAN`/`DEFAULT_STDDEV` are presumably used when the
 * corresponding args are omitted — confirm.
 */
export declare class TruncatedNormal extends Initializer {
    /** @nocollapse */
    static className: string;
    readonly DEFAULT_MEAN = 0;
    readonly DEFAULT_STDDEV = 0.05;
    private mean;
    private stddev;
    private seed;
    constructor(args: TruncatedNormalArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
/** Arguments for the `Identity` initializer. */
export interface IdentityArgs {
    /**
     * Multiplicative factor to apply to the identity matrix.
     */
    gain?: number;
}
/** Initializer based on the identity matrix, scaled by `IdentityArgs.gain`. */
export declare class Identity extends Initializer {
    /** @nocollapse */
    static className: string;
    private gain;
    constructor(args: IdentityArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
/** Arguments for the `VarianceScaling` initializer. */
export interface VarianceScalingArgs {
    /** Scaling factor (positive float). */
    scale?: number;
    /** Fanning mode for inputs and outputs. */
    mode?: FanMode;
    /** Probabilistic distribution of the values. */
    distribution?: Distribution;
    /** Random number generator seed. */
    seed?: number;
}
/**
 * Initializer configured by a scale, a fan mode and a distribution
 * (see `VarianceScalingArgs`); base class of the Glorot/He/LeCun presets
 * below.
 */
export declare class VarianceScaling extends Initializer {
    /** @nocollapse */
    static className: string;
    private scale;
    private mode;
    private distribution;
    private seed;
    /**
     * Constructor of VarianceScaling.
     * @throws ValueError for invalid value in scale.
     */
    constructor(args: VarianceScalingArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
/** Arguments for initializers whose only configurable option is the seed. */
export interface SeedOnlyInitializerArgs {
    /** Random number generator seed. */
    seed?: number;
}
/**
 * `VarianceScaling` preset; only `seed` is configurable by callers.
 *
 * NOTE(review): the fixed scale/mode/distribution are chosen by the
 * implementation, which is not visible in this declaration file.
 */
export declare class GlorotUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * Constructor of GlorotUniform.
     * @param args Optional arguments. Only `seed` (random number generator
     *   seed) can be set; scale, mode and distribution are not configurable
     *   through this constructor.
     */
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
/**
 * `VarianceScaling` preset; only `seed` is configurable by callers.
 *
 * NOTE(review): the fixed scale/mode/distribution are chosen by the
 * implementation, which is not visible in this declaration file.
 */
export declare class GlorotNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * Constructor of GlorotNormal.
     * @param args Optional arguments. Only `seed` (random number generator
     *   seed) can be set; scale, mode and distribution are not configurable
     *   through this constructor.
     */
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
/**
 * `VarianceScaling` preset; only `seed` is configurable by callers.
 *
 * NOTE(review): the fixed scale/mode/distribution are chosen by the
 * implementation, which is not visible in this declaration file.
 */
export declare class HeNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /** @param args Optional arguments; only `seed` can be set. */
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
/**
 * `VarianceScaling` preset; only `seed` is configurable by callers.
 *
 * NOTE(review): the fixed scale/mode/distribution are chosen by the
 * implementation, which is not visible in this declaration file.
 */
export declare class HeUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /** @param args Optional arguments; only `seed` can be set. */
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
/**
 * `VarianceScaling` preset; only `seed` is configurable by callers.
 *
 * NOTE(review): the fixed scale/mode/distribution are chosen by the
 * implementation, which is not visible in this declaration file.
 */
export declare class LeCunNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /** @param args Optional arguments; only `seed` can be set. */
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
/**
 * `VarianceScaling` preset; only `seed` is configurable by callers.
 *
 * NOTE(review): the fixed scale/mode/distribution are chosen by the
 * implementation, which is not visible in this declaration file.
 */
export declare class LeCunUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /** @param args Optional arguments; only `seed` can be set. */
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
/** Arguments for the `Orthogonal` initializer. */
export interface OrthogonalArgs extends SeedOnlyInitializerArgs {
    /**
     * Multiplicative factor to apply to the orthogonal matrix. Defaults to 1.
     */
    gain?: number;
}
/** Initializer based on an orthogonal matrix, scaled by `OrthogonalArgs.gain`. */
export declare class Orthogonal extends Initializer {
    /** @nocollapse */
    static className: string;
    readonly DEFAULT_GAIN = 1;
    // NOTE(review): presumably the element count above which the
    // implementation warns that generation may be slow — confirm.
    readonly ELEMENTS_WARN_SLOW = 2000;
    protected readonly gain: number;
    protected readonly seed: number;
    constructor(args?: OrthogonalArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
// Names an initializer by string. Because the union ends in `| string`, the
// type as a whole is equivalent to `string`: the literal members document the
// built-in names but do not restrict what the compiler accepts.
/** @docinline */
export type InitializerIdentifier =
    | 'constant'
    | 'glorotNormal'
    | 'glorotUniform'
    | 'heNormal'
    | 'heUniform'
    | 'identity'
    | 'leCunNormal'
    | 'leCunUniform'
    | 'ones'
    | 'orthogonal'
    | 'randomNormal'
    | 'randomUniform'
    | 'truncatedNormal'
    | 'varianceScaling'
    | 'zeros'
    | string;
/**
 * Maps each `InitializerIdentifier` to a string.
 *
 * NOTE(review): presumably the value is the registered class name used for
 * serialization lookup — confirm against the implementation.
 */
export declare const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP: {
    [identifier in InitializerIdentifier]: string;
};
/**
 * Serializes an initializer into a config value.
 * @param initializer The initializer to serialize.
 * @returns The serialized form as a `serialization.ConfigDictValue`.
 */
export declare function serializeInitializer(initializer: Initializer): serialization.ConfigDictValue;
/**
 * Resolves an initializer from a string identifier, an existing
 * `Initializer`, or a serialized `serialization.ConfigDict`.
 *
 * NOTE(review): handling of unknown identifiers is defined by the
 * implementation, which is not visible here; presumably it throws — confirm.
 *
 * @param identifier What to resolve.
 * @returns The resolved `Initializer`.
 */
export declare function getInitializer(identifier: InitializerIdentifier | Initializer | serialization.ConfigDict): Initializer;