UNPKG

@tensorflow/tfjs-node

Version:

This repository provides native TensorFlow execution in backend JavaScript applications under the Node.js runtime, accelerated by the TensorFlow C binary under the hood. It provides the same API as [TensorFlow.js](https://js.tensorflow.org/api/latest/).

367 lines (341 loc) 12.6 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {CustomCallback, LayersModel, Logs, nextFrame, util} from '@tensorflow/tfjs';
import * as path from 'path';
import * as ProgressBar from 'progress';

import {summaryFileWriter, SummaryFileWriter} from './tensorboard';

type LogFunction = (message: string) => void;

// A helper object created for testing with the jasmine `spyOn` method, which
// operates only on member methods of objects.
// tslint:disable-next-line:no-any
export const progressBarHelper: {ProgressBar: any, log: LogFunction} = {
  ProgressBar,
  log: console.log
};

/**
 * Terminal-based progress bar callback for tf.Model.fit().
 */
export class ProgbarLogger extends CustomCallback {
  private numTrainBatchesPerEpoch: number;
  private progressBar: ProgressBar;
  private currentEpochBegin: number;
  private epochDurationMillis: number;
  private usPerStep: number;
  private batchesInLatestEpoch: number;
  private terminalWidth: number;

  private readonly RENDER_THROTTLE_MS = 50;

  /**
   * Constructor of ProgbarLogger.
   */
  constructor() {
    super({
      onTrainBegin: async (logs?: Logs) => {
        const samples = this.params.samples as number;
        const batchSize = this.params.batchSize as number;
        const steps = this.params.steps as number;
        if (samples != null || steps != null) {
          this.numTrainBatchesPerEpoch =
              samples != null ? Math.ceil(samples / batchSize) : steps;
        } else {
          // Undetermined number of batches per epoch, e.g., due to
          // `fitDataset()` without `batchesPerEpoch`.
          this.numTrainBatchesPerEpoch = 0;
        }
      },
      onEpochBegin: async (epoch: number, logs?: Logs) => {
        progressBarHelper.log(`Epoch ${epoch + 1} / ${this.params.epochs}`);
        this.currentEpochBegin = util.now();
        this.epochDurationMillis = null;
        this.usPerStep = null;
        this.batchesInLatestEpoch = 0;
        this.terminalWidth = process.stderr.columns;
      },
      onBatchEnd: async (batch: number, logs?: Logs) => {
        this.batchesInLatestEpoch++;
        if (batch === 0) {
          this.progressBar = new progressBarHelper.ProgressBar(
              'eta=:eta :bar :placeholderForLossesAndMetrics', {
                width: Math.floor(0.5 * this.terminalWidth),
                // One extra tick is reserved for `onEpochEnd`, which ticks
                // the bar once more after the last batch.
                total: this.numTrainBatchesPerEpoch + 1,
                head: `>`,
                renderThrottle: this.RENDER_THROTTLE_MS
              });
        }
        const maxMetricsStringLength =
            Math.floor(this.terminalWidth * 0.5 - 12);
        const tickTokens = {
          placeholderForLossesAndMetrics:
              this.formatLogsAsMetricsContent(logs, maxMetricsStringLength)
        };
        if (this.numTrainBatchesPerEpoch === 0) {
          // Undetermined number of batches per epoch.
          this.progressBar.tick(0, tickTokens);
        } else {
          this.progressBar.tick(tickTokens);
        }
        await nextFrame();
        if (batch === this.numTrainBatchesPerEpoch - 1) {
          this.epochDurationMillis = util.now() - this.currentEpochBegin;
          // NOTE(review): when `samples` is known this divides by the sample
          // count, so the value printed as `us/step` in `onEpochEnd` is a
          // per-sample duration rather than a per-batch one — confirm intent.
          this.usPerStep = this.params.samples != null ?
              this.epochDurationMillis / (this.params.samples as number) * 1e3 :
              this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
        }
      },
      onEpochEnd: async (epoch: number, logs?: Logs) => {
        if (this.epochDurationMillis == null) {
          // In cases where the number of batches per epoch is not determined,
          // the calculation of the per-step duration is done at the end of the
          // epoch. N.B., this includes the time spent on validation.
          this.epochDurationMillis = util.now() - this.currentEpochBegin;
          this.usPerStep =
              this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
        }
        // The bar is only created in `onBatchEnd`; guard against an epoch
        // that ended before any batch did.
        if (this.progressBar != null) {
          this.progressBar.tick({placeholderForLossesAndMetrics: ''});
        }
        const lossesAndMetricsString = this.formatLogsAsMetricsContent(logs);
        progressBarHelper.log(
            `${this.epochDurationMillis.toFixed(0)}ms ` +
            `${this.usPerStep.toFixed(0)}us/step - ` +
            `${lossesAndMetricsString}`);
        await nextFrame();
      },
    });
  }

  /**
   * Formats the relevant entries of `logs` as space-separated `key=value`
   * pairs, sorted by key.
   *
   * @param logs Loss and metric values. `null`/`undefined` yields ''.
   * @param maxMetricsLength Optional maximum length of the result. Longer
   *   strings are truncated and suffixed with '...'.
   * @return The formatted metrics string.
   */
  private formatLogsAsMetricsContent(logs: Logs, maxMetricsLength?: number):
      string {
    if (logs == null) {
      return '';
    }
    let metricsContent = '';
    const keys = Object.keys(logs).sort();
    for (const key of keys) {
      if (this.isFieldRelevant(key)) {
        const value = logs[key];
        metricsContent += `${key}=${getSuccinctNumberDisplay(value)} `;
      }
    }
    if (maxMetricsLength != null && metricsContent.length > maxMetricsLength) {
      // Cut off metrics strings that are too long to avoid new lines being
      // constantly created. The end index is clamped at 0 so that a very
      // narrow terminal does not turn it into a negative (from-the-end) index.
      metricsContent =
          metricsContent.slice(0, Math.max(0, maxMetricsLength - 3)) + '...';
    }
    return metricsContent;
  }

  /** Returns false for the bookkeeping keys `batch` and `size`. */
  private isFieldRelevant(key: string) {
    return key !== 'batch' && key !== 'size';
  }
}

const BASE_NUM_DIGITS = 2;
const MAX_NUM_DECIMAL_PLACES = 4;

/**
 * Get a succinct string representation of a number.
 *
 * Uses decimal notation if the number isn't too small.
 * Otherwise, uses exponential (scientific) notation with `BASE_NUM_DIGITS`
 * digits after the decimal point.
 *
 * @param x Input number.
 * @return Succinct string representing `x`.
 */
export function getSuccinctNumberDisplay(x: number): string {
  const decimalPlaces = getDisplayDecimalPlaces(x);
  return decimalPlaces > MAX_NUM_DECIMAL_PLACES ?
      x.toExponential(BASE_NUM_DIGITS) :
      x.toFixed(decimalPlaces);
}

/**
 * Determine the number of decimal places to display.
 *
 * Non-finite numbers, zero, and numbers outside [-1, 1] get
 * `BASE_NUM_DIGITS` places. Numbers inside [-1, 1] get more places the
 * smaller their magnitude, so that about `BASE_NUM_DIGITS + 1` significant
 * digits remain visible.
 *
 * @param x Number to display.
 * @return Number of decimal places to display for `x`.
 */
export function getDisplayDecimalPlaces(x: number): number {
  if (!Number.isFinite(x) || x === 0 || x > 1 || x < -1) {
    return BASE_NUM_DIGITS;
  } else {
    return BASE_NUM_DIGITS - Math.floor(Math.log10(Math.abs(x)));
  }
}

export interface TensorBoardCallbackArgs {
  /**
   * The frequency at which loss and metric values are written to logs.
   *
   * Currently supported options are:
   *
   * - 'batch': Write logs at the end of every batch of training, in addition
   *   to the end of every epoch of training.
   * - 'epoch': Write logs at the end of every epoch of training.
   *
   * Note that writing logs too often slows down the training.
   *
   * Default: 'epoch'.
   */
  updateFreq?: 'batch'|'epoch';

  /**
   * The frequency (in epochs) at which to compute activation and weight
   * histograms for the layers of the model.
   *
   * If set to 0, histograms won't be computed.
   *
   * Validation data (or split) must be specified for histogram visualizations.
   *
   * Default: 0.
   */
  histogramFreq?: number;
}

/**
 * Callback for logging to TensorBoard during training.
 *
 * Users are expected to access this class through the `tensorBoard()`
 * factory function instead.
 */
export class TensorBoardCallback extends CustomCallback {
  private model: LayersModel = null;
  private trainWriter: SummaryFileWriter;
  private valWriter: SummaryFileWriter;
  private batchesSeen: number;
  private readonly args: TensorBoardCallbackArgs;

  constructor(readonly logdir = './logs', args?: TensorBoardCallbackArgs) {
    super({
      onBatchEnd: async (batch: number, logs?: Logs) => {
        this.batchesSeen++;
        if (this.args.updateFreq !== 'epoch') {
          this.logMetrics(logs, 'batch_', this.batchesSeen);
        }
      },
      onEpochEnd: async (epoch: number, logs?: Logs) => {
        this.logMetrics(logs, 'epoch_', epoch + 1);
        if (this.args.histogramFreq > 0 &&
            epoch % this.args.histogramFreq === 0) {
          this.logWeights(epoch);
        }
      },
      onTrainEnd: async (logs?: Logs) => {
        if (this.trainWriter != null) {
          this.trainWriter.flush();
        }
        if (this.valWriter != null) {
          this.valWriter.flush();
        }
      }
    });

    // Copy the caller's object so that filling in defaults below does not
    // mutate it.
    this.args = args == null ? {} : {...args};
    if (this.args.updateFreq == null) {
      this.args.updateFreq = 'epoch';
    }
    util.assert(
        ['batch', 'epoch'].indexOf(this.args.updateFreq) !== -1,
        () => `Expected updateFreq to be 'batch' or 'epoch', but got ` +
            `${this.args.updateFreq}`);
    if (this.args.histogramFreq == null) {
      this.args.histogramFreq = 0;
    }
    util.assert(
        Number.isInteger(this.args.histogramFreq) &&
            this.args.histogramFreq >= 0,
        () => `Expected histogramFreq to be a non-negative integer, but got ` +
            `${this.args.histogramFreq}`);
    this.batchesSeen = 0;
  }

  setModel(model: LayersModel): void {
    // This method is inherited from BaseCallback. To avoid cyclical imports,
    // that class uses Container instead of LayersModel, and uses a run-time
    // check to make sure the model is a LayersModel.
    // Since this subclass isn't imported by tfjs-layers, we can safely type
    // the parameter as a LayersModel.
    this.model = model;
  }

  /**
   * Writes every loss/metric value in `logs` as a scalar summary at `step`.
   *
   * Keys starting with `val_` go to the validation writer with that prefix
   * replaced by `prefix`; all other keys go to the training writer with
   * `prefix` prepended. The keys `batch`, `size` and `num_steps` are skipped.
   */
  private logMetrics(logs: Logs, prefix: string, step: number) {
    for (const key in logs) {
      if (key === 'batch' || key === 'size' || key === 'num_steps') {
        continue;
      }

      const VAL_PREFIX = 'val_';
      if (key.startsWith(VAL_PREFIX)) {
        this.ensureValWriterCreated();
        const scalarName = prefix + key.slice(VAL_PREFIX.length);
        this.valWriter.scalar(scalarName, logs[key], step);
      } else {
        this.ensureTrainWriterCreated();
        this.trainWriter.scalar(`${prefix}${key}`, logs[key], step);
      }
    }
  }

  /**
   * Writes a histogram of each of the model's weights, named after the
   * weight, to the training writer at `step`.
   */
  private logWeights(step: number) {
    // The training writer is otherwise only created lazily by `logMetrics`,
    // so make sure it exists before writing histograms to it.
    this.ensureTrainWriterCreated();
    for (const weights of this.model.weights) {
      this.trainWriter.histogram(weights.name, weights.read(), step);
    }
  }

  /** Creates the writer for `<logdir>/train` if it does not exist yet. */
  private ensureTrainWriterCreated() {
    if (this.trainWriter == null) {
      this.trainWriter = summaryFileWriter(path.join(this.logdir, 'train'));
    }
  }

  /** Creates the writer for `<logdir>/val` if it does not exist yet. */
  private ensureValWriterCreated() {
    if (this.valWriter == null) {
      this.valWriter = summaryFileWriter(path.join(this.logdir, 'val'));
    }
  }
}

/**
 * Callback for logging to TensorBoard during training.
 *
 * Writes the loss and metric values (if any) to the specified log directory
 * (`logdir`) which can be ingested and visualized by TensorBoard.
 * This callback is usually passed as a callback to `tf.Model.fit()` or
 * `tf.Model.fitDataset()` calls during model training. The frequency at which
 * the values are logged can be controlled with the `updateFreq` field of the
 * configuration object (2nd argument).
 *
 * Usage example:
 * ```js
 * // Construct a toy multilayer-perceptron regressor for demo purpose.
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
 * model.add(tf.layers.dense({units: 1}));
 * model.compile({
 *   loss: 'meanSquaredError',
 *   optimizer: 'sgd',
 *   metrics: ['MAE']
 * });
 *
 * // Generate some random fake data for demo purpose.
 * const xs = tf.randomUniform([10000, 200]);
 * const ys = tf.randomUniform([10000, 1]);
 * const valXs = tf.randomUniform([1000, 200]);
 * const valYs = tf.randomUniform([1000, 1]);
 *
 * // Start model training process.
 * await model.fit(xs, ys, {
 *   epochs: 100,
 *   validationData: [valXs, valYs],
 *    // Add the tensorBoard callback here.
 *   callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
 * });
 * ```
 *
 * Then you can use the following commands to point tensorboard
 * to the logdir:
 *
 * ```sh
 * pip install tensorboard  # Unless you've already installed it.
 * tensorboard --logdir /tmp/fit_logs_1
 * ```
 *
 * @param logdir Directory to which the logs will be written.
 * @param args Optional configuration arguments.
 * @returns An instance of `TensorBoardCallback`, which is a subclass of
 *   `tf.CustomCallback`.
 *
 * @doc {heading: 'TensorBoard', namespace: 'node'}
 */
export function tensorBoard(
    logdir = './logs', args?: TensorBoardCallbackArgs): TensorBoardCallback {
  return new TensorBoardCallback(logdir, args);
}