@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
90 lines • 9.96 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { NotImplementedError } from '../../../errors';
import { Task } from './task';
/**
* Base class for Generative Task models.
*/
class GenerativeTask extends Task {
  /**
   * Configure the task, including the sampling method later used by
   * `generate()`.
   *
   * This base class provides no implementation.
   *
   * @param {ModelCompileArgs} args Compile arguments.
   * @throws {NotImplementedError} Always; a subclass must override this.
   */
  compile(args) {
    throw new NotImplementedError(
        'GenerativeTask.compile() must be overridden by a subclass.');
  }

  /**
   * Run the generation on a single batch of input.
   *
   * @param {NamedTensorMap} inputs Batch of model inputs.
   * @param {number} endTokenId Token id at which generation should stop.
   * @returns {NamedTensorMap} Generated outputs for the batch.
   * @throws {NotImplementedError} Always; a subclass must override this.
   */
  generateStep(inputs, endTokenId) {
    throw new NotImplementedError(
        'GenerativeTask.generateStep() must be overridden by a subclass.');
  }

  /**
   * Create or return the compiled generation function.
   *
   * @returns {function(NamedTensorMap, number=): NamedTensorMap} A function
   *     with the same shape as `generateStep` (`endTokenId` optional).
   * @throws {NotImplementedError} Always; a subclass must override this.
   */
  makeGenerateFunction() {
    throw new NotImplementedError(
        'GenerativeTask.makeGenerateFunction() must be overridden by a ' +
        'subclass.');
  }

  /**
   * Normalize user input to the generate function.
   *
   * Intended to convert inputs to tensors and add a batch dimension if
   * necessary.
   *
   * @protected
   * @param {Tensor} inputs Raw user input.
   * @returns {Array} A `[Tensor, boolean]` pair. Presumably these are the
   *     normalized input and whether a batch dimension was added (the flag
   *     later passed to `normalizeGenerateOutputs` as `inputIsScalar`).
   *     TODO confirm against subclasses.
   * @throws {NotImplementedError} Always; a subclass must override this.
   */
  normalizeGenerateInputs(inputs) {
    throw new NotImplementedError(
        'GenerativeTask.normalizeGenerateInputs() must be overridden by a ' +
        'subclass.');
  }

  /**
   * Normalize output from the generate function.
   *
   * If a batch dimension was added to the input, it is removed from the
   * output, so that `generate` can be string in, string out.
   *
   * NOTE(review): the upstream Keras NLP docstring also described converting
   * outputs to NumPy arrays / Python strings. Here the declared return type
   * is a Tensor; confirm the intended conversion in subclasses.
   *
   * @protected
   * @param {Tensor} outputs Raw output of the generate function.
   * @param {boolean} inputIsScalar Whether a batch dimension was added to the
   *     input during normalization.
   * @returns {Tensor} The normalized output.
   * @throws {NotImplementedError} Always; a subclass must override this.
   */
  normalizeGenerateOutputs(outputs, inputIsScalar) {
    throw new NotImplementedError(
        'GenerativeTask.normalizeGenerateOutputs() must be overridden by a ' +
        'subclass.');
  }

  /**
   * Generate text given prompt `inputs`.
   *
   * The sampling method used for generation can be set via the `compile()`
   * method. `inputs` will be handled as a single batch.
   *
   * If a `preprocessor` is attached to the model, `inputs` will be
   * preprocessed inside `generate()` and should match the structure expected
   * by the `preprocessor` layer (usually raw strings). If no `preprocessor`
   * is attached, `inputs` should match the structure expected by the
   * `backbone`.
   *
   * @param {Tensor} inputs Tensor data. If a `preprocessor` is attached to
   *     the model, `inputs` should match the structure expected by the
   *     `preprocessor` layer. Otherwise, `inputs` should match the structure
   *     expected by the `backbone` model.
   * @param {number=} maxLength Integer. The max length of the generated
   *     sequence. Defaults to the max configured `sequenceLength` of the
   *     `preprocessor`. If `preprocessor` is `null`, `inputs` should be
   *     padded to the desired maximum length and this argument will be
   *     ignored.
   * @throws {NotImplementedError} Always; a subclass must override this.
   */
  generate(inputs, maxLength) {
    throw new NotImplementedError(
        'GenerativeTask.generate() must be overridden by a subclass.');
  }
}
/** @nocollapse */
GenerativeTask.className = 'GenerativeTask';
export { GenerativeTask };
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZ2VuZXJhdGl2ZV90YXNrLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1sYXllcnMvc3JjL2xheWVycy9ubHAvbW9kZWxzL2dlbmVyYXRpdmVfdGFzay50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFTSCxPQUFPLEVBQUUsbUJBQW1CLEVBQUUsTUFBTSxpQkFBaUIsQ0FBQztBQUd0RCxPQUFPLEVBQUUsSUFBSSxFQUFFLE1BQU0sUUFBUSxDQUFDO0FBSzlCOztHQUVHO0FBQ0gsTUFBYSxjQUFlLFNBQVEsSUFBSTtJQU03QixPQUFPLENBQUMsSUFBc0I7UUFDckMsTUFBTSxJQUFJLG1CQUFtQixFQUFFLENBQUM7SUFDbEMsQ0FBQztJQUVEOztPQUVHO0lBQ0gsWUFBWSxDQUNWLE1BQXNCLEVBQ3RCLFVBQWtCO1FBRWxCLE1BQU0sSUFBSSxtQkFBbUIsRUFBRSxDQUFDO0lBQ2xDLENBQUM7SUFFRDs7T0FFRztJQUNILG9CQUFvQjtRQUNsQixNQUFNLElBQUksbUJBQW1CLEVBQUUsQ0FBQztJQUNsQyxDQUFDO0lBRUQ7Ozs7O09BS0c7SUFDTyx1QkFBdUIsQ0FBQyxNQUFjO1FBQzlDLE1BQU0sSUFBSSxtQkFBbUIsRUFBRSxDQUFDO0lBQ2xDLENBQUM7SUFFRDs7Ozs7OztPQU9HO0lBQ08sd0JBQXdCLENBQ2hDLE9BQWUsRUFDZixhQUFzQjtRQUV0QixNQUFNLElBQUksbUJBQW1CLEVBQUUsQ0FBQztJQUNsQyxDQUFDO0lBRUQ7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7OztPQXdCRztJQUNILFFBQVEsQ0FBQyxNQUFjLEVBQUUsU0FBa0I7UUFDekMsTUFBTSxJQUFJLG1CQUFtQixFQUFFLENBQUM7SUFDbEMsQ0FBQzs7QUE5RUQsa0JBQWtCO0FBQ0Ysd0JBQVMsR0FBRyxnQkFBZ0IsQ0FBQztTQUZsQyxjQUFjIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjMgR29vZ2xlIExMQy5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaX
NzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG4vKipcbiAqICBCYXNlIGNsYXNzIGZvciBHZW5lcmF0aXZlIFRhc2sgbW9kZWxzLlxuICovXG5cbi8qIE9yaWdpbmFsIHNvdXJjZToga2VyYXNfbmxwL21vZGVscy9nZW5lcmF0aXZlX3Rhc2sucHkgKi9cbmltcG9ydCB7IE5hbWVkVGVuc29yTWFwLCBUZW5zb3IgfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQgeyBOb3RJbXBsZW1lbnRlZEVycm9yIH0gZnJvbSAnLi4vLi4vLi4vZXJyb3JzJztcbmltcG9ydCB7IE1vZGVsQ29tcGlsZUFyZ3MgfSBmcm9tICcuLi8uLi8uLi9lbmdpbmUvdHJhaW5pbmcnO1xuXG5pbXBvcnQgeyBUYXNrIH0gZnJvbSAnLi90YXNrJztcblxuZXhwb3J0IHR5cGUgR2VuZXJhdGVGbiA9XG4gIChpbnB1dHM6IE5hbWVkVGVuc29yTWFwLCBlbmRUb2tlbklkPzogbnVtYmVyKSA9PiBOYW1lZFRlbnNvck1hcDtcblxuLyoqXG4gKiAgQmFzZSBjbGFzcyBmb3IgR2VuZXJhdGl2ZSBUYXNrIG1vZGVscy5cbiAqL1xuZXhwb3J0IGNsYXNzIEdlbmVyYXRpdmVUYXNrIGV4dGVuZHMgVGFzayB7XG4gIC8qKiBAbm9jb2xsYXBzZSAqL1xuICBzdGF0aWMgb3ZlcnJpZGUgY2xhc3NOYW1lID0gJ0dlbmVyYXRpdmVUYXNrJztcblxuICBwcm90ZWN0ZWQgZ2VuZXJhdGVGdW5jdGlvbjogR2VuZXJhdGVGbjtcblxuICBvdmVycmlkZSBjb21waWxlKGFyZ3M6IE1vZGVsQ29tcGlsZUFyZ3MpOiB2b2lkIHtcbiAgICB0aHJvdyBuZXcgTm90SW1wbGVtZW50ZWRFcnJvcigpO1xuICB9XG5cbiAgLyoqXG4gICAqIFJ1biB0aGUgZ2VuZXJhdGlvbiBvbiBhIHNpbmdsZSBiYXRjaCBvZiBpbnB1dC5cbiAgICovXG4gIGdlbmVyYXRlU3RlcChcbiAgICBpbnB1dHM6IE5hbWVkVGVuc29yTWFwLFxuICAgIGVuZFRva2VuSWQ6IG51bWJlclxuICApOiBOYW1lZFRlbnNvck1hcCB7XG4gICAgdGhyb3cgbmV3IE5vdEltcGxlbWVudGVkRXJyb3IoKTtcbiAgfVxuXG4gIC8qKlxuICAgKiBDcmVhdGUgb3IgcmV0dXJuIHRoZSBjb21waWxlZCBnZW5lcmF0aW9uIGZ1bmN0aW9uLlxuICAgKi9cbiAgbWFrZUdlbmVyYXRlRnVuY3Rpb24oKTogR2VuZXJhdGVGbiB7XG4gICAgdGhyb3cgbmV3IE5vdEltcGxlbWVudGVkRXJyb3IoKTtcbiAgfVxuXG4gIC8qKlxuICAgKiBOb3JtYWxpemUgdXNlciBpbnB1dCB0byB0aGUgZ2VuZXJhdGUgZnVuY3Rpb24uXG4gICAqXG4gICAqIFRoaXMgZnVuY3Rpb24gY29udmVydHMgYWxsIGlucHV0cyB0byB0ZW5zb3JzLCBhZGRzIGEgYmF0Y2ggZGltZW5zaW9uIGlmXG4gICAqIG5lY2Vzc2FyeSwgYW5kIHJldHVybnMgYSBpdGVyYWJsZSBcImRhdGFzZXQgbGlrZVwiIG9iamVjdC5cbiAgICovXG4gIHByb3RlY3RlZCBub3JtYWxpemVHZW5lcmF0ZUlucHV0cyhpbnB1dH
M6IFRlbnNvcik6IFtUZW5zb3IsIGJvb2xlYW5dIHtcbiAgICB0aHJvdyBuZXcgTm90SW1wbGVtZW50ZWRFcnJvcigpO1xuICB9XG5cbiAgLyoqXG4gICAqIE5vcm1hbGl6ZSB1c2VyIG91dHB1dCBmcm9tIHRoZSBnZW5lcmF0ZSBmdW5jdGlvbi5cbiAgICpcbiAgICogVGhpcyBmdW5jdGlvbiBjb252ZXJ0cyBhbGwgb3V0cHV0IHRvIG51bXB5IChmb3IgaW50ZWdlciBvdXRwdXQpLCBvclxuICAgKiBweXRob24gc3RyaW5ncyAoZm9yIHN0cmluZyBvdXRwdXQpLiBJZiBhIGJhdGNoIGRpbWVuc2lvbiB3YXMgYWRkZWQgdG9cbiAgICogdGhlIGlucHV0LCBpdCBpcyByZW1vdmVkIGZyb20gdGhlIG91dHB1dCAoc28gZ2VuZXJhdGUgY2FuIGJlIHN0cmluZyBpbixcbiAgICogc3RyaW5nIG91dCkuXG4gICAqL1xuICBwcm90ZWN0ZWQgbm9ybWFsaXplR2VuZXJhdGVPdXRwdXRzKFxuICAgIG91dHB1dHM6IFRlbnNvcixcbiAgICBpbnB1dElzU2NhbGFyOiBib29sZWFuXG4gICk6IFRlbnNvciB7XG4gICAgdGhyb3cgbmV3IE5vdEltcGxlbWVudGVkRXJyb3IoKTtcbiAgfVxuXG4gIC8qKlxuICAgKiBHZW5lcmF0ZSB0ZXh0IGdpdmVuIHByb21wdCBgaW5wdXRzYC5cbiAgICpcbiAgICogVGhpcyBtZXRob2QgZ2VuZXJhdGVzIHRleHQgYmFzZWQgb24gZ2l2ZW4gYGlucHV0c2AuIFRoZSBzYW1wbGluZyBtZXRob2RcbiAgICogdXNlZCBmb3IgZ2VuZXJhdGlvbiBjYW4gYmUgc2V0IHZpYSB0aGUgYGNvbXBpbGUoKWAgbWV0aG9kLlxuICAgKlxuICAgKiBgaW5wdXRzYCB3aWxsIGJlIGhhbmRsZWQgYXMgYSBzaW5nbGUgYmF0Y2guXG4gICAqXG4gICAqIElmIGEgYHByZXByb2Nlc3NvcmAgaXMgYXR0YWNoZWQgdG8gdGhlIG1vZGVsLCBgaW5wdXRzYCB3aWxsIGJlXG4gICAqIHByZXByb2Nlc3NlZCBpbnNpZGUgdGhlIGBnZW5lcmF0ZSgpYCBmdW5jdGlvbiBhbmQgc2hvdWxkIG1hdGNoIHRoZVxuICAgKiBzdHJ1Y3R1cmUgZXhwZWN0ZWQgYnkgdGhlIGBwcmVwcm9jZXNzb3JgIGxheWVyICh1c3VhbGx5IHJhdyBzdHJpbmdzKS5cbiAgICogSWYgYSBgcHJlcHJvY2Vzc29yYCBpcyBub3QgYXR0YWNoZWQsIGlucHV0cyBzaG91bGQgbWF0Y2ggdGhlIHN0cnVjdHVyZVxuICAgKiBleHBlY3RlZCBieSB0aGUgYGJhY2tib25lYC4gU2VlIHRoZSBleGFtcGxlIHVzYWdlIGFib3ZlIGZvciBhXG4gICAqIGRlbW9uc3RyYXRpb24gb2YgZWFjaC5cbiAgICpcbiAgICogQHBhcmFtIGlucHV0cyB0ZW5zb3IgZGF0YS4gSWYgYSBgcHJlcHJvY2Vzc29yYCBpcyBhdHRhY2hlZCB0byB0aGUgbW9kZWwsXG4gICAqICBgaW5wdXRzYCBzaG91bGQgbWF0Y2ggdGhlIHN0cnVjdHVyZSBleHBlY3RlZCBieSB0aGUgYHByZXByb2Nlc3NvcmAgbGF5ZXIuXG4gICAqICBJZiBhIGBwcmVwcm9jZXNzb3JgIGlzIG5vdCBhdHRhY2hlZCwgYGlucHV0c2Agc2hvdWxkIG1hdGNoIHRoZSBzdHJ1Y3R1cmVcbiAgICogIGV4cGVjdGVkIHRoZSB0aGUgYGJhY2tib25lYCBtb2RlbC5cbiAgICogQHBhcm
FtIG1heExlbmd0aCBJbnRlZ2VyLiBUaGUgbWF4IGxlbmd0aCBvZiB0aGUgZ2VuZXJhdGVkIHNlcXVlbmNlLlxuICAgKiAgV2lsbCBkZWZhdWx0IHRvIHRoZSBtYXggY29uZmlndXJlZCBgc2VxdWVuY2VMZW5ndGhgIG9mIHRoZVxuICAgKiAgYHByZXByb2Nlc3NvcmAuIElmIGBwcmVwcm9jZXNzb3JgIGlzIGBudWxsYCwgYGlucHV0c2Agc2hvdWxkIGJlXG4gICAqICBzaG91bGQgYmUgcGFkZGVkIHRvIHRoZSBkZXNpcmVkIG1heGltdW0gbGVuZ3RoIGFuZCB0aGlzIGFyZ3VtZW50XG4gICAqICB3aWxsIGJlIGlnbm9yZWQuXG4gICAqL1xuICBnZW5lcmF0ZShpbnB1dHM6IFRlbnNvciwgbWF4TGVuZ3RoPzogbnVtYmVyKSB7XG4gICAgdGhyb3cgbmV3IE5vdEltcGxlbWVudGVkRXJyb3IoKTtcbiAgfVxufVxuIl19