@tensorflow/tfjs-layers

TensorFlow layers API in JavaScript

/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/modeling/cached_multihead_attention" />
/**
 * Cached MHA layer based on `MultiHeadAttention`.
 */
import { Tensor } from '@tensorflow/tfjs-core';
import { MultiHeadAttention } from '../multihead_attention';
export declare interface CachedMultiHeadAttentionOptions {
    /**
     * (The query `Tensor`, of shape `(B, T, dim)`, is passed as the first
     * positional argument to `call` and `callAndReturnCache`, not through
     * this options object.)
     */
    /**
     * Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*`
     * must equal `S` and match the shape of `attentionMask`. If `cache` is
     * not `null`, `S*` can be any length less than `S`, and the computed
     * value will be spliced into `cache` at `cacheUpdateIndex`.
     */
    value: Tensor;
    /**
     * Key `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*` must
     * equal `S` and match the shape of `attentionMask`. If `cache` is not
     * `null`, `S*` can be any length less than `S`, and the computed value
     * will be spliced into `cache` at `cacheUpdateIndex`.
     */
    key?: Tensor;
    /**
     * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents
     * attention to certain positions. The boolean mask specifies which
     * query elements can attend to which key elements: 1 indicates
     * attention and 0 indicates no attention. Broadcasting can happen for
     * the missing batch dimensions and the head dimension.
     */
    attentionMask?: Tensor;
    /**
     * A dense float `Tensor`. The key/value cache, of shape
     * `[B, 2, S, numHeads, keyDims]`, where `S` must agree with the
     * `attentionMask` shape. This argument is intended for use during
     * generation to avoid recomputing intermediate state.
     */
    cache?: Tensor;
    /**
     * Integer or integer `Tensor`. The index at which to update `cache`
     * (usually the index of the current token being processed when running
     * generation). If `cacheUpdateIndex` is `null` while `cache` is set,
     * the cache will not be updated.
     */
    cacheUpdateIndex?: number;
}
/**
 * MultiHeadAttention layer with cache support.
 *
 * This layer is suitable for use in autoregressive decoding. It can be used
 * to cache decoder self-attention and cross-attention. The forward pass
 * can happen in one of three modes:
 * - No cache, same as regular multi-head attention.
 * - Static cache (`cacheUpdateIndex` is `null`). In this case, the cached
 *   key/value projections will be used and the input values will be
 *   ignored.
 * - Updated cache (`cacheUpdateIndex` is not `null`). In this case, new
 *   key/value projections are computed using the input, and spliced into
 *   the cache at the specified index.
 *
 * Note that caching is useful only during inference and should not be used
 * during training.
 *
 * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
 * `T` is the target sequence length, and `S` is the source sequence length.
 * Note that during generative decoding, `T` is usually 1 (you are generating
 * a target sequence of length one to predict the next token).
 *
 * Returns:
 *     An `(attentionOutput, cache)` tuple. `attentionOutput` is the result
 *     of the computation, of shape `(B, T, dim)`, where `T` is the target
 *     sequence length and `dim` is the query input's last dimension if
 *     `outputShape` is `null`; otherwise, the multi-head outputs are
 *     projected to the shape specified by `outputShape`. `cache` is the
 *     updated cache.
 */
export declare class CachedMultiHeadAttention extends MultiHeadAttention {
    call(query: Tensor, kwargs: CachedMultiHeadAttentionOptions): Tensor;
    /**
     * Exactly like `call` except also returns the updated cache.
     */
    callAndReturnCache(query: Tensor, {
        value, key, attentionMask, cache, cacheUpdateIndex
    }: CachedMultiHeadAttentionOptions): [Tensor, Tensor];
}
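
For illustration, here is a minimal sketch of the "updated cache" mode. Only `callAndReturnCache` and the option/cache shapes come from the declaration above; the deep import path and the `numHeads`/`keyDim` constructor arguments (mirroring the parent `MultiHeadAttention`) are assumptions not confirmed by this file.

// Sketch only: the import path and constructor arguments are assumptions;
// the call signature and shapes follow the declaration above.
import * as tf from '@tensorflow/tfjs-core';
import { CachedMultiHeadAttention } from '@tensorflow/tfjs-layers/dist/layers/nlp/modeling/cached_multihead_attention';

const B = 2, S = 16, dim = 64, numHeads = 4, keyDim = 16;
const layer = new CachedMultiHeadAttention({ numHeads, keyDim });

// Key/value cache of shape [B, 2, S, numHeads, keyDim]. Axis 1 is assumed
// to hold key projections at index 0 and value projections at index 1.
const cache = tf.zeros([B, 2, S, numHeads, keyDim]);

// "Updated cache" mode: project the first token (T = 1) and splice its
// key/value projections into the cache at position 0.
const firstToken = tf.randomNormal([B, 1, dim]);
const [attentionOutput, updatedCache] = layer.callAndReturnCache(firstToken, {
    value: firstToken,
    cache,
    cacheUpdateIndex: 0,
});
// attentionOutput: [B, 1, dim]; updatedCache: [B, 2, S, numHeads, keyDim].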
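
Continuing the sketch, threading the cache through an autoregressive loop shows why `T` is usually 1 during decoding: each step projects only the newly produced token, and earlier positions are read back from the cache instead of being recomputed. `decoderStep` is a hypothetical stand-in for the rest of the decoder (feed-forward, sampling, embedding of the next token), not part of this API.

// Hypothetical greedy-decode skeleton; reuses `layer`, `B`, `S`, `dim`,
// `numHeads`, and `keyDim` from the sketch above.
const decoderStep = (attnOut: tf.Tensor): tf.Tensor => attnOut; // placeholder

let kvCache = tf.zeros([B, 2, S, numHeads, keyDim]);
let hidden = tf.randomNormal([B, 1, dim]); // stand-in for the start token
for (let step = 0; step < S; ++step) {
    // One new token per step (T = 1), spliced into the cache at `step`.
    const [attnOut, newCache] = layer.callAndReturnCache(hidden, {
        value: hidden,
        cache: kvCache,
        cacheUpdateIndex: step,
    });
    kvCache.dispose(); // tfjs tensors are immutable; free the stale cache
    kvCache = newCache;
    hidden = decoderStep(attnOut); // hypothetical rest of the decoder step
}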