UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

178 lines 26.9 kB
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/**
 * Interfaces and methods for training models using tf.Tensor objects.
 */
import * as tfc from '@tensorflow/tfjs-core';
import { Tensor } from '@tensorflow/tfjs-core';
import { expandDims, gather, sliceAlongFirstAxis } from '../backend/tfjs_backend';

/**
 * Assert that a batch size is a positive integer.
 *
 * @param batchSize The value to validate.
 * @throws Via `tfc.util.assert` when `batchSize` is not an integer greater
 *   than zero.
 */
export function checkBatchSize(batchSize) {
  tfc.util.assert(
    batchSize > 0 && Number.isInteger(batchSize),
    () => `batchSize is required to be a positive integer, but got ${batchSize}`);
}

/**
 * Slice a Tensor or an Array of Tensors, by start and stop indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 *   function and `sliceArraysByIndices()` together.
 *
 * @param arrays: the input.
 * @param start: the starting index (inclusive).
 * @param stop: the stopping index (exclusive).
 * @returns The result of the slicing. If `arrays` is an `Array` of
 *   `tf.Tensor`s, the slicing will be applied to all elements of the `Array`
 *   in the same way. If `arrays` is `null` or `undefined`, the result is
 *   `[null]`.
 */
export function sliceArrays(arrays, start, stop) {
  if (arrays == null) {
    return [null];
  }
  // `sliceAlongFirstAxis` takes a start and a length, hence `stop - start`.
  const length = stop - start;
  if (Array.isArray(arrays)) {
    return arrays.map(array => sliceAlongFirstAxis(array, start, length));
  }
  // Single Tensor.
  return sliceAlongFirstAxis(arrays, start, length);
}

/**
 * Slice a Tensor or an Array of Tensors, by random-order indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 *   function and `sliceArrays()` together.
 *
 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
 *   If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the
 *   same fashion.
 * @param indices The indices to use for slicing along the first (batch)
 *   dimension. Cast to 'int32' when its dtype is anything else.
 * @returns Result(s) of the slicing, or `null` when `arrays` is `null` or
 *   `undefined`.
 */
export function sliceArraysByIndices(arrays, indices) {
  return tfc.tidy(() => {
    if (arrays == null) {
      return null;
    }
    if (Array.isArray(arrays)) {
      return arrays.map(array => sliceArraysByIndices(array, indices));
    }
    // TODO(cais): indices should be a pre-constructed Tensor1D to avoid
    //   tensor1d() calls.
    const int32Indices =
        indices.dtype === 'int32' ? indices : tfc.cast(indices, 'int32');
    return gather(arrays, int32Indices);
  });
}

/**
 * Returns a list of batch indices (tuples of indices).
 *
 * @param size: Integer, total size of the data to slice into batches.
 * @param batchSize: Integer, batch size.
 * @returns An Array of [batchStart, batchEnd] tuples. batchStart is
 *   inclusive; batchEnd is exclusive. I.e., each batch consists of indices x
 *   that satisfy batchStart <= x < batchEnd. The last batch is truncated at
 *   `size`. An empty Array is returned when `size` is not greater than 0.
 * @throws Error if there is data to batch (`size > 0`) but `batchSize` is
 *   not greater than zero (including `NaN`).
 */
export function makeBatches(size, batchSize) {
  // A zero or negative step would never reach `size`, so the loop below would
  // spin forever while growing `output`. Fail fast instead.
  if (size > 0 && !(batchSize > 0)) {
    throw new Error(
        `batchSize is required to be a positive number, but got ${batchSize}`);
  }
  const output = [];
  let batchStart = 0;
  while (batchStart < size) {
    let batchEnd = batchStart + batchSize;
    if (batchEnd >= size) {
      batchEnd = size;
    }
    output.push([batchStart, batchEnd]);
    batchStart = batchEnd;
  }
  return output;
}

/**
 * Ensure tensors all have a rank of at least 2.
 *
 * If a tensor has a rank of 1, it is dimension-expanded to rank 2.
 * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.
 *
 * @param tensors A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
 * @returns A new Array with one entry per input tensor. Tensors of rank 2 or
 *   higher are passed through as-is.
 * @throws Error if any tensor has rank 0.
 */
export function ensureTensorsRank2OrHigher(tensors) {
  // Normalize to a list without reassigning the parameter.
  const tensorList = tensors instanceof Tensor ? [tensors] : tensors;
  const outs = [];
  // Make Tensors at least 2D.
  for (let i = 0; i < tensorList.length; ++i) {
    const tensor = tensorList[i];
    if (tensor.rank === 1) {
      outs.push(expandDims(tensor, 1));
    } else if (tensor.rank === 0) {
      throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' +
          '(scalar).');
    } else {
      outs.push(tensor);
    }
  }
  return outs;
}

/**
 * Flatten a Tensor container into a plain Array of its members.
 *
 * @param container A `tf.Tensor`, an `Array` of `tf.Tensor`s, a map from
 *   string name to `tf.Tensor`, or `null`/`undefined`.
 * @returns An Array of the contained values; empty for `null`/`undefined`.
 *   An input `Array` is returned as-is (not copied).
 */
function collectTensors(container) {
  if (container == null) {
    return [];
  }
  if (container instanceof Tensor) {
    return [container];
  }
  if (Array.isArray(container)) {
    return container;
  }
  // Map from string name to Tensor: only own enumerable properties count.
  return Object.keys(container).map(name => container[name]);
}

/**
 * Compare a set of tensors with a reference (old) set, discard the ones
 * in the new set that are not present in the reference set.
 *
 * This method is used for memory cleanup during calls such as
 * LayersModel.fit().
 *
 * Tensors are matched by their `id`. A tensor in `tensors` is disposed only
 * if its id is absent from `refTensors` and it is not already disposed.
 *
 * @param tensors New set which may contain Tensors not present in
 *   `refTensors`. May be a single Tensor, an Array of Tensors, or a map from
 *   string name to Tensor. Nothing happens when it is `null`/`undefined`.
 * @param refTensors Reference Tensor set, in any of the same forms.
 */
// TODO(cais, kangyizhang): Deduplicate with tfjs-data.
export function disposeNewTensors(tensors, refTensors) {
  if (tensors == null) {
    return;
  }
  // A Set keeps each membership test O(1) instead of rescanning an id array.
  const oldTensorIds = new Set();
  collectTensors(refTensors).forEach(t => oldTensorIds.add(t.id));

  // Decide what to dispose before disposing anything.
  const tensorsToDispose = [];
  collectTensors(tensors).forEach(t => {
    if (!oldTensorIds.has(t.id)) {
      tensorsToDispose.push(t);
    }
  });
  tensorsToDispose.forEach(t => {
    if (!t.isDisposed) {
      t.dispose();
    }
  });
}
//# sourceMappingURL=data:application/json;base64,