UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

95 lines 12 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Utility functions for `TransformerDecoder`.
 */
/* Original source: keras_nlp/layers/modeling/transformer_layer_utils.py */
import { add, expandDims, tensor, tidy } from '@tensorflow/tfjs-core';
import { ValueError } from '../../../errors';

/**
 * Validate the rank of the optional masks before they are merged.
 *
 * Only the number of dimensions is checked; individual dimension sizes are
 * not compared against each other or against `inputs`.
 *
 * @param inputs the input sequence. Accepted for signature parity with the
 *   callers but not inspected here.
 * @param paddingMask optional mask; when provided it must be rank 2
 *   (`[batchSize, targetLength]`).
 * @param attentionMask optional mask; when provided it must be rank 3
 *   (`[batchSize, targetLength, sourceLength]`).
 * @throws ValueError if a provided mask does not have the expected rank.
 */
function checkMasksShapes(inputs, paddingMask, attentionMask) {
  if (paddingMask != null && paddingMask.shape.length !== 2) {
    throw new ValueError(
        '`paddingMask` should have shape ' +
        `[batchSize, targetLength]. Received shape ${paddingMask.shape}.`);
  }
  if (attentionMask != null && attentionMask.shape.length !== 3) {
    throw new ValueError(
        '`attentionMask` should have shape ' +
        `[batchSize, targetLength, sourceLength]. ` +
        `Received shape ${attentionMask.shape}.`);
  }
}

/**
 * Compute a causal attention mask for a transformer decoder.
 *
 * Entry `[b, q, k]` of the result is 1 when `q + cacheIndex >= k` and 0
 * otherwise, i.e. each query position may attend to key/value positions up
 * to and including its own (offset) position.
 *
 * @param batchSize batch size for the mask.
 * @param inputLength the length of key/value tensors in the attention layer.
 * @param outputLength the length of query tensor in the attention layer.
 * @param cacheIndex the current index for cached generation. If passed, the
 *   query sequence will be considered to start at `cacheIndex` rather than
 *   zero. For example, a causal mask with `outputLength=1` and `cacheIndex=5`
 *   allows the single query position to attend to key/value positions 0
 *   through 5 (inclusive).
 *
 * @returns a causal attention mask of dtype int32 with shape
 *   `[batchSize, outputLength, inputLength]` that can be passed to an
 *   attention layer.
 */
export function computeCausalMask(batchSize, inputLength, outputLength, cacheIndex = 0) {
  return tidy(() => {
    // Column vector [outputLength, 1] of query positions, shifted by the
    // cache offset.
    const queryPositions = add(
        expandDims(Array.from({ length: outputLength }, (_, idx) => idx), 1),
        cacheIndex);
    // Row vector [inputLength] of key/value positions.
    const keyPositions =
        tensor(Array.from({ length: inputLength }, (_, idx) => idx));
    // Broadcasting the comparison yields [outputLength, inputLength]; a
    // leading axis is added so it can be broadcast over the batch.
    const mask =
        queryPositions.greaterEqual(keyPositions).cast('int32').expandDims(0);
    return mask.broadcastTo([batchSize, outputLength, inputLength]);
  });
}

/**
 * Merge the padding mask with a customized attention mask.
 *
 * @param inputs the input sequence. Only forwarded to the shape check, which
 *   does not inspect it.
 * @param paddingMask the padding mask, of shape
 *   `[batchSize, sequenceLength]` (1D per example), or null/undefined.
 * @param attentionMask the customized mask, of shape
 *   `[batchSize, sequenceLength, sequence2Length]` (2D per example), or
 *   null/undefined.
 * @returns
 *   - `null` when neither mask is provided;
 *   - `paddingMask` cast to int32 with one extra axis inserted at position 1
 *     when only `paddingMask` is provided;
 *   - `attentionMask` cast to int32 when only `attentionMask` is provided;
 *   - the element-wise minimum of the two int32 masks when both are
 *     provided.
 * @throws ValueError if a provided mask does not have the expected rank.
 */
export function mergePaddingAndAttentionMask(inputs, paddingMask, attentionMask) {
  return tidy(() => {
    checkMasksShapes(inputs, paddingMask, attentionMask);
    // Start from an explicit null so that the "no mask supplied" case returns
    // the documented `null` rather than `undefined`.
    let mask = null;
    if (paddingMask != null) {
      // Add an axis for broadcasting, the attention mask should be 2D
      // (not including the batch axis).
      mask = paddingMask.expandDims(1).cast('int32');
    }
    if (attentionMask != null) {
      // Use a local instead of reassigning the parameter.
      const intAttentionMask = attentionMask.cast('int32');
      return mask == null ? intAttentionMask : mask.minimum(intAttentionMask);
    }
    return mask;
  });
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"transformer_layer_utils.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/transformer_layer_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,2EAA2E;AAC3E,OAAO,EAAU,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAE9E,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAE7C,SAAS,gBAAgB,CACrB,MAAc,EAAE,WAAmB,EAAE,aAAqB;IAC5D,IAAI,WAAW,IAAI,IAAI,EAAE;QACvB,IAAI,WAAW,CAAC,KAAK,CAAC,MAAM,KAAI,CAAC,EAAE;YACjC,MAAM,IAAI,UAAU,CAClB,kCAAkC;gBAClC,6CAA6C,WAAW,CAAC,KAAK,GAAG,CAClE,CAAC;SACH;KACF;IACD,IAAI,aAAa,IAAI,IAAI,EAAE;QACzB,IAAI,aAAa,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;YACpC,MAAM,IAAI,UAAU,CAClB,oCAAoC;gBACpC,2CAA2C;gBAC3C,kBAAkB,aAAa,CAAC,KAAK,GAAG,CACzC,CAAC;SACH;KACF;AACH,CAAC;AAED;;;;;;;;;;;;;;;GAeG;AACH,MAAM,UAAU,iBAAiB,CAC7B,SAAiB,EACjB,WAAmB,EACnB,YAAoB,EACpB,UAAU,GAAG,CAAC;IAEhB,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,MAAM,CAAC,GAAG,GAAG,CACX,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,EAAC,MAAM,EAAE,YAAY,EAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAC9D,UAAU,CACX,CAAC;QACF,MAAM,CAAC,GAAG,MAAM,CAAC,KAAK,CAAC,IAAI,CAAC,EAAC,MAAM,EAAE,WAAW,EAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACjE,MAAM,IAAI,GAAG,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;QAC3D,OAAO,IAAI,CAAC,WAAW,CAAC,CAAC,SAAS,EAAE,YAAY,EAAE,WAAW,CAAC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;;;;;GAWG;AACH,MAAM,UAAU,4BAA4B,CACxC,MAAc,EAAE,WAAmB,EAAE,aAAqB;IAC5D,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,gBAAgB,CAAC,MAAM,EAAE,WAAW,EAAE,aAAa,CAAC,CAAC;QACrD,IAAI,IAAY,CAAC;QACjB,IAAI,WAAW,IAAI,IAAI,EAAE;YACvB,gEAAgE;YAChE,kCAAkC;YAClC,IAAI,GAAG,WAAW,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,IAAI,CAA
C,OAAO,CAAC,CAAC;SAChD;QACD,IAAI,aAAa,IAAI,IAAI,EAAE;YACzB,aAAa,GAAG,aAAa,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAC5C,IAAI,IAAI,IAAI,IAAI,EAAE;gBAChB,OAAO,aAAa,CAAC;aACtB;iBAAM;gBACL,OAAO,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,CAAC;aACpC;SACF;QACD,OAAO,IAAI,CAAC;IACd,CAAC,CAAC,CAAC;AACL,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Utility functions for `TransformerDecoder`.\n */\n\n/* Original source: keras_nlp/layers/modeling/transformer_layer_utils.py */\nimport { Tensor, add, expandDims, tensor, tidy } from '@tensorflow/tfjs-core';\n\nimport { ValueError } from '../../../errors';\n\nfunction checkMasksShapes(\n    inputs: Tensor, paddingMask: Tensor, attentionMask: Tensor): void {\n  if (paddingMask != null) {\n    if (paddingMask.shape.length !==2) {\n      throw new ValueError(\n        '`paddingMask` should have shape ' +\n        `[batchSize, targetLength]. Received shape ${paddingMask.shape}.`\n      );\n    }\n  }\n  if (attentionMask != null) {\n    if (attentionMask.shape.length !== 3) {\n      throw new ValueError(\n        '`attentionMask` should have shape ' +\n        `[batchSize, targetLength, sourceLength]. 
` +\n        `Received shape ${attentionMask.shape}.`\n      );\n    }\n  }\n}\n\n/**\n * Compute a causal attention mask for a transformer decoder.\n *\n * @param batchSize batch size for the mask.\n * @param inputLength the length of key/value tensors in the attention layer.\n * @param outputLength the length of query tensor in the attention layer.\n * @param cacheIndex the current index for cached generation. If passed, the\n *  query sequence will be considered to start at `cacheIndex` rather than zero.\n *  For example, a casual mask with `outputLength=1` and `cacheIndex=5` would\n *  allow the query tensor to attend to the first five positions of the\n *  key/value tensors.\n *\n * @returns a causal attention mask with shape\n *  `[batchSize, outputLength, inputLength]` that can be passed to a attention\n *  layer.\n */\nexport function computeCausalMask(\n    batchSize: number,\n    inputLength: number,\n    outputLength: number,\n    cacheIndex = 0\n  ): Tensor {\n  return tidy(() => {\n    const i = add(\n      expandDims(Array.from({length: outputLength}, (_, i) => i), 1),\n      cacheIndex,\n    );\n    const j = tensor(Array.from({length: inputLength}, (_, i) => i));\n    const mask = i.greaterEqual(j).cast('int32').expandDims(0);\n    return mask.broadcastTo([batchSize, outputLength, inputLength]);\n  });\n}\n\n/**\n * Merge the padding mask with a customized attention mask.\n *\n * @param inputs the input sequence.\n * @param paddingMask the 1D padding mask, of shape\n *          [batchSize, sequenceLength].\n * @param attentionMask the 2D customized mask, of shape\n *          [batchSize, sequenceLength, sequence2_length].\n * @returns\n *  A merged 2D mask or null. 
If only `paddingMask` is provided, the\n *  returned mask is paddingMask with one additional axis.\n */\nexport function mergePaddingAndAttentionMask(\n    inputs: Tensor, paddingMask: Tensor, attentionMask: Tensor): Tensor {\n  return tidy(() => {\n    checkMasksShapes(inputs, paddingMask, attentionMask);\n    let mask: Tensor;\n    if (paddingMask != null) {\n      // Add an axis for broadcasting, the attention mask should be 2D\n      // (not including the batch axis).\n      mask = paddingMask.expandDims(1).cast('int32');\n    }\n    if (attentionMask != null) {\n      attentionMask = attentionMask.cast('int32');\n      if (mask == null) {\n        return attentionMask;\n      } else {\n        return mask.minimum(attentionMask);\n      }\n    }\n    return mask;\n  });\n}\n"]}