@tensorflow/tfjs-node

This package provides native TensorFlow execution for backend JavaScript applications under the Node.js runtime, accelerated by the TensorFlow C binary under the hood. It provides the same API as [TensorFlow.js](https://js.tensorflow.org/api/latest/).
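For a quick illustration of that API parity, here is a minimal sketch (not part of the package source; it assumes the package has been installed locally, e.g. with `npm install @tensorflow/tfjs-node`). Importing the package registers the native Node.js backend, after which the ordinary TensorFlow.js tensor API can be used unchanged:

```js
// Minimal sketch: the familiar TensorFlow.js API, executed by the native
// TensorFlow C binary. Assumes `npm install @tensorflow/tfjs-node` was run.
const tf = require('@tensorflow/tfjs-node');

const a = tf.tensor2d([[1, 2], [3, 4]]);
const b = tf.tensor2d([[5, 6], [7, 8]]);

// Same API as in the browser; the computation runs on the native backend.
a.matMul(b).print();
```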

"use strict"; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var __extends = (this && this.__extends) || (function () { var extendStatics = function (d, b) { extendStatics = Object.setPrototypeOf || ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || function (d, b) { for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p]; }; return extendStatics(d, b); }; return function (d, b) { if (typeof b !== "function" && b !== null) throw new TypeError("Class extends value " + String(b) + " is not a constructor or null"); extendStatics(d, b); function __() { this.constructor = d; } d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); }; })(); var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }; var __generator = (this && this.__generator) || function (thisArg, body) { var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; function verb(n) { return function (v) { return step([n, v]); }; } function step(op) { if (f) throw new TypeError("Generator is already executing."); while (g && (g = 0, op[0] && (_ = 0)), _) try { if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; if (y = 0, t) op = [op[0] & 2, t.value]; switch (op[0]) { case 0: case 1: t = op; break; case 4: _.label++; return { value: op[1], done: false }; case 5: _.label++; y = op[1]; op = [0]; continue; case 7: op = _.ops.pop(); _.trys.pop(); continue; default: if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } if (t[2]) _.ops.pop(); _.trys.pop(); continue; } op = body.call(thisArg, _); } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } if (op[0] & 5) throw op[1]; return { value: op[0] ? 
            op[1] : void 0, done: true };
    }
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.tensorBoard = exports.TensorBoardCallback = exports.getDisplayDecimalPlaces = exports.getSuccinctNumberDisplay = exports.ProgbarLogger = exports.progressBarHelper = void 0;
var tfjs_1 = require("@tensorflow/tfjs");
var path = require("path");
var ProgressBar = require("progress");
var tensorboard_1 = require("./tensorboard");
// A helper object created for testing with the jasmine `spyOn` method, which
// operates only on member methods of objects.
// tslint:disable-next-line:no-any
exports.progressBarHelper = {
    ProgressBar: ProgressBar,
    log: console.log
};
/**
 * Terminal-based progress bar callback for tf.Model.fit().
 */
var ProgbarLogger = /** @class */ (function (_super) {
    __extends(ProgbarLogger, _super);
    /**
     * Constructor of ProgbarLogger.
     */
    function ProgbarLogger() {
        var _this = _super.call(this, {
            onTrainBegin: function (logs) { return __awaiter(_this, void 0, void 0, function () {
                var samples, batchSize, steps;
                return __generator(this, function (_a) {
                    samples = this.params.samples;
                    batchSize = this.params.batchSize;
                    steps = this.params.steps;
                    if (samples != null || steps != null) {
                        this.numTrainBatchesPerEpoch =
                            samples != null ? Math.ceil(samples / batchSize) : steps;
                    }
                    else {
                        // Undetermined number of batches per epoch, e.g., due to
                        // `fitDataset()` without `batchesPerEpoch`.
                        this.numTrainBatchesPerEpoch = 0;
                    }
                    return [2 /*return*/];
                });
            }); },
            onEpochBegin: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    exports.progressBarHelper.log("Epoch ".concat(epoch + 1, " / ").concat(this.params.epochs));
                    this.currentEpochBegin = tfjs_1.util.now();
                    this.epochDurationMillis = null;
                    this.usPerStep = null;
                    this.batchesInLatestEpoch = 0;
                    this.terminalWidth = process.stderr.columns;
                    return [2 /*return*/];
                });
            }); },
            onBatchEnd: function (batch, logs) { return __awaiter(_this, void 0, void 0, function () {
                var maxMetricsStringLength, tickTokens;
                return __generator(this, function (_a) {
                    switch (_a.label) {
                        case 0:
                            this.batchesInLatestEpoch++;
                            if (batch === 0) {
                                this.progressBar = new exports.progressBarHelper.ProgressBar('eta=:eta :bar :placeholderForLossesAndMetrics', {
                                    width: Math.floor(0.5 * this.terminalWidth),
                                    total: this.numTrainBatchesPerEpoch + 1,
                                    head: ">",
                                    renderThrottle: this.RENDER_THROTTLE_MS
                                });
                            }
                            maxMetricsStringLength = Math.floor(this.terminalWidth * 0.5 - 12);
                            tickTokens = {
                                placeholderForLossesAndMetrics: this.formatLogsAsMetricsContent(logs, maxMetricsStringLength)
                            };
                            if (this.numTrainBatchesPerEpoch === 0) {
                                // Undetermined number of batches per epoch.
                                this.progressBar.tick(0, tickTokens);
                            }
                            else {
                                this.progressBar.tick(tickTokens);
                            }
                            return [4 /*yield*/, (0, tfjs_1.nextFrame)()];
                        case 1:
                            _a.sent();
                            if (batch === this.numTrainBatchesPerEpoch - 1) {
                                this.epochDurationMillis = tfjs_1.util.now() - this.currentEpochBegin;
                                this.usPerStep = this.params.samples != null ?
                                    this.epochDurationMillis / this.params.samples * 1e3 :
                                    this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
                            }
                            return [2 /*return*/];
                    }
                });
            }); },
            onEpochEnd: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                var lossesAndMetricsString;
                return __generator(this, function (_a) {
                    switch (_a.label) {
                        case 0:
                            if (this.epochDurationMillis == null) {
                                // In cases where the number of batches per epoch is not determined,
                                // the calculation of the per-step duration is done at the end of the
                                // epoch.
                                // N.B., this includes the time spent on validation.
                                this.epochDurationMillis = tfjs_1.util.now() - this.currentEpochBegin;
                                this.usPerStep =
                                    this.epochDurationMillis / this.batchesInLatestEpoch * 1e3;
                            }
                            this.progressBar.tick({ placeholderForLossesAndMetrics: '' });
                            lossesAndMetricsString = this.formatLogsAsMetricsContent(logs);
                            exports.progressBarHelper.log("".concat(this.epochDurationMillis.toFixed(0), "ms ") +
                                "".concat(this.usPerStep.toFixed(0), "us/step - ") +
                                "".concat(lossesAndMetricsString));
                            return [4 /*yield*/, (0, tfjs_1.nextFrame)()];
                        case 1:
                            _a.sent();
                            return [2 /*return*/];
                    }
                });
            }); },
        }) || this;
        _this.RENDER_THROTTLE_MS = 50;
        return _this;
    }
    ProgbarLogger.prototype.formatLogsAsMetricsContent = function (logs, maxMetricsLength) {
        var metricsContent = '';
        var keys = Object.keys(logs).sort();
        for (var _i = 0, keys_1 = keys; _i < keys_1.length; _i++) {
            var key = keys_1[_i];
            if (this.isFieldRelevant(key)) {
                var value = logs[key];
                metricsContent += "".concat(key, "=").concat(getSuccinctNumberDisplay(value), " ");
            }
        }
        if (maxMetricsLength != null && metricsContent.length > maxMetricsLength) {
            // Cut off metrics strings that are too long to avoid new lines being
            // constantly created.
            metricsContent = metricsContent.slice(0, maxMetricsLength - 3) + '...';
        }
        return metricsContent;
    };
    ProgbarLogger.prototype.isFieldRelevant = function (key) {
        return key !== 'batch' && key !== 'size';
    };
    return ProgbarLogger;
}(tfjs_1.CustomCallback));
exports.ProgbarLogger = ProgbarLogger;
var BASE_NUM_DIGITS = 2;
var MAX_NUM_DECIMAL_PLACES = 4;
/**
 * Get a succinct string representation of a number.
 *
 * Uses decimal notation if the number isn't too small.
 * Otherwise, uses exponential notation.
 *
 * @param x Input number.
 * @return Succinct string representing `x`.
 */
function getSuccinctNumberDisplay(x) {
    var decimalPlaces = getDisplayDecimalPlaces(x);
    return decimalPlaces > MAX_NUM_DECIMAL_PLACES ?
        x.toExponential(BASE_NUM_DIGITS) :
        x.toFixed(decimalPlaces);
}
exports.getSuccinctNumberDisplay = getSuccinctNumberDisplay;
/**
 * Determine the number of decimal places to display.
 *
 * @param x Number to display.
 * @return Number of decimal places to display for `x`.
 */
function getDisplayDecimalPlaces(x) {
    if (!Number.isFinite(x) || x === 0 || x > 1 || x < -1) {
        return BASE_NUM_DIGITS;
    }
    else {
        return BASE_NUM_DIGITS - Math.floor(Math.log10(Math.abs(x)));
    }
}
exports.getDisplayDecimalPlaces = getDisplayDecimalPlaces;
/**
 * Callback for logging to TensorBoard during training.
 *
 * Users are expected to access this class through the `tensorBoard()`
 * factory method instead.
 */
var TensorBoardCallback = /** @class */ (function (_super) {
    __extends(TensorBoardCallback, _super);
    function TensorBoardCallback(logdir, args) {
        if (logdir === void 0) { logdir = './logs'; }
        var _this = _super.call(this, {
            onBatchEnd: function (batch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    this.batchesSeen++;
                    if (this.args.updateFreq !== 'epoch') {
                        this.logMetrics(logs, 'batch_', this.batchesSeen);
                    }
                    return [2 /*return*/];
                });
            }); },
            onEpochEnd: function (epoch, logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    this.logMetrics(logs, 'epoch_', epoch + 1);
                    if (this.args.histogramFreq > 0 && epoch % this.args.histogramFreq === 0) {
                        this.logWeights(epoch);
                    }
                    return [2 /*return*/];
                });
            }); },
            onTrainEnd: function (logs) { return __awaiter(_this, void 0, void 0, function () {
                return __generator(this, function (_a) {
                    if (this.trainWriter != null) {
                        this.trainWriter.flush();
                    }
                    if (this.valWriter != null) {
                        this.valWriter.flush();
                    }
                    return [2 /*return*/];
                });
            }); }
        }) || this;
        _this.logdir = logdir;
        _this.model = null;
        _this.args = args == null ? {} : args;
        if (_this.args.updateFreq == null) {
            _this.args.updateFreq = 'epoch';
        }
        tfjs_1.util.assert(['batch', 'epoch'].indexOf(_this.args.updateFreq) !== -1, function () { return "Expected updateFreq to be 'batch' or 'epoch', but got " +
            "".concat(_this.args.updateFreq); });
        if (_this.args.histogramFreq == null) {
            _this.args.histogramFreq = 0;
        }
        tfjs_1.util.assert(Number.isInteger(_this.args.histogramFreq) && _this.args.histogramFreq >= 0, function () { return "Expected histogramFreq to be a positive integer, but got " +
            "".concat(_this.args.histogramFreq); });
        _this.batchesSeen = 0;
        return _this;
    }
    TensorBoardCallback.prototype.setModel = function (model) {
        // This method is inherited from BaseCallback. To avoid cyclical imports,
        // that class uses Container instead of LayersModel, and uses a run-time
        // check to make sure the model is a LayersModel.
        // Since this subclass isn't imported by tfjs-layers, we can safely type
        // the parameter as a LayersModel.
        this.model = model;
    };
    TensorBoardCallback.prototype.logMetrics = function (logs, prefix, step) {
        for (var key in logs) {
            if (key === 'batch' || key === 'size' || key === 'num_steps') {
                continue;
            }
            var VAL_PREFIX = 'val_';
            if (key.startsWith(VAL_PREFIX)) {
                this.ensureValWriterCreated();
                var scalarName = prefix + key.slice(VAL_PREFIX.length);
                this.valWriter.scalar(scalarName, logs[key], step);
            }
            else {
                this.ensureTrainWriterCreated();
                this.trainWriter.scalar("".concat(prefix).concat(key), logs[key], step);
            }
        }
    };
    TensorBoardCallback.prototype.logWeights = function (step) {
        for (var _i = 0, _a = this.model.weights; _i < _a.length; _i++) {
            var weights = _a[_i];
            this.trainWriter.histogram(weights.name, weights.read(), step);
        }
    };
    TensorBoardCallback.prototype.ensureTrainWriterCreated = function () {
        this.trainWriter = (0, tensorboard_1.summaryFileWriter)(path.join(this.logdir, 'train'));
    };
    TensorBoardCallback.prototype.ensureValWriterCreated = function () {
        this.valWriter = (0, tensorboard_1.summaryFileWriter)(path.join(this.logdir, 'val'));
    };
    return TensorBoardCallback;
}(tfjs_1.CustomCallback));
exports.TensorBoardCallback = TensorBoardCallback;
/**
 * Callback for logging to TensorBoard during training.
 *
 * Writes the loss and metric values (if any) to the specified log directory
 * (`logdir`) which can be ingested and visualized by TensorBoard.
 *
 * This callback is usually passed as a callback to `tf.Model.fit()` or
 * `tf.Model.fitDataset()` calls during model training. The frequency at which
 * the values are logged can be controlled with the `updateFreq` field of the
 * configuration object (2nd argument).
 *
 * Usage example:
 * ```js
 * // Construct a toy multilayer-perceptron regressor for demo purposes.
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 100, activation: 'relu', inputShape: [200]}));
 * model.add(tf.layers.dense({units: 1}));
 * model.compile({
 *   loss: 'meanSquaredError',
 *   optimizer: 'sgd',
 *   metrics: ['MAE']
 * });
 *
 * // Generate some random fake data for demo purposes.
 * const xs = tf.randomUniform([10000, 200]);
 * const ys = tf.randomUniform([10000, 1]);
 * const valXs = tf.randomUniform([1000, 200]);
 * const valYs = tf.randomUniform([1000, 1]);
 *
 * // Start model training process.
 * await model.fit(xs, ys, {
 *   epochs: 100,
 *   validationData: [valXs, valYs],
 *   // Add the tensorBoard callback here.
 *   callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
 * });
 * ```
 *
 * Then you can use the following commands to point tensorboard
 * to the logdir:
 *
 * ```sh
 * pip install tensorboard # Unless you've already installed it.
 * tensorboard --logdir /tmp/fit_logs_1
 * ```
 *
 * @param logdir Directory to which the logs will be written.
 * @param args Optional configuration arguments.
 * @returns An instance of `TensorBoardCallback`, which is a subclass of
 *   `tf.CustomCallback`.
 *
 * @doc {heading: 'TensorBoard', namespace: 'node'}
 */
function tensorBoard(logdir, args) {
    if (logdir === void 0) { logdir = './logs'; }
    return new TensorBoardCallback(logdir, args);
}
exports.tensorBoard = tensorBoard;
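The JSDoc example above shows the default configuration. As a complementary, non-authoritative sketch (the model, data, and log directory below are invented for illustration), here is how the optional second argument validated by `TensorBoardCallback` can be passed: `updateFreq` must be `'batch'` or `'epoch'`, and `histogramFreq` must be a non-negative integer, where a value greater than zero writes weight histograms every that many epochs.

```js
// Hedged sketch (not from the file above): exercising the optional `args`
// parameter of tf.node.tensorBoard(). Model, data, and logdir are invented
// for illustration only.
const tf = require('@tensorflow/tfjs-node');

async function main() {
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [10]}));
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

  const xs = tf.randomNormal([100, 10]);
  const ys = tf.randomNormal([100, 1]);

  await model.fit(xs, ys, {
    epochs: 5,
    callbacks: tf.node.tensorBoard('/tmp/fit_logs_2', {
      updateFreq: 'batch',  // log loss/metrics after every batch ('epoch' is the default)
      histogramFreq: 1      // also write weight histograms at the end of every epoch
    })
  });
}

main();
```

Running `tensorboard --logdir /tmp/fit_logs_2` afterwards shows the `batch_`- and `epoch_`-prefixed scalars that `logMetrics` writes under the `train` subdirectory; a `val` subdirectory is only created when `val_`-prefixed validation metrics are present.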