@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
103 lines • 13.2 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Position embedding implementation based on `tf.layers.Layer`.
*/
/* Original source: keras_nlp/layers/modeling/position_embedding.py */
import { serialization, tidy } from '@tensorflow/tfjs-core';
import { Layer } from '../../../engine/topology';
import { ValueError } from '../../../errors';
import { getInitializer, serializeInitializer } from '../../../initializers';
import { getExactlyOneTensor } from '../../../utils/types_utils';
/**
 * A layer which learns a position embedding for input sequences.
 *
 * This class assumes that in the input tensor, the last dimension corresponds
 * to the features, and the dimension before the last corresponds to the
 * sequence.
 *
 * Examples:
 *
 * Called on input (`apply` builds the layer's weights before calling it).
 * ```js
 * const layer = new PositionEmbedding({sequenceLength: 10});
 * layer.apply(tf.zeros([8, 10, 16]));
 * ```
 *
 * Combine with a token embedding.
 * ```js
 * const seqLength = 50;
 * const vocabSize = 5000;
 * const embedDim = 128;
 * const inputs = tf.input({shape: [seqLength]});
 * const tokenEmbeddings = tf.layers.embedding({
 *     inputDim: vocabSize, outputDim: embedDim
 * }).apply(inputs);
 * const positionEmbeddings = new PositionEmbedding({
 *     sequenceLength: seqLength
 * }).apply(tokenEmbeddings);
 * const outputs = tf.layers.add().apply([tokenEmbeddings, positionEmbeddings]);
 * ```
 *
 * Reference:
 *  - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)
 */
class PositionEmbedding extends Layer {
    /**
     * @param args Layer arguments. `args.sequenceLength` (required) is the
     *     maximum sequence length the embedding table covers;
     *     `args.initializer` defaults to `'glorotUniform'`.
     * @throws ValueError if `args` or `args.sequenceLength` is missing.
     */
    constructor(args) {
        super(args);
        // Guard `args` itself too, so a missing argument object surfaces as
        // the intended ValueError rather than a TypeError.
        if (args == null || args.sequenceLength == null) {
            throw new ValueError('`sequenceLength` must be an Integer, received `null`.');
        }
        this.sequenceLength = args.sequenceLength;
        this.initializer = getInitializer(args.initializer || 'glorotUniform');
    }
    /**
     * Returns the serializable config: this layer's `sequenceLength` and
     * serialized `initializer`, merged with the base `Layer` config (base
     * keys are assigned last and therefore win on any collision).
     */
    getConfig() {
        const config = {
            'sequenceLength': this.sequenceLength,
            'initializer': serializeInitializer(this.initializer),
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    /**
     * Creates the `embeddings` weight of shape `[sequenceLength, featureSize]`,
     * where `featureSize` is the last dimension of `inputShape`.
     */
    build(inputShape) {
        const featureSize = inputShape[inputShape.length - 1];
        // Positional args after the initializer: regularizer = null and a
        // trailing `true` — presumably `trainable` per Layer.addWeight (verify).
        this.positionEmbeddings = this.addWeight('embeddings', [this.sequenceLength, featureSize], null, this.initializer, null, true);
        super.build(inputShape);
    }
    /**
     * Returns the learned position embeddings for the input's sequence
     * positions, broadcast to the input's shape.
     *
     * @param inputs A single tensor (or a one-element list), whose
     *     second-to-last dimension is the sequence and last is the features.
     * @param kwargs Optional; `kwargs.startIndex` is the first position to
     *     read from the embedding table (defaults to 0).
     */
    call(inputs, kwargs) {
        return tidy(() => {
            // `kwargs` may be omitted when `call` is invoked directly, so
            // resolve the start index into a local instead of dereferencing
            // (and mutating) the caller's options object.
            const startIndex = (kwargs != null && kwargs.startIndex != null) ? kwargs.startIndex : 0;
            const shape = getExactlyOneTensor(inputs).shape;
            const featureLength = shape[shape.length - 1];
            const sequenceLength = shape[shape.length - 2];
            // Trim to match the length of the input sequence, which might be
            // shorter than the `sequenceLength` of the layer.
            const positionEmbeddings = this.positionEmbeddings.read().slice([startIndex, 0], [sequenceLength, featureLength]);
            return positionEmbeddings.broadcastTo(shape);
        });
    }
    /** The output is broadcast to the input's shape, so shapes are equal. */
    computeOutputShape(inputShape) {
        return inputShape;
    }
}
/** @nocollapse */
PositionEmbedding.className = 'PositionEmbedding';
export { PositionEmbedding };
serialization.registerClass(PositionEmbedding);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"position_embedding.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/position_embedding.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,sEAAsE;AACtE,OAAO,EAAU,aAAa,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAGpE,OAAO,EAAE,KAAK,EAAa,MAAM,0BAA0B,CAAC;AAC5D,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAC7C,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,uBAAuB,CAAC;AACjH,OAAO,EAAE,mBAAmB,EAAE,MAAM,4BAA4B,CAAC;AAwBjE;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAgCG;AACH,MAAa,iBAAkB,SAAQ,KAAK;IAO1C,YAAY,IAA2B;QACrC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,IAAI,CAAC,cAAc,IAAI,IAAI,EAAE;YAC/B,MAAM,IAAI,UAAU,CAClB,uDAAuD,CAAC,CAAC;SAC5D;QACD,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC;QAC1C,IAAI,CAAC,WAAW,GAAG,cAAc,CAAC,IAAI,CAAC,WAAW,IAAI,eAAe,CAAC,CAAC;IACzE,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,gBAAgB,EAAE,IAAI,CAAC,cAAc;YACrC,aAAa,EAAE,oBAAoB,CAAC,IAAI,CAAC,WAAW,CAAC;SACtD,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,KAAK,CAAC,UAAiB;QAC9B,MAAM,WAAW,GAAG,UAAU,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACtD,IAAI,CAAC,kBAAkB,GAAG,IAAI,CAAC,SAAS,CACtC,YAAY,EACZ,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,CAAC,EAClC,IAAI,EACJ,IAAI,CAAC,WAAW,EAChB,IAAI,EACJ,IAAI,CACL,CAAC;QACF,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC1B,CAAC;IAEQ,IAAI,CACX,MAAuB,EACvB,MAAiC;QAEjC,OAAO,IAAI,CAAC,GAAG,EAAE;;YACf,MAAM,CAAC,UAAU,GAAG,MAAA,MAAM,CAAC,UAAU,mCAAI,CAAC,CAAC;YAC3C,MAAM,KAAK,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC;YAChD,MAAM,aAAa,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC9C,MAAM,cAAc,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC/C,sEAAsE;YACtE,yCAAyC;YACzC,MAAM,kBAAkB,GAAG,IAAI,CAAC,kBAAkB,CAAC,IAAI,EAAE,CAAC,KAAK,CAC7D,CAAC,MAAM,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,CAAC,cAAc,EAAE,aAAa,CAAC,CAAC,CAAC;YAC3D,OAAO,kBAAkB,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC;QAC/C,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAiB;QAC3C,OAAO,UAA
U,CAAC;IACpB,CAAC;;AA1DD,kBAAkB;AACF,2BAAS,GAAG,mBAAmB,CAAC;SAFrC,iBAAiB;AA6D9B,aAAa,CAAC,aAAa,CAAC,iBAAiB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Position embedding implementation based on `tf.layers.Layer`.\n */\n\n/* Original source: keras_nlp/layers/modeling/position_embedding.py */\nimport { Tensor, serialization, tidy } from '@tensorflow/tfjs-core';\n\nimport { Shape } from '../../../keras_format/common';\nimport { Layer, LayerArgs } from '../../../engine/topology';\nimport { ValueError } from '../../../errors';\nimport { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../../initializers';\nimport { getExactlyOneTensor } from '../../../utils/types_utils';\nimport { LayerVariable } from '../../../variables';\n\nexport declare interface PositionEmbeddingArgs extends LayerArgs {\n  /**\n   * Integer. The maximum length of the dynamic sequence.\n   */\n  sequenceLength: number;\n\n  /**\n   * The initializer to use for the embedding weights.\n   * Defaults to `\"glorotUniform\"`.\n   */\n  initializer?: Initializer|InitializerIdentifier;\n}\n\nexport declare interface PositionEmbeddingOptions {\n  /**\n   * Integer. 
Index to start the position embeddings at.\n   * Defaults to 0.\n   */\n  startIndex?: number;\n}\n\n/**\n * A layer which learns a position embedding for input sequences.\n *\n * This class assumes that in the input tensor, the last dimension corresponds\n * to the features, and the dimension before the last corresponds to the\n * sequence.\n *\n * Examples:\n *\n * Called directly on input.\n * ```js\n * const layer = new PositionEmbedding({sequenceLength=10});\n * layer.call(tf.zeros([8, 10, 16]));\n * ```\n *\n * Combine with a token embedding.\n * ```js\n * const seqLength = 50;\n * const vocabSize = 5000;\n * const embedDim = 128;\n * const inputs = tf.input({shape: [seqLength]});\n * const tokenEmbeddings = tf.layers.embedding({\n *     inputDim=vocabSize, outputDim=embedDim\n * }).apply(inputs);\n * const positionEmbeddings = new PositionEmbedding({\n *     sequenceLength: seqLength\n * }).apply(tokenEmbeddings);\n * const outputs = tf.add(tokenEmbeddings, positionEmbeddings);\n * ```\n *\n * Reference:\n *  - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)\n */\nexport class PositionEmbedding extends Layer {\n  /** @nocollapse */\n  static readonly className = 'PositionEmbedding';\n  private sequenceLength: number;\n  private initializer: Initializer;\n  protected positionEmbeddings: LayerVariable;\n\n  constructor(args: PositionEmbeddingArgs) {\n    super(args);\n    if (args.sequenceLength == null) {\n      throw new ValueError(\n        '`sequenceLength` must be an Integer, received `null`.');\n    }\n    this.sequenceLength = args.sequenceLength;\n    this.initializer = getInitializer(args.initializer || 'glorotUniform');\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      'sequenceLength': this.sequenceLength,\n      'initializer': serializeInitializer(this.initializer),\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override 
build(inputShape: Shape): void {\n    const featureSize = inputShape[inputShape.length - 1];\n    this.positionEmbeddings = this.addWeight(\n      'embeddings',\n      [this.sequenceLength, featureSize],\n      null,\n      this.initializer,\n      null,\n      true\n    );\n    super.build(inputShape);\n  }\n\n  override call(\n    inputs: Tensor|Tensor[],\n    kwargs?: PositionEmbeddingOptions\n  ): Tensor {\n    return tidy(() => {\n      kwargs.startIndex = kwargs.startIndex ?? 0;\n      const shape = getExactlyOneTensor(inputs).shape;\n      const featureLength = shape[shape.length - 1];\n      const sequenceLength = shape[shape.length - 2];\n      // trim to match the length of the input sequence, which might be less\n      // than the sequence_length of the layer.\n      const positionEmbeddings = this.positionEmbeddings.read().slice(\n        [kwargs.startIndex, 0], [sequenceLength, featureLength]);\n      return positionEmbeddings.broadcastTo(shape);\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape): Shape {\n    return inputShape;\n  }\n}\nserialization.registerClass(PositionEmbedding);\n"]}