// @tensorflow/tfjs-layers
// Version: (not recorded)
// TensorFlow layers API in JavaScript
// 100 lines • 13.8 kB
// JavaScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/* original source: keras/regularizers.py */
import * as tfc from '@tensorflow/tfjs-core';
import { abs, add, serialization, sum, tidy, zeros } from '@tensorflow/tfjs-core';
import * as K from './backend/tfjs_backend';
import { deserializeKerasObject, serializeKerasObject } from './utils/generic_utils';
/**
 * Validates the argument passed to the L1L2 constructor or to the `l1`/`l2`
 * factory helpers.
 *
 * Accepts `null`, `undefined`, or any value whose `typeof` is 'object'.
 * Anything else (for example a bare number such as `0.01`) is rejected.
 *
 * @param args The candidate options argument.
 * @throws Error when `args` is non-null and not an object.
 */
function assertObjectArgs(args) {
  const acceptable = args == null || typeof args === 'object';
  if (acceptable) {
    return;
  }
  throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an ` +
    `object, but received: ${args}`);
}
/**
 * Regularizer base class.
 *
 * Concrete regularizers (such as `L1L2` in this module) extend this class and
 * provide an `apply(x)` method that returns the penalty for a tensor. The
 * TypeScript source declares `apply` as abstract, but this compiled class has
 * no members, so nothing enforces that contract at runtime.
 *
 * `getRegularizer` checks `instanceof Regularizer` to return existing
 * instances unchanged instead of deserializing them.
 */
export class Regularizer extends serialization.Serializable {
}
/**
 * Regularizer that adds a weighted L1 penalty, a weighted L2 penalty, or both.
 *
 * The constructor `args` is an optional object with numeric `l1` and `l2`
 * rates. A rate that is omitted, or given as null/undefined, defaults to 0.01.
 * A rate of exactly 0 turns that term off, so `apply` never computes it.
 */
export class L1L2 extends Regularizer {
  constructor(args) {
    super();
    assertObjectArgs(args);
    const defaultRate = 0.01;
    if (args == null) {
      this.l1 = defaultRate;
      this.l2 = defaultRate;
    } else {
      this.l1 = args.l1 == null ? defaultRate : args.l1;
      this.l2 = args.l2 == null ? defaultRate : args.l2;
    }
    // Remember which terms are active so apply() can skip disabled ones.
    this.hasL1 = this.l1 !== 0;
    this.hasL2 = this.l2 !== 0;
  }

  /**
   * Computes the regularization penalty for `x`.
   *
   * Porting note: Renamed from __call__.
   * @param x Variable of which to calculate the regularization score.
   * @returns A rank-0 tensor equal to `sum(l1 * abs(x)) + sum(l2 * K.square(x))`.
   *   Any term whose rate is 0 is left out.
   */
  apply(x) {
    // tidy() releases the intermediate tensors created inside the callback.
    return tidy(() => {
      let penalty = zeros([1]);
      if (this.hasL1) {
        const l1Term = sum(tfc.mul(this.l1, abs(x)));
        penalty = add(penalty, l1Term);
      }
      if (this.hasL2) {
        const l2Term = sum(tfc.mul(this.l2, K.square(x)));
        penalty = add(penalty, l2Term);
      }
      // The penalty is accumulated with shape [1]; collapse it to a scalar.
      return tfc.reshape(penalty, []);
    });
  }

  /** Returns the serializable config: the two regularization rates. */
  getConfig() {
    const config = { l1: this.l1, l2: this.l2 };
    return config;
  }

  /**
   * Builds an instance of `cls` from a config shaped like `getConfig()` output.
   * @nocollapse
   */
  static fromConfig(cls, config) {
    const l1Rate = config['l1'];
    const l2Rate = config['l2'];
    return new cls({ l1: l1Rate, l2: l2Rate });
  }
}
/** @nocollapse */
L1L2.className = 'L1L2';
// Register with the tfjs-core serialization registry (presumably keyed by
// `className`) so deserializeRegularizer can resolve 'L1L2' configs.
serialization.registerClass(L1L2);
export function l1(args) {
assertObjectArgs(args);
return new L1L2({ l1: args != null ? args.l1 : null, l2: 0 });
}
export function l2(args) {
assertObjectArgs(args);
return new L1L2({ l2: args != null ? args.l2 : null, l1: 0 });
}
// Short, JavaScript-style regularizer aliases mapped to the Keras class names
// they stand for. getRegularizer translates an alias through this table
// before deserializing; other strings are used as class names directly.
export const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
  l1l2: 'L1L2',
};
export function serializeRegularizer(constraint) {
return serializeKerasObject(constraint);
}
export function deserializeRegularizer(config, customObjects = {}) {
return deserializeKerasObject(config, serialization.SerializationMap.getMap().classNameMap, customObjects, 'regularizer');
}
export function getRegularizer(identifier) {
if (identifier == null) {
return null;
}
if (typeof identifier === 'string') {
const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
identifier;
const config = { className, config: {} };
return deserializeRegularizer(config);
}
else if (identifier instanceof Regularizer) {
return identifier;
}
else {
return deserializeRegularizer(identifier);
}
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"regularizers.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/regularizers.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,4CAA4C;AAE5C,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAC7C,OAAO,EAAC,GAAG,EAAE,GAAG,EAAU,aAAa,EAAE,GAAG,EAAU,IAAI,EAAE,KAAK,EAAC,MAAM,uBAAuB,CAAC;AAChG,OAAO,KAAK,CAAC,MAAM,wBAAwB,CAAC;AAC5C,OAAO,EAAC,sBAAsB,EAAE,oBAAoB,EAAC,MAAM,uBAAuB,CAAC;AAEnF,SAAS,gBAAgB,CAAC,IAA4B;IACpD,IAAI,IAAI,IAAI,IAAI,IAAI,OAAO,IAAI,KAAK,QAAQ,EAAE;QAC5C,MAAM,IAAI,KAAK,CACX,kEAAkE;YAClE,yBAAyB,IAAI,EAAE,CAAC,CAAC;KACtC;AACH,CAAC;AAED;;GAEG;AACH,MAAM,OAAgB,WAAY,SAAQ,aAAa,CAAC,YAAY;CAEnE;AAmBD,MAAa,IAAK,SAAQ,WAAW;IAQnC,YAAY,IAAe;QACzB,KAAK,EAAE,CAAC;QAER,gBAAgB,CAAC,IAAI,CAAC,CAAC;QAEvB,IAAI,CAAC,EAAE,GAAG,IAAI,IAAI,IAAI,IAAI,IAAI,CAAC,EAAE,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC;QAC3D,IAAI,CAAC,EAAE,GAAG,IAAI,IAAI,IAAI,IAAI,IAAI,CAAC,EAAE,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC;QAC3D,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,EAAE,KAAK,CAAC,CAAC;QAC3B,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,EAAE,KAAK,CAAC,CAAC;IAC7B,CAAC;IAED;;;OAGG;IACH,KAAK,CAAC,CAAS;QACb,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,cAAc,GAAW,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACxC,IAAI,IAAI,CAAC,KAAK,EAAE;gBACd,cAAc,GAAG,GAAG,CAAC,cAAc,EAAE,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aACrE;YACD,IAAI,IAAI,CAAC,KAAK,EAAE;gBACd,cAAc;oBACV,GAAG,CAAC,cAAc,EAAE,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,EAAE,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;aAC7D;YACD,OAAO,GAAG,CAAC,OAAO,CAAC,cAAc,EAAE,EAAE,CAAC,CAAC;QACzC,CAAC,CAAC,CAAC;IACL,CAAC;IAED,SAAS;QACP,OAAO,EAAC,IAAI,EAAE,IAAI,CAAC,EAAE,EAAE,IAAI,EAAE,IAAI,CAAC,EAAE,EAAC,CAAC;IACxC,CAAC;IAED,kBAAkB;IAClB,MAAM,CAAU,UAAU,CACtB,GAA6C,EAC7C,MAAgC;QAClC,OAAO,IAAI,GAAG,CAAC,EAAC,EAAE,EAAE,MAAM,CAAC,IAAI,CAAW,EAAE,EAAE,EAAE,MAAM,CAAC,IAAI,CAAW,EAAC,CAAC,CAAC;IAC3E,CAAC;;AA7CD,kBAAkB;AACX,cAAS,GAAG,MAAM,CAAC;SAFf
,IAAI;AAgDjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC,MAAM,UAAU,EAAE,CAAC,IAAa;IAC9B,gBAAgB,CAAC,IAAI,CAAC,CAAC;IACvB,OAAO,IAAI,IAAI,CAAC,EAAC,EAAE,EAAE,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,EAAE,EAAE,EAAE,CAAC,EAAC,CAAC,CAAC;AAC9D,CAAC;AAED,MAAM,UAAU,EAAE,CAAC,IAAY;IAC7B,gBAAgB,CAAC,IAAI,CAAC,CAAC;IACvB,OAAO,IAAI,IAAI,CAAC,EAAC,EAAE,EAAE,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,EAAE,EAAE,EAAE,CAAC,EAAC,CAAC,CAAC;AAC9D,CAAC;AAKD,+EAA+E;AAC/E,MAAM,CAAC,MAAM,0CAA0C,GACD;IAChD,MAAM,EAAE,MAAM;CACf,CAAC;AAEN,MAAM,UAAU,oBAAoB,CAAC,UAAuB;IAE1D,OAAO,oBAAoB,CAAC,UAAU,CAAC,CAAC;AAC1C,CAAC;AAED,MAAM,UAAU,sBAAsB,CAClC,MAAgC,EAChC,gBAA0C,EAAE;IAC9C,OAAO,sBAAsB,CACzB,MAAM,EAAE,aAAa,CAAC,gBAAgB,CAAC,MAAM,EAAE,CAAC,YAAY,EAC5D,aAAa,EAAE,aAAa,CAAC,CAAC;AACpC,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,UAEW;IACxC,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,OAAO,IAAI,CAAC;KACb;IACD,IAAI,OAAO,UAAU,KAAK,QAAQ,EAAE;QAClC,MAAM,SAAS,GAAG,UAAU,IAAI,0CAA0C,CAAC,CAAC;YACxE,0CAA0C,CAAC,UAAU,CAAC,CAAC,CAAC;YACxD,UAAU,CAAC;QACf,MAAM,MAAM,GAAG,EAAC,SAAS,EAAE,MAAM,EAAE,EAAE,EAAC,CAAC;QACvC,OAAO,sBAAsB,CAAC,MAAM,CAAC,CAAC;KACvC;SAAM,IAAI,UAAU,YAAY,WAAW,EAAE;QAC5C,OAAO,UAAU,CAAC;KACnB;SAAM;QACL,OAAO,sBAAsB,CAAC,UAAU,CAAC,CAAC;KAC3C;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/* original source: keras/regularizers.py */\n\nimport * as tfc from '@tensorflow/tfjs-core';\nimport {abs, add, Scalar, serialization, sum, Tensor, tidy, zeros} from '@tensorflow/tfjs-core';\nimport * as K from './backend/tfjs_backend';\nimport {deserializeKerasObject, serializeKerasObject} from './utils/generic_utils';\n\nfunction assertObjectArgs(args: L1Args|L2Args|L1L2Args): void {\n  if (args != null && typeof args !== 
'object') {\n    throw new Error(\n        `Argument to L1L2 regularizer's constructor is expected to be an ` +\n        `object, but received: ${args}`);\n  }\n}\n\n/**\n * Regularizer base class.\n */\nexport abstract class Regularizer extends serialization.Serializable {\n  abstract apply(x: Tensor): Scalar;\n}\n\nexport interface L1L2Args {\n  /** L1 regularization rate. Defaults to 0.01. */\n  l1?: number;\n  /** L2 regularization rate. Defaults to 0.01. */\n  l2?: number;\n}\n\nexport interface L1Args {\n  /** L1 regularization rate. Defaults to 0.01. */\n  l1: number;\n}\n\nexport interface L2Args {\n  /** L2 regularization rate. Defaults to 0.01. */\n  l2: number;\n}\n\nexport class L1L2 extends Regularizer {\n  /** @nocollapse */\n  static className = 'L1L2';\n\n  private readonly l1: number;\n  private readonly l2: number;\n  private readonly hasL1: boolean;\n  private readonly hasL2: boolean;\n  constructor(args?: L1L2Args) {\n    super();\n\n    assertObjectArgs(args);\n\n    this.l1 = args == null || args.l1 == null ? 0.01 : args.l1;\n    this.l2 = args == null || args.l2 == null ? 
0.01 : args.l2;\n    this.hasL1 = this.l1 !== 0;\n    this.hasL2 = this.l2 !== 0;\n  }\n\n  /**\n   * Porting note: Renamed from __call__.\n   * @param x Variable of which to calculate the regularization score.\n   */\n  apply(x: Tensor): Scalar {\n    return tidy(() => {\n      let regularization: Tensor = zeros([1]);\n      if (this.hasL1) {\n        regularization = add(regularization, sum(tfc.mul(this.l1, abs(x))));\n      }\n      if (this.hasL2) {\n        regularization =\n            add(regularization, sum(tfc.mul(this.l2, K.square(x))));\n      }\n      return tfc.reshape(regularization, []);\n    });\n  }\n\n  getConfig(): serialization.ConfigDict {\n    return {'l1': this.l1, 'l2': this.l2};\n  }\n\n  /** @nocollapse */\n  static override fromConfig<T extends serialization.Serializable>(\n      cls: serialization.SerializableConstructor<T>,\n      config: serialization.ConfigDict): T {\n    return new cls({l1: config['l1'] as number, l2: config['l2'] as number});\n  }\n}\nserialization.registerClass(L1L2);\n\nexport function l1(args?: L1Args) {\n  assertObjectArgs(args);\n  return new L1L2({l1: args != null ? args.l1 : null, l2: 0});\n}\n\nexport function l2(args: L2Args) {\n  assertObjectArgs(args);\n  return new L1L2({l2: args != null ? 
args.l2 : null, l1: 0});\n}\n\n/** @docinline */\nexport type RegularizerIdentifier = 'l1l2'|string;\n\n// Maps the JavaScript-like identifier keys to the corresponding keras symbols.\nexport const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP:\n    {[identifier in RegularizerIdentifier]: string} = {\n      'l1l2': 'L1L2'\n    };\n\nexport function serializeRegularizer(constraint: Regularizer):\n    serialization.ConfigDictValue {\n  return serializeKerasObject(constraint);\n}\n\nexport function deserializeRegularizer(\n    config: serialization.ConfigDict,\n    customObjects: serialization.ConfigDict = {}): Regularizer {\n  return deserializeKerasObject(\n      config, serialization.SerializationMap.getMap().classNameMap,\n      customObjects, 'regularizer');\n}\n\nexport function getRegularizer(identifier: RegularizerIdentifier|\n                               serialization.ConfigDict|\n                               Regularizer): Regularizer {\n  if (identifier == null) {\n    return null;\n  }\n  if (typeof identifier === 'string') {\n    const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?\n        REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :\n        identifier;\n    const config = {className, config: {}};\n    return deserializeRegularizer(config);\n  } else if (identifier instanceof Regularizer) {\n    return identifier;\n  } else {\n    return deserializeRegularizer(identifier);\n  }\n}\n"]}