UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

112 lines 18.6 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Cached MHA layer based on `MultiHeadAttention`.
 */
/* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */
import { cast, einsum, mul, reciprocal, serialization, sqrt, stack, tidy } from '@tensorflow/tfjs-core';
import { ValueError } from '../../../errors';
import { MultiHeadAttention } from '../multihead_attention';
import { sliceUpdate } from '../utils';
/**
 * MultiHeadAttention layer with cache support.
 *
 * This layer is suitable for use in autoregressive decoding. It can be used
 * to cache decoder self-attention and cross-attention. The forward pass
 * can happen in one of three modes:
 * - No cache (`cache` is `null`): same as regular multi-head attention.
 * - Static cache (`cache` is set, `cacheUpdateIndex` is `null`). In this
 *   case, the cached key/value projections are used and the input
 *   key/value tensors are ignored.
 * - Updated cache (`cache` is set, `cacheUpdateIndex` is not `null`). In
 *   this case, new key/value projections are computed from the input and
 *   spliced into the cache at the specified index.
 *
 * Note that caching is useful only during inference and should not be used
 * during training.
 *
 * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
 * `T` is the target sequence length, and `S` is the source sequence length.
 * Note that during generative decoding, `T` is usually 1 (you are
 * generating a target sequence of length one to predict the next token).
 *
 * Returns (from `callAndReturnCache`; `call` returns only the first item):
 *   An `[attentionOutput, cache]` tuple. `attentionOutput` is the result
 *   of the computation, of shape `(B, T, dim)`, where `T` is for target
 *   sequence shapes and `dim` is the query input last dimension if
 *   `outputShape` is `null`. Otherwise, the multi-head outputs are
 *   projected to the shape specified by `outputShape`. `cache` is the
 *   updated cache.
 */
export class CachedMultiHeadAttention extends MultiHeadAttention {
  /**
   * Runs cached attention and returns only the attention output. The
   * (possibly updated) cache is discarded; use `callAndReturnCache` to get
   * it back.
   *
   * @param query Query tensor, documented as shape `(B, T, dim)`.
   * @param kwargs Options object with `value`, and optionally `key`,
   *     `attentionMask`, `cache` and `cacheUpdateIndex`; forwarded unchanged
   *     to `callAndReturnCache`.
   */
  call(query, kwargs) {
    return this.callAndReturnCache(query, kwargs)[0];
  }

  /**
   * Exactly like `call` except also returns the updated cache.
   *
   * @param query Query tensor, documented as shape `(B, T, dim)`.
   * @param options.value Value tensor. Its shape is used to build the layer
   *     on the first call; in static-cache mode its contents are ignored.
   * @param options.key Optional key tensor; defaults to `value` when absent.
   * @param options.attentionMask Optional mask passed to `maskedSoftmax`.
   * @param options.cache Optional stacked key/value cache: the key cache sits
   *     at index 0 and the value cache at index 1 of axis 1 (documented as
   *     `[B, 2, S, numHeads, keyDims]`).
   * @param options.cacheUpdateIndex Optional offset along the sequence axis
   *     (axis 1 of the per-key/value cache) where the new projections are
   *     written. Must be `null`/`undefined` when `cache` is not given.
   * @returns `[attentionOutput, cache]`. `cache` is the input `cache` as-is
   *     (`null`/`undefined` when none was given, or the caller's tensor in
   *     static-cache mode), or a newly stacked tensor in update mode.
   * @throws ValueError if `cacheUpdateIndex` is set while `cache` is `null`.
   */
  callAndReturnCache(query, { value, key, attentionMask, cache, cacheUpdateIndex }) {
    // `tidy` disposes intermediate tensors created in the callback and keeps
    // the ones it returns.
    return tidy(() => {
      if (!this.builtFromSignature) {
        this.buildFromSignature(query.shape, value.shape, key ? key.shape : null);
      }
      if (key == null) {
        key = value;
      }

      query = this.queryDense.apply(query);
      // If cache is not `null`, we will use the cache to compute the final key
      // and value tensors. If `cacheUpdateIndex` is not `null`, we will first
      // update the cache before use. To do this, we first call the
      // `keyDense` and `valueDense` layers, and copy the outputs into the
      // cache at the specified index. `cache = null` handles the training
      // case, where we don't use the cache at all.
      if (cache != null) {
        // Pick the key (index 0) and value (index 1) slices of axis 1 and drop
        // ONLY that axis. A bare `squeeze()` would also remove any other
        // size-1 dimension (e.g. batch size 1 during generation), producing a
        // lower-rank tensor that no longer lines up with the 4-element `start`
        // offsets below or with the einsum equations.
        const keyCache = cache.gather([0], 1).squeeze([1]);
        const valueCache = cache.gather([1], 1).squeeze([1]);
        if (cacheUpdateIndex == null) {
          // Static cache: reuse the cached projections; inputs are ignored.
          key = keyCache;
          value = valueCache;
        } else {
          // Updated cache: project the new inputs and write them into the
          // cache starting at `cacheUpdateIndex` along axis 1.
          const keyUpdate = this.keyDense.apply(key);
          const valueUpdate = this.valueDense.apply(value);
          const start = [0, cacheUpdateIndex, 0, 0];
          key = sliceUpdate(keyCache, start, keyUpdate);
          value = sliceUpdate(valueCache, start, valueUpdate);
          // Re-stack along axis 1 so the returned cache has the same layout
          // as the input cache.
          cache = stack([key, value], 1);
        }
      } else {
        if (cacheUpdateIndex != null) {
          throw new ValueError(
            '`cacheUpdateIndex` should not be set if `cache` is `null`. ' +
            `Received: cache=${cache}, cacheUpdateIndex=${cacheUpdateIndex}`);
        }
        key = this.keyDense.apply(key);
        value = this.valueDense.apply(value);
      }

      // Scale the query by 1 / sqrt(keyDim) before the dot product.
      query = mul(query, reciprocal(sqrt(cast(this.keyDim, query.dtype))));
      let attentionScores = einsum(this.dotProductEquation, key, query);
      attentionScores = this.maskedSoftmax(attentionScores, attentionMask);
      attentionScores = this.dropoutLayer.apply(attentionScores);

      let attentionOutput = einsum(this.combineEquation, attentionScores, value);
      attentionOutput = this.outputDense.apply(attentionOutput);

      return [attentionOutput, cache];
    });
  }
}
serialization.registerClass(CachedMultiHeadAttention);
// NOTE(review): the inline source map below (mappings and embedded
// `sourcesContent`) predates the `squeeze([1])` fix above and still shows the
// bare `squeeze()`. Apply the same change in the TypeScript source
// (cached_multihead_attention.ts) and regenerate this file and its map.
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"cached_multihead_attention.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,+EAA+E;AAC/E,OAAO,EAAU,IAAI,EAAE,MAAM,EAAE,GAAG,EAAE,UAAU,EAAE,aAAa,EAAE,IAAI,EAAE,KAAK,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAEhH,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAC7C,OAAO,EAAE,kBAAkB,EAAE,MAAM,wBAAwB,CAAC;AAC5D,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAiDvC;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AACH,MAAM,OAAO,wBAAyB,SAAQ,kBAAkB;IAErD,IAAI,CACX,KAAa,EAAE,MAAuC;QAEtD,OAAO,IAAI,CAAC,kBAAkB,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;IACnD,CAAC;IAED;;OAEG;IACH,kBAAkB,CAChB,KAAa,EACb,EACE,KAAK,EACL,GAAG,EACH,aAAa,EACb,KAAK,EACL,gBAAgB,EACiB;QAEnC,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,IAAI,CAAC,kBAAkB,EAAE;gBAC5B,IAAI,CAAC,kBAAkB,CACrB,KAAK,CAAC,KAAK,EAAE,KAAK,CAAC,KAAK,EAAE,GAAG,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;aACrD;YACD,IAAI,GAAG,IAAI,IAAI,EAAE;gBACf,GAAG,GAAG,KAAK,CAAC;aACb;YAED,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;YAC/C,yEAAyE;YACzE,wEAAwE;YACxE,6DAA6D;YAC7D,oEAAoE;YACpE,oEAAoE;YACpE,6CAA6C;YAC7C,IAAI,KAAK,IAAI,IAAI,EAAE;gBACjB,MAAM,QAAQ,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;gBAC
hD,MAAM,UAAU,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;gBAClD,IAAI,gBAAgB,IAAI,IAAI,EAAE;oBAC5B,GAAG,GAAG,QAAQ,CAAC;oBACf,KAAK,GAAG,UAAU,CAAC;iBACpB;qBAAM;oBACL,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,KAAK,CAAC,GAAG,CAAW,CAAC;oBACrD,MAAM,WAAW,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;oBAC3D,MAAM,KAAK,GAAG,CAAC,CAAC,EAAE,gBAAgB,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;oBAC1C,GAAG,GAAG,WAAW,CAAC,QAAQ,EAAE,KAAK,EAAE,SAAS,CAAC,CAAC;oBAC9C,KAAK,GAAG,WAAW,CAAC,UAAU,EAAE,KAAK,EAAE,WAAW,CAAC,CAAC;oBACpD,KAAK,GAAG,KAAK,CAAC,CAAC,GAAG,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC;iBAChC;aACF;iBAAM;gBACL,IAAI,gBAAgB,IAAI,IAAI,EAAE;oBAC5B,MAAM,IAAI,UAAU,CAClB,6DAA6D;wBAC7D,mBAAmB,KAAK,sBAAsB,gBAAgB,EAAE,CACjE,CAAC;iBACH;gBACD,GAAG,GAAG,IAAI,CAAC,QAAQ,CAAC,KAAK,CAAC,GAAG,CAAW,CAAC;gBACzC,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;aAChD;YAED,KAAK,GAAG,GAAG,CAAC,KAAK,EAAE,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YACrE,IAAI,eAAe,GAAG,MAAM,CAAC,IAAI,CAAC,kBAAkB,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC;YAClE,eAAe,GAAG,IAAI,CAAC,aAAa,CAAC,eAAe,EAAE,aAAa,CAAC,CAAC;YACrE,eAAe,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,eAAe,CAAW,CAAC;YAErE,IAAI,eAAe,GACjB,MAAM,CAAC,IAAI,CAAC,eAAe,EAAE,eAAe,EAAE,KAAK,CAAC,CAAC;YACvD,eAAe,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,eAAe,CAAW,CAAC;YAEpE,OAAO,CAAC,eAAe,EAAE,KAAK,CAAC,CAAC;QAClC,CAAC,CAAC,CAAC;IACL,CAAC;CACF;AACD,aAAa,CAAC,aAAa,CAAC,wBAAwB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the 
specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Cached MHA layer based on `MultiHeadAttention`.\n */\n\n/* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */\nimport { Tensor, cast, einsum, mul, reciprocal, serialization, sqrt, stack, tidy } from '@tensorflow/tfjs-core';\n\nimport { ValueError } from '../../../errors';\nimport { MultiHeadAttention } from '../multihead_attention';\nimport { sliceUpdate } from '../utils';\n\nexport declare interface CachedMultiHeadAttentionOptions {\n  /**\n   * Query `Tensor` of shape `(B, T, dim)`.\n   */\n\n  /**\n   * Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*`\n   * must equal `S` and match the shape of `attentionMask`. If `cache` is\n   * not `null`, `S*` can be any length less than `S`, and the computed\n   * value will be spliced into `cache` at `cacheUpdateIndex`.\n   */\n  value: Tensor;\n\n  /**\n   * Key `Tensor` of shape `(B, S*, dim)`.  If `cache` is `null`, `S*` must\n   * equal `S` and match the shape of `attentionMask`. If `cache` is not `null`,\n   * `S*` can be any length less than `S`, and the computed value will be\n   * spliced into `cache` at `cacheUpdateIndex`.\n   */\n  key?: Tensor;\n\n  /**\n   * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents\n   * attention to certain positions. The boolean mask specifies which\n   * query elements can attend to which key elements, 1 indicates\n   * attention and 0 indicates no attention. Broadcasting can happen for\n   * the missing batch dimensions and the head dimension.\n   */\n  attentionMask?: Tensor;\n\n  /**\n   * A dense float Tensor. The key/value cache, of shape\n   * `[B, 2, S, numHeads, keyDims]`, where `S` must agree with the\n   * `attentionMask` shape. 
This argument is intended for use during\n   * generation to avoid recomputing intermediate state.\n   */\n  cache?: Tensor;\n\n  /**\n   * Integer or Integer `Tensor`. The index at which to update `cache`\n   * (usually the index of the current token being processed when running\n   * generation). If `cacheUpdateIndex=null` while `cache` is set, the cache\n   * will not be updated.\n   */\n  cacheUpdateIndex?: number;\n}\n\n/**\n * MultiHeadAttention layer with cache support.\n *\n * This layer is suitable for use in autoregressive decoding. It can be use\n * to cache decoder self-attention and cross-attention. The forward pass\n * can happen in one of three modes:\n * - No cache, same as regular multi-head attention.\n * - Static cache (`cacheUpdateIndex` is None). In this case, the\n *     cached key/value projections will be used and the input values will\n *     be ignored.\n * - Updated cache (`cacheUpdateIndex` is not None). In this case, new\n *     key/value projections are computed using the input, and spliced into\n *     the cache at the specified index.\n *\n * Note that caching is useful only during inference and should not be used\n * during training.\n *\n * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,\n * `T` is the target sequence length, and `S` in the source sequence length.\n * Note that during generative decoding, `T` is usually 1 (you are\n * generating a target sequence of length one to predict the next token).\n *\n * Returns:\n *     An `(attentionOutput, cache)` tuple. `attentionOutput` is the result\n *     of the computation, of shape `(B, T, dim)`, where `T` is for target\n *     sequence shapes and `dim` is the query input last dimension if\n *     `outputShape` is `null`. Otherwise, the multi-head outputs are\n *     projected to the shape specified by `outputShape`. 
`cache` is the\n *     updated cache.\n */\nexport class CachedMultiHeadAttention extends MultiHeadAttention {\n\n  override call(\n    query: Tensor, kwargs: CachedMultiHeadAttentionOptions\n  ): Tensor {\n    return this.callAndReturnCache(query, kwargs)[0];\n  }\n\n  /**\n   * Exactly like `call` except also returns the updated cache.\n   */\n  callAndReturnCache(\n    query: Tensor,\n    {\n      value,\n      key,\n      attentionMask,\n      cache,\n      cacheUpdateIndex\n    } : CachedMultiHeadAttentionOptions\n  ): [Tensor, Tensor] {\n    return tidy(() => {\n      if (!this.builtFromSignature) {\n        this.buildFromSignature(\n          query.shape, value.shape, key ? key.shape : null);\n      }\n      if (key == null) {\n        key = value;\n      }\n\n      query = this.queryDense.apply(query) as Tensor;\n      // If cache is not `null`, we will use the cache to compute the final key\n      // and value tensors. If `cacheUpdateIndex` is not `null`, we will first\n      // update the cache before use. To do this, we first call the\n      // `keyDense` and `valueDense` layers, and copy the outputs into the\n      // cache at the specified index. 
`cache = null` handles the training\n      // case, where we don't use the cache at all.\n      if (cache != null) {\n        const keyCache = cache.gather([0], 1).squeeze();\n        const valueCache = cache.gather([1], 1).squeeze();\n        if (cacheUpdateIndex == null) {\n          key = keyCache;\n          value = valueCache;\n        } else {\n          const keyUpdate = this.keyDense.apply(key) as Tensor;\n          const valueUpdate = this.valueDense.apply(value) as Tensor;\n          const start = [0, cacheUpdateIndex, 0, 0];\n          key = sliceUpdate(keyCache, start, keyUpdate);\n          value = sliceUpdate(valueCache, start, valueUpdate);\n          cache = stack([key, value], 1);\n        }\n      } else {\n        if (cacheUpdateIndex != null) {\n          throw new ValueError(\n            '`cacheUpdateIndex` should not be set if `cache` is `null`. ' +\n            `Received: cache=${cache}, cacheUpdateIndex=${cacheUpdateIndex}`\n          );\n        }\n        key = this.keyDense.apply(key) as Tensor;\n        value = this.valueDense.apply(value) as Tensor;\n      }\n\n      query = mul(query, reciprocal(sqrt(cast(this.keyDim, query.dtype))));\n      let attentionScores = einsum(this.dotProductEquation, key, query);\n      attentionScores = this.maskedSoftmax(attentionScores, attentionMask);\n      attentionScores = this.dropoutLayer.apply(attentionScores) as Tensor;\n\n      let attentionOutput =\n        einsum(this.combineEquation, attentionScores, value);\n      attentionOutput = this.outputDense.apply(attentionOutput) as Tensor;\n\n      return [attentionOutput, cache];\n    });\n  }\n}\nserialization.registerClass(CachedMultiHeadAttention);\n"]}