@tensorflow/tfjs-node
Version:
This repository provides native TensorFlow execution in backend JavaScript applications under the Node.js runtime, accelerated by the TensorFlow C binary under the hood. It provides the same API as [TensorFlow.js](https://js.tensorflow.org/api/latest/).
367 lines (341 loc) • 12.6 kB
text/typescript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {CustomCallback, LayersModel, Logs, nextFrame, util} from '@tensorflow/tfjs';
import * as path from 'path';
import * as ProgressBar from 'progress';
import {summaryFileWriter, SummaryFileWriter} from './tensorboard';
type LogFunction = (message: string) => void;
// A helper class created for testing with the jasmine `spyOn` method, which
// operates only on member methods of objects.
// tslint:disable-next-line:no-any
export const progressBarHelper: {ProgressBar: any, log: LogFunction} = {
ProgressBar,
log: console.log
};
/**
* Terminal-based progress bar callback for tf.Model.fit().
*/
export class ProgbarLogger extends CustomCallback {
private numTrainBatchesPerEpoch: number;
private progressBar: ProgressBar;
private currentEpochBegin: number;
private epochDurationMillis: number;
private usPerStep: number;
private batchesInLatestEpoch: number;
private terminalWidth: number;
private readonly RENDER_THROTTLE_MS = 50;
/**
* Construtor of LoggingCallback.
*/
constructor() {
super({
onTrainBegin: async (logs?: Logs) => {
const samples = this.params.samples as number;
const batchSize = this.params.batchSize as number;
const steps = this.params.steps as number;
if (samples != null || steps != null) {
this.numTrainBatchesPerEpoch =
samples != null ? Math.ceil(samples / batchSize) : steps;
} else {
// Undetermined number of batches per epoch, e.g., due to
// `fitDataset()` without `batchesPerEpoch`.
this.numTrainBatchesPerEpoch = 0;
}
},
onEpochBegin: async (epoch: number, logs?: Logs) => {
progressBarHelper.log(`Epoch ${epoch + 1} / ${this.params.epochs}`);
this.currentEpochBegin = util.now();
this.epochDurationMillis = null;
this.usPerStep = null;
this.batchesInLatestEpoch = 0;
this.terminalWidth = process.stderr.columns;
},
onBatchEnd: async (batch: number, logs?: Logs) => {
this.batchesInLatestEpoch++;
if (batch === 0) {
this.progressBar = new progressBarHelper.ProgressBar(
'eta=:eta :bar :placeholderForLossesAndMetrics', {
width: Math.floor(0.5 * this.terminalWidth),
total: this.numTrainBatchesPerEpoch + 1,
head: `>`,
renderThrottle: this.RENDER_THROTTLE_MS
});
}
const maxMetricsStringLength =
Math.floor(this.terminalWidth * 0.5 - 12);
const tickTokens = {
placeholderForLossesAndMetrics:
this.formatLogsAsMetricsContent(logs, maxMetricsStringLength)
};
if (this.numTrainBatchesPerEpoch === 0) {
// Undetermined number of batches per epoch.
this.progressBar.tick(0, tickTokens);
} else {
this.progressBar.tick(tickTokens);
}
await nextFrame();
if (batch === this.numTrainBatchesPerEpoch - 1) {
this.epochDurationMillis = util.now() - this.currentEpochBegin;
this.usPerStep = this.params.samples != null ?
this.epochDurationMillis / (this.params.samples as number) * 1e3 :
this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
}
},
onEpochEnd: async (epoch: number, logs?: Logs) => {
if (this.epochDurationMillis == null) {
// In cases where the number of batches per epoch is not determined,
// the calculation of the per-step duration is done at the end of the
// epoch. N.B., this includes the time spent on validation.
this.epochDurationMillis = util.now() - this.currentEpochBegin;
this.usPerStep =
this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
}
this.progressBar.tick({placeholderForLossesAndMetrics: ''});
const lossesAndMetricsString = this.formatLogsAsMetricsContent(logs);
progressBarHelper.log(
`${this.epochDurationMillis.toFixed(0)}ms ` +
`${this.usPerStep.toFixed(0)}us/step - ` +
`${lossesAndMetricsString}`);
await nextFrame();
},
});
}
private formatLogsAsMetricsContent(logs: Logs, maxMetricsLength?: number):
string {
let metricsContent = '';
const keys = Object.keys(logs).sort();
for (const key of keys) {
if (this.isFieldRelevant(key)) {
const value = logs[key];
metricsContent += `${key}=${getSuccinctNumberDisplay(value)} `;
}
}
if (maxMetricsLength != null && metricsContent.length > maxMetricsLength) {
// Cut off metrics strings that are too long to avoid new lines being
// constantly created.
metricsContent = metricsContent.slice(0, maxMetricsLength - 3) + '...';
}
return metricsContent;
}
private isFieldRelevant(key: string) {
return key !== 'batch' && key !== 'size';
}
}
const BASE_NUM_DIGITS = 2;
const MAX_NUM_DECIMAL_PLACES = 4;
/**
* Get a succint string representation of a number.
*
* Uses decimal notation if the number isn't too small.
* Otherwise, use engineering notation.
*
* @param x Input number.
* @return Succinct string representing `x`.
*/
export function getSuccinctNumberDisplay(x: number): string {
const decimalPlaces = getDisplayDecimalPlaces(x);
return decimalPlaces > MAX_NUM_DECIMAL_PLACES ?
x.toExponential(BASE_NUM_DIGITS) :
x.toFixed(decimalPlaces);
}
/**
* Determine the number of decimal places to display.
*
* @param x Number to display.
* @return Number of decimal places to display for `x`.
*/
export function getDisplayDecimalPlaces(x: number): number {
if (!Number.isFinite(x) || x === 0 || x > 1 || x < -1) {
return BASE_NUM_DIGITS;
} else {
return BASE_NUM_DIGITS - Math.floor(Math.log10(Math.abs(x)));
}
}
export interface TensorBoardCallbackArgs {
/**
* The frequency at which loss and metric values are written to logs.
*
* Currently supported options are:
*
* - 'batch': Write logs at the end of every batch of training, in addition
* to the end of every epoch of training.
* - 'epoch': Write logs at the end of every epoch of training.
*
* Note that writing logs too often slows down the training.
*
* Default: 'epoch'.
*/
updateFreq?: 'batch'|'epoch';
/**
* The frequency (in epochs) at which to compute activation and weight
* histograms for the layers of the model.
*
* If set to 0, histograms won't be computed.
*
* Validation data (or split) must be specified for histogram visualizations.
*
* Default: 0.
*/
histogramFreq?: number;
}
/**
* Callback for logging to TensorBoard during training.
*
* Users are expected to access this class through the `tensorBoardCallback()`
* factory method instead.
*/
export class TensorBoardCallback extends CustomCallback {
private model: LayersModel = null;
private trainWriter: SummaryFileWriter;
private valWriter: SummaryFileWriter;
private batchesSeen: number;
private readonly args: TensorBoardCallbackArgs;
constructor(readonly logdir = './logs', args?: TensorBoardCallbackArgs) {
super({
onBatchEnd: async (batch: number, logs?: Logs) => {
this.batchesSeen++;
if (this.args.updateFreq !== 'epoch') {
this.logMetrics(logs, 'batch_', this.batchesSeen);
}
},
onEpochEnd: async (epoch: number, logs?: Logs) => {
this.logMetrics(logs, 'epoch_', epoch + 1);
if (this.args.histogramFreq > 0 &&
epoch % this.args.histogramFreq === 0) {
this.logWeights(epoch);
}
},
onTrainEnd: async (logs?: Logs) => {
if (this.trainWriter != null) {
this.trainWriter.flush();
}
if (this.valWriter != null) {
this.valWriter.flush();
}
}
});
this.args = args == null ? {} : args;
if (this.args.updateFreq == null) {
this.args.updateFreq = 'epoch';
}
util.assert(
['batch', 'epoch'].indexOf(this.args.updateFreq) !== -1,
() => `Expected updateFreq to be 'batch' or 'epoch', but got ` +
`${this.args.updateFreq}`);
if (this.args.histogramFreq == null) {
this.args.histogramFreq = 0;
}
util.assert(
Number.isInteger(this.args.histogramFreq) &&
this.args.histogramFreq >= 0,
() => `Expected histogramFreq to be a positive integer, but got ` +
`${this.args.histogramFreq}`);
this.batchesSeen = 0;
}
setModel(model: LayersModel): void {
// This method is inherited from BaseCallback. To avoid cyclical imports,
// that class uses Container instead of LayersModel, and uses a run-time
// check to make sure the model is a LayersModel.
// Since this subclass isn't imported by tfjs-layers, we can safely use type
// the parameter as a LayersModel.
this.model = model;
}
private logMetrics(logs: Logs, prefix: string, step: number) {
for (const key in logs) {
if (key === 'batch' || key === 'size' || key === 'num_steps') {
continue;
}
const VAL_PREFIX = 'val_';
if (key.startsWith(VAL_PREFIX)) {
this.ensureValWriterCreated();
const scalarName = prefix + key.slice(VAL_PREFIX.length);
this.valWriter.scalar(scalarName, logs[key], step);
} else {
this.ensureTrainWriterCreated();
this.trainWriter.scalar(`${prefix}${key}`, logs[key], step);
}
}
}
private logWeights(step: number) {
for (const weights of this.model.weights) {
this.trainWriter.histogram(weights.name, weights.read(), step);
}
}
private ensureTrainWriterCreated() {
this.trainWriter = summaryFileWriter(path.join(this.logdir, 'train'));
}
private ensureValWriterCreated() {
this.valWriter = summaryFileWriter(path.join(this.logdir, 'val'));
}
}
/**
* Callback for logging to TensorBoard during training.
*
* Writes the loss and metric values (if any) to the specified log directory
* (`logdir`) which can be ingested and visualized by TensorBoard.
* This callback is usually passed as a callback to `tf.Model.fit()` or
* `tf.Model.fitDataset()` calls during model training. The frequency at which
* the values are logged can be controlled with the `updateFreq` field of the
* configuration object (2nd argument).
*
* Usage example:
* ```js
* // Constructor a toy multilayer-perceptron regressor for demo purpose.
* const model = tf.sequential();
* model.add(
* tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
* model.add(tf.layers.dense({units: 1}));
* model.compile({
* loss: 'meanSquaredError',
* optimizer: 'sgd',
* metrics: ['MAE']
* });
*
* // Generate some random fake data for demo purpose.
* const xs = tf.randomUniform([10000, 200]);
* const ys = tf.randomUniform([10000, 1]);
* const valXs = tf.randomUniform([1000, 200]);
* const valYs = tf.randomUniform([1000, 1]);
*
* // Start model training process.
* await model.fit(xs, ys, {
* epochs: 100,
* validationData: [valXs, valYs],
* // Add the tensorBoard callback here.
* callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
* });
* ```
*
* Then you can use the following commands to point tensorboard
* to the logdir:
*
* ```sh
* pip install tensorboard # Unless you've already installed it.
* tensorboard --logdir /tmp/fit_logs_1
* ```
*
* @param logdir Directory to which the logs will be written.
* @param args Optional configuration arguments.
* @returns An instance of `TensorBoardCallback`, which is a subclass of
* `tf.CustomCallback`.
*
* @doc {heading: 'TensorBoard', namespace: 'node'}
*/
export function tensorBoard(
logdir = './logs', args?: TensorBoardCallbackArgs): TensorBoardCallback {
return new TensorBoardCallback(logdir, args);
}