UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

165 lines (164 loc) 6.72 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/models/gpt2/gpt2_causal_lm" />
/**
 * GPT2 Causal LM (Language Model).
 */
import { NamedTensorMap, Tensor, serialization } from '@tensorflow/tfjs-core';
import { GPT2Preprocessor } from './gpt2_preprocessor';
import { GenerativeTask } from '../generative_task';
import { GPT2Backbone } from './gpt2_backbone';
import { PipelineModelArgs } from '../../utils';

/**
 * Constructor arguments for `GPT2CausalLM`.
 */
export declare interface GPT2CausalLMArgs extends PipelineModelArgs {
  /**
   * A `GPT2Backbone` instance.
   */
  backbone: GPT2Backbone;
  /**
   * Optional `GPT2Preprocessor` (typically a `GPT2CausalLMPreprocessor`).
   * If not provided, this model will not apply preprocessing, and inputs
   * should be preprocessed before calling the model.
   */
  preprocessor?: GPT2Preprocessor;
}

/**
 * An end-to-end GPT2 model for causal language modeling.
 *
 * A causal language model (LM) predicts the next token based on previous
 * tokens. This task setup can be used to train the model unsupervised on
 * plain text input, or to autoregressively generate plain text similar to
 * the data used for training. This task can be used for pre-training or
 * fine-tuning a GPT-2 model, simply by calling `fit()`.
 *
 * This model has a `generate()` method, which generates text based on a
 * prompt. The generation strategy used is controlled by an additional
 * `sampler` argument on `compile()`.
 * By default, the top k results will be returned.
 *
 * This model can optionally be configured with a `preprocessor` layer, in
 * which case it will automatically apply preprocessing to string inputs during
 * `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
 * when creating the model with `fromPreset()`.
 *
 * Disclaimer: Pre-trained models are provided on an "as is" basis, without
 * warranties or conditions of any kind. The underlying model is provided by a
 * third party and subject to a separate license, available
 * [here](https://github.com/openai/gpt-2).
 *
 * Use `generate()` to do text generation.
 * ```js
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.generate("I want to say", 30);  // maxLength = 30
 * // Generate with batched prompts.
 * gpt2LM.generate(["This is a", "Where are you"], 30);  // maxLength = 30
 * ```
 *
 * Use `generate()` without preprocessing.
 * ```js
 * // Prompt the model with `5338, 318` (the token ids for `"Who is"`).
 * // Use `"paddingMask"` to indicate values that should not be overridden.
 * const prompt = {
 *   tokenIds: tf.tensor([[5338, 318, 0, 0, 0], [5338, 318, 0, 0, 0]]),
 *   paddingMask: tf.tensor([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]),
 * };
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.generate(prompt);
 * ```
 *
 * Call `fit()` on a single batch.
 * ```js
 * const features = ['The quick brown fox jumped.', 'I forgot my homework.'];
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 *
 * Call `fit()` without preprocessing.
 * ```js
 * const x = {
 *   tokenIds: tf.tensor([[50256, 1, 2, 3, 4], [50256, 1, 2, 3, 4]]),
 *   paddingMask: tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]),
 * };
 * const y = tf.tensor([[1, 2, 3, 4, 50256], [1, 2, 3, 4, 50256]]);
 * const sw = tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]);
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.fit(x, y, {sampleWeight: sw, batchSize: 2});
 * ```
 *
 * Custom backbone and vocabulary.
 * ```js
 * const features = ["a quick fox.", "a fox quick."];
 * const vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6};
 * const merges = [
 *   "Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"
 * ];
 * const tokenizer = new GPT2Tokenizer({vocabulary: vocab, merges});
 * const preprocessor = new GPT2CausalLMPreprocessor({
 *   tokenizer,
 *   sequenceLength: 128,
 * });
 * const backbone = new GPT2Backbone({
 *   vocabularySize: 30552,
 *   numLayers: 4,
 *   numHeads: 4,
 *   hiddenDim: 256,
 *   intermediateDim: 512,
 *   maxSequenceLength: 128,
 * });
 * const gpt2LM = new GPT2CausalLM({backbone, preprocessor});
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 */
export declare class GPT2CausalLM extends GenerativeTask {
  /** @nocollapse */
  static className: string;

  /**
   * @param args The backbone to wrap and an optional preprocessor; see
   *     `GPT2CausalLMArgs`.
   */
  constructor(args: GPT2CausalLMArgs);

  /**
   * Preset lookup for this task class.
   *
   * NOTE(review): the declared return type is `{}`, so callers get no typing
   * for the returned value — confirm its shape against the implementation.
   */
  static presets<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>): {};

  /**
   * Forward pass of `GPT2CausalLM` with cache.
   *
   * `callWithCache` adds an additional forward pass for the model for
   * autoregressive inference. Unlike calling the model directly, this method
   * allows caching previous key/value Tensors in multi-head attention layer,
   * and avoids recomputing the outputs of seen tokens.
   *
   * @param tokenIds a dense int Tensor with shape `[batchSize, maxLength]`.
   * @param cache a dense float Tensor, the cache of key and value.
   * @param cacheUpdateIndex Integer. The index of current inputs in the whole
   *     sequence.
   * @returns [logits, hiddenStates, cache], where `logits` is the
   *     language model logits for the input tokenIds, `hiddenStates` is
   *     the final hidden representation of the input tokens, and `cache` is
   *     the decoding cache.
   */
  callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor];

  /**
   * Build an empty cache for use with `callWithCache()`.
   */
  private buildCache;

  /**
   * A compilable generation function for a single batch of inputs.
   *
   * This function represents the inner generation function for a single batch
   * of inputs.
   *
   * @param inputs An object with two keys `tokenIds` and `paddingMask` and
   *     batched tensor values.
   * @param endTokenId The id of the end token to stop on. If all
   *     sequences have produced a new `endTokenId`, generation will stop.
   * @returns A `NamedTensorMap`. NOTE(review): presumably the same
   *     `tokenIds`/`paddingMask` keys with generated tokens filled in —
   *     confirm against the implementation.
   */
  generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap;
}