UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

228 lines (227 loc) 9.41 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/modeling/transformer_decoder" />
/**
 * Transformer decoder block implementation based on TFJS `Layer`.
 */
import { Tensor, serialization } from '@tensorflow/tfjs-core';
import { Activation } from '../../../activations';
import { Layer, LayerArgs, SymbolicTensor } from '../../../engine/topology';
import { Initializer, InitializerIdentifier } from '../../../initializers';
import { ActivationIdentifier } from '../../../keras_format/activation_config';
import { Shape } from '../../../keras_format/common';
import { Dense, Dropout } from '../../core';
import { LayerNormalization } from '../../normalization';
import { CachedMultiHeadAttention } from './cached_multihead_attention';

/**
 * Constructor arguments for {@link TransformerDecoder}.
 */
export declare interface TransformerDecoderArgs extends LayerArgs {
  /**
   * Integer. The hidden size of the feedforward network.
   */
  intermediateDim: number;
  /**
   * Integer. The number of heads in MultiHeadAttention.
   */
  numHeads: number;
  /**
   * The dropout value, shared by MultiHeadAttention and the feedforward
   * network.
   * Defaults to `0`.
   */
  dropout?: number;
  /**
   * The activation function of the feedforward network.
   * Defaults to `"relu"`.
   */
  activation?: Activation | ActivationIdentifier;
  /**
   * The epsilon value used by the layer normalization components.
   * Defaults to `1e-5`.
   */
  layerNormEpsilon?: number;
  /**
   * The kernel initializer for the dense and multi-headed attention layers.
   * Defaults to `"glorotUniform"`.
   */
  kernelInitializer?: Initializer | InitializerIdentifier;
  /**
   * The bias initializer for the dense and multi-headed attention layers.
   * Defaults to `"zeros"`.
   */
  biasInitializer?: Initializer | InitializerIdentifier;
  /**
   * If true, the inputs to the attention layer(s) and the intermediate dense
   * layer are normalized (similar to GPT-2). If false, the outputs of the
   * attention layer and the intermediate dense layer are normalized
   * (similar to BERT).
   * Defaults to `false`.
   */
  normalizeFirst?: boolean;
}

/**
 * Per-call options for {@link TransformerDecoder.apply},
 * {@link TransformerDecoder.call} and
 * {@link TransformerDecoder.callAndReturnCaches}.
 *
 * The decoder input sequence itself is not a member of this interface; it is
 * passed as the first positional argument (`decoderSequence`) of those
 * methods.
 */
export declare interface TransformerDecoderOptions {
  /**
   * The encoder input sequence. For decoder-only models (like GPT-2), omit
   * this option. Once the layer has been called without an `encoderSequence`,
   * it cannot be called again with one (and vice versa): the number of inputs
   * must be consistent across calls.
   */
  encoderSequence?: Tensor | SymbolicTensor;
  /**
   * A boolean Tensor, the padding mask of the decoder sequence. Must be of
   * shape `[batchSize, decoderSequenceLength]`.
   */
  decoderPaddingMask?: Tensor | SymbolicTensor;
  /**
   * A boolean Tensor. Customized decoder sequence mask. Must be of shape
   * `[batchSize, decoderSequenceLength, decoderSequenceLength]`.
   */
  decoderAttentionMask?: Tensor;
  /**
   * A boolean Tensor, the padding mask of the encoder sequence. Must be of
   * shape `[batchSize, encoderSequenceLength]`.
   */
  encoderPaddingMask?: Tensor;
  /**
   * A boolean Tensor. Customized encoder sequence mask. Must be of shape
   * `[batchSize, encoderSequenceLength, encoderSequenceLength]`.
   */
  encoderAttentionMask?: Tensor;
  /**
   * A dense float Tensor. The cache of key/value pairs in the self-attention
   * layer. Has shape `[batchSize, 2, maxSeqLen, numHeads, keyDims]`.
   */
  selfAttentionCache?: Tensor;
  /**
   * Integer. The index at which to update the `selfAttentionCache`. Usually,
   * this is the index of the current token being processed during decoding.
   */
  selfAttentionCacheUpdateIndex?: number;
  /**
   * A dense float Tensor. The cache of key/value pairs in the cross-attention
   * layer. Has shape `[batchSize, 2, S, numHeads, keyDims]`, where `S` is
   * presumably the encoder sequence length (TODO: confirm).
   */
  crossAttentionCache?: Tensor;
  /**
   * Integer. The index at which to update the `crossAttentionCache`. Usually,
   * this is either `0` (compute the entire `crossAttentionCache`), or `null`
   * (reuse a previously computed `crossAttentionCache`). Note that the
   * declared type is `number`, so under `strictNullChecks` leave this option
   * unset rather than passing `null`.
   */
  crossAttentionCacheUpdateIndex?: number;
  /**
   * If true, a causal mask (masking out future input) is applied to the
   * decoder sequence.
   * Defaults to `true`.
   */
  useCausalMask?: boolean;
}

/**
 * Transformer decoder.
 *
 * This class follows the architecture of the transformer decoder layer in the
 * paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
 * can instantiate multiple instances of this class to stack up a decoder.
 *
 * By default, this layer applies a causal mask to the decoder attention
 * layer. This layer will correctly compute an attention mask from an implicit
 * padding mask (for example, by passing `maskZero=true` to a
 * `tf.layers.embedding` layer). See the Masking and Padding
 * [guide](https://keras.io/guides/understanding_masking_and_padding/)
 * for more details.
 *
 * This layer can be called with either one or two inputs. The number of inputs
 * must be consistent across all calls. The options are as follows:
 *   `layer.apply(decoderSequence)`: no cross-attention will be built into the
 *     decoder block. This is useful when building a "decoder-only"
 *     transformer such as GPT-2.
 *   `layer.apply(decoderSequence, {encoderSequence})`: cross-attention will be
 *     built into the decoder block. This is useful when building an
 *     "encoder-decoder" transformer, such as the original transformer
 *     model described in Attention is All You Need.
 *
 * Use `apply` (rather than `call`) with symbolic inputs from `tf.input`:
 * `call` accepts only concrete `Tensor`s and requires its options argument.
 *
 * Examples:
 * ```js
 * // Create a single transformer decoder layer.
 * const decoder = new TransformerDecoder({intermediateDim: 64, numHeads: 8});
 *
 * // Create a simple model containing the decoder.
 * const decoderInput = tf.input({shape: [10, 64]});
 * const encoderInput = tf.input({shape: [10, 64]});
 * const output = decoder.apply(decoderInput, {encoderSequence: encoderInput});
 * const model = tf.model({
 *   inputs: [decoderInput, encoderInput],
 *   outputs: output,
 * });
 *
 * // Run the model on random input data.
 * const decoderInputData = tf.randomUniform([2, 10, 64]);
 * const encoderInputData = tf.randomUniform([2, 10, 64]);
 * const decoderOutput = model.predict([decoderInputData, encoderInputData]);
 * ```
 *
 * References:
 *  - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
 */
export declare class TransformerDecoder extends Layer {
  /** @nocollapse */
  static readonly className = "TransformerDecoder";
  protected intermediateDim: number;
  protected numHeads: number;
  protected dropout: number;
  protected activation: Activation;
  protected layerNormEpsilon: number;
  protected kernelInitializer: Initializer;
  protected biasInitializer: Initializer;
  protected normalizeFirst: boolean;
  protected decoderSequenceShape: Shape;
  protected encoderSequenceShape: Shape;
  protected selfAttentionLayer: CachedMultiHeadAttention;
  protected selfAttentionLayernorm: LayerNormalization;
  protected selfAttentionDropout: Dropout;
  protected selfCrossAttentionLayer: CachedMultiHeadAttention;
  protected selfCrossAttentionLayernorm: LayerNormalization;
  protected selfCrossAttentionDropout: Dropout;
  protected feedforwardIntermediateDense: Dense;
  protected feedforwardOutputDense: Dense;
  protected feedforwardLayernorm: LayerNormalization;
  protected feedforwardDropout: Dropout;
  /**
   * @param args Layer configuration; see {@link TransformerDecoderArgs}.
   */
  constructor(args: TransformerDecoderArgs);
  /**
   * Builds the layer's sub-layers for the given input shape(s).
   *
   * @param inputShape `decoderSequenceShape`, or
   *   `[decoderSequenceShape, encoderSequenceShape]` when the layer is used
   *   with cross-attention.
   */
  build(inputShape: Shape | [Shape, Shape]): void;
  /**
   * Applies the decoder block. Accepts either concrete or symbolic tensors.
   *
   * @param decoderSequence The decoder input sequence.
   * @param kwargs Optional per-call options; see
   *   {@link TransformerDecoderOptions}.
   */
  apply(
      decoderSequence: Tensor|SymbolicTensor,
      kwargs?: TransformerDecoderOptions): Tensor|SymbolicTensor;
  /**
   * Runs the decoder block on concrete tensors.
   *
   * @param decoderSequence The decoder input sequence.
   * @param kwargs Per-call options (required by this signature; pass `{}` when
   *   no options are needed).
   */
  call(decoderSequence: Tensor, kwargs: TransformerDecoderOptions): Tensor;
  /**
   * Forward pass of the TransformerDecoder.
   *
   * @param decoderSequence The decoder input sequence.
   * @param kwargs Per-call options; see {@link TransformerDecoderOptions}.
   * @returns One of three things, depending on call arguments:
   *   - `[outputs, null, null]`, if `selfAttentionCache` is `null`.
   *   - `[outputs, selfAttentionCache, null]`, if `selfAttentionCache` is
   *     set and the layer has no cross-attention.
   *   - `[outputs, selfAttentionCache, crossAttentionCache]`, if
   *     `selfAttentionCache` and `crossAttentionCache` are set and
   *     the layer has cross-attention.
   *   NOTE: the declared tuple type is `[Tensor, Tensor, Tensor]`, but per
   *   the cases above the second and third entries may be `null`; check them
   *   before use.
   */
  callAndReturnCaches(
      decoderSequence: Tensor,
      kwargs: TransformerDecoderOptions): [Tensor, Tensor, Tensor];
  private computeSelfAttentionMask;
  getConfig(): serialization.ConfigDict;
  /**
   * @param decoderSequenceShape Shape of the decoder input sequence.
   * @returns The output shape of the layer.
   */
  computeOutputShape(decoderSequenceShape: Shape): Shape;
}