UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

140 lines 23.3 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/**
 * TensorFlow.js Layers: Embedding Layer.
 *
 * Original source: keras/layers/embeddings.py
 */
import { notEqual, reshape, serialization, tidy, zerosLike } from '@tensorflow/tfjs-core';
import * as K from '../backend/tfjs_backend';
import { getConstraint, serializeConstraint } from '../constraints';
import { Layer } from '../engine/topology';
import { ValueError } from '../errors';
import { getInitializer, serializeInitializer } from '../initializers';
import { getRegularizer, serializeRegularizer } from '../regularizers';
import * as generic_utils from '../utils/generic_utils';
import { getExactlyOneShape, getExactlyOneTensor } from '../utils/types_utils';

/**
 * Layer that maps integer indices to rows of an `[inputDim, outputDim]`
 * weight matrix.
 *
 * The output has the shape of the input with one extra trailing axis of size
 * `outputDim` (see `computeOutputShape`).
 */
class Embedding extends Layer {
    /**
     * @param args Layer configuration. Fields read here: `inputDim`,
     *   `outputDim`, `embeddingsInitializer`, `embeddingsRegularizer`,
     *   `activityRegularizer`, `embeddingsConstraint`, `maskZero`,
     *   `inputLength`, plus `batchInputShape`, `inputShape` and `batchSize`
     *   for the input-shape fallback below. `inputDim` and `outputDim` are
     *   validated with `generic_utils.assertPositiveInteger`.
     */
    constructor(args) {
        super(args);
        // The weight variable is created lazily in build().
        this.embeddings = null;
        this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform';
        if (args.batchInputShape == null && args.inputShape == null) {
            // Porting Note: This logic is copied from Layer's constructor, since we
            // can't do exactly what the Python constructor does for Embedding().
            // Specifically, the super constructor can not be called after the
            // mutation of the `config` argument.
            let batchSize = null;
            if (args.batchSize != null) {
                batchSize = args.batchSize;
            }
            if (args.inputLength == null) {
                // Fix super-constructor to what it would have done if
                // 'config.inputShape' were (None, )
                this.batchInputShape = [batchSize, null];
            }
            else {
                // Fix super-constructor to what it would have done if
                // 'config.inputShape' were (config.inputLength, )
                this.batchInputShape = [batchSize].concat(generic_utils.toList(args.inputLength));
            }
        }
        this.inputDim = args.inputDim;
        generic_utils.assertPositiveInteger(this.inputDim, 'inputDim');
        this.outputDim = args.outputDim;
        generic_utils.assertPositiveInteger(this.outputDim, 'outputDim');
        this.embeddingsInitializer = getInitializer(args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER);
        this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
        this.activityRegularizer = getRegularizer(args.activityRegularizer);
        this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
        this.maskZero = args.maskZero;
        // Masking support mirrors the maskZero flag exactly as it was passed in.
        this.supportsMasking = args.maskZero;
        this.inputLength = args.inputLength;
    }

    /**
     * Creates the `embeddings` weight of shape `[inputDim, outputDim]` and
     * marks the layer as built.
     *
     * @param inputShape Unused; the weight shape depends only on the
     *   constructor arguments.
     */
    build(inputShape) {
        this.embeddings = this.addWeight('embeddings', [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, true, this.embeddingsConstraint);
        this.built = true;
    }

    // Override warnOnIncompatibleInputShape because an embedding layer allows
    // the input to have varying ranks.
    warnOnIncompatibleInputShape(inputShape) { }

    /**
     * Computes the mask produced by this layer.
     *
     * @param inputs Input tensor (or a list holding exactly one tensor).
     * @param mask Unused.
     * @returns `null` when `maskZero` is falsy; otherwise the element-wise
     *   result of `notEqual(inputs, zerosLike(inputs))`.
     */
    computeMask(inputs, mask) {
        return tidy(() => {
            if (!this.maskZero) {
                return null;
            }
            else {
                inputs = getExactlyOneTensor(inputs);
                return notEqual(inputs, zerosLike(inputs));
            }
        });
    }

    /**
     * Computes the output shape for a given input shape.
     *
     * Without `inputLength` the result is the input shape with `outputDim`
     * appended. With `inputLength`, every non-null entry must agree with the
     * matching non-null input dimension (dimensions after the first), and null
     * entries are filled in from the input shape.
     *
     * @param inputShape Input shape (or a list holding exactly one shape).
     * @returns `[inputShape[0], ...lengths, outputDim]`.
     * @throws ValueError if `inputLength` has a different number of entries
     *   than the input has non-leading dimensions, or if a known length
     *   conflicts with a known input dimension.
     */
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        if (this.inputLength == null) {
            return [...inputShape, this.outputDim];
        }
        const mismatch = () => new ValueError(`"inputLength" is ${this.inputLength}, but received ` +
            `input shape has shape ${inputShape}`);
        // inputLength can be an array if input is 3D or higher.
        // Work on a copy: toList() may return `this.inputLength` itself when it
        // is already an array, and the loop below fills in null entries. Writing
        // through would permanently pin those dimensions on the layer (and in
        // the caller's array) after the first call.
        const inLens = [...generic_utils.toList(this.inputLength)];
        if (inLens.length !== inputShape.length - 1) {
            throw mismatch();
        }
        for (let k = 0; k < inLens.length; ++k) {
            const s1 = inLens[k];
            const s2 = inputShape[k + 1];
            if ((s1 != null) && (s2 != null) && (s1 !== s2)) {
                throw mismatch();
            }
            else if (s1 == null) {
                inLens[k] = s2;
            }
        }
        return [inputShape[0], ...inLens, this.outputDim];
    }

    /**
     * Looks up the embedding row for every index in the input.
     *
     * @param inputs Input tensor (or a list holding exactly one tensor). It is
     *   cast to 'int32' when its dtype differs.
     * @param kwargs Forwarded to `invokeCallHook`.
     * @returns The gathered rows, reshaped to
     *   `computeOutputShape(input.shape)`.
     */
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            // Embedding layer accepts only a single input.
            let input = getExactlyOneTensor(inputs);
            if (input.dtype !== 'int32') {
                input = K.cast(input, 'int32');
            }
            // Gather on the flattened indices, then restore the input's leading
            // dimensions with the embedding axis appended.
            const output = K.gather(this.embeddings.read(), reshape(input, [input.size]));
            return reshape(output, getExactlyOneShape(this.computeOutputShape(input.shape)));
        });
    }

    /**
     * Returns the serializable configuration of this layer.
     *
     * Entries from the base class's `getConfig()` are assigned last, so they
     * take precedence on any key collision.
     */
    getConfig() {
        const config = {
            inputDim: this.inputDim,
            outputDim: this.outputDim,
            embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
            embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
            activityRegularizer: serializeRegularizer(this.activityRegularizer),
            embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
            maskZero: this.maskZero,
            inputLength: this.inputLength
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
/** @nocollapse */
Embedding.className = 'Embedding';
export { Embedding };
serialization.registerClass(Embedding);
//#
sourceMappingURL=data:application/json;base64,{"version":3,"file":"embeddings.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/layers/embeddings.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;;;GAIG;AACH,OAAO,EAAC,QAAQ,EAAE,OAAO,EAAE,aAAa,EAAU,IAAI,EAAE,SAAS,EAAC,MAAM,uBAAuB,CAAC;AAEhG,OAAO,KAAK,CAAC,MAAM,yBAAyB,CAAC;AAC7C,OAAO,EAAmC,aAAa,EAAE,mBAAmB,EAAC,MAAM,gBAAgB,CAAC;AACpG,OAAO,EAAC,KAAK,EAAY,MAAM,oBAAoB,CAAC;AACpD,OAAO,EAAC,UAAU,EAAC,MAAM,WAAW,CAAC;AACrC,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,KAAK,aAAa,MAAM,wBAAwB,CAAC;AACxD,OAAO,EAAC,kBAAkB,EAAE,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AAiD7E,MAAa,SAAU,SAAQ,KAAK;IAgBlC,YAAY,IAAwB;QAClC,KAAK,CAAC,IAAI,CAAC,CAAC;QARN,eAAU,GAAkB,IAAI,CAAC;QAEhC,mCAA8B,GACnC,eAAe,CAAC;QAMlB,IAAI,IAAI,CAAC,eAAe,IAAI,IAAI,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3D,wEAAwE;YACxE,qEAAqE;YACrE,kEAAkE;YAClE,qCAAqC;YACrC,IAAI,SAAS,GAAW,IAAI,CAAC;YAC7B,IAAI,IAAI,CAAC,SAAS,IAAI,IAAI,EAAE;gBAC1B,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;aAC5B;YACD,IAAI,IAAI,CAAC,WAAW,IAAI,IAAI,EAAE;gBAC5B,sDAAsD;gBACtD,oCAAoC;gBACpC,IAAI,CAAC,eAAe,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,CAAC;aAC1C;iBAAM;gBACL,sDAAsD;gBACtD,kDAAkD;gBAClD,IAAI,CAAC,eAAe;oBAChB,CAAC,SAAS,CAAC,CAAC,MAAM,CAAC,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC;aAChE;SACF;QACD,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,aAAa,CAAC,qBAAqB,CAAC,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC,CAAC;QAC/D,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAChC,aAAa,CAAC,qBAAqB,CAAC,IAAI,CAAC,SAAS,EAAE,WAAW,CAAC,CAAC;QACjE,IAAI,CAAC,qBAAqB,GAAG,cAAc,CACvC,IAAI,CAAC,qBAAqB,IAAI,IAAI,CAAC,8BAA8B,CAAC,CAAC;QACvE,IAAI,CAAC,qBAAqB,GAAG,cAAc,CAAC,IAAI,CAAC,qBAAqB,CAAC,CAAC;QACxE,IAAI,CAAC,mBAAmB,GAAG,cAAc,CAAC,IAAI,CAAC,mBAAmB,CAAC,CAAC;QACpE,IAAI,CAAC,oBAAoB,GAAG,aAAa,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;QACrE,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC,QAAQ,CAAC;QACrC,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,WAAW,CAAC;IACtC,CAAC;IAEe,KAAK,C
AAC,UAAyB;QAC7C,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,SAAS,CAC5B,YAAY,EAAE,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,SAAS,CAAC,EAAE,IAAI,CAAC,KAAK,EACzD,IAAI,CAAC,qBAAqB,EAAE,IAAI,CAAC,qBAAqB,EAAE,IAAI,EAC5D,IAAI,CAAC,oBAAoB,CAAC,CAAC;QAC/B,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;IACpB,CAAC;IAED,0EAA0E;IAC1E,mCAAmC;IAChB,4BAA4B,CAAC,UAAiB,IAAG,CAAC;IAE5D,WAAW,CAAC,MAAuB,EAAE,IAAsB;QAElE,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE;gBAClB,OAAO,IAAI,CAAC;aACb;iBAAM;gBACL,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;gBACrC,OAAO,QAAQ,CAAC,MAAM,EAAE,SAAS,CAAC,MAAM,CAAC,CAAC,CAAC;aAC5C;QACH,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,IAAI,IAAI,CAAC,WAAW,IAAI,IAAI,EAAE;YAC5B,OAAO,CAAC,GAAG,UAAU,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;SACxC;QACD,wDAAwD;QACxD,MAAM,MAAM,GAAa,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;QAChE,IAAI,MAAM,CAAC,MAAM,KAAK,UAAU,CAAC,MAAM,GAAG,CAAC,EAAE;YAC3C,MAAM,IAAI,UAAU,CAChB,oBAAoB,IAAI,CAAC,WAAW,iBAAiB;gBACrD,yBAAyB,UAAU,EAAE,CAAC,CAAC;SAC5C;aAAM;YACL,IAAI,CAAC,GAAG,CAAC,CAAC;YACV,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBACtC,MAAM,EAAE,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;gBACrB,MAAM,EAAE,GAAG,UAAU,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBAC7B,IAAI,CAAC,EAAE,IAAI,IAAI,CAAC,IAAI,CAAC,EAAE,IAAI,IAAI,CAAC,IAAI,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE;oBAC/C,MAAM,IAAI,UAAU,CAChB,oBAAoB,IAAI,CAAC,WAAW,iBAAiB;wBACrD,yBAAyB,UAAU,EAAE,CAAC,CAAC;iBAC5C;qBAAM,IAAI,EAAE,IAAI,IAAI,EAAE;oBACrB,MAAM,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC;iBAChB;gBACD,CAAC,EAAE,CAAC;aACL;SACF;QACD,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,EAAE,GAAG,MAAM,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;IACpD,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,cAAc,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YACpC,+CAA+C;YAC/C,IAAI,KAAK,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;YACxC,IAAI,KAAK,CAAC,KAAK,KAAK,OAAO,EAAE;gBAC3B,KAAK,GAAG,CAAC,CAAC,IAAI,CAAC,KAAK,EAAE,OAAO,CAAC,CAAC;aAChC;YACD,MAAM,MAAM,GACR,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,EAAE,OAAO,CAAC,KAAK,EAAE
,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;YACnE,OAAO,OAAO,CACV,MAAM,EAAE,kBAAkB,CAAC,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QACxE,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,qBAAqB,EAAE,oBAAoB,CAAC,IAAI,CAAC,qBAAqB,CAAC;YACvE,qBAAqB,EAAE,oBAAoB,CAAC,IAAI,CAAC,qBAAqB,CAAC;YACvE,mBAAmB,EAAE,oBAAoB,CAAC,IAAI,CAAC,mBAAmB,CAAC;YACnE,oBAAoB,EAAE,mBAAmB,CAAC,IAAI,CAAC,oBAAoB,CAAC;YACpE,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,WAAW,EAAE,IAAI,CAAC,WAAW;SAC9B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AArID,kBAAkB;AACX,mBAAS,GAAG,WAAW,AAAd,CAAe;SAFpB,SAAS;AAwItB,aAAa,CAAC,aAAa,CAAC,SAAS,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * TensorFlow.js Layers: Embedding Layer.\n *\n * Original source: keras/constraints.py\n */\nimport {notEqual, reshape, serialization, Tensor, tidy, zerosLike} from '@tensorflow/tfjs-core';\n\nimport * as K from '../backend/tfjs_backend';\nimport {Constraint, ConstraintIdentifier, getConstraint, serializeConstraint} from '../constraints';\nimport {Layer, LayerArgs} from '../engine/topology';\nimport {ValueError} from '../errors';\nimport {getInitializer, Initializer, InitializerIdentifier, serializeInitializer} from '../initializers';\nimport {Shape} from '../keras_format/common';\nimport {getRegularizer, Regularizer, RegularizerIdentifier, serializeRegularizer} from '../regularizers';\nimport {Kwargs} from '../types';\nimport * as generic_utils from '../utils/generic_utils';\nimport {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';\nimport {LayerVariable} from 
'../variables';\n\nexport declare interface EmbeddingLayerArgs extends LayerArgs {\n  /**\n   * Integer > 0. Size of the vocabulary, i.e. maximum integer index + 1.\n   */\n  inputDim: number;\n  /**\n   * Integer >= 0. Dimension of the dense embedding.\n   */\n  outputDim: number;\n  /**\n   * Initializer for the `embeddings` matrix.\n   */\n  embeddingsInitializer?: InitializerIdentifier|Initializer;\n  /**\n   * Regularizer function applied to the `embeddings` matrix.\n   */\n  embeddingsRegularizer?: RegularizerIdentifier|Regularizer;\n  /**\n   * Regularizer function applied to the activation.\n   */\n  activityRegularizer?: RegularizerIdentifier|Regularizer;\n  /**\n   * Constraint function applied to the `embeddings` matrix.\n   */\n  embeddingsConstraint?: ConstraintIdentifier|Constraint;\n  /**\n   * Whether the input value 0 is a special \"padding\" value that should be\n   * masked out. This is useful when using recurrent layers which may take\n   * variable length input.\n   *\n   * If this is `True` then all subsequent layers in the model need to support\n   * masking or an exception will be raised. 
If maskZero is set to `True`, as a\n   * consequence, index 0 cannot be used in the vocabulary (inputDim should\n   * equal size of vocabulary + 1).\n   */\n  maskZero?: boolean;\n  /**\n   * Length of input sequences, when it is constant.\n   *\n   * This argument is required if you are going to connect `flatten` then\n   * `dense` layers upstream (without it, the shape of the dense outputs cannot\n   * be computed).\n   */\n  inputLength?: number|number[];\n}\n\nexport class Embedding extends Layer {\n  /** @nocollapse */\n  static className = 'Embedding';\n  private inputDim: number;\n  private outputDim: number;\n  private embeddingsInitializer: Initializer;\n  private maskZero: boolean;\n  private inputLength: number|number[];\n\n  private embeddings: LayerVariable = null;\n\n  readonly DEFAULT_EMBEDDINGS_INITIALIZER: InitializerIdentifier =\n      'randomUniform';\n  private readonly embeddingsRegularizer?: Regularizer;\n  private readonly embeddingsConstraint?: Constraint;\n\n  constructor(args: EmbeddingLayerArgs) {\n    super(args);\n    if (args.batchInputShape == null && args.inputShape == null) {\n      // Porting Note: This logic is copied from Layer's constructor, since we\n      // can't do exactly what the Python constructor does for Embedding().\n      // Specifically, the super constructor can not be called after the\n      // mutation of the `config` argument.\n      let batchSize: number = null;\n      if (args.batchSize != null) {\n        batchSize = args.batchSize;\n      }\n      if (args.inputLength == null) {\n        // Fix super-constructor to what it would have done if\n        // 'config.inputShape' were (None, )\n        this.batchInputShape = [batchSize, null];\n      } else {\n        // Fix super-constructor to what it would have done if\n        // 'config.inputShape' were (config.inputLength, )\n        this.batchInputShape =\n            [batchSize].concat(generic_utils.toList(args.inputLength));\n      }\n    }\n    
this.inputDim = args.inputDim;\n    generic_utils.assertPositiveInteger(this.inputDim, 'inputDim');\n    this.outputDim = args.outputDim;\n    generic_utils.assertPositiveInteger(this.outputDim, 'outputDim');\n    this.embeddingsInitializer = getInitializer(\n        args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER);\n    this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);\n    this.activityRegularizer = getRegularizer(args.activityRegularizer);\n    this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);\n    this.maskZero = args.maskZero;\n    this.supportsMasking = args.maskZero;\n    this.inputLength = args.inputLength;\n  }\n\n  public override build(inputShape: Shape|Shape[]): void {\n    this.embeddings = this.addWeight(\n        'embeddings', [this.inputDim, this.outputDim], this.dtype,\n        this.embeddingsInitializer, this.embeddingsRegularizer, true,\n        this.embeddingsConstraint);\n    this.built = true;\n  }\n\n  // Override warnOnIncompatibleInputShape because an embedding layer allows\n  // the input to have varying ranks.\n  protected override warnOnIncompatibleInputShape(inputShape: Shape) {}\n\n  override computeMask(inputs: Tensor|Tensor[], mask?: Tensor|Tensor[]):\n      Tensor {\n    return tidy(() => {\n      if (!this.maskZero) {\n        return null;\n      } else {\n        inputs = getExactlyOneTensor(inputs);\n        return notEqual(inputs, zerosLike(inputs));\n      }\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    inputShape = getExactlyOneShape(inputShape);\n    if (this.inputLength == null) {\n      return [...inputShape, this.outputDim];\n    }\n    // inputLength can be an array if input is 3D or higher.\n    const inLens: number[] = generic_utils.toList(this.inputLength);\n    if (inLens.length !== inputShape.length - 1) {\n      throw new ValueError(\n          `\"inputLength\" is ${this.inputLength}, but received ` +\n     
     `input shape has shape ${inputShape}`);\n    } else {\n      let i = 0;\n      for (let k = 0; k < inLens.length; ++k) {\n        const s1 = inLens[k];\n        const s2 = inputShape[k + 1];\n        if ((s1 != null) && (s2 != null) && (s1 !== s2)) {\n          throw new ValueError(\n              `\"inputLength\" is ${this.inputLength}, but received ` +\n              `input shape has shape ${inputShape}`);\n        } else if (s1 == null) {\n          inLens[i] = s2;\n        }\n        i++;\n      }\n    }\n    return [inputShape[0], ...inLens, this.outputDim];\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    return tidy(() => {\n      this.invokeCallHook(inputs, kwargs);\n      // Embedding layer accepts only a single input.\n      let input = getExactlyOneTensor(inputs);\n      if (input.dtype !== 'int32') {\n        input = K.cast(input, 'int32');\n      }\n      const output =\n          K.gather(this.embeddings.read(), reshape(input, [input.size]));\n      return reshape(\n          output, getExactlyOneShape(this.computeOutputShape(input.shape)));\n    });\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      inputDim: this.inputDim,\n      outputDim: this.outputDim,\n      embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),\n      embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),\n      activityRegularizer: serializeRegularizer(this.activityRegularizer),\n      embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),\n      maskZero: this.maskZero,\n      inputLength: this.inputLength\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(Embedding);\n"]}