UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

103 lines 13.2 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 *  Position embedding implementation based on `tf.layers.Layer`.
 */
/* Original source: keras_nlp/layers/modeling/position_embedding.py */
import { serialization, tidy } from '@tensorflow/tfjs-core';
import { Layer } from '../../../engine/topology';
import { ValueError } from '../../../errors';
import { getInitializer, serializeInitializer } from '../../../initializers';
import { getExactlyOneTensor } from '../../../utils/types_utils';
/**
 * A layer which learns a position embedding for input sequences.
 *
 * This class assumes that in the input tensor, the last dimension corresponds
 * to the features, and the dimension before the last corresponds to the
 * sequence.
 *
 * Examples:
 *
 * Called directly on input.
 * ```js
 * const layer = new PositionEmbedding({sequenceLength: 10});
 * layer.apply(tf.zeros([8, 10, 16]));
 * ```
 *
 * Combine with a token embedding.
 * ```js
 * const seqLength = 50;
 * const vocabSize = 5000;
 * const embedDim = 128;
 * const inputs = tf.input({shape: [seqLength]});
 * const tokenEmbeddings = tf.layers.embedding({
 *     inputDim: vocabSize, outputDim: embedDim
 * }).apply(inputs);
 * const positionEmbeddings = new PositionEmbedding({
 *     sequenceLength: seqLength
 * }).apply(tokenEmbeddings);
 * const outputs =
 *     tf.layers.add().apply([tokenEmbeddings, positionEmbeddings]);
 * ```
 *
 * Reference:
 *  - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)
 */
class PositionEmbedding extends Layer {
    /**
     * @param args Layer arguments. `args.sequenceLength` is required and sets
     *   the number of rows of the learned embedding table;
     *   `args.initializer` falls back to `'glorotUniform'` when falsy.
     * @throws ValueError if `args.sequenceLength` is null or undefined.
     */
    constructor(args) {
        super(args);
        if (args.sequenceLength == null) {
            throw new ValueError('`sequenceLength` must be an Integer, received `null`.');
        }
        this.sequenceLength = args.sequenceLength;
        this.initializer = getInitializer(args.initializer || 'glorotUniform');
    }
    /**
     * Serializes `sequenceLength` and `initializer`, then merges the base
     * `Layer` config on top (keys from the base config win on collision,
     * since `Object.assign` copies `baseConfig` onto `config`).
     */
    getConfig() {
        const config = {
            'sequenceLength': this.sequenceLength,
            'initializer': serializeInitializer(this.initializer),
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    /**
     * Creates the `embeddings` weight of shape `[sequenceLength, featureSize]`,
     * where `featureSize` is the last entry of `inputShape`, initialized with
     * the configured initializer.
     */
    build(inputShape) {
        const featureSize = inputShape[inputShape.length - 1];
        this.positionEmbeddings = this.addWeight('embeddings', [this.sequenceLength, featureSize], null, this.initializer, null, true);
        super.build(inputShape);
    }
    /**
     * Returns the learned position embeddings for the input's sequence
     * positions, broadcast to the input's shape.
     *
     * @param inputs A single tensor (or a one-element array) whose last axis
     *   holds features and whose second-to-last axis holds the sequence.
     * @param kwargs Optional. `kwargs.startIndex` (default 0) is the first row
     *   of the embedding table to use.
     * @returns A tensor with the same shape as the input.
     * @throws ValueError if the input has rank < 2, or if `startIndex` is
     *   negative or `startIndex + inputSequenceLength` exceeds the layer's
     *   `sequenceLength`.
     */
    call(inputs, kwargs) {
        // `kwargs` is optional, so read `startIndex` defensively rather than
        // dereferencing (and mutating) a possibly-undefined options object.
        const startIndex = (kwargs == null || kwargs.startIndex == null) ?
            0 :
            kwargs.startIndex;
        const shape = getExactlyOneTensor(inputs).shape;
        if (shape.length < 2) {
            throw new ValueError(`PositionEmbedding expects an input of rank >= 2 ` +
                `(..., sequence, features), but received shape [${shape}].`);
        }
        const featureLength = shape[shape.length - 1];
        const sequenceLength = shape[shape.length - 2];
        // Validate up front so an out-of-range window yields a clear error
        // instead of an opaque failure from inside the slice op.
        if (startIndex < 0 || startIndex + sequenceLength > this.sequenceLength) {
            throw new ValueError(`\`startIndex\` (${startIndex}) plus the input sequence length ` +
                `(${sequenceLength}) must lie within [0, ${this.sequenceLength}], ` +
                `the \`sequenceLength\` of this layer.`);
        }
        return tidy(() => {
            // Trim to match the length of the input sequence, which might be
            // less than the `sequenceLength` of the layer.
            const positionEmbeddings = this.positionEmbeddings.read().slice([startIndex, 0], [sequenceLength, featureLength]);
            return positionEmbeddings.broadcastTo(shape);
        });
    }
    /** The output has exactly the input's shape (see `broadcastTo` in `call`). */
    computeOutputShape(inputShape) {
        return inputShape;
    }
}
/** @nocollapse */
PositionEmbedding.className = 'PositionEmbedding';
export { PositionEmbedding };
serialization.registerClass(PositionEmbedding);
// NOTE(review): the generated source map below predates the edits above, so
// its mappings no longer line up exactly; regenerate it on the next build.
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"position_embedding.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/position_embedding.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,sEAAsE;AACtE,OAAO,EAAU,aAAa,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAGpE,OAAO,EAAE,KAAK,EAAa,MAAM,0BAA0B,CAAC;AAC5D,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAC7C,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,uBAAuB,CAAC;AACjH,OAAO,EAAE,mBAAmB,EAAE,MAAM,4BAA4B,CAAC;AAwBjE;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAgCG;AACH,MAAa,iBAAkB,SAAQ,KAAK;IAO1C,YAAY,IAA2B;QACrC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,IAAI,CAAC,cAAc,IAAI,IAAI,EAAE;YAC/B,MAAM,IAAI,UAAU,CAClB,uDAAuD,CAAC,CAAC;SAC5D;QACD,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC;QAC1C,IAAI,CAAC,WAAW,GAAG,cAAc,CAAC,IAAI,CAAC,WAAW,IAAI,eAAe,CAAC,CAAC;IACzE,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,gBAAgB,EAAE,IAAI,CAAC,cAAc;YACrC,aAAa,EAAE,oBAAoB,CAAC,IAAI,CAAC,WAAW,CAAC;SACtD,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,KAAK,CAAC,UAAiB;QAC9B,MAAM,WAAW,GAAG,UAAU,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACtD,IAAI,CAAC,kBAAkB,GAAG,IAAI,CAAC,SAAS,CACtC,YAAY,EACZ,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,CAAC,EAClC,IAAI,EACJ,IAAI,CAAC,WAAW,EAChB,IAAI,EACJ,IAAI,CACL,CAAC;QACF,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC1B,CAAC;IAEQ,IAAI,CACX,MAAuB,EACvB,MAAiC;QAEjC,OAAO,IAAI,CAAC,GAAG,EAAE;;YACf,MAAM,CAAC,UAAU,GAAG,MAAA,MAAM,CAAC,UAAU,mCAAI,CAAC,CAAC;YAC3C,MAAM,KAAK,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC;YAChD,MAAM,aAAa,GAAG,KAAK,CAAC,KAAK,CAAC,M
AAM,GAAG,CAAC,CAAC,CAAC;YAC9C,MAAM,cAAc,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC/C,sEAAsE;YACtE,yCAAyC;YACzC,MAAM,kBAAkB,GAAG,IAAI,CAAC,kBAAkB,CAAC,IAAI,EAAE,CAAC,KAAK,CAC7D,CAAC,MAAM,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,CAAC,cAAc,EAAE,aAAa,CAAC,CAAC,CAAC;YAC3D,OAAO,kBAAkB,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC;QAC/C,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAiB;QAC3C,OAAO,UAAU,CAAC;IACpB,CAAC;;AA1DD,kBAAkB;AACF,2BAAS,GAAG,mBAAmB,CAAC;SAFrC,iBAAiB;AA6D9B,aAAa,CAAC,aAAa,CAAC,iBAAiB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Position embedding implementation based on `tf.layers.Layer`.\n */\n\n/* Original source: keras_nlp/layers/modeling/position_embedding.py */\nimport { Tensor, serialization, tidy } from '@tensorflow/tfjs-core';\n\nimport { Shape } from '../../../keras_format/common';\nimport { Layer, LayerArgs } from '../../../engine/topology';\nimport { ValueError } from '../../../errors';\nimport { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../../initializers';\nimport { getExactlyOneTensor } from '../../../utils/types_utils';\nimport { LayerVariable } from '../../../variables';\n\nexport declare interface PositionEmbeddingArgs extends LayerArgs {\n  /**\n   * Integer. 
The maximum length of the dynamic sequence.\n   */\n  sequenceLength: number;\n\n  /**\n   * The initializer to use for the embedding weights.\n   * Defaults to `\"glorotUniform\"`.\n   */\n  initializer?: Initializer|InitializerIdentifier;\n}\n\nexport declare interface PositionEmbeddingOptions {\n  /**\n   * Integer. Index to start the position embeddings at.\n   * Defaults to 0.\n   */\n  startIndex?: number;\n}\n\n/**\n * A layer which learns a position embedding for input sequences.\n *\n * This class assumes that in the input tensor, the last dimension corresponds\n * to the features, and the dimension before the last corresponds to the\n * sequence.\n *\n * Examples:\n *\n * Called directly on input.\n * ```js\n * const layer = new PositionEmbedding({sequenceLength=10});\n * layer.call(tf.zeros([8, 10, 16]));\n * ```\n *\n * Combine with a token embedding.\n * ```js\n * const seqLength = 50;\n * const vocabSize = 5000;\n * const embedDim = 128;\n * const inputs = tf.input({shape: [seqLength]});\n * const tokenEmbeddings = tf.layers.embedding({\n *     inputDim=vocabSize, outputDim=embedDim\n * }).apply(inputs);\n * const positionEmbeddings = new PositionEmbedding({\n *     sequenceLength: seqLength\n * }).apply(tokenEmbeddings);\n * const outputs = tf.add(tokenEmbeddings, positionEmbeddings);\n * ```\n *\n * Reference:\n *  - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)\n */\nexport class PositionEmbedding extends Layer {\n  /** @nocollapse */\n  static readonly className = 'PositionEmbedding';\n  private sequenceLength: number;\n  private initializer: Initializer;\n  protected positionEmbeddings: LayerVariable;\n\n  constructor(args: PositionEmbeddingArgs) {\n    super(args);\n    if (args.sequenceLength == null) {\n      throw new ValueError(\n        '`sequenceLength` must be an Integer, received `null`.');\n    }\n    this.sequenceLength = args.sequenceLength;\n    this.initializer = getInitializer(args.initializer || 'glorotUniform');\n  
}\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      'sequenceLength': this.sequenceLength,\n      'initializer': serializeInitializer(this.initializer),\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override build(inputShape: Shape): void {\n    const featureSize = inputShape[inputShape.length - 1];\n    this.positionEmbeddings = this.addWeight(\n      'embeddings',\n      [this.sequenceLength, featureSize],\n      null,\n      this.initializer,\n      null,\n      true\n    );\n    super.build(inputShape);\n  }\n\n  override call(\n    inputs: Tensor|Tensor[],\n    kwargs?: PositionEmbeddingOptions\n  ): Tensor {\n    return tidy(() => {\n      kwargs.startIndex = kwargs.startIndex ?? 0;\n      const shape = getExactlyOneTensor(inputs).shape;\n      const featureLength = shape[shape.length - 1];\n      const sequenceLength = shape[shape.length - 2];\n      // trim to match the length of the input sequence, which might be less\n      // than the sequence_length of the layer.\n      const positionEmbeddings = this.positionEmbeddings.read().slice(\n        [kwargs.startIndex, 0], [sequenceLength, featureLength]);\n      return positionEmbeddings.broadcastTo(shape);\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape): Shape {\n    return inputShape;\n  }\n}\nserialization.registerClass(PositionEmbedding);\n"]}