@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
197 lines (196 loc) • 6.8 kB
TypeScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/normalization" />
import { serialization, Tensor } from '@tensorflow/tfjs-core';
import { Constraint, ConstraintIdentifier } from '../constraints';
import { Layer, LayerArgs } from '../engine/topology';
import { Initializer, InitializerIdentifier } from '../initializers';
import { Shape } from '../keras_format/common';
import { Regularizer, RegularizerIdentifier } from '../regularizers';
import { Kwargs } from '../types';
/**
* Applies batch normalization on x given mean, var, beta and gamma.
*
* I.e. returns:
* `output = (x - mean) / (sqrt(var) + epsilon) * gamma + beta`
*
* @param x Input tensor.
* @param mean Mean of batch.
* @param variance Variance of batch.
* @param beta Tensor with which to center the input.
* @param gamma Tensor by which to scale the input.
* @param epsilon Fuzz factor.
* @returns The result of the batch normalization.
*/
export declare function batchNormalization(x: Tensor, mean: Tensor, variance: Tensor, beta?: Tensor, gamma?: Tensor, epsilon?: number): Tensor;
/**
* Batch normalization for use in training (not inference).
*
* @param x Input tensor to be normalized.
* @param gamma Tensor by which to scale the input.
* @param beta Tensor by which to center the input.
* @param reductionAxes Axes over which to normalize.
* @param epsilon Fuzz factor.
* @returns An `Array` of three `Tensors`:
* [normalized tensor, mean of input, variance of input].
*/
export declare function normalizeBatchInTraining(x: Tensor, gamma: Tensor, beta: Tensor, reductionAxes: number[], epsilon?: number): [Tensor, Tensor, Tensor];
export declare interface BatchNormalizationLayerArgs extends LayerArgs {
/**
* The integer axis that should be normalized (typically the features axis).
* Defaults to -1.
*
* For instance, after a `Conv2D` layer with `data_format="channels_first"`,
* set `axis=1` in `batchNormalization`.
*/
axis?: number;
/**
* Momentum of the moving average. Defaults to 0.99.
*/
momentum?: number;
/**
* Small float added to the variance to avoid dividing by zero. Defaults to
* 1e-3.
*/
epsilon?: number;
/**
* If `true`, add offset of `beta` to normalized tensor.
* If `false`, `beta` is ignored.
* Defaults to `true`.
*/
center?: boolean;
/**
* If `true`, multiply by `gamma`.
* If `false`, `gamma` is not used.
* When the next layer is linear (also e.g. `nn.relu`),
* this can be disabled since the scaling will be done by the next layer.
* Defaults to `true`.
*/
scale?: boolean;
/**
* Initializer for the beta weight.
* Defaults to 'zeros'.
*/
betaInitializer?: InitializerIdentifier | Initializer;
/**
* Initializer for the gamma weight.
* Defaults to `ones`.
*/
gammaInitializer?: InitializerIdentifier | Initializer;
/**
* Initializer for the moving mean.
* Defaults to `zeros`
*/
movingMeanInitializer?: InitializerIdentifier | Initializer;
/**
* Initializer for the moving variance.
* Defaults to 'Ones'.
*/
movingVarianceInitializer?: InitializerIdentifier | Initializer;
/**
* Constraint for the beta weight.
*/
betaConstraint?: ConstraintIdentifier | Constraint;
/**
* Constraint for gamma weight.
*/
gammaConstraint?: ConstraintIdentifier | Constraint;
/**
* Regularizer for the beta weight.
*/
betaRegularizer?: RegularizerIdentifier | Regularizer;
/**
* Regularizer for the gamma weight.
*/
gammaRegularizer?: RegularizerIdentifier | Regularizer;
}
export declare class BatchNormalization extends Layer {
/** @nocollapse */
static className: string;
private readonly axis;
private readonly momentum;
private readonly epsilon;
private readonly center;
private readonly scale;
private readonly betaInitializer;
private readonly gammaInitializer;
private readonly movingMeanInitializer;
private readonly movingVarianceInitializer;
private readonly betaConstraint;
private readonly gammaConstraint;
private readonly betaRegularizer;
private readonly gammaRegularizer;
private gamma;
private beta;
private movingMean;
private movingVariance;
constructor(args?: BatchNormalizationLayerArgs);
build(inputShape: Shape | Shape[]): void;
call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
getConfig(): serialization.ConfigDict;
}
export interface LayerNormalizationLayerArgs extends LayerArgs {
/**
* The axis or axes that should be normalized (typically, the feature axis).
* Defaults to -1 (the last axis).
*/
axis?: number | number[];
/**
* A small positive float added to variance to avoid division by zero.
* Defaults to 1e-3.
*/
epsilon?: number;
/**
* If `true`, add offset of `beta` to normalized tensor.
* If `false`, `beta` is ignored.
* Default: `true`.
*/
center?: boolean;
/**
* If `true`, multiply output by `gamma`.
* If `false`, `gamma` is not used.
* When the next layer is linear, this can be disabled since scaling will
* be done by the next layer.
* Default: `true`.
*/
scale?: boolean;
/**
* Initializer for the beta weight.
* Default: `'zeros'`.
*/
betaInitializer?: InitializerIdentifier | Initializer;
/**
* Initializer for the gamma weight.
* Default: `'ones'`.
*/
gammaInitializer?: InitializerIdentifier | Initializer;
/** Regularizer for the beta weight. */
betaRegularizer?: RegularizerIdentifier | Regularizer;
/** Regularizer for the gamma weight. */
gammaRegularizer?: RegularizerIdentifier | Regularizer;
}
export declare class LayerNormalization extends Layer {
/** @nocollapse */
static className: string;
private axis;
readonly epsilon: number;
readonly center: boolean;
readonly scale: boolean;
readonly betaInitializer: Initializer;
readonly gammaInitializer: Initializer;
readonly betaRegularizer: Regularizer;
readonly gammaRegularizer: Regularizer;
private gamma;
private beta;
constructor(args?: LayerNormalizationLayerArgs);
build(inputShape: Shape | Shape[]): void;
call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
getConfig(): serialization.ConfigDict;
}