@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
112 lines • 18.6 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Cached MHA layer based on `MultiHeadAttention`.
*/
/* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */
import { cast, einsum, mul, reciprocal, serialization, sqrt, stack, tidy } from '@tensorflow/tfjs-core';
import { ValueError } from '../../../errors';
import { MultiHeadAttention } from '../multihead_attention';
import { sliceUpdate } from '../utils';
/**
* MultiHeadAttention layer with cache support.
*
* This layer is suitable for use in autoregressive decoding. It can be used
* to cache decoder self-attention and cross-attention. The forward pass
* can happen in one of three modes:
* - No cache, same as regular multi-head attention.
* - Static cache (`cacheUpdateIndex` is `null`). In this case, the
* cached key/value projections will be used and the input values will
* be ignored.
* - Updated cache (`cacheUpdateIndex` is not `null`). In this case, new
* key/value projections are computed using the input, and spliced into
* the cache at the specified index.
*
* Note that caching is useful only during inference and should not be used
* during training.
*
* We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
* `T` is the target sequence length, and `S` is the source sequence length.
* Note that during generative decoding, `T` is usually 1 (you are
* generating a target sequence of length one to predict the next token).
*
* Returns:
* An `(attentionOutput, cache)` tuple. `attentionOutput` is the result
* of the computation, of shape `(B, T, dim)`, where `T` is for target
* sequence shapes and `dim` is the query input last dimension if
* `outputShape` is `null`. Otherwise, the multi-head outputs are
* projected to the shape specified by `outputShape`. `cache` is the
* updated cache.
*/
export class CachedMultiHeadAttention extends MultiHeadAttention {
  /**
   * Runs the attention computation and returns only the attention output.
   *
   * @param query Query tensor, documented upstream as shape `(B, T, dim)`.
   * @param kwargs Options object forwarded unchanged to `callAndReturnCache`
   *     (`value`, and optionally `key`, `attentionMask`, `cache`,
   *     `cacheUpdateIndex`).
   * @returns The attention output tensor; the (possibly updated) cache is
   *     discarded.
   */
  call(query, kwargs) {
    return this.callAndReturnCache(query, kwargs)[0];
  }

  /**
   * Exactly like `call` except also returns the updated cache.
   *
   * @param query Query tensor, documented upstream as shape `(B, T, dim)`.
   * @param options.value Value tensor. Always required: its shape is used to
   *     build the layer on the first call, and it is the fallback for `key`.
   * @param options.key Optional key tensor; defaults to `value`.
   * @param options.attentionMask Optional mask passed to `maskedSoftmax`.
   * @param options.cache Optional key/value cache of shape
   *     `[B, 2, S, numHeads, keyDims]`; index 0 on axis 1 holds the keys and
   *     index 1 holds the values.
   * @param options.cacheUpdateIndex Optional position along the `S` axis at
   *     which the freshly projected key/value are written into `cache`. When
   *     it is `null`, the cached projections are used as-is and the input
   *     key/value are not projected.
   * @returns `[attentionOutput, cache]`, where `cache` is the rebuilt cache
   *     when `cacheUpdateIndex` was given, and otherwise the `cache` argument
   *     exactly as passed in.
   * @throws ValueError if `cacheUpdateIndex` is set while `cache` is `null`.
   */
  callAndReturnCache(query, { value, key, attentionMask, cache, cacheUpdateIndex }) {
    return tidy(() => {
      if (!this.builtFromSignature) {
        this.buildFromSignature(query.shape, value.shape, key ? key.shape : null);
      }
      if (key == null) {
        key = value;
      }
      query = this.queryDense.apply(query);
      // If cache is not `null`, we will use the cache to compute the final key
      // and value tensors. If `cacheUpdateIndex` is not `null`, we will first
      // update the cache before use. To do this, we first call the
      // `keyDense` and `valueDense` layers, and copy the outputs into the
      // cache at the specified index. `cache = null` handles the training
      // case, where we don't use the cache at all.
      if (cache != null) {
        // Select keys (index 0) / values (index 1) along axis 1 and drop ONLY
        // that axis. A bare `squeeze()` would also drop the batch axis when
        // B == 1 (or any other unit axis such as numHeads == 1), leaving a
        // tensor whose rank no longer matches the rank-4 `start` used by
        // `sliceUpdate` below or the attention einsum equations.
        const keyCache = cache.gather([0], 1).squeeze([1]);
        const valueCache = cache.gather([1], 1).squeeze([1]);
        if (cacheUpdateIndex == null) {
          // Static cache: reuse cached projections; input key/value ignored.
          key = keyCache;
          value = valueCache;
        }
        else {
          const keyUpdate = this.keyDense.apply(key);
          const valueUpdate = this.valueDense.apply(value);
          const start = [0, cacheUpdateIndex, 0, 0];
          key = sliceUpdate(keyCache, start, keyUpdate);
          value = sliceUpdate(valueCache, start, valueUpdate);
          cache = stack([key, value], 1);
        }
      }
      else {
        if (cacheUpdateIndex != null) {
          throw new ValueError('`cacheUpdateIndex` should not be set if `cache` is `null`. ' +
            `Received: cache=${cache}, cacheUpdateIndex=${cacheUpdateIndex}`);
        }
        key = this.keyDense.apply(key);
        value = this.valueDense.apply(value);
      }
      // Scale queries by 1 / sqrt(keyDim) before the query-key dot product.
      query = mul(query, reciprocal(sqrt(cast(this.keyDim, query.dtype))));
      let attentionScores = einsum(this.dotProductEquation, key, query);
      attentionScores = this.maskedSoftmax(attentionScores, attentionMask);
      attentionScores = this.dropoutLayer.apply(attentionScores);
      let attentionOutput = einsum(this.combineEquation, attentionScores, value);
      attentionOutput = this.outputDense.apply(attentionOutput);
      return [attentionOutput, cache];
    });
  }
}
serialization.registerClass(CachedMultiHeadAttention);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"cached_multihead_attention.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,+EAA+E;AAC/E,OAAO,EAAU,IAAI,EAAE,MAAM,EAAE,GAAG,EAAE,UAAU,EAAE,aAAa,EAAE,IAAI,EAAE,KAAK,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAEhH,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAC7C,OAAO,EAAE,kBAAkB,EAAE,MAAM,wBAAwB,CAAC;AAC5D,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAiDvC;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AACH,MAAM,OAAO,wBAAyB,SAAQ,kBAAkB;IAErD,IAAI,CACX,KAAa,EAAE,MAAuC;QAEtD,OAAO,IAAI,CAAC,kBAAkB,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;IACnD,CAAC;IAED;;OAEG;IACH,kBAAkB,CAChB,KAAa,EACb,EACE,KAAK,EACL,GAAG,EACH,aAAa,EACb,KAAK,EACL,gBAAgB,EACiB;QAEnC,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,IAAI,CAAC,kBAAkB,EAAE;gBAC5B,IAAI,CAAC,kBAAkB,CACrB,KAAK,CAAC,KAAK,EAAE,KAAK,CAAC,KAAK,EAAE,GAAG,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;aACrD;YACD,IAAI,GAAG,IAAI,IAAI,EAAE;gBACf,GAAG,GAAG,KAAK,CAAC;aACb;YAED,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;YAC/C,yEAAyE;YACzE,wEAAwE;YACxE,6DAA6D;YAC7D,oEAAoE;YACpE,oEAAoE;YACpE,6CAA6C;YAC7C,IAAI,KAAK,IAAI,IAAI,EAAE;gBACjB,MAAM,QAAQ,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;gBAChD,MAAM,UAAU,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;gBAClD,IAAI,gBAAgB,IAAI,IAAI,EAAE;oBAC5B,GAAG,GAAG,QAAQ,CAAC;oBACf,KAAK,GAAG,UAAU,CAAC;iBACpB;qBAAM;oBACL,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,KAAK,CAAC,GAAG,CAAW,CAAC;oBACrD,MAAM,WAAW,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;oBAC3D,MAAM,KAAK,GAAG,CAAC,CAAC,EAAE,gBAAgB,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;oBAC1C,GAAG,GAAG,WAAW,CAAC,QAAQ,EAAE,KAAK,EAAE,SAAS,CAAC,CAAC;oBAC9C,KAAK,GAAG,WAAW,CAAC,UAAU,EAAE,KAAK,EAAE,WAAW,CAAC,CAAC;oBACpD,KAAK,GAAG,KAAK,CAAC,CAAC,GAAG,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC;iBAChC;aACF;iBAAM;gBACL,IAAI,gBAAgB,IAAI,IAAI,EAAE;o
BAC5B,MAAM,IAAI,UAAU,CAClB,6DAA6D;wBAC7D,mBAAmB,KAAK,sBAAsB,gBAAgB,EAAE,CACjE,CAAC;iBACH;gBACD,GAAG,GAAG,IAAI,CAAC,QAAQ,CAAC,KAAK,CAAC,GAAG,CAAW,CAAC;gBACzC,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;aAChD;YAED,KAAK,GAAG,GAAG,CAAC,KAAK,EAAE,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YACrE,IAAI,eAAe,GAAG,MAAM,CAAC,IAAI,CAAC,kBAAkB,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC;YAClE,eAAe,GAAG,IAAI,CAAC,aAAa,CAAC,eAAe,EAAE,aAAa,CAAC,CAAC;YACrE,eAAe,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,eAAe,CAAW,CAAC;YAErE,IAAI,eAAe,GACjB,MAAM,CAAC,IAAI,CAAC,eAAe,EAAE,eAAe,EAAE,KAAK,CAAC,CAAC;YACvD,eAAe,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,eAAe,CAAW,CAAC;YAEpE,OAAO,CAAC,eAAe,EAAE,KAAK,CAAC,CAAC;QAClC,CAAC,CAAC,CAAC;IACL,CAAC;CACF;AACD,aAAa,CAAC,aAAa,CAAC,wBAAwB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Cached MHA layer based on `MultiHeadAttention`.\n */\n\n/* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */\nimport { Tensor, cast, einsum, mul, reciprocal, serialization, sqrt, stack, tidy } from '@tensorflow/tfjs-core';\n\nimport { ValueError } from '../../../errors';\nimport { MultiHeadAttention } from '../multihead_attention';\nimport { sliceUpdate } from '../utils';\n\nexport declare interface CachedMultiHeadAttentionOptions {\n  
/**\n   * Query `Tensor` of shape `(B, T, dim)`.\n   */\n\n  /**\n   * Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*`\n   * must equal `S` and match the shape of `attentionMask`. If `cache` is\n   * not `null`, `S*` can be any length less than `S`, and the computed\n   * value will be spliced into `cache` at `cacheUpdateIndex`.\n   */\n  value: Tensor;\n\n  /**\n   * Key `Tensor` of shape `(B, S*, dim)`.  If `cache` is `null`, `S*` must\n   * equal `S` and match the shape of `attentionMask`. If `cache` is not `null`,\n   * `S*` can be any length less than `S`, and the computed value will be\n   * spliced into `cache` at `cacheUpdateIndex`.\n   */\n  key?: Tensor;\n\n  /**\n   * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents\n   * attention to certain positions. The boolean mask specifies which\n   * query elements can attend to which key elements, 1 indicates\n   * attention and 0 indicates no attention. Broadcasting can happen for\n   * the missing batch dimensions and the head dimension.\n   */\n  attentionMask?: Tensor;\n\n  /**\n   * A dense float Tensor. The key/value cache, of shape\n   * `[B, 2, S, numHeads, keyDims]`, where `S` must agree with the\n   * `attentionMask` shape. This argument is intended for use during\n   * generation to avoid recomputing intermediate state.\n   */\n  cache?: Tensor;\n\n  /**\n   * Integer or Integer `Tensor`. The index at which to update `cache`\n   * (usually the index of the current token being processed when running\n   * generation). If `cacheUpdateIndex=null` while `cache` is set, the cache\n   * will not be updated.\n   */\n  cacheUpdateIndex?: number;\n}\n\n/**\n * MultiHeadAttention layer with cache support.\n *\n * This layer is suitable for use in autoregressive decoding. It can be use\n * to cache decoder self-attention and cross-attention. 
The forward pass\n * can happen in one of three modes:\n * - No cache, same as regular multi-head attention.\n * - Static cache (`cacheUpdateIndex` is None). In this case, the\n *     cached key/value projections will be used and the input values will\n *     be ignored.\n * - Updated cache (`cacheUpdateIndex` is not None). In this case, new\n *     key/value projections are computed using the input, and spliced into\n *     the cache at the specified index.\n *\n * Note that caching is useful only during inference and should not be used\n * during training.\n *\n * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,\n * `T` is the target sequence length, and `S` in the source sequence length.\n * Note that during generative decoding, `T` is usually 1 (you are\n * generating a target sequence of length one to predict the next token).\n *\n * Returns:\n *     An `(attentionOutput, cache)` tuple. `attentionOutput` is the result\n *     of the computation, of shape `(B, T, dim)`, where `T` is for target\n *     sequence shapes and `dim` is the query input last dimension if\n *     `outputShape` is `null`. Otherwise, the multi-head outputs are\n *     projected to the shape specified by `outputShape`. `cache` is the\n *     updated cache.\n */\nexport class CachedMultiHeadAttention extends MultiHeadAttention {\n\n  override call(\n    query: Tensor, kwargs: CachedMultiHeadAttentionOptions\n  ): Tensor {\n    return this.callAndReturnCache(query, kwargs)[0];\n  }\n\n  /**\n   * Exactly like `call` except also returns the updated cache.\n   */\n  callAndReturnCache(\n    query: Tensor,\n    {\n      value,\n      key,\n      attentionMask,\n      cache,\n      cacheUpdateIndex\n    } : CachedMultiHeadAttentionOptions\n  ): [Tensor, Tensor] {\n    return tidy(() => {\n      if (!this.builtFromSignature) {\n        this.buildFromSignature(\n          query.shape, value.shape, key ? 
key.shape : null);\n      }\n      if (key == null) {\n        key = value;\n      }\n\n      query = this.queryDense.apply(query) as Tensor;\n      // If cache is not `null`, we will use the cache to compute the final key\n      // and value tensors. If `cacheUpdateIndex` is not `null`, we will first\n      // update the cache before use. To do this, we first call the\n      // `keyDense` and `valueDense` layers, and copy the outputs into the\n      // cache at the specified index. `cache = null` handles the training\n      // case, where we don't use the cache at all.\n      if (cache != null) {\n        const keyCache = cache.gather([0], 1).squeeze();\n        const valueCache = cache.gather([1], 1).squeeze();\n        if (cacheUpdateIndex == null) {\n          key = keyCache;\n          value = valueCache;\n        } else {\n          const keyUpdate = this.keyDense.apply(key) as Tensor;\n          const valueUpdate = this.valueDense.apply(value) as Tensor;\n          const start = [0, cacheUpdateIndex, 0, 0];\n          key = sliceUpdate(keyCache, start, keyUpdate);\n          value = sliceUpdate(valueCache, start, valueUpdate);\n          cache = stack([key, value], 1);\n        }\n      } else {\n        if (cacheUpdateIndex != null) {\n          throw new ValueError(\n            '`cacheUpdateIndex` should not be set if `cache` is `null`. 
' +\n            `Received: cache=${cache}, cacheUpdateIndex=${cacheUpdateIndex}`\n          );\n        }\n        key = this.keyDense.apply(key) as Tensor;\n        value = this.valueDense.apply(value) as Tensor;\n      }\n\n      query = mul(query, reciprocal(sqrt(cast(this.keyDim, query.dtype))));\n      let attentionScores = einsum(this.dotProductEquation, key, query);\n      attentionScores = this.maskedSoftmax(attentionScores, attentionMask);\n      attentionScores = this.dropoutLayer.apply(attentionScores) as Tensor;\n\n      let attentionOutput =\n        einsum(this.combineEquation, attentionScores, value);\n      attentionOutput = this.outputDense.apply(attentionOutput) as Tensor;\n\n      return [attentionOutput, cache];\n    });\n  }\n}\nserialization.registerClass(CachedMultiHeadAttention);\n"]}