UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

84 lines (83 loc) 3.62 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/models/generative_task" />
/**
 * Base class for Generative Task models.
 */
import { NamedTensorMap, Tensor } from '@tensorflow/tfjs-core';
import { ModelCompileArgs } from '../../../engine/training';
import { Task } from './task';
/**
 * Signature of a generation function: it receives a map of named input
 * tensors plus an optional end-token id, and returns a map of named tensors.
 *
 * NOTE(review): which keys the input and output maps carry is not visible in
 * this declaration file -- confirm against the implementation.
 */
export type GenerateFn = (inputs: NamedTensorMap, endTokenId?: number) => NamedTensorMap;
/**
 * Base class for Generative Task models.
 *
 * Extends `Task` with the declarations needed for generation: a stored
 * generation function, a single-batch generation step, input/output
 * normalization helpers, and the public `generate()` entry point.
 */
export declare class GenerativeTask extends Task {
    /** @nocollapse */
    static className: string;
    /**
     * Generation function held by this instance.
     *
     * NOTE(review): presumably the function created or returned by
     * `makeGenerateFunction()` -- confirm against the implementation.
     */
    protected generateFunction: GenerateFn;
    /**
     * Compile the model with the given arguments.
     *
     * According to the `generate()` documentation in this class, the sampling
     * method used for generation is set through this method.
     *
     * NOTE(review): `ModelCompileArgs` is declared in `engine/training`; which
     * of its fields selects the sampling method is not visible here.
     *
     * @param args Compile arguments.
     */
    compile(args: ModelCompileArgs): void;
    /**
     * Run the generation on a single batch of input.
     *
     * Note that `endTokenId` is required here, whereas it is optional in the
     * `GenerateFn` signature.
     *
     * @param inputs Named input tensors for one batch.
     * @param endTokenId Id of the end token.
     * @returns Named output tensors for the batch.
     */
    generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap;
    /**
     * Create or return the compiled generation function.
     *
     * @returns The generation function.
     */
    makeGenerateFunction(): GenerateFn;
    /**
     * Normalize user input to the generate function.
     *
     * This function converts all inputs to tensors, adds a batch dimension if
     * necessary, and returns an iterable "dataset like" object.
     *
     * NOTE(review): the wording above appears to be inherited from a Python
     * implementation. The declared return type here is a `[Tensor, boolean]`
     * tuple, not a dataset-like iterable; the boolean is presumably the
     * `inputIsScalar` flag later passed to `normalizeGenerateOutputs()` --
     * confirm against the implementation.
     *
     * @param inputs Input tensor supplied by the user.
     * @returns A `[Tensor, boolean]` tuple.
     */
    protected normalizeGenerateInputs(inputs: Tensor): [Tensor, boolean];
    /**
     * Normalize user output from the generate function.
     *
     * This function converts all output to numpy (for integer output), or
     * python strings (for string output). If a batch dimension was added to
     * the input, it is removed from the output (so generate can be string in,
     * string out).
     *
     * NOTE(review): the references to numpy and python strings appear to be
     * inherited from a Python implementation; the declared return type here
     * is `Tensor`. Confirm the actual conversion against the implementation.
     *
     * @param outputs Output tensor produced by generation.
     * @param inputIsScalar Flag describing the original input; presumably
     *     `true` when a batch dimension was added to the input -- confirm.
     * @returns The normalized output tensor.
     */
    protected normalizeGenerateOutputs(outputs: Tensor, inputIsScalar: boolean): Tensor;
    /**
     * Generate text given prompt `inputs`.
     *
     * This method generates text based on given `inputs`. The sampling method
     * used for generation can be set via the `compile()` method.
     *
     * `inputs` will be handled as a single batch.
     *
     * If a `preprocessor` is attached to the model, `inputs` will be
     * preprocessed inside the `generate()` function and should match the
     * structure expected by the `preprocessor` layer (usually raw strings).
     * If a `preprocessor` is not attached, inputs should match the structure
     * expected by the `backbone`. See the example usage above for a
     * demonstration of each.
     *
     * NOTE(review): no example usage appears in this declaration file, so the
     * reference to "the example usage above" is dangling. Also, this method
     * is declared to return `void`, so the generated output is not returned
     * to the caller through this signature.
     *
     * @param inputs tensor data. If a `preprocessor` is attached to the model,
     *     `inputs` should match the structure expected by the `preprocessor`
     *     layer. If a `preprocessor` is not attached, `inputs` should match
     *     the structure expected by the `backbone` model.
     * @param maxLength Integer. The max length of the generated sequence.
     *     Will default to the max configured `sequenceLength` of the
     *     `preprocessor`. If `preprocessor` is `null`, `inputs` should be
     *     padded to the desired maximum length and this argument will be
     *     ignored.
     */
    generate(inputs: Tensor, maxLength?: number): void;
}