@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
1,104 lines • 265 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/* Original Source: engine/training.py */
import * as tfc from '@tensorflow/tfjs-core';
import { io, Optimizer, scalar, serialization, Tensor, tensor1d, util } from '@tensorflow/tfjs-core';
import * as K from '../backend/tfjs_backend';
import { configureCallbacks, standardizeCallbacks } from '../base_callbacks';
import { nameScope } from '../common';
import { NotImplementedError, RuntimeError, ValueError } from '../errors';
import { deserialize } from '../layers/serialization';
import { disposeTensorsInLogs } from '../logs';
import * as losses from '../losses';
import * as Metrics from '../metrics';
import * as optimizers from '../optimizers';
import { checkUserDefinedMetadata } from '../user_defined_metadata';
import { count, pyListRepeat, singletonOrArray, toCamelCase, toSnakeCase, unique } from '../utils/generic_utils';
import { printSummary } from '../utils/layer_utils';
import { range } from '../utils/math_utils';
import { convertPythonicToTs } from '../utils/serialization_utils';
import { version } from '../version';
import { Container } from './container';
import { execute, FeedDict } from './executor';
import { evaluateDataset, fitDataset } from './training_dataset';
import { checkBatchSize, disposeNewTensors, ensureTensorsRank2OrHigher, makeBatches, sliceArrays, sliceArraysByIndices } from './training_tensors';
import { computeWeightedLoss, standardizeClassWeights, standardizeWeights } from './training_utils';
/**
 * Polymorphic input helper, case 1: a single Tensor.
 *
 * @param x Candidate input value.
 * @returns `true` when `x` is an instance of `Tensor`, `false` otherwise.
 */
export function isDataTensor(x) {
    const isTensor = x instanceof Tensor;
    return isTensor;
}
/**
 * Polymorphic input helper, case 2: an Array (of Tensors).
 *
 * Only the container is inspected; the elements are not checked.
 *
 * @param x Candidate input value.
 * @returns `true` when `x` is an Array, `false` otherwise.
 */
export function isDataArray(x) {
    const isArray = Array.isArray(x);
    return isArray;
}
/**
 * Polymorphic input helper, case 3: a "dict" of Tensors.
 *
 * Defined by exclusion: anything that is neither a single Tensor nor an Array
 * is treated as a dict. Note that this includes `null` and `undefined`.
 *
 * @param x Candidate input value.
 * @returns `true` when `x` is neither a Tensor nor an Array.
 */
export function isDataDict(x) {
    return !(isDataTensor(x) || isDataArray(x));
}
/**
 * Normalizes inputs and targets provided by users into an Array of Tensors.
 *
 * `data` may be a single Tensor, an Array of Tensors, or a dict (plain
 * object) keyed by Tensor name. The result always has one entry per element
 * of `names`, in the order of `names`.
 *
 * @param data User-provided input data (polymorphic).
 * @param names An Array of expected Tensor names.
 * @param shapes Optional Array of expected Tensor shapes. A `null` entry
 *   skips the check for that Tensor; a `null` or negative dimension skips the
 *   check for that axis.
 * @param checkBatchAxis Whether to check that the batch axis (axis 0) of the
 *   arrays matches the expected value found in `shapes`.
 * @param exceptionPrefix String prefix used for exception formatting.
 * @returns List of standardized input Tensors (one Tensor per model input).
 *   An empty Array when no names are expected; an Array of `null`s (one per
 *   name) when `data` is `null`/`undefined`.
 * @throws ValueError: in case of improperly formatted user data.
 */
export function standardizeInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
    if (names == null || names.length === 0) {
        // The model expects no data. Only complain if something non-empty was
        // actually sent.
        if (data != null) {
            let gotUnexpectedData;
            if (isDataArray(data)) {
                // An empty Array carries no data and must not be mistaken for a
                // singleton Tensor.
                gotUnexpectedData = data.length > 0;
            }
            else if (isDataDict(data)) {
                // Own enumerable keys only; safe for objects without a prototype.
                gotUnexpectedData = Object.keys(data).length > 0;
            }
            else {
                // `data` is a singleton Tensor in this case.
                gotUnexpectedData = true;
            }
            if (gotUnexpectedData) {
                throw new ValueError(`Error when checking model ${exceptionPrefix} expected no data, ` +
                    `but got ${data}`);
            }
        }
        return [];
    }
    if (data == null) {
        return names.map(() => null);
    }
    let arrays;
    if (isDataDict(data)) {
        // Pick the Tensors out of the dict in the order given by `names`.
        arrays = [];
        for (const name of names) {
            if (data[name] == null) {
                throw new ValueError(`No data provided for "${name}". Need data for each key in: ` +
                    `${names}`);
            }
            arrays.push(data[name]);
        }
    }
    else if (isDataArray(data)) {
        if (data.length !== names.length) {
            throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
                `Tensors that you are passing to your model is not the size the ` +
                `model expected. Expected to see ${names.length} Tensor(s), but ` +
                `instead got the following list of Tensor(s): ${data}`);
        }
        arrays = data;
    }
    else {
        // Singleton Tensor: only valid when exactly one Tensor is expected.
        if (names.length > 1) {
            throw new ValueError(`The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` +
                `but only received one Tensor. Found: Tensor with shape ${data.shape}`);
        }
        arrays = [data];
    }
    // Rank normalization is delegated to the imported helper; the shape checks
    // below run on whatever it returns.
    arrays = ensureTensorsRank2OrHigher(arrays);
    // Check shape compatibility.
    if (shapes != null) {
        for (let i = 0; i < names.length; ++i) {
            if (shapes[i] == null) {
                continue;
            }
            const array = arrays[i];
            if (array.shape.length !== shapes[i].length) {
                throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
                    `to have ${shapes[i].length} dimension(s). but got array with ` +
                    `shape ${array.shape}`);
            }
            for (let j = 0; j < shapes[i].length; ++j) {
                if (j === 0 && !checkBatchAxis) {
                    // Skip the first (batch) axis.
                    continue;
                }
                const dim = array.shape[j];
                const refDim = shapes[i][j];
                if (refDim != null && refDim >= 0 && dim !== refDim) {
                    throw new ValueError(`${exceptionPrefix} expected a batch of elements where each ` +
                        `example has shape [${shapes[i].slice(1, shapes[i].length)}] ` +
                        `(i.e.,tensor shape [*,${shapes[i].slice(1, shapes[i].length)}])` +
                        ` but the ${exceptionPrefix} received an input with ${array.shape[0]}` +
                        ` examples, each with shape [${array.shape.slice(1, array.shape.length)}]` +
                        ` (tensor shape [${array.shape}])`);
                }
            }
        }
    }
    return arrays;
}
/**
 * Validates that user-supplied Tensors agree on the number of samples.
 *
 * The sample count of a Tensor is taken to be `shape[0]`. All inputs must
 * share one count, all targets must share one count, and (when both are
 * non-empty) the two counts must be equal.
 *
 * @param inputs `Array` of `tf.Tensor`s for inputs.
 * @param targets `Array` of `tf.Tensor`s for targets.
 * @param weights Optional `Array` of `tf.Tensor`s for sample weights.
 *   Currently not inspected.
 * @throws ValueError: in case of incorrectly formatted data.
 */
export function checkArrayLengths(inputs, targets, weights) {
    const inputSampleCounts = unique(inputs.map(tensor => tensor.shape[0]));
    inputSampleCounts.sort();
    const targetSampleCounts = unique(targets.map(tensor => tensor.shape[0]));
    targetSampleCounts.sort();
    // Used only when building an error message.
    const describeShapes = tensors => JSON.stringify(tensors.map(tensor => tensor.shape));
    // TODO(cais): Check `weights` as well.
    if (inputSampleCounts.length > 1) {
        throw new ValueError(`All input Tensors (x) should have the same number of samples. ` +
            `Got array shapes: ` +
            `${describeShapes(inputs)}`);
    }
    if (targetSampleCounts.length > 1) {
        throw new ValueError(`All target Tensors (y) should have the same number of samples. ` +
            `Got array shapes: ` +
            `${describeShapes(targets)}`);
    }
    const bothPresent = inputSampleCounts.length > 0 && targetSampleCounts.length > 0;
    if (bothPresent && !util.arraysEqual(inputSampleCounts, targetSampleCounts)) {
        const [numInputSamples] = inputSampleCounts;
        const [numTargetSamples] = targetSampleCounts;
        throw new ValueError(`Input Tensors should have the same number of samples as target ` +
            `Tensors. Found ${numInputSamples} input sample(s) and ${numTargetSamples} target ` +
            `sample(s).`);
    }
}
/**
 * Validation on the compatibility of targets and loss functions.
 *
 * This helps prevent users from using loss functions incorrectly. Two checks
 * are made per target, and only when a loss function is present for it:
 *  - with `categoricalCrossentropy`, the last target dimension must not be 1;
 *  - with `meanSquaredError`, `binaryCrossentropy` or
 *    `categoricalCrossentropy`, every non-batch target dimension must equal
 *    the corresponding output dimension (output dimensions that are
 *    `null`/`undefined` are not compared).
 *
 * @param targets `Array` of `tf.Tensor`s of targets.
 * @param lossFns `Array` of loss functions.
 * @param outputShapes `Array` of shapes of model outputs.
 * @throws ValueError: when a target is incompatible with its loss function.
 */
function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
    // TODO(cais): Dedicated test coverage?
    // Losses that require the target shape to match the output shape.
    const shapeSensitiveLosses = [
        losses.meanSquaredError,
        losses.binaryCrossentropy,
        losses.categoricalCrossentropy
    ];
    for (const [index, target] of targets.entries()) {
        const lossFn = lossFns[index];
        const outputShape = outputShapes[index];
        if (lossFn == null) {
            continue;
        }
        const lastTargetDim = () => target.shape[target.shape.length - 1];
        if (lossFn === losses.categoricalCrossentropy && lastTargetDim() === 1) {
            throw new ValueError(`You are passing a target array of shape ${target.shape} while using ` +
                `a loss 'categorical_crossentropy'. 'categorical_crossentropy'` +
                `expects targets to be binary matrices (1s and 0s) of shape ` +
                `[samples, classes].`);
            // TODO(cais): Example code in error message.
        }
        if (!shapeSensitiveLosses.includes(lossFn)) {
            continue;
        }
        // Compare everything except the leading (batch) axis.
        const targetDims = target.shape.slice(1);
        const outputDims = outputShape.slice(1);
        for (let axis = 0; axis < targetDims.length; ++axis) {
            const expectedDim = outputDims[axis];
            if (expectedDim != null && targetDims[axis] !== expectedDim) {
                throw new ValueError(`A target Tensor with shape ${target.shape} was passed for an ` +
                    `output of shape ${outputShape}, while using a loss function that ` +
                    `expects targets to have the same shape as the output.`);
            }
        }
    }
}
/**
 * Check inputs provided by the user.
 *
 * Porting Note: This corresponds to _standardize_input_data() in Python
 *   Keras. No data conversion is performed here; the function only validates
 *   and returns nothing. Specifically:
 *   1) in PyKeras, `data` can be `DataFrame` instances from pandas, for
 *      example. That is not handled here because there is no widely popular
 *      JavaScript/TypeScript equivalent of pandas (so far).
 *   2) in PyKeras, inputs can be a Python dict. Here the data is either a
 *      single `tf.Tensor` or an Array of `tf.Tensor`s; `Object` inputs are
 *      not supported by this function.
 *
 * The checks performed are: number of Tensors against `names.length`, rank
 * of each Tensor against the expected shape, and every non-`null` expected
 * dimension against the actual dimension.
 *
 * @param data The input data: a single Tensor or an Array of Tensors.
 * @param names Names for the inputs, from the model.
 * @param shapes Expected shapes for the input data, from the model. `null`
 *   (or a `null` entry, or a `null` dimension) disables the respective check.
 * @param checkBatchAxis Whether the size along the batch axis (i.e., the
 *   first dimension) will be checked for matching.
 * @param exceptionPrefix Exception prefix message, used in generating error
 *   messages.
 * @throws ValueError: on incorrect number of inputs or mismatches in shapes.
 */
function checkInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
    const dataIsArray = Array.isArray(data);
    if (dataIsArray && data.length !== names.length) {
        throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
            `Tensors that you are passing to your model is not the size the ` +
            `the model expected. Expected to see ${names.length} Tensor(s),` +
            ` but instead got ${data.length} Tensors(s).`);
    }
    if (!dataIsArray && names.length > 1) {
        throw new ValueError(`The model expects ${names.length} ${exceptionPrefix} Tensors, ` +
            `but only received one Tensor. Found: array with shape ` +
            `${JSON.stringify(data.shape)}.`);
    }
    if (shapes == null) {
        return;
    }
    const tensors = dataIsArray ? data : [data];
    // When the batch axis is exempt, dimension comparison starts at axis 1.
    const firstCheckedAxis = checkBatchAxis ? 0 : 1;
    for (let i = 0; i < names.length; ++i) {
        const expectedShape = shapes[i];
        if (expectedShape == null) {
            continue;
        }
        const actualShape = tensors[i].shape;
        if (actualShape.length !== expectedShape.length) {
            throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
                `to have ${expectedShape.length} dimension(s), but got array with ` +
                `shape ${JSON.stringify(actualShape)}`);
        }
        for (let axis = firstCheckedAxis; axis < expectedShape.length; ++axis) {
            const expectedDim = expectedShape[axis];
            if (expectedDim != null && expectedDim !== actualShape[axis]) {
                throw new ValueError(`Error when checking ${exceptionPrefix}: expected ` +
                    `${names[i]} to have shape ${JSON.stringify(expectedShape)} but ` +
                    `got array with shape ${JSON.stringify(actualShape)}.`);
            }
        }
    }
}
/**
 * Maps metric functions to model outputs.
 *
 * @param metrics A shortcut string name, a metric function, an `Array` of
 *   those, or a dict (`Object`) keyed by output name.
 * @param outputNames An `Array` of the names of model outputs.
 * @returns An `Array` (one entry per model output) of `Array`s of metrics.
 *   For instance, if the model has 2 outputs, and for the first output we
 *   want to compute `binaryAccuracy` and `binaryCrossentropy`, and just
 *   `binaryAccuracy` for the second output, the `Array` would look like:
 *     `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]`
 *   When `metrics` is a string, a function or an `Array`, every output
 *   receives the very same `Array` instance. With a dict, outputs that have
 *   no entry receive an empty `Array`.
 * @throws TypeError: incompatible metrics format.
 */
export function collectMetrics(metrics, outputNames) {
    const nothingRequested = metrics == null || (Array.isArray(metrics) && metrics.length === 0);
    if (nothingRequested) {
        return outputNames.map(() => []);
    }
    const metricsType = typeof metrics;
    let metricsSpec;
    if (metricsType === 'string' || metricsType === 'function') {
        metricsSpec = [metrics];
    }
    else if (Array.isArray(metrics) || metricsType === 'object') {
        metricsSpec = metrics;
    }
    else {
        throw new TypeError('Type of metrics argument not understood. Expected an string,' +
            `function, Array, or Object, found: ${metrics}`);
    }
    if (Array.isArray(metricsSpec)) {
        // Every output gets all of the metrics (the same Array is shared).
        return outputNames.map(() => metricsSpec);
    }
    // Otherwise the spec is a dict keyed by output name.
    return Array.from(outputNames, outputName => {
        const entry = metricsSpec.hasOwnProperty(outputName) ? metricsSpec[outputName] : [];
        return Array.isArray(entry) ? entry : [entry];
    });
}
// Name of the layers-model format. It is not referenced in the portion of the
// file shown alongside this definition.
// NOTE(review): presumably the `format` tag written into saved model
// artifacts — confirm against the save path.
const LAYERS_MODEL_FORMAT_NAME = 'layers-model';
/**
* A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
* for training, evaluation, prediction and saving.
*
* `tf.LayersModel` is the basic unit of training, inference and evaluation in
* TensorFlow.js. To create a `tf.LayersModel`, use `tf.model`.
*
* See also:
* `tf.Sequential`, `tf.loadLayersModel`.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
class LayersModel extends Container {
constructor(args) {
// Forward the construction arguments unchanged to the `Container` base class.
super(args);
// A newly constructed model starts with its `isTraining` flag cleared.
this.isTraining = false;
}
/**
* Print a text summary of the model's layers.
*
* The summary includes
* - Name and type of all layers that comprise the model.
* - Output shape(s) of the layers
* - Number of weight parameters of each layer
* - If the model has non-sequential-like topology, the inputs each layer
* receives
* - The total number of trainable and non-trainable parameters of the model.
*
* ```js
* const input1 = tf.input({shape: [10]});
* const input2 = tf.input({shape: [20]});
* const dense1 = tf.layers.dense({units: 4}).apply(input1);
* const dense2 = tf.layers.dense({units: 8}).apply(input2);
* const concat = tf.layers.concatenate().apply([dense1, dense2]);
* const output =
* tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
*
* const model = tf.model({inputs: [input1, input2], outputs: output});
* model.summary();
* ```
*
* @param lineLength Custom line length, in number of characters.
* @param positions Custom widths of each of the columns, as either
* fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
* of characters (e.g., `[30, 50, 65]`). Each number corresponds to
* right-most (i.e., ending) position of a column.
* @param printFn Custom print function. Can be used to replace the default
* `console.log`. For example, you can use `x => {}` to mute the printed
* messages in the console.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
summary(lineLength, positions, printFn = console.log) {
if (!this.built) {
throw new ValueError(`This model has never been called, thus its weights have not been ` +
`created yet. So no summary can be displayed. Build the model ` +
`first (e.g., by calling it on some test data).`);
}
printSummary(this, lineLength, positions, printFn);
}
/**
* Configures and prepares the model for training and evaluation. Compiling
* outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
* or `evaluate` on an un-compiled model will throw an error.
*
* @param args a `ModelCompileArgs` specifying the loss, optimizer, and
* metrics to be used for fitting and evaluating this model.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
compile(args) {
// A missing loss is normalized to an empty Array. Note that this writes to
// the caller-supplied `args` object.
if (args.loss == null) {
args.loss = [];
}
this.loss = args.loss;
// Resolve the optimizer. One looked up from a string name is flagged as owned
// by this model (`isOptimizerOwned = true`); a caller-supplied instance is
// flagged as not owned.
if (typeof args.optimizer === 'string') {
this.optimizer_ = optimizers.getOptimizer(args.optimizer);
this.isOptimizerOwned = true;
}
else {
if (!(args.optimizer instanceof Optimizer)) {
throw new ValueError(`User-defined optimizer must be an instance of tf.Optimizer.`);
}
this.optimizer_ = args.optimizer;
this.isOptimizerOwned = false;
}
// TODO(cais): Add lossWeights.
// TODO(cais): Add sampleWeightMode.
// Prepare loss functions.
// Three accepted forms of `args.loss` are handled below: a dict keyed by
// output name, an Array with one entry per output, or a single string or
// function that is applied to every output.
let lossFunctions = [];
if (!Array.isArray(args.loss) && typeof args.loss !== 'string' &&
typeof args.loss !== 'function') {
// Dict form. The self-assignment below has no runtime effect.
args.loss = args.loss;
// Reject keys that do not name a model output.
for (const name in args.loss) {
if (this.outputNames.indexOf(name) === -1) {
throw new ValueError(`Unknown entry in loss dictionary: "${name}". ` +
`Only expected the following keys: ${this.outputNames}`);
}
}
for (const name of this.outputNames) {
if (args.loss[name] == null) {
console.warn(`Output "${name}" is missing from loss dictionary. We assume ` +
`this was done on purpose, and we will not be expecting data ` +
`to be passed to ${name} during training`);
}
// NOTE(review): `losses.get` is still called with the missing
// (null/undefined) entry after the warning — confirm it tolerates that.
lossFunctions.push(losses.get(args.loss[name]));
}
}
else if (Array.isArray(args.loss)) {
// Array form: exactly one loss per model output is required.
if (args.loss.length !== this.outputs.length) {
throw new ValueError(`When passing an Array as loss, it should have one entry per ` +
`model output. The model has ${this.outputs.length} output(s), ` +
`but you passed loss=${args.loss}.`);
}
const theLosses = args.loss;
lossFunctions = theLosses.map(l => losses.get(l));
}
else {
// Single string/function form: the same loss function is used for every
// output.
const lossFunction = losses.get(args.loss);
this.outputs.forEach(_ => {
lossFunctions.push(lossFunction);
});
}
this.lossFunctions = lossFunctions;
// Record, per output, the name, shape and loss function of the target to be
// fed. No output is skipped at present, so these mirror the model outputs.
this.feedOutputNames = [];
this.feedOutputShapes = [];
this.feedLossFns = [];
for (let i = 0; i < this.outputs.length; ++i) {
// TODO(cais): Logic for skipping target(s).
const shape = this.internalOutputShapes[i];
const name = this.outputNames[i];
this.feedOutputNames.push(name);
this.feedOutputShapes.push(shape);
this.feedLossFns.push(this.lossFunctions[i]);
}
// TODO(cais): Add logic for output masks.
// TODO(cais): Add logic for sample weights.
// Always empty here, so the `indexOf` guards in the loops below never skip an
// output.
const skipTargetIndices = [];
// Prepare metrics.
this.metrics = args.metrics;
// TODO(cais): Add weightedMetrics.
// The first metrics name is always 'loss'. `metricsTensors` holds
// `[function, outputIndex]` pairs and has no entry for that first name.
this.metricsNames = ['loss'];
this.metricsTensors = [];
// Compute total loss.
// Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects.
// Here, metricsTensors are TypeScript functions. This difference is due
// to the difference in symbolic/imperative property of the backends.
nameScope('loss', () => {
for (let i = 0; i < this.outputs.length; ++i) {
if (skipTargetIndices.indexOf(i) !== -1) {
continue;
}
// TODO(cais): Add weightedLoss, sampleWeight and mask.
// The following line should be weightedLoss
const weightedLoss = this.lossFunctions[i];
// Per-output loss entries are only recorded for multi-output models.
if (this.outputs.length > 1) {
this.metricsTensors.push([weightedLoss, i]);
this.metricsNames.push(this.outputNames[i] + '_loss');
}
}
// Porting Note: Due to the imperative nature of the backend, we calculate
// the regularizer penalties in the totalLossFunction, instead of here.
});
// One Array of metrics per output, in output order.
const nestedMetrics = collectMetrics(args.metrics, this.outputNames);
// TODO(cais): Add nestedWeightedMetrics.
/**
* Helper function used in loop below.
* Registers a metric name and its `[function, outputIndex]` pair. For
* models with more than one output the name is prefixed with the output
* name and an underscore.
*/
const appendMetric = (outputIndex, metricName, metricTensor) => {
if (this.outputNames.length > 1) {
metricName = this.outputNames[outputIndex] + '_' + metricName;
}
this.metricsNames.push(metricName);
this.metricsTensors.push([metricTensor, outputIndex]);
};
nameScope('metric', () => {
for (let i = 0; i < this.outputs.length; ++i) {
if (skipTargetIndices.indexOf(i) !== -1) {
continue;
}
const outputMetrics = nestedMetrics[i];
// TODO(cais): Add weights and outputWeightedMetrics.
// TODO(cais): Add optional arg `weights` to the following function.
// Resolves each requested metric of output `i` to a function and a name,
// then registers it through `appendMetric`.
const handleMetrics = (metrics) => {
const metricNamePrefix = '';
let metricName;
let accFn;
let weightedMetricFn;
// TODO(cais): Use 'weights_' for weighted metrics.
for (const metric of metrics) {
// The shortcut strings 'accuracy'/'acc' and 'crossentropy'/'ce' are
// resolved to a concrete metric function based on the output shape and
// the loss function of this output.
if (typeof metric === 'string' &&
['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
-1) {
const outputShape = this.internalOutputShapes[i];
if (outputShape[outputShape.length - 1] === 1 ||
this.lossFunctions[i] === losses.binaryCrossentropy) {
// case: binary accuracy/crossentropy.
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = Metrics.binaryAccuracy;
}
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = Metrics.binaryCrossentropy;
}
}
else if (this.lossFunctions[i] ===
losses.sparseCategoricalCrossentropy) {
// case: categorical accuracy / crossentropy with sparse
// targets.
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = Metrics.sparseCategoricalAccuracy;
}
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = Metrics.sparseCategoricalCrossentropy;
}
}
else {
// case: categorical accuracy / crossentropy.
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = Metrics.categoricalAccuracy;
}
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = Metrics.categoricalCrossentropy;
}
}
// The registered name uses the short suffix regardless of which alias
// was given.
let suffix;
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
suffix = 'acc';
}
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
suffix = 'ce';
}
// TODO(cais): Add weighting actually.
weightedMetricFn = accFn;
metricName = metricNamePrefix + suffix;
}
else {
// Any other metric (other string or function) is resolved through
// `Metrics.get`, and its name through `Metrics.getLossOrMetricName`.
const metricFn = Metrics.get(metric);
// TODO(cais): Add weighting actually.
weightedMetricFn = metricFn;
metricName =
metricNamePrefix + Metrics.getLossOrMetricName(metric);
}
// TODO(cais): Add weighting and masking to metricResult.
// What is stored is the metric function itself, not a computed value.
let metricResult;
nameScope(metricName, () => {
metricResult = weightedMetricFn;
});
appendMetric(i, metricName, metricResult);
}
};
handleMetrics(outputMetrics);
// TODO(cais): Call handleMetrics with weights.
}
});
// Porting Notes: Given the imperative backend of tfjs-core,
// there is no need for constructing the symbolic graph and placeholders.
// Keep the value of `trainableWeights` as seen at compile time.
this.collectedTrainableWeights = this.trainableWeights;
}
/**
* Check trainable weights count consistency.
*
* This will raise a warning if `this.trainableWeights` and
* `this.collectedTrainableWeights` are inconsistent (i.e., have different
* numbers of parameters).
* Inconsistency will typically arise when one modifies `model.trainable`
* without calling `model.compile()` again.
*/
checkTrainableWeightsConsistency() {
if (this.collectedTrainableWeights == null) {
return;
}
if (this.trainableWeights.length !==
this.collectedTrainableWeights.length) {
console.warn('Discrepancy between trainableweights and collected trainable ' +
'weights. Did you set `model.trainable` without calling ' +
'`model.compile()` afterwards?');
}
}
/**
* Returns the loss value & metrics values for the model in test mode.
*
* Loss and metrics are specified during `compile()`, which needs to happen
* before calls to `evaluate()`.
*
* Computation is done in batches.
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
* const result = model.evaluate(
* tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
* result.print();
* ```
*
* @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
* model has multiple inputs.
* @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
* model has multiple outputs.
* @param args A `ModelEvaluateArgs`, containing optional fields.
*
* @return `Scalar` test loss (if the model has a single output and no
* metrics) or `Array` of `Scalar`s (if the model has multiple outputs
* and/or metrics). The attribute `model.metricsNames`
* will give you the display labels for the scalar outputs.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
evaluate(x, y, args = {}) {
const batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize);
// TODO(cais): Standardize `config.sampleWeights` as well.
// Validate user data.
const checkBatchAxis = true;
const standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
try {
// TODO(cais): If uses `useLearningPhase`, set the corresponding element
// of the input to 0.
const ins = standardizedOuts[0].concat(standardizedOuts[1]);
this.makeTestFunction();
const f = this.testFunction;
const testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
return singletonOrArray(testOuts);
}
finally {
disposeNewTensors(standardizedOuts[0], x);
disposeNewTensors(standardizedOuts[1], y);
}
}
// TODO(cais): Add code snippet below once real dataset objects are
// available.
/**
* Evaluate model using a dataset object.
*
* Note: Unlike `evaluate()`, this method is asynchronous (`async`).
*
* @param dataset A dataset object. Its `iterator()` method is expected
* to generate a dataset iterator object, the `next()` method of which
* is expected to produce data batches for evaluation. The return value
* of the `next()` call ought to contain a boolean `done` field and a
* `value` field. The `value` field is expected to be an array of two
* `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
* case is for models with exactly one input and one output (e.g.
* a sequential model). The latter case is for models with multiple
* inputs and/or multiple outputs. Of the two items in the array, the
* first is the input feature(s) and the second is the output target(s).
* @param args A configuration object for the dataset-based evaluation.
* @returns Loss and metric values as an Array of `Scalar` objects.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
async evaluateDataset(dataset, args) {
this.makeTestFunction();
return evaluateDataset(this, dataset, args);
}
/**
* Get number of samples provided for training, evaluation or prediction.
*
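* @param ins Input `tf.Tensor`, or an `Array` of `tf.Tensor`s (in which case
* the first element is used to determine the number of samples).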
* @param batchSize Integer batch size, optional.
* @param steps Total number of steps (batches of samples) before
* declaring loop finished. Optional.
* @param stepsName The public API's parameter name for `steps`.
* @returns Number of samples provided.
*/
checkNumSamples(ins, batchSize, steps, stepsName = 'steps') {
let numSamples;
if (steps != null) {
numSamples = null;
if (batchSize != null) {
throw new ValueError(`If ${stepsName} is set, batchSize must be null or undefined.` +
`Got batchSize = ${batchSize}`);
}
}
else if (ins != null) {
if (Array.isArray(ins)) {
numSamples = ins[0].shape[0];
}
else {
numSamples = ins.shape[0];
}
}
else {
throw new ValueError(`Either the input data should have a defined shape, or ` +
`${stepsName} shoud be specified.`);
}
return numSamples;
}
/**
* Execute internal tensors of the model with input data feed.
* @param inputs Input data feed. Must match the inputs of the model.
* @param outputs Names of the output tensors to be fetched. Must match
* names of the SymbolicTensors that belong to the graph.
* @returns Fetched values for `outputs`.
*/
execute(inputs, outputs) {
if (Array.isArray(outputs) && outputs.length === 0) {
throw new ValueError('`outputs` is an empty Array, which is not allowed.');
}
const outputsIsArray = Array.isArray(outputs);
const outputNames = (outputsIsArray ? outputs : [outputs]);
const outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames);
// Format the input into a FeedDict.
const feedDict = new FeedDict();
if (inputs instanceof Tensor) {
inputs = [inputs];
}
if (Array.isArray(inputs)) {
if (inputs.length !== this.inputs.length) {
throw new ValueError(`The number of inputs provided (${inputs.length}) ` +
`does not match the number of inputs of this model ` +
`(${this.inputs.length}).`);
}
for (let i = 0; i < this.inputs.length; ++i) {
feedDict.add(this.inputs[i], inputs[i]);
}
}
else {
for (const input of this.inputs) {
const tensorValue = inputs[input.name];
if (tensorValue == null) {
throw new ValueError(`No value is provided for the model's input ${input.name}`);
}
feedDict.add(input, tensorValue);
}
}
// Run execution.
const executeOutputs = execute(outputSymbolicTensors, feedDict);
return outputsIsArray ? executeOutputs : executeOutputs[0];
}
/**
* Retrieve the model's internal symbolic tensors from symbolic-tensor names.
*/
retrieveSymbolicTensors(symbolicTensorNames) {
const outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
let outputsRemaining = symbolicTensorNames.length;
for (const layer of this.layers) {
const layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
const layerOutputNames = layerOutputs.map(output => output.name);
for (let i = 0; i < symbolicTensorNames.length; ++i) {
const index = layerOutputNames.indexOf(symbolicTensorNames[i]);
if (index !== -1) {
outputSymbolicTensors[i] = layerOutputs[index];
outputsRemaining--;
}
if (outputsRemaining === 0) {
break;
}
}
if (outputsRemaining === 0) {
break;
}
}
if (outputsRemaining > 0) {
const remainingNames = [];
outputSymbolicTensors.forEach((tensor, i) => {
if (tensor == null) {
remainingNames.push(symbolicTensorNames[i]);
}
});
throw new ValueError(`Cannot find SymbolicTensors for output name(s): ` +
`${JSON.stringify(remainingNames)}`);
}
return outputSymbolicTensors;
}
/**
* Helper method to loop over some data in batches.
*
* Porting Note: Not using the functional approach in the Python equivalent
* due to the imperative backend.
* Porting Note: Does not support step mode currently.
*
* @param ins: input data
* @param batchSize: integer batch size.
* @param verbose: verbosity mode
* @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
* `tf.Tensor` (if multiple outputs).
*/
predictLoop(ins, batchSize = 32, verbose = false) {
return tfc.tidy(() => {
const numSamples = this.checkNumSamples(ins);
if (verbose) {
throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
}
// Sample-based predictions.
// Porting Note: Tensor currently does not support sliced assignments as
// in numpy, e.g., x[1:3] = y. Therefore we use concatenation while
// iterating over the batches.
const batches = makeBatches(numSamples, batchSize);
const outsBatches = this.outputs.map(output => []);
// TODO(cais): Can the scope() be pushed down inside the for loop?
for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
const batchOuts = tfc.tidy(() => {
const batchStart = batches[batchIndex][0];
const batchEnd = batches[batchIndex][1];
// TODO(cais): Take care of the case of the last element is a flag for
// training/test.
const insBatch = sliceArrays(ins, batchStart, batchEnd);
// Construct the feeds for execute();
const feeds = [];
if (Array.isArray(insBatch)) {
for (let i = 0; i < insBatch.length; ++i) {
feeds.push({ key: this.inputs[i], value: insBatch[i] });
}
}
else {
feeds.push({ key: this.inputs[0], value: insBatch });
}
const feedDict = new FeedDict(feeds);
return execute(this.outputs, feedDict);
});
batchOuts.forEach((batchOut, i) => outsBatches[i].push(batchOut));
}
return singletonOrArray(outsBatches.map(batches => tfc.concat(batches, 0)));
});
}
/**
* Generates output predictions for the input samples.
*
* Computation is done in batches.
*
* Note: the "step" mode of predict() is currently not supported.
* This is because the TensorFlow.js core backend is imperative only.
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
* ```
*
* @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
* the model has multiple inputs.
* @param args A `ModelPredictArgs` object containing optional fields.
*
* @return Prediction results as a `tf.Tensor`(s).
*
* @exception ValueError In case of mismatch between the provided input data
* and the model's expectations, or in case a stateful model receives a
* number of samples that is not a multiple of the batch size.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
predict(x, args = {}) {
const xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
try {
// TODO(cais): Take care of stateful models.
// if (this.stateful) ...
// TODO(cais): Take care of the learning_phase boolean flag.
// if (this.useLearningPhase) ...
const batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize);
return this.predictLoop(xsRank2OrHigher, batchSize);
}
finally {
disposeNewTensors(xsRank2OrHigher, x);
}
}
/**
* Returns predictions for a single batch of samples.
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.predictOnBatch(tf.ones([8, 10])).print();
* ```
* @param x: Input samples, as a Tensor (for models with exactly one
* input) or an array of Tensors (for models with more than one input).
* @return Tensor(s) of predictions
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
predictOnBatch(x) {
checkInputData(x, this.inputNames, this.feedInputShapes, true);
// TODO(cais): Take care of the learning_phase boolean flag.
// if (this.useLearningPhase) ...
const batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
return this.predictLoop(x, batchSize);
}
standardizeUserDataXY(x, y, checkBatchAxis = true, batchSize) {
// TODO(cais): Add sampleWeight, classWeight
if (this.optimizer_ == null) {
throw new RuntimeError('You must compile a model before training/testing. Use ' +
'LayersModel.compile(modelCompileArgs).');
}
const outputShapes = [];
for (let i = 0; i < this.feedOutputShapes.length; ++i) {
const outputShape = this.feedOutputShapes[i];
const lossFn = this.feedLossFns[i];
if (lossFn === losses.sparseCategoricalCrossentropy) {
outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
}
else {
// Porting Note: Because of strong typing `lossFn` must be a function.
outputShapes.push(outputShape);
}
}
x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target');
// TODO(cais): Standardize sampleWeights & classWeights.
checkArrayLengths(x, y, null);
// TODO(cais): Check sampleWeights as well.
checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
if (this.stateful && batchSize != null && batchSize > 0) {
if (x[0].shape[0] % batchSize !== 0) {
throw new ValueError(`In a stateful network, you should only pass inputs with a ` +
`number of samples that is divisible by the batch size ` +
`${batchSize}. Found: ${x[0].shape[0]} sample(s).`);
}
}
return [x, y];
}
async standardizeUserData(x, y, sampleWeight, classWeight, checkBatchAxis = true, batchSize) {
const [standardXs, standardYs] = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
// TODO(cais): Handle sampleWeights.
if (sampleWeight != null) {
throw new Error('sample weight is not supported yet.');
}
let standardSampleWeights = null;
if (classWeight != null) {
const classWeights = standardizeClassWeights(classWeight, this.outputNames);
standardSampleWeights = [];
for (let i = 0; i < classWeights.length; ++i) {
standardSampleWeights.push(await standardizeWeights(standardYs[i], null, classWeights[i]));
}
}
// TODO(cais): Deal with the case of model.stateful == true.
return [standardXs, standardYs, standardSampleWeights];
}
/**
* Loop over some test data in batches.
* @param f A Function returning a list of tensors.
* @param ins Array of tensors to be fed to `f`.
* @param batchSize Integer batch size or `null` / `undefined`.
* @param verbose verbosity mode.
* @param steps Total number of steps (batches of samples) before
* declaring test finished. Ignored with the default value of `null` /
* `undefined`.
* @returns Array of Scalars.
*/
testLoop(f, ins, batchSize, verbose = 0, steps) {
return tfc.tidy(() => {
const numSamples = this.checkNumSamples(ins, batchSize, steps, 'steps');
const outs = [];
if (verbose > 0) {
throw new NotImplementedError('Verbose mode is not implemented yet.');
}
// TODO(cais): Use `indicesForConversionToDense' to prevent slow down.
if (steps != null) {
throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
}
else {
const batches = makeBatches(numSamples, batchSize);
const indexArray = tensor1d(range(0, numSamples));
for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
const batchStart = batches[batchIndex][0];
const batchEnd = batches[batchIndex][1];
const batchIds = K.sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart);
// TODO(cais): In ins, train flag can be a number, instead of an
// Tensor? Do we need to handle this in tfjs-layers?
const insBatch = sliceArraysByIndices(ins, batchIds);
const batchOuts = f(insBatch);
if (batchIndex === 0) {
for (let i = 0; i < batchOuts.length; ++i) {
outs.push(scalar(0));
}
}
for (let i = 0; i < batchOuts.length; ++i) {
const batchOut = batchOuts[i];
outs[i] =
tfc.add(outs[i], tfc.mul(batchEnd - batchStart, batchOut));
}
}
for (let i = 0; i < outs.length; ++i) {
outs[i] = tfc.div(outs[i], numSamples);
}
}
return outs;
});
}
getDedupedMetricsNames() {
const outLabels = this.metricsNames;
// Rename duplicated metrics names (can happen with an output layer
// shared among multiple dataflows).
const dedupedOutLabels = [];
for (let i = 0; i < outLabels.length; ++i) {
const label = outLabels[i];
let newLabel = label;
if (count(outLabels, label) > 1) {
const dupIndex = count(outLabels.slice(0, i), label);
newLabel += `_${dupIndex}`;
}
dedupedOutLabels.push(newLabel);
}
return dedupedOutLabels;
}
/**
* Creates a function that performs the following actions:
*
* 1. computes the losses
* 2. sums them to get the total loss
* 3. calls the optimizer to compute the gradients of the LayersModel's
* trainable weights w.r.t. the total loss and to update the variables
* 4. calculates the metrics
* 5. returns the values of the losses and metrics.
*/
makeTrainFunction() {
return (data) => {
const lossValues = [];
const inputs = data.slice(0, this.inputs.length);
const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
const sampleWeights = data.slice(this.inputs.length + this.outputs.length, this.inputs.length + this.outputs.length * 2);
const metricsValues = [];
// Create a function that computes the total loss based on the
// inputs. This function is used for obtaining gradients through
// backprop.
const totalLossFunction = () => {
const feeds = [];
for (let i = 0; i < this.inputs.length; ++i) {
feeds.push({ key: this.inputs[i], value: inputs[i] });
}
const feedDict = new FeedDict(feeds);
const outputs = execute(this.outputs, feedDict, { 'training': true });
// TODO(cais): Take care of the case of multiple outputs from a
// single layer?
let totalLoss;
for (let i = 0; i < this.lossFunctions.length; ++i) {
const lossFunction = this.lossFunctions[i];
let loss = lossFunction(targets[i], outputs[i]);
if (sampleWeights[i] != null) {
loss = computeWeightedLoss(loss, sampleWeights[i]);
}
// TODO(cais): push Scalar instead.
const meanLoss = tfc.mean(loss);
// TODO(cais): Use a scope() instead, to avoid ownership.
lossValues.push(meanLoss);
if (i === 0) {
totalLoss = loss;
}
else {
totalLoss = tfc.add(totalLoss, loss);
}
}
// Compute the metrics.
// TODO(cais): These should probably be calculated outside
// totalLossFunction to benefit speed?
for (let i = 0; i < this.metricsTensors.length; ++i) {
let weightedMetric;
if (this.outputs.length > 1 && i < this.outputs.length) {
weightedMetric = lossValues[i];
}
els