UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

144 lines (143 loc) 4.99 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/constraints" />
import { serialization, Tensor } from '@tensorflow/tfjs-core';
/**
 * Base class for functions that impose constraints on weight values.
 *
 * Subclasses implement `apply`, which takes a weight tensor and returns
 * the constrained tensor.
 *
 * @doc {
 *   heading: 'Constraints',
 *   subheading: 'Classes',
 *   namespace: 'constraints'
 * }
 */
export declare abstract class Constraint extends serialization.Serializable {
  /** Applies the constraint to the weight tensor `w` and returns the result. */
  abstract apply(w: Tensor): Tensor;
  getConfig(): serialization.ConfigDict;
}
export interface MaxNormArgs {
  /**
   * Maximum norm for incoming weights.
   *
   * Optional. `MaxNorm` declares a `defaultMaxValue` field, presumably used
   * when this is omitted (its value is not visible in this declaration).
   */
  maxValue?: number;
  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix has shape
   * `[inputDim, outputDim]`; set `axis` to `0` to constrain each weight
   * vector of length `inputDim`.
   *
   * This option is declared as a single `number`, so a multi-axis setting
   * such as `[0, 1, 2]` (which would constrain each filter of a `Conv2D`
   * kernel of shape `[rows, cols, inputDepth, outputDepth]` with
   * `dataFormat="channels_last"`) does not type-check against this
   * declaration.
   * NOTE(review): confirm whether the implementation accepts an array of
   * axes at runtime before widening this type.
   */
  axis?: number;
}
/**
 * Constraint parameterized by a maximum norm (`maxValue`) calculated along
 * `axis`; see `MaxNormArgs`.
 *
 * NOTE(review): the class declares `defaultMaxValue` and `defaultAxis`
 * fields, presumably used for options omitted from `args`; their values
 * are not visible in this declaration file.
 */
export declare class MaxNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "MaxNorm";
  private maxValue;
  private axis;
  private readonly defaultMaxValue;
  private readonly defaultAxis;
  constructor(args: MaxNormArgs);
  apply(w: Tensor): Tensor;
  getConfig(): serialization.ConfigDict;
}
export interface UnitNormArgs {
  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix has shape
   * `[inputDim, outputDim]`; set `axis` to `0` to constrain each weight
   * vector of length `inputDim`.
   *
   * This option is declared as a single `number`, so a multi-axis setting
   * such as `[0, 1, 2]` (which would constrain each filter of a `Conv2D`
   * kernel of shape `[rows, cols, inputDepth, outputDepth]` with
   * `dataFormat="channels_last"`) does not type-check against this
   * declaration.
   * NOTE(review): confirm whether the implementation accepts an array of
   * axes at runtime before widening this type.
   */
  axis?: number;
}
/**
 * Constraint whose only option is the `axis` along which norms are
 * calculated; see `UnitNormArgs`.
 *
 * NOTE(review): by its name this presumably rescales weights to unit norm;
 * the implementation is not visible in this declaration file. The class
 * declares a `defaultAxis` field, presumably used when `axis` is omitted.
 */
export declare class UnitNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "UnitNorm";
  private axis;
  private readonly defaultAxis;
  constructor(args: UnitNormArgs);
  apply(w: Tensor): Tensor;
  getConfig(): serialization.ConfigDict;
}
/**
 * Constraint with no options of its own: it declares no constructor and
 * does not override `getConfig`, so it inherits `Constraint.getConfig`.
 *
 * NOTE(review): by its name this presumably constrains weights to be
 * non-negative; confirm against the implementation.
 */
export declare class NonNeg extends Constraint {
  /** @nocollapse */
  static readonly className = "NonNeg";
  apply(w: Tensor): Tensor;
}
export interface MinMaxNormArgs {
  /**
   * Minimum norm for incoming weights.
   */
  minValue?: number;
  /**
   * Maximum norm for incoming weights.
   */
  maxValue?: number;
  /**
   * Axis along which to calculate norms.
   *
   * For instance, in a `Dense` layer the weight matrix has shape
   * `[inputDim, outputDim]`; set `axis` to `0` to constrain each weight
   * vector of length `inputDim`.
   *
   * This option is declared as a single `number`, so a multi-axis setting
   * such as `[0, 1, 2]` (which would constrain each filter of a `Conv2D`
   * kernel of shape `[rows, cols, inputDepth, outputDepth]` with
   * `dataFormat="channels_last"`) does not type-check against this
   * declaration.
   * NOTE(review): confirm whether the implementation accepts an array of
   * axes at runtime before widening this type.
   */
  axis?: number;
  /**
   * Rate for enforcing the constraint: weights will be rescaled to yield
   * `(1 - rate) * norm + rate * norm.clip(minValue, maxValue)`.
   * Effectively, this means that rate=1.0 stands for strict
   * enforcement of the constraint, while rate<1.0 means that
   * weights will be rescaled at each step to slowly move
   * towards a value inside the desired interval.
   */
  rate?: number;
}
/**
 * Constraint that keeps weight norms (calculated along `axis`) within
 * `[minValue, maxValue]`, enforced with strength `rate`; see
 * `MinMaxNormArgs` for the rescaling formula.
 *
 * NOTE(review): the class declares `defaultMinValue`, `defaultMaxValue`,
 * `defaultRate` and `defaultAxis` fields, presumably used for options
 * omitted from `args`; their values are not visible in this declaration.
 */
export declare class MinMaxNorm extends Constraint {
  /** @nocollapse */
  static readonly className = "MinMaxNorm";
  private minValue;
  private maxValue;
  private rate;
  private axis;
  private readonly defaultMinValue;
  private readonly defaultMaxValue;
  private readonly defaultRate;
  private readonly defaultAxis;
  constructor(args: MinMaxNormArgs);
  apply(w: Tensor): Tensor;
  getConfig(): serialization.ConfigDict;
}
/**
 * String identifier accepted by `getConstraint`.
 *
 * The named literals are the built-in constraint identifiers. Because the
 * union also includes `string`, TypeScript reduces the whole type to
 * `string`: any string type-checks, and the literals serve only as
 * documentation.
 *
 * @docinline
 */
export type ConstraintIdentifier =
  | 'maxNorm'
  | 'minMaxNorm'
  | 'nonNeg'
  | 'unitNorm'
  | string;
/**
 * Maps each constraint identifier to a string.
 *
 * Because `ConstraintIdentifier` reduces to `string`, this mapped type is
 * effectively a `string`-keyed index signature.
 * NOTE(review): presumably each value is the registered class name
 * (e.g. `MaxNorm.className`); confirm against the implementation.
 */
export declare const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP: {
  [identifier in ConstraintIdentifier]: string;
};
/**
 * Converts `constraint` into a serializable config value.
 *
 * @param constraint The constraint to serialize.
 * @returns A `serialization.ConfigDictValue` describing `constraint`.
 */
export declare function serializeConstraint(
  constraint: Constraint
): serialization.ConfigDictValue;
/**
 * Reconstructs a `Constraint` from a serialized `config`.
 *
 * @param config Serialized constraint config.
 * @param customObjects Optional extra objects consulted during
 *   deserialization. NOTE(review): presumably user-defined constraint
 *   classes; confirm against the implementation.
 * @returns The deserialized constraint.
 */
export declare function deserializeConstraint(
  config: serialization.ConfigDict,
  customObjects?: serialization.ConfigDict
): Constraint;
/**
 * Resolves `identifier` to a `Constraint` instance.
 *
 * @param identifier A `ConstraintIdentifier` string, a serialized config
 *   dict, or an existing `Constraint`.
 * @returns The resolved constraint.
 *
 * NOTE(review): presumably string identifiers are looked up via
 * `CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP` and `Constraint` instances
 * are returned as-is; confirm against the implementation.
 */
export declare function getConstraint(
  identifier: ConstraintIdentifier | serialization.ConfigDict | Constraint
): Constraint;