// Package: @tensorflow/tfjs-layers (TensorFlow layers API in JavaScript)
// Version: (not recorded)
// Source listing: 223 lines (222 loc), 9.08 kB, TypeScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/// <amd-module name="@tensorflow/tfjs-layers/dist/base_callbacks" />
import { Tensor } from '@tensorflow/tfjs-core';
import { Container } from './engine/container';
import { Logs, UnresolvedLogs } from './logs';
/**
 * Verbosity level for logging while a model is being fit.
 *
 * The members carry explicit numeric values, so levels can be compared
 * numerically.
 */
export declare enum ModelLoggingVerbosity {
  /** Level 0: the quietest setting. */
  SILENT = 0,
  /** Level 1: the more verbose setting. */
  VERBOSE = 1,
}
/**
 * Default interval, in milliseconds, between yields to the main thread
 * while training.
 */
export declare const DEFAULT_YIELD_EVERY_MS: 125;
/**
 * Training parameters handed to callbacks through `setParams()`.
 *
 * A string-keyed dictionary whose values are numbers, strings or booleans,
 * or homogeneous arrays of one of those.
 */
export type Params = Record<
  string,
  number | string | boolean | number[] | string[] | boolean[]
>;
/**
 * Accepted values for the `yieldEvery` setting taken by `CustomCallback`
 * and `standardizeCallbacks`.
 *
 * The setting is either one of the keywords below or a number. The number
 * is presumably an interval in milliseconds, cf. `DEFAULT_YIELD_EVERY_MS`
 * (TODO confirm).
 */
export type YieldEveryOptions =
  | 'auto'
  | 'batch'
  | 'epoch'
  | 'never'
  | number;
/**
 * Abstract base class used to build new callbacks.
 *
 * The `logs` dictionary that callback methods take as an argument contains
 * keys for quantities relevant to the current batch or epoch.
 *
 * Currently, the `.fit()` method of the `Sequential` model class includes
 * the following quantities in the `logs` that it passes to its callbacks:
 *
 * - `onEpochEnd`: `acc` and `loss`. Optionally also `valLoss` (if
 *   validation is enabled in `fit`) and `valAcc` (if both validation and
 *   accuracy monitoring are enabled).
 * - `onBatchBegin`: `size`, the number of samples in the current batch.
 * - `onBatchEnd`: `loss`, and optionally `acc` (if accuracy monitoring is
 *   enabled).
 */
export declare abstract class BaseCallback {
  /**
   * Validation data, either a single tensor or an array of tensors.
   * Presumably assigned by the training loop (TODO confirm).
   */
  validationData: Tensor | Tensor[];
  /**
   * Training parameters (e.g. verbosity, batch size, number of epochs...).
   */
  params: Params;
  /** Supplies the training parameters (presumably stored in `params`). */
  setParams(params: Params): void;
  /** Called at the start of epoch `epoch`. */
  onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /** Called at the end of epoch `epoch`. */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /** Called right before processing batch `batch` of the current epoch. */
  onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /** Called at the end of batch `batch` of the current epoch. */
  onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /** Called at the beginning of training. */
  onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
  /** Called at the end of training. */
  onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
  /** Supplies the model (a `Container`) this callback is attached to. */
  setModel(model: Container): void;
}
/**
 * Wraps an array of callbacks so they can be driven as a single unit.
 */
export declare class CallbackList {
  /** The wrapped callbacks. */
  callbacks: BaseCallback[];
  /**
   * Queue length for keeping running statistics over callback execution
   * time.
   */
  queueLength: number;

  /**
   * @param callbacks Initial `Callback` instances (optional).
   * @param queueLength Queue length for running statistics over callback
   *   execution time (optional).
   */
  constructor(callbacks?: BaseCallback[], queueLength?: number);

  /** Appends a callback to the list. */
  append(callback: BaseCallback): void;

  /** Sets the model on the wrapped callbacks (presumably each one; TODO confirm). */
  setModel(model: Container): void;
  /** Sets the training parameters on the wrapped callbacks (presumably each one; TODO confirm). */
  setParams(params: Params): void;

  // Lifecycle hooks, listed in the order a training run presumably invokes
  // them. Each takes an optional dictionary of logs.

  /**
   * Called at the beginning of training.
   * @param logs Dictionary of logs.
   */
  onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called at the start of an epoch.
   * @param epoch Index of the epoch.
   * @param logs Dictionary of logs.
   */
  onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called right before a batch is processed.
   * @param batch Index of the batch within the current epoch.
   * @param logs Dictionary of logs.
   */
  onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called once a batch has been processed.
   * @param batch Index of the batch within the current epoch.
   * @param logs Dictionary of logs.
   */
  onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called at the end of an epoch.
   * @param epoch Index of the epoch.
   * @param logs Dictionary of logs.
   */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called at the end of training.
   * @param logs Dictionary of logs.
   */
  onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
}
/**
 * Callback that accumulates per-epoch averages of metrics.
 *
 * It is applied automatically to every LayersModel.
 */
export declare class BaseLogger extends BaseCallback {
  constructor();

  // Running state for the epoch in progress. These are presumably a sample
  // count and per-metric totals (TODO confirm against the implementation).
  private seen;
  private totals;

  /** Called at the start of epoch `epoch`. Note that this override takes no logs. */
  onEpochBegin(epoch: number): Promise<void>;
  /** Called at the end of batch `batch`, with that batch's logs. */
  onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /** Called at the end of epoch `epoch`, with that epoch's logs. */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
}
/**
 * Callback that records training events into a `History` object.
 *
 * It is applied automatically to every TF.js Layers model. The `History`
 * object is what a model's `fit` method returns.
 */
export declare class History extends BaseCallback {
  /** Recorded epoch indices (presumably one per completed epoch; TODO confirm). */
  epoch: number[];
  /**
   * Recorded values, keyed by name (presumably loss/metric names). Entries
   * may still be `Tensor`s whose values `syncData()` awaits.
   */
  history: Record<string, (number | Tensor)[]>;

  /** Called at the beginning of training. */
  onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
  /** Called at the end of epoch `epoch`, with that epoch's logs. */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Awaits the values of all recorded losses and metrics.
   */
  syncData(): Promise<void>;
}
/**
 * Plain-object description of a training callback.
 *
 * Every hook is optional and may return either nothing or a Promise.
 */
export interface CustomCallbackArgs {
  /** Called at the beginning of training. */
  onTrainBegin?: (logs?: Logs) => void | Promise<void>;
  /** Called at the start of each epoch. */
  onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;
  /** Called right before each batch. */
  onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;
  /** Called at the end of each batch. */
  onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;
  /** Called at the end of each epoch. */
  onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;
  /** Called at the end of training. */
  onTrainEnd?: (logs?: Logs) => void | Promise<void>;
  /**
   * Hook tied to the yield policy. Unlike the other hooks, its `logs`
   * argument is required.
   */
  onYield?: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
  /** Optional override for the clock function (TODO confirm expected signature). */
  nowFunc?: Function;
  /** Optional override for the next-frame function (TODO confirm expected signature). */
  nextFrameFunc?: Function;
}
/**
 * Training callback built from a `CustomCallbackArgs` object of hook
 * functions.
 */
export declare class CustomCallback extends BaseCallback {
  // User-supplied hooks, in training-lifecycle order. Each may return
  // either nothing or a Promise.
  protected readonly trainBegin: (logs?: Logs) => void | Promise<void>;
  protected readonly epochBegin: (epoch: number, logs?: Logs) => void | Promise<void>;
  protected readonly batchBegin: (batch: number, logs?: Logs) => void | Promise<void>;
  protected readonly batchEnd: (batch: number, logs?: Logs) => void | Promise<void>;
  protected readonly epochEnd: (epoch: number, logs?: Logs) => void | Promise<void>;
  protected readonly trainEnd: (logs?: Logs) => void | Promise<void>;
  /** Yield hook. It receives the current epoch, batch and (required) logs. */
  protected readonly yield: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;

  // Internal state. These presumably hold the yield policy and the index of
  // the epoch in progress (TODO confirm).
  private yieldEvery;
  private currentEpoch;

  /** Clock function, overridable via `CustomCallbackArgs.nowFunc` (presumably). */
  nowFunc: Function;
  /** Next-frame function, overridable via `CustomCallbackArgs.nextFrameFunc` (presumably). */
  nextFrameFunc: Function;

  /**
   * @param args Hook functions and optional timing overrides.
   * @param yieldEvery How often to yield (optional); see `YieldEveryOptions`.
   */
  constructor(args: CustomCallbackArgs, yieldEvery?: YieldEveryOptions);

  /**
   * Presumably waits/yields when the `yieldEvery` policy calls for it
   * (TODO confirm against the implementation).
   */
  maybeWait(epoch: number, batch: number, logs: UnresolvedLogs): Promise<void>;

  onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
  onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
  onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
}
/**
 * Normalizes callbacks, or configurations of callbacks, into an array of
 * `BaseCallback`s.
 *
 * @param callbacks A single callback, an array of callbacks, a
 *   `CustomCallbackArgs` configuration, or an array of such configurations.
 * @param yieldEvery Yield setting. It is presumably applied to callbacks
 *   built from configurations (TODO confirm).
 * @returns The standardized array of callbacks.
 */
export declare function standardizeCallbacks(
  callbacks:
    | BaseCallback
    | BaseCallback[]
    | CustomCallbackArgs
    | CustomCallbackArgs[],
  yieldEvery: YieldEveryOptions
): BaseCallback[];
/**
 * A no-argument constructor that produces a `BaseCallback`. This is the form
 * accepted by `CallbackConstructorRegistry.registerCallbackConstructor`.
 */
export declare type BaseCallbackConstructor = new () => BaseCallback;
/**
 * A global registry of callback constructors to be used during
 * `LayersModel.fit()`.
 */
export declare class CallbackConstructorRegistry {
  // The registered constructors. Presumably grouped by verbosity level
  // (TODO confirm against the implementation).
  private static constructors;
  /**
   * Blocks public access to the constructor.
   */
  private constructor();
  /**
   * Register a `tf.LayersModel.fit()` callback constructor.
   *
   * The registered callback constructor will be used to instantiate
   * callbacks for every `tf.LayersModel.fit()` call afterwards.
   *
   * @param verbosityLevel Level of verbosity at which the
   *   `callbackConstructor` is to be registered.
   * @param callbackConstructor A no-arg constructor for `tf.Callback`.
   * @throws Error, if the same `callbackConstructor` has been registered
   *   before, either at the same or a different `verbosityLevel`.
   */
  static registerCallbackConstructor(verbosityLevel: number, callbackConstructor: BaseCallbackConstructor): void;
  // Presumably the helper that enforces the duplicate-registration check
  // described under `@throws` above (TODO confirm).
  private static checkForDuplicate;
  /**
   * Clear all registered callback constructors.
   */
  protected static clear(): void;
  /**
   * Create callbacks using the registered callback constructors.
   *
   * Given `verbosityLevel`, all constructors registered at that level or
   * above will be called and the instantiated callbacks will be used.
   *
   * @param verbosityLevel Level of verbosity.
   * @returns The instantiated callbacks.
   */
  static createCallbacks(verbosityLevel: number): BaseCallback[];
}
/**
 * Assembles the callbacks for a training run. Presumably used by
 * `fit()`-style training loops (TODO confirm).
 *
 * @param callbacks Callbacks supplied for this run.
 * @param verbose Logging verbosity for the run.
 * @param epochs Number of epochs.
 * @param initialEpoch Epoch index to start from.
 * @param numTrainSamples Number of training samples.
 * @param stepsPerEpoch Steps per epoch.
 * @param batchSize Batch size.
 * @param doValidation Whether validation is performed.
 * @param callbackMetrics Names of the metrics exposed to callbacks.
 * @returns The `CallbackList` to drive during training, together with the
 *   `History` callback.
 */
export declare function configureCallbacks(
  callbacks: BaseCallback[],
  verbose: ModelLoggingVerbosity,
  epochs: number,
  initialEpoch: number,
  numTrainSamples: number,
  stepsPerEpoch: number,
  batchSize: number,
  doValidation: boolean,
  callbackMetrics: string[]
): { callbackList: CallbackList; history: History };