// Package: @tensorflow/tfjs-layers (version not recorded)
// Description: TensorFlow layers API in JavaScript
// Source listing: 144 lines (143 loc), 4.99 kB, TypeScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/constraints" />
import { serialization, Tensor } from '@tensorflow/tfjs-core';
/**
 * Abstract base for functions that impose constraints on weight values.
 *
 * Subclasses must implement `apply`, which maps a weight tensor to a tensor
 * (presumably the constrained weights — confirm against the implementation).
 * `getConfig` is declared concrete here, so a subclass with no settings of
 * its own does not need to redeclare it.
 *
 * @doc {
 *   heading: 'Constraints',
 *   subheading: 'Classes',
 *   namespace: 'constraints'
 * }
 */
export declare abstract class Constraint extends serialization.Serializable {
  /** Returns this constraint's serializable configuration. */
  getConfig(): serialization.ConfigDict;

  /** Applies the constraint to the weight tensor `w`. */
  abstract apply(w: Tensor): Tensor;
}
/**
 * Options for the `MaxNorm` constraint. Every field is optional.
 * NOTE(review): omitted fields presumably fall back to the `defaultMaxValue` /
 * `defaultAxis` values held by `MaxNorm`; confirm in the implementation.
 */
export interface MaxNormArgs {
  /**
   * Maximum norm for incoming weights.
   */
  maxValue?: number;

  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix has shape
   * `[inputDim, outputDim]`; set `axis` to `0` to constrain each weight
   * vector of length `[inputDim,]`.
   *
   * Only a single axis is accepted by this declaration (`axis?: number`).
   * An array form such as `[0, 1, 2]` does not type-check here. That form
   * would constrain each `[rows, cols, inputDepth]` filter of a
   * `channels_last` `Conv2D` kernel shaped `[rows, cols, inputDepth, outputDepth]`.
   */
  axis?: number;
}
/**
 * Weight constraint configured by `MaxNormArgs`: a maximum norm (`maxValue`)
 * for incoming weights, with norms calculated along `axis`.
 *
 * The `default*` fields hold fallback values whose actual numbers live in the
 * implementation, not in this declaration.
 */
export declare class MaxNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "MaxNorm";

  // Fallbacks (values defined by the implementation).
  private readonly defaultMaxValue;
  private readonly defaultAxis;

  // Effective settings for this instance.
  private maxValue;
  private axis;

  constructor(args: MaxNormArgs);

  apply(w: Tensor): Tensor;

  getConfig(): serialization.ConfigDict;
}
/**
 * Options for the `UnitNorm` constraint. Every field is optional.
 * NOTE(review): an omitted `axis` presumably falls back to `UnitNorm`'s
 * `defaultAxis`; confirm in the implementation.
 */
export interface UnitNormArgs {
  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix has shape
   * `[inputDim, outputDim]`; set `axis` to `0` to constrain each weight
   * vector of length `[inputDim,]`.
   *
   * Only a single axis is accepted by this declaration (`axis?: number`).
   * An array form such as `[0, 1, 2]` does not type-check here. That form
   * would constrain each `[rows, cols, inputDepth]` filter of a
   * `channels_last` `Conv2D` kernel shaped `[rows, cols, inputDepth, outputDepth]`.
   */
  axis?: number;
}
/**
 * Weight constraint configured by `UnitNormArgs`, with norms calculated
 * along `axis`. NOTE(review): by its name this presumably normalizes weights
 * to unit norm; confirm in the implementation.
 */
export declare class UnitNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "UnitNorm";

  // Fallback axis (value defined by the implementation).
  private readonly defaultAxis;

  // Effective axis for this instance.
  private axis;

  constructor(args: UnitNormArgs);

  apply(w: Tensor): Tensor;

  getConfig(): serialization.ConfigDict;
}
/**
 * Option-less weight constraint. It declares no constructor and no
 * `getConfig` override of its own, so both come from the base class.
 * NOTE(review): by its name this presumably enforces non-negative weights;
 * confirm in the implementation.
 */
export declare class NonNeg extends Constraint {
  /** @nocollapse */
  static readonly className = "NonNeg";

  apply(w: Tensor): Tensor;
}
/**
 * Options for the `MinMaxNorm` constraint. Every field is optional.
 * NOTE(review): omitted fields presumably fall back to the `default*` values
 * held by `MinMaxNorm`; confirm in the implementation.
 */
export interface MinMaxNormArgs {
  /**
   * Minimum norm for incoming weights.
   */
  minValue?: number;

  /**
   * Maximum norm for incoming weights.
   */
  maxValue?: number;

  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix has shape
   * `[inputDim, outputDim]`; set `axis` to `0` to constrain each weight
   * vector of length `[inputDim,]`.
   *
   * Only a single axis is accepted by this declaration (`axis?: number`).
   * An array form such as `[0, 1, 2]` does not type-check here. That form
   * would constrain each `[rows, cols, inputDepth]` filter of a
   * `channels_last` `Conv2D` kernel shaped `[rows, cols, inputDepth, outputDepth]`.
   */
  axis?: number;

  /**
   * Rate for enforcing the constraint. Weights are rescaled so that their
   * norm becomes `(1 - rate) * norm + rate * clip(norm, minValue, maxValue)`.
   *
   * `rate = 1.0` therefore means strict enforcement of the constraint.
   * `rate < 1.0` means the weights are rescaled at each step so that their
   * norm moves gradually towards a value inside `[minValue, maxValue]`.
   */
  rate?: number;
}
/**
 * Weight constraint configured by `MinMaxNormArgs`. It keeps the weight norm
 * (calculated along `axis`) towards `[minValue, maxValue]`, enforced at the
 * strength given by `rate`.
 *
 * The `default*` fields hold fallback values whose actual numbers live in the
 * implementation, not in this declaration.
 */
export declare class MinMaxNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "MinMaxNorm";

  // Fallbacks (values defined by the implementation).
  private readonly defaultMinValue;
  private readonly defaultMaxValue;
  private readonly defaultRate;
  private readonly defaultAxis;

  // Effective settings for this instance.
  private minValue;
  private maxValue;
  private rate;
  private axis;

  constructor(args: MinMaxNormArgs);

  apply(w: Tensor): Tensor;

  getConfig(): serialization.ConfigDict;
}
/**
 * Name that can identify a constraint (accepted, for example, by
 * `getConstraint`). The four built-in names are spelled out for readers.
 * Because the union also includes `string`, the type checker treats the
 * whole type as `string`, so any string is accepted.
 *
 * @docinline
 */
export type ConstraintIdentifier =
  | 'maxNorm'
  | 'minMaxNorm'
  | 'nonNeg'
  | 'unitNorm'
  | string;
/**
 * Table from `ConstraintIdentifier` keys to string values.
 * NOTE(review): the name suggests the values are registered class names
 * (such as each constraint's `className`); confirm against the implementation.
 */
export declare const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP: Record<ConstraintIdentifier, string>;
/**
 * Converts a `Constraint` into a `serialization.ConfigDictValue`.
 * NOTE(review): the exact output shape is defined by the implementation.
 *
 * @param constraint The constraint instance to serialize.
 */
export declare function serializeConstraint(
  constraint: Constraint
): serialization.ConfigDictValue;
/**
 * Builds a `Constraint` from a serialized `serialization.ConfigDict`.
 *
 * @param config Serialized constraint configuration.
 * @param customObjects Optional extra dictionary. NOTE(review): it presumably
 *   maps names to user-defined constraint classes; confirm in the implementation.
 */
export declare function deserializeConstraint(
  config: serialization.ConfigDict,
  customObjects?: serialization.ConfigDict
): Constraint;
/**
 * Resolves a constraint to a `Constraint` instance. The constraint may be
 * given as a `ConstraintIdentifier` string, a serialized
 * `serialization.ConfigDict`, or an existing `Constraint`.
 *
 * @param identifier Name, config, or instance describing the constraint.
 */
export declare function getConstraint(
  identifier: ConstraintIdentifier | serialization.ConfigDict | Constraint
): Constraint;