UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

177 lines 20.1 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * GPT2 Causal LM (Language Model).
 */
/* Original source: keras-nlp/models/gpt2/gpt2_causal_lm.py */
import { serialization } from '@tensorflow/tfjs-core';
import { NotImplementedError } from '../../../../errors';
import { Layer } from '../../../../exports_layers';
import { GenerativeTask } from '../generative_task';
/**
 * Layer stub that stores the `embedding` layer it is constructed with.
 *
 * Neither `call()` nor `computeOutputShape()` is implemented yet; both always
 * throw `NotImplementedError`.
 *
 * NOTE(review): presumably meant to map hidden states back to vocabulary
 * logits using the stored embedding, as in the keras-nlp original -- confirm
 * when the implementation lands.
 */
class ReverseEmbedding extends Layer {
    /**
     * @param args Layer arguments; `args.embedding` is kept on the instance.
     */
    constructor(args) {
        super(args);
        this.embedding = args.embedding;
    }
    /**
     * Not implemented.
     *
     * @throws NotImplementedError always.
     */
    call(inputs, kwargs) {
        throw new NotImplementedError();
    }
    /**
     * Not implemented.
     *
     * @throws NotImplementedError always.
     */
    computeOutputShape(inputShape) {
        throw new NotImplementedError();
    }
}
/**
 * An end-to-end GPT2 model for causal language modeling.
 *
 * A causal language model (LM) predicts the next token based on previous
 * tokens. This task setup can be used to train the model unsupervised on
 * plain text input, or to autoregressively generate plain text similar to
 * the data used for training. This task can be used for pre-training or
 * fine-tuning a GPT-2 model, simply by calling `fit()`.
 *
 * This model has a `generate()` method, which generates text based on a
 * prompt. The generation strategy used is controlled by an additional
 * `sampler` argument on `compile()`.
 * By default, the top k results will be returned.
 *
 * This model can optionally be configured with a `preprocessor` layer, in
 * which case it will automatically apply preprocessing to string inputs during
 * `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
 * when creating the model with `fromPreset()`.
 *
 * Disclaimer: Pre-trained models are provided on an "as is" basis, without
 * warranties or conditions of any kind. The underlying model is provided by a
 * third party and subject to a separate license, available
 * [here](https://github.com/openai/gpt-2).
 *
 * Implementation status: the constructor and every method of this class
 * currently throw `NotImplementedError`, so the examples below describe the
 * intended API rather than working code.
 *
 * NOTE(review): the examples were ported from the Python keras-nlp docs and
 * still contain Python-style spellings (`max_length=30`, `sequence_length`,
 * `vocabularysize`, `numlayers`, ...). Confirm the real argument and option
 * names against `GenerativeTask`, `GPT2Backbone` and the preprocessor before
 * relying on them.
 *
 * Use `generate()` to do text generation.
 * ```js
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.generate("I want to say", max_length=30);
 * // Generate with batched prompts.
 * gpt2LM.generate(["This is a", "Where are you"], max_length=30);
 * ```
 *
 * Use `generate()` without preprocessing.
 * ```js
 * // Prompt the model with `5338, 318` (the token ids for `"Who is"`).
 * // Use `"paddingMask"` to indicate values that should not be overridden.
 * const prompt = {
 *   tokenIds: tf.tensor([[5338, 318, 0, 0, 0], [5338, 318, 0, 0, 0]]),
 *   paddingMask: tf.tensor([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]),
 * };
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.generate(prompt);
 * ```
 *
 * Call `fit()` on a single batch.
 * ```js
 * const features = ['The quick brown fox jumped.', 'I forgot my homework.'];
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 *
 * Call `fit()` without preprocessing.
 * ```js
 * const x = {
 *   tokenIds: tf.tensor([[50256, 1, 2, 3, 4], [50256, 1, 2, 3, 4]]),
 *   paddingMask: tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]),
 * };
 * const y = tf.tensor([[1, 2, 3, 4, 50256], [1, 2, 3, 4, 50256]]);
 * const sw = tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]);
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.fit(x, y, {sampleWeight: sw, batchSize: 2});
 * ```
 *
 * Custom backbone and vocabulary.
 * ```js
 * const features = ["a quick fox.", "a fox quick."];
 * const vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6};
 * const merges = [
 *   "Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"
 * ];
 * const tokenizer = new GPT2Tokenizer({vocabulary: vocab, merges});
 * const preprocessor = new GPT2CausalLMPreprocessor({
 *   tokenizer,
 *   sequence_length: 128,
 * });
 * const backbone = new GPT2Backbone({
 *   vocabularysize: 30552,
 *   numlayers: 4,
 *   numheads: 4,
 *   hiddendim: 256,
 *   intermediatedim: 512,
 *   maxSequenceLength: 128,
 * });
 * const gpt2LM = new GPT2CausalLM({backbone, preprocessor});
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 */
class GPT2CausalLM extends GenerativeTask {
    /**
     * Not implemented: runs the base-class constructor and then throws.
     *
     * @param args Task arguments, forwarded unchanged to `GenerativeTask`.
     * @throws NotImplementedError always.
     */
    constructor(args) {
        super(args);
        // Interpolate the class *name*: interpolating the class itself would
        // stringify its entire source text into the error message.
        throw new NotImplementedError(`Uses ${ReverseEmbedding.name}.`);
    }
    /**
     * Not implemented.
     *
     * @param cls The class the presets are requested for.
     * @throws NotImplementedError always.
     */
    static presets(cls) {
        throw new NotImplementedError();
    }
    /**
     * Forward pass of `GPT2CausalLM` with cache.
     *
     * `callWithCache` adds an additional forward pass for the model for
     * autoregressive inference. Unlike calling the model directly, this method
     * allows caching previous key/value Tensors in multi-head attention layer,
     * and avoids recomputing the outputs of seen tokens.
     *
     * @param tokenIds a dense int Tensor with shape `[batchSize, maxLength]`.
     * @param cache a dense float Tensor, the cache of key and value.
     * @param cacheUpdateIndex Integer. The index of current inputs in the whole
     *  sequence.
     * @returns [logits, hiddenStates, cache], where `logits` is the
     *  language model logits for the input tokenIds, `hiddenStates` is
     *  the final hidden representation of the input tokens, and `cache` is
     *  the decoding cache.
     * @throws NotImplementedError always (not implemented yet).
     */
    callWithCache(tokenIds, cache, cacheUpdateIndex) {
        throw new NotImplementedError();
    }
    /**
     * Build an empty cache for use with `callWithCache()`.
     *
     * @throws NotImplementedError always (not implemented yet).
     */
    buildCache(tokenIds) {
        throw new NotImplementedError();
    }
    /**
     * A compilable generation function for a single batch of inputs.
     *
     * This function represents the inner generation function for a single batch
     * of inputs.
     *
     * @param inputs An object with two keys `tokenIds` and `paddingMask` and
     *  batched tensor values.
     * @param endTokenId The id of the end token to stop on. If all
     *  sequences have produced a new `endTokenId`, generation will stop.
     * @throws NotImplementedError always (not implemented yet).
     */
    generateStep(inputs, endTokenId) {
        // Interpolate the method *name*: interpolating the function itself
        // would stringify its source text into the error message.
        throw new NotImplementedError(`Uses ${this.buildCache.name}`);
    }
}
/** @nocollapse */
GPT2CausalLM.className = 'GPT2CausalLM';
export { GPT2CausalLM };
serialization.registerClass(GPT2CausalLM);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"gpt2_causal_lm.js","sourceRoot":"","sources":["../../../../../../../../../tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,8DAA8D;AAC9D,OAAO,EAA0B,aAAa,EAAE,MAAM,uBAAuB,CAAC;AAG9E,OAAO,EAAE,mBAAmB,EAAE,MAAM,oBAAoB,CAAC;AACzD,OAAO,EAAE,KAAK,EAAE,MAAM,4BAA4B,CAAC;AAInD,OAAO,EAAE,cAAc,EAAE,MAAM,oBAAoB,CAAC;AASpD,MAAM,gBAAiB,SAAQ,KAAK;IAGlC,YAAY,IAA0B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;IAClC,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;CAEF;AAgBD;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAsFG;AACH,MAAa,YAAa,SAAQ,cAAc;IAI9C,YAAY,IAAsB;QAChC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,MAAM,IAAI,mBAAmB,CAAC,QAAQ,gBAAgB,GAAG,CAAC,CAAC;IAC7D,CAAC;IAED,MAAM,CAAU,OAAO,CACrB,GAA6C;QAE7C,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;IAED;;;;;;;;;;;;;;;;OAgBG;IACH,aAAa,CACX,QAAgB,EAChB,KAAa,EACb,gBAAwB;QAExB,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;IAED;;OAEG;IACK,UAAU,CAAC,QAAgB;QACjC,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;IAED;;;;;;;;;;OAUG;IACM,YAAY,CACnB,MAAsB,EACtB,UAAkB;QAElB,MAAM,IAAI,mBAAmB,CAAC,QAAQ,IAAI,CAAC,UAAU,EAAE,CAAC,CAAC;IAC3D,CAAC;;AA9DD,kBAAkB;AACF,sBAAS,GAAG,cAAc,CAAC;SAFhC,YAAY;AAiEzB,aAAa,CAAC,aAAa,CAAC,YAAY,CAAC,CAAC",
"sourcesContent":["/**\r\n * @license\r\n * Copyright 2023 Google LLC.\r\n * Licensed under the Apache License, Version 2.0 (the \"License\");\r\n * you may not use this file except in compliance with the License.\r\n * You may obtain a copy of the License at\r\n *\r\n * http://www.apache.org/licenses/LICENSE-2.0\r\n *\r\n * Unless required by applicable law or agreed to in writing, software\r\n * distributed under the License is distributed on an \"AS IS\" BASIS,\r\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n * See the License for the specific language governing permissions and\r\n * limitations under the License.\r\n * =============================================================================\r\n */\r\n\r\n/**\r\n * GPT2 Causal LM (Language Model).\r\n */\r\n\r\n/* Original source: keras-nlp/models/gpt2/gpt2_causal_lm.py */\r\nimport { NamedTensorMap, Tensor, serialization } from '@tensorflow/tfjs-core';\r\n\r\nimport { GPT2Preprocessor } from './gpt2_preprocessor';\r\nimport { NotImplementedError } from '../../../../errors';\r\nimport { Layer } from '../../../../exports_layers';\r\nimport { LayerArgs } from '../../../../engine/topology';\r\nimport { Embedding } from '../../../../layers/embeddings';\r\nimport { Shape } from '../../../../keras_format/common';\r\nimport { GenerativeTask } from '../generative_task';\r\nimport { GPT2Backbone } from './gpt2_backbone';\r\nimport { PipelineModelArgs } from '../../utils';\r\nimport { Kwargs } from '../../../../types';\r\n\r\ndeclare interface ReverseEmbeddingArgs extends LayerArgs {\r\n  embedding: Embedding;\r\n}\r\n\r\nclass ReverseEmbedding extends Layer {\r\n  protected embedding: Embedding;\r\n\r\n  constructor(args: ReverseEmbeddingArgs) {\r\n    super(args);\r\n    this.embedding = args.embedding;\r\n  }\r\n\r\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\r\n    throw new NotImplementedError();\r\n  }\r\n\r\n  override 
computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\r\n    throw new NotImplementedError();\r\n  }\r\n\r\n}\r\n\r\nexport declare interface GPT2CausalLMArgs extends PipelineModelArgs {\r\n  /**\r\n   * A `GPT2Backbone` instance.\r\n   */\r\n  backbone: GPT2Backbone;\r\n\r\n  /**\r\n   * Optional `GPT2CausalLMPreprocessor`.\r\n   * If `null`, this model will not apply preprocessing, and inputs should be\r\n   * preprocessed before calling the model.\r\n   */\r\n  preprocessor?: GPT2Preprocessor;\r\n}\r\n\r\n/**\r\n * An end-to-end GPT2 model for causal language modeling.\r\n *\r\n * A causal language model (LM) predicts the next token based on previous\r\n * tokens. This task setup can be used to train the model unsupervised on\r\n * plain text input, or to autoregressively generate plain text similar to\r\n * the data used for training. This task can be used for pre-training or\r\n * fine-tuning a GPT-2 model, simply by calling `fit()`.\r\n *\r\n * This model has a `generate()` method, which generates text based on a\r\n * prompt. The generation strategy used is controlled by an additional\r\n * sampler` argument on `compile()`.\r\n * By default, the top k results will be returned.\r\n *\r\n * This model can optionally be configured with a `preprocessor` layer, in\r\n * which case it will automatically apply preprocessing to string inputs during\r\n * fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default\r\n * when creating the model with `fromPreset()`.\r\n *\r\n * Disclaimer: Pre-trained models are provided on an \"as is\" basis, without\r\n * warranties or conditions of any kind. 
The underlying model is provided by a\r\n * third party and subject to a separate license, available\r\n * here](https://github.com/openai/gpt-2).\r\n *\r\n * Use `generate()` to do text generation.\r\n * ```js\r\n * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');\r\n * gpt2LM.generate(\"I want to say\", max_length=30);\r\n * // Generate with batched prompts.\r\n * gpt2LM.generate([\"This is a\", \"Where are you\"], max_length=30);\r\n * ```\r\n *\r\n * Use `generate()` without preprocessing.\r\n * ```js\r\n * // Prompt the model with `5338, 318` (the token ids for `\"Who is\"`).\r\n * // Use `\"paddingMask\"` to indicate values that should not be overridden.\r\n * const prompt = {\r\n *  tokenIds: tf.tensor([[5338, 318, 0, 0, 0], [5338, 318, 0, 0, 0]]),\r\n *  paddingMask: tf.tensor([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]]),\r\n * };\r\n * const gpt2LM = GPT2CausalLM.from_preset('gpt2_base_en', null);\r\n * gpt2LM.generate(prompt);\r\n * ```\r\n *\r\n * Call `fit()` on a single batch.\r\n * ```js\r\n * const features = ['The quick brown fox jumped.', 'I forgot my homework.'];\r\n * const gpt2LM = GPT2CausalLM.from_preset('gpt2_base_en');\r\n * gpt2LM.fit(features, {batchSize: 2});\r\n * ```\r\n *\r\n * Call `fit()` without preprocessing.\r\n * ```js\r\n * const x = {\r\n *  tokenIds: tf.tensor([[50256, 1, 2, 3, 4], [50256, 1, 2, 3, 4]]),\r\n *  paddingMask: tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]),\r\n * };\r\n * const y = tf.tensor([[1, 2, 3, 4, 50256], [1, 2, 3, 4, 50256]]);\r\n * const sw = tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]);\r\n * const gpt2LM = GPT2CausalLM.from_preset('gpt2_base_en', null);\r\n * gpt2LM.fit(x, y, {sampleWeight: sw, batchSize: 2});\r\n * ```\r\n *\r\n * Custom backbone and vocabulary.\r\n * ```js\r\n * const features = [\"a quick fox.\", \"a fox quick.\"];\r\n * const vocab = {\"<|endoftext|>\": 0, \"a\": 4, \"Ġquick\": 5, \"Ġfox\": 6};\r\n * const merges = [\r\n *  \"Ġ q\", \"u i\", \"c k\", \"ui ck\", \"Ġq uick\", \"Ġ 
f\", \"o x\", \"Ġf ox\"\r\n * ];\r\n * const tokenizer = new GPT2Tokenizer({vocabulary: vocab, merges});\r\n * const preprocessor =  new GPT2CausalLMPreprocessor({\r\n *  tokenizer,\r\n *  sequence_length: 128,\r\n * });\r\n * const backbone = new GPT2Backbone({\r\n *  vocabularysize: 30552,\r\n *  numlayers: 4,\r\n *  numheads: 4,\r\n *  hiddendim: 256,\r\n *  intermediatedim: 512,\r\n *  maxSequenceLength: 128,\r\n * });\r\n * const gpt2LM = new GPT2CausalLM({backbone, preprocessor});\r\n * gpt2LM.fit(features, {batch_size: 2});\r\n * ```\r\n */\r\nexport class GPT2CausalLM extends GenerativeTask {\r\n  /** @nocollapse */\r\n  static override className = 'GPT2CausalLM';\r\n\r\n  constructor(args: GPT2CausalLMArgs) {\r\n    super(args);\r\n    throw new NotImplementedError(`Uses ${ReverseEmbedding}.`);\r\n  }\r\n\r\n  static override presets<T extends serialization.Serializable>(\r\n    cls: serialization.SerializableConstructor<T>\r\n  ): {} {\r\n    throw new NotImplementedError();\r\n  }\r\n\r\n  /**\r\n   * Forward pass of `GPT2CausalLM` with cache.\r\n   *\r\n   * `callWithCache` adds an additional forward pass for the model for\r\n   * autoregressive inference. Unlike calling the model directly, this method\r\n   * allows caching previous key/value Tensors in multi-head attention layer,\r\n   * and avoids recomputing the outputs of seen tokens.\r\n   *\r\n   * @param tokenIds a dense int Tensor with shape `[batchSize, maxLength]`.\r\n   * @param cache a dense float Tensor, the cache of key and value.\r\n   * @param cacheUpdateIndex Integer. 
The index of current inputs in the whole\r\n   *  sequence.\r\n   * @returns [logits, hiddenStates, cache], where `logits` is the\r\n   *  language model logits for the input tokenIds, `hiddenStates` is\r\n   *  the final hidden representation of the input tokens, and `cache` is\r\n   *  the decoding cache.\r\n   */\r\n  callWithCache(\r\n    tokenIds: Tensor,\r\n    cache: Tensor,\r\n    cacheUpdateIndex: number\r\n  ): [Tensor, Tensor, Tensor] {\r\n    throw new NotImplementedError();\r\n  }\r\n\r\n  /**\r\n   * Build an empty cache for use with `callWithCache()`.\r\n   */\r\n  private buildCache(tokenIds: Tensor): [Tensor, Tensor] {\r\n    throw new NotImplementedError();\r\n  }\r\n\r\n  /**\r\n   * A compilable generation function for a single batch of inputs.\r\n   *\r\n   * This function represents the inner generation function for a single batch\r\n   *  of inputs.\r\n   *\r\n   * @param inputs An object with two keys `tokenIds` and `paddingMask` and\r\n   *  batched tensor values.\r\n   * @param endTokenId The id of the end token to stop on. If all\r\n   *  sequences have produced a new `endTokenId`, generation will stop.\r\n   */\r\n  override generateStep(\r\n    inputs: NamedTensorMap,\r\n    endTokenId: number\r\n  ): NamedTensorMap {\r\n    throw new NotImplementedError(`Uses ${this.buildCache}`);\r\n  }\r\n}\r\nserialization.registerClass(GPT2CausalLM);\r\n"]}