UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

223 lines (222 loc) 9.08 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/base_callbacks" />
// NOTE(review): the directive above must stay on its own line. If the file is
// ever collapsed onto one line, `///` comments out everything after it and the
// declarations below silently disappear.
import { Tensor } from '@tensorflow/tfjs-core';
import { Container } from './engine/container';
import { Logs, UnresolvedLogs } from './logs';
/** Verbosity logging level when fitting a model. */
export declare enum ModelLoggingVerbosity {
    SILENT = 0,
    VERBOSE = 1
}
/** How often to yield to the main thread when training (in ms). */
export declare const DEFAULT_YIELD_EVERY_MS = 125;
/**
 * Training parameters keyed by name. Each value is a number, string or
 * boolean, or an array of one of those primitive types.
 */
export type Params = {
    [key: string]: number | string | boolean | number[] | string[] | boolean[];
};
/**
 * Accepted values for the `yieldEvery` setting: one of the string modes
 * `'auto'`, `'batch'`, `'epoch'`, `'never'`, or a number.
 *
 * NOTE(review): a numeric value is presumably an interval in milliseconds,
 * matching the unit of `DEFAULT_YIELD_EVERY_MS`; confirm against the
 * implementation.
 */
export type YieldEveryOptions = 'auto' | 'batch' | 'epoch' | 'never' | number;
/**
 * Abstract base class used to build new callbacks.
 *
 * The `logs` dictionary that callback methods take as argument will contain
 * keys for quantities relevant to the current batch or epoch.
 *
 * Currently, the `.fit()` method of the `Sequential` model class
 * will include the following quantities in the `logs` that
 * it passes to its callbacks:
 *
 * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`
 *   (if validation is enabled in `fit`), and `valAcc` (if validation and
 *   accuracy monitoring are enabled).
 * onBatchBegin: Logs include `size`, the number of samples in the current
 *   batch.
 * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring
 *   is enabled).
 */
export declare abstract class BaseCallback {
    /** Validation data: a single `Tensor` or an array of `Tensor`s. */
    validationData: Tensor | Tensor[];
    /**
     * Training parameters (e.g. verbosity, batch size, number of epochs...).
     */
    params: Params;
    setParams(params: Params): void;
    onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
    onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
    onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
    onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
    onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
    onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
    setModel(model: Container): void;
}
/**
 * Container abstracting a list of callbacks.
 */
export declare class CallbackList {
    callbacks: BaseCallback[];
    queueLength: number;
    /**
     * Constructor of CallbackList.
     * @param callbacks Array of `Callback` instances.
     * @param queueLength Queue length for keeping running statistics over
     *   callback execution time.
     */
    constructor(callbacks?: BaseCallback[], queueLength?: number);
    append(callback: BaseCallback): void;
    setParams(params: Params): void;
    setModel(model: Container): void;
    /**
     * Called at the start of an epoch.
     * @param epoch Index of epoch.
     * @param logs Dictionary of logs.
     */
    onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
    /**
     * Called at the end of an epoch.
     * @param epoch Index of epoch.
     * @param logs Dictionary of logs.
     */
    onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
    /**
     * Called right before processing a batch.
     * @param batch Index of batch within the current epoch.
     * @param logs Dictionary of logs.
     */
    onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
    /**
     * Called at the end of a batch.
     * @param batch Index of batch within the current epoch.
     * @param logs Dictionary of logs.
     */
    onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
    /**
     * Called at the beginning of training.
     * @param logs Dictionary of logs.
     */
    onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
    /**
     * Called at the end of training.
     * @param logs Dictionary of logs.
     */
    onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
}
/**
 * Callback that accumulates epoch averages of metrics.
 *
 * This callback is automatically applied to every LayersModel.
 */
export declare class BaseLogger extends BaseCallback {
    private seen;
    private totals;
    constructor();
    onEpochBegin(epoch: number): Promise<void>;
    onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
    onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
}
/**
 * Callback that records events into a `History` object. This callback is
 * automatically applied to every TF.js Layers model. The `History` object
 * gets returned by the `fit` method of models.
 */
export declare class History extends BaseCallback {
    epoch: number[];
    history: {
        [key: string]: Array<number | Tensor>;
    };
    onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
    onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
    /**
     * Await the values of all losses and metrics.
     */
    syncData(): Promise<void>;
}
/**
 * Plain-object form of a training callback. Every member is optional.
 *
 * Each hook may return `void` or a `Promise<void>`. `onYield`, `nowFunc`
 * and `nextFrameFunc` are accepted in addition to the lifecycle hooks.
 */
export interface CustomCallbackArgs {
    onTrainBegin?: (logs?: Logs) => void | Promise<void>;
    onTrainEnd?: (logs?: Logs) => void | Promise<void>;
    onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;
    onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;
    onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;
    onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;
    onYield?: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
    nowFunc?: Function;
    nextFrameFunc?: Function;
}
/**
 * Custom callback for training.
 *
 * The constructor takes a `CustomCallbackArgs` object and an optional
 * `yieldEvery` setting.
 */
export declare class CustomCallback extends BaseCallback {
    protected readonly trainBegin: (logs?: Logs) => void | Promise<void>;
    protected readonly trainEnd: (logs?: Logs) => void | Promise<void>;
    protected readonly epochBegin: (epoch: number, logs?: Logs) => void | Promise<void>;
    protected readonly epochEnd: (epoch: number, logs?: Logs) => void | Promise<void>;
    protected readonly batchBegin: (batch: number, logs?: Logs) => void | Promise<void>;
    protected readonly batchEnd: (batch: number, logs?: Logs) => void | Promise<void>;
    protected readonly yield: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
    private yieldEvery;
    private currentEpoch;
    nowFunc: Function;
    nextFrameFunc: Function;
    constructor(args: CustomCallbackArgs, yieldEvery?: YieldEveryOptions);
    maybeWait(epoch: number, batch: number, logs: UnresolvedLogs): Promise<void>;
    onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
    onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
    onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
    onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
    onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
    onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
}
/**
 * Standardize callbacks or configurations of them to an Array of callbacks.
 */
export declare function standardizeCallbacks(callbacks: BaseCallback | BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[], yieldEvery: YieldEveryOptions): BaseCallback[];
/** Type of a no-argument constructor that produces a `BaseCallback`. */
export declare type BaseCallbackConstructor = {
    new (): BaseCallback;
};
/**
 * A global registry for callback constructors to be used during
 * LayersModel.fit().
 */
export declare class CallbackConstructorRegistry {
    private static constructors;
    /**
     * Blocks public access to constructor.
     */
    private constructor();
    /**
     * Register a tf.LayersModel.fit() callback constructor.
     *
     * The registered callback constructor will be used to instantiate
     * callbacks for every tf.LayersModel.fit() call afterwards.
     *
     * @param verbosityLevel Level of verbosity at which the `callbackConstructor`
     *   is to be registered.
     * @param callbackConstructor A no-arg constructor for `tf.Callback`.
     * @throws Error, if the same callbackConstructor has been registered before,
     *   either at the same or a different `verbosityLevel`.
     */
    static registerCallbackConstructor(verbosityLevel: number, callbackConstructor: BaseCallbackConstructor): void;
    private static checkForDuplicate;
    /**
     * Clear all registered callback constructors.
     */
    protected static clear(): void;
    /**
     * Create callbacks using the registered callback constructors.
     *
     * Given `verbosityLevel`, all constructors registered at that level or above
     * will be called and the instantiated callbacks will be used.
     *
     * @param verbosityLevel Level of verbosity.
     */
    static createCallbacks(verbosityLevel: number): BaseCallback[];
}
/**
 * Assembles the callbacks for a training run.
 *
 * @param callbacks Callbacks to include.
 * @param verbose Logging verbosity for the run.
 * @param epochs Number of epochs.
 * @param initialEpoch Epoch index at which training starts.
 * @param numTrainSamples Number of training samples.
 * @param stepsPerEpoch Number of steps per epoch.
 * @param batchSize Batch size.
 * @param doValidation Whether validation is performed.
 * @param callbackMetrics Names of the metrics reported to callbacks.
 * @returns An object holding the assembled `CallbackList` and a `History`.
 *   NOTE(review): presumably the returned `History` is also registered in
 *   `callbackList`; confirm against the implementation.
 */
export declare function configureCallbacks(callbacks: BaseCallback[], verbose: ModelLoggingVerbosity, epochs: number, initialEpoch: number, numTrainSamples: number, stepsPerEpoch: number, batchSize: number, doValidation: boolean, callbackMetrics: string[]): {
    callbackList: CallbackList;
    history: History;
};