UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

178 lines 26.9 kB
/** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Interfaces and methods for training models using tf.Tensor objects. */ import * as tfc from '@tensorflow/tfjs-core'; import { Tensor } from '@tensorflow/tfjs-core'; import { expandDims, gather, sliceAlongFirstAxis } from '../backend/tfjs_backend'; export function checkBatchSize(batchSize) { tfc.util.assert(batchSize > 0 && Number.isInteger(batchSize), () => `batchSize is required to be a positive integer, but got ${batchSize}`); } /** * Slice a Tensor or an Array of Tensors, by start and stop indices. * * Porting Note: The `_slice_arrays` function in PyKeras is covered by this * function and `sliceArraysByIndices()` together. * * @param arrays: the input. * @param start: the starting index (inclusive). * @param stop: the stopping index (exclusive). * @returns The result of the slicing. If `arrays` is an `Array` of * `tf.Tensor`s, the slicing will be applied to all elements of the `Array` * in the same way. */ export function sliceArrays(arrays, start, stop) { if (arrays == null) { return [null]; } else if (Array.isArray(arrays)) { return arrays.map(array => sliceAlongFirstAxis(array, start, stop - start)); } else { // Tensor. return sliceAlongFirstAxis(arrays, start, stop - start); } } /** * Slice a Tensor or an Array of Tensors, by random-order indices. * * Porting Note: The `_slice_arrays` function in PyKeras is covered by this * function and `sliceArrays()` together. * * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice. * If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the * same fashion. * @param indices The indices to use for slicing along the first (batch) * dimension. * @returns Result(s) of the slicing. */ export function sliceArraysByIndices(arrays, indices) { return tfc.tidy(() => { if (arrays == null) { return null; } else if (Array.isArray(arrays)) { return arrays.map(array => sliceArraysByIndices(array, indices)); } else { // TODO(cais): indices should be a pre-constructed Tensor1D to avoid // tensor1d() calls. return gather(arrays, indices.dtype === 'int32' ? indices : tfc.cast(indices, 'int32')); } }); } /** * Returns a list of batch indices (tuples of indices). * @param size: Integer, total size of the data to slice into batches. * @param batchSize: Integer, batch size. * @returns An Array of [batchStart, batchEnd] tuples. batchStart is * inclusive; batchEnd is exclusive. I.e., each batch consists of indices x * that satisfy batchStart <= x < batchEnd. */ export function makeBatches(size, batchSize) { const output = []; let batchStart = 0; let batchEnd = null; while (batchStart < size) { batchEnd = batchStart + batchSize; if (batchEnd >= size) { batchEnd = size; } output.push([batchStart, batchEnd]); batchStart = batchEnd; } return output; } /** * Ensure tensors all have a rank of at least 2. * * If a tensor has a rank of 1, it is dimension-expanded to rank 2. * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown. */ export function ensureTensorsRank2OrHigher(tensors) { const outs = []; if (tensors instanceof Tensor) { tensors = [tensors]; } // Make Tensors at least 2D. for (let i = 0; i < tensors.length; ++i) { const tensor = tensors[i]; if (tensor.rank === 1) { outs.push(expandDims(tensor, 1)); } else if (tensor.rank === 0) { throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' + '(scalar).'); } else { outs.push(tensor); } } return outs; } /** * Compare a set of tensors with a reference (old) set, discard the ones * in the new set that are not present in the reference set. * * This method is used for memory clenaup during calls such as * LayersModel.fit(). * * @param tensors New set which may contain Tensors not present in * `refTensors`. * @param refTensors Reference Tensor set. */ // TODO(cais, kangyizhang): Deduplicate with tfjs-data. export function disposeNewTensors(tensors, refTensors) { if (tensors == null) { return; } const oldTensorIds = []; if (refTensors instanceof Tensor) { oldTensorIds.push(refTensors.id); } else if (Array.isArray(refTensors)) { refTensors.forEach(t => oldTensorIds.push(t.id)); } else if (refTensors != null) { // `oldTensors` is a map from string name to Tensor. for (const name in refTensors) { const oldTensor = refTensors[name]; oldTensorIds.push(oldTensor.id); } } const tensorsToDispose = []; if (tensors instanceof Tensor) { if (oldTensorIds.indexOf(tensors.id) === -1) { tensorsToDispose.push(tensors); } } else if (Array.isArray(tensors)) { tensors.forEach(t => { if (oldTensorIds.indexOf(t.id) === -1) { tensorsToDispose.push(t); } }); } else if (tensors != null) { // `oldTensors` is a map from string name to Tensor. for (const name in tensors) { const tensor = tensors[name]; if (oldTensorIds.indexOf(tensor.id) === -1) { tensorsToDispose.push(tensor); } } } tensorsToDispose.forEach(t => { if (!t.isDisposed) { t.dispose(); } }); } //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"training_tensors.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/engine/training_tensors.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAC7C,OAAO,EAAC,MAAM,EAAW,MAAM,uBAAuB,CAAC;AACvD,OAAO,EAAC,UAAU,EAAE,MAAM,EAAE,mBAAmB,EAAC,MAAM,yBAAyB,CAAC;AA6IhF,MAAM,UAAU,cAAc,CAAC,SAAiB;IAC9C,GAAG,CAAC,IAAI,CAAC,MAAM,CACX,SAAS,GAAG,CAAC,IAAI,MAAM,CAAC,SAAS,CAAC,SAAS,CAAC,EAC5C,GAAG,EAAE,CAAC,2DACF,SAAS,EAAE,CAAC,CAAC;AACvB,CAAC;AAED;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,WAAW,CACvB,MAAuB,EAAE,KAAa,EAAE,IAAY;IACtD,IAAI,MAAM,IAAI,IAAI,EAAE;QAClB,OAAO,CAAC,IAAI,CAAC,CAAC;KACf;SAAM,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;QAChC,OAAO,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC,mBAAmB,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,GAAG,KAAK,CAAC,CAAC,CAAC;KAC7E;SAAM,EAAG,UAAU;QAClB,OAAO,mBAAmB,CAAC,MAAM,EAAE,KAAK,EAAE,IAAI,GAAG,KAAK,CAAC,CAAC;KACzD;AACH,CAAC;AAED;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,oBAAoB,CAChC,MAAuB,EAAE,OAAiB;IAC5C,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE;QACnB,IAAI,MAAM,IAAI,IAAI,EAAE;YAClB,OAAO,IAAI,CAAC;SACb;aAAM,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAChC,OAAO,MAAM,CAAC,GAAG,CACb,KAAK,CAAC,EAAE,CAAE,oBAAoB,CAAC,KAAK,EAAE,OAAO,CAAY,CAAC,CAAC;SAChE;aAAM;YACL,oEAAoE;YACpE,sBAAsB;YACtB,OAAO,MAAM,CACT,MAAM,EACN,OAAO,CAAC,KAAK,KAAK,OAAO,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;SACvE;IACH,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,WAAW,CACvB,IAAY,EAAE,SAAiB;IACjC,MAAM,MAAM,GAA4B,EAAE,CAAC;IAC3C,IAAI,UAAU,GAAG,CAAC,CAAC;IACnB,IAAI,QAAQ,GAAW,IAAI,CAAC;IAC5B,OAAO,UAAU,GAAG,IAAI,EAAE;QACxB,QAAQ,GAAG,UAAU,GAAG,SAAS,CAAC;QAClC,IAAI,QAAQ,IAAI,IAAI,EAAE;YACpB,QAAQ,GAAG,IAAI,CAAC;SACjB;QACD,MAAM,CAAC,IAAI,CAAC,CAAC,UAAU,EAAE,QAAQ,CAAC,CAAC,CAAC;QACpC,UAAU,GAAG,QAAQ,CAAC;KACvB;IACD,OAAO,MAAM,CAAC;AAChB,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,0BAA0B,CAAC,OAAwB;IACjE,MAAM,IAAI,GAAa,EAAE,CAAC;IAC1B,IAAI,OAAO,YAAY,MAAM,EAAE;QAC7B,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC;KACrB;IAED,4BAA4B;IAC5B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACvC,MAAM,MAAM,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;QAC1B,IAAI,MAAM,CAAC,IAAI,KAAK,CAAC,EAAE;YACrB,IAAI,CAAC,IAAI,CAAC,UAAU,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC;SAClC;aAAM,IAAI,MAAM,CAAC,IAAI,KAAK,CAAC,EAAE;YAC5B,MAAM,IAAI,KAAK,CACX,8DAA8D;gBAC9D,WAAW,CAAC,CAAC;SAClB;aAAM;YACL,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;SACnB;KACF;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;;;;;;;GAUG;AACH,uDAAuD;AACvD,MAAM,UAAU,iBAAiB,CAC7B,OAAsD,EACtD,UAAyD;IAC3D,IAAI,OAAO,IAAI,IAAI,EAAE;QACnB,OAAO;KACR;IACD,MAAM,YAAY,GAAa,EAAE,CAAC;IAClC,IAAI,UAAU,YAAY,MAAM,EAAE;QAChC,YAAY,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;KAClC;SAAM,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE;QACpC,UAAU,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;KAClD;SAAM,IAAI,UAAU,IAAI,IAAI,EAAE;QAC7B,oDAAoD;QACpD,KAAK,MAAM,IAAI,IAAI,UAAU,EAAE;YAC7B,MAAM,SAAS,GAAG,UAAU,CAAC,IAAI,CAAC,CAAC;YACnC,YAAY,CAAC,IAAI,CAAC,SAAS,CAAC,EAAE,CAAC,CAAC;SACjC;KACF;IAED,MAAM,gBAAgB,GAAa,EAAE,CAAC;IACtC,IAAI,OAAO,YAAY,MAAM,EAAE;QAC7B,IAAI,YAAY,CAAC,OAAO,CAAC,OAAO,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE;YAC3C,gBAAgB,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SAChC;KACF;SAAM,IAAI,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,EAAE;QACjC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;YAClB,IAAI,YAAY,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE;gBACrC,gBAAgB,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aAC1B;QACH,CAAC,CAAC,CAAC;KACJ;SAAM,IAAI,OAAO,IAAI,IAAI,EAAE;QAC1B,oDAAoD;QACpD,KAAK,MAAM,IAAI,IAAI,OAAO,EAAE;YAC1B,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAAC;YAC7B,IAAI,YAAY,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE;gBAC1C,gBAAgB,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;aAC/B;SACF;KACF;IAED,gBAAgB,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;QAC3B,IAAI,CAAC,CAAC,CAAC,UAAU,EAAE;YACjB,CAAC,CAAC,OAAO,EAAE,CAAC;SACb;IACH,CAAC,CAAC,CAAC;AACL,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * Interfaces and methods for training models using tf.Tensor objects.\n */\n\nimport * as tfc from '@tensorflow/tfjs-core';\nimport {Tensor, Tensor1D} from '@tensorflow/tfjs-core';\nimport {expandDims, gather, sliceAlongFirstAxis} from '../backend/tfjs_backend';\nimport {BaseCallback, CustomCallbackArgs, ModelLoggingVerbosity, YieldEveryOptions} from '../base_callbacks';\nimport {ClassWeight, ClassWeightMap} from './training_utils';\n\n/**\n * Interface configuration model training based on data as `tf.Tensor`s.\n */\nexport interface ModelFitArgs {\n  /**\n   * Number of samples per gradient update. If unspecified, it\n   * will default to 32.\n   */\n  batchSize?: number;\n\n  /**\n   * Integer number of times to iterate over the training data arrays.\n   */\n  epochs?: number;\n\n  /**\n   * Verbosity level.\n   *\n   * Expected to be 0, 1, or 2. Default: 1.\n   *\n   * 0 - No printed message during fit() call.\n   * 1 - In Node.js (tfjs-node), prints the progress bar, together with\n   *     real-time updates of loss and metric values and training speed.\n   *     In the browser: no action. This is the default.\n   * 2 - Not implemented yet.\n   */\n  verbose?: ModelLoggingVerbosity | 2;\n\n  /**\n   * List of callbacks to be called during training.\n   * Can have one or more of the following callbacks:\n   *   - `onTrainBegin(logs)`: called when training starts.\n   *   - `onTrainEnd(logs)`: called when training ends.\n   *   - `onEpochBegin(epoch, logs)`: called at the start of every epoch.\n   *   - `onEpochEnd(epoch, logs)`: called at the end of every epoch.\n   *   - `onBatchBegin(batch, logs)`: called at the start of every batch.\n   *   - `onBatchEnd(batch, logs)`: called at the end of every batch.\n   *   - `onYield(epoch, batch, logs)`: called every `yieldEvery` milliseconds\n   *      with the current epoch, batch and logs. The logs are the same\n   *      as in `onBatchEnd()`. Note that `onYield` can skip batches or\n   *      epochs. See also docs for `yieldEvery` below.\n   */\n  callbacks?: BaseCallback[]|CustomCallbackArgs|CustomCallbackArgs[];\n\n  /**\n   * Float between 0 and 1: fraction of the training data\n   * to be used as validation data. The model will set apart this fraction of\n   * the training data, will not train on it, and will evaluate the loss and\n   * any model metrics on this data at the end of each epoch.\n   * The validation data is selected from the last samples in the `x` and `y`\n   * data provided, before shuffling.\n   */\n  validationSplit?: number;\n\n  /**\n   * Data on which to evaluate the loss and any model\n   * metrics at the end of each epoch. The model will not be trained on this\n   * data. This could be a tuple [xVal, yVal] or a tuple [xVal, yVal,\n   * valSampleWeights]. The model will not be trained on this data.\n   * `validationData` will override `validationSplit`.\n   */\n  validationData?: [\n    Tensor|Tensor[], Tensor|Tensor[]\n  ]|[Tensor | Tensor[], Tensor|Tensor[], Tensor|Tensor[]];\n\n  /**\n   * Whether to shuffle the training data before each epoch. Has\n   * no effect when `stepsPerEpoch` is not `null`.\n   */\n  shuffle?: boolean;\n\n  /**\n   * Optional object mapping class indices (integers) to\n   * a weight (float) to apply to the model's loss for the samples from this\n   * class during training. This can be useful to tell the model to \"pay more\n   * attention\" to samples from an under-represented class.\n   *\n   * If the model has multiple outputs, a class weight can be specified for\n   * each of the outputs by setting this field an array of weight object\n   * or an object that maps model output names (e.g., `model.outputNames[0]`)\n   * to weight objects.\n   */\n  classWeight?: ClassWeight|ClassWeight[]|ClassWeightMap;\n\n  /**\n   * Optional array of the same length as x, containing\n   * weights to apply to the model's loss for each sample. In the case of\n   * temporal data, you can pass a 2D array with shape (samples,\n   * sequenceLength), to apply a different weight to every timestep of every\n   * sample. In this case you should make sure to specify\n   * sampleWeightMode=\"temporal\" in compile().\n   */\n  sampleWeight?: Tensor;\n\n  /**\n   * Epoch at which to start training (useful for resuming a previous training\n   * run). When this is used, `epochs` is the index of the \"final epoch\".\n   * The model is not trained for a number of iterations given by `epochs`,\n   * but merely until the epoch of index `epochs` is reached.\n   */\n  initialEpoch?: number;\n\n  /**\n   * Total number of steps (batches of samples) before\n   * declaring one epoch finished and starting the next epoch. When training\n   * with Input Tensors such as TensorFlow data tensors, the default `null` is\n   * equal to the number of unique samples in your dataset divided by the\n   * batch size, or 1 if that cannot be determined.\n   */\n  stepsPerEpoch?: number;\n\n  /**\n   * Only relevant if `stepsPerEpoch` is specified. Total number of steps\n   * (batches of samples) to validate before stopping.\n   */\n  validationSteps?: number;\n\n  /**\n   * Configures the frequency of yielding the main thread to other tasks.\n   *\n   * In the browser environment, yielding the main thread can improve the\n   * responsiveness of the page during training. In the Node.js environment,\n   * it can ensure tasks queued in the event loop can be handled in a timely\n   * manner.\n   *\n   * The value can be one of the following:\n   *   - `'auto'`: The yielding happens at a certain frame rate (currently set\n   *               at 125ms). This is the default.\n   *   - `'batch'`: yield every batch.\n   *   - `'epoch'`: yield every epoch.\n   *   - any `number`: yield every `number` milliseconds.\n   *   - `'never'`: never yield. (yielding can still happen through `await\n   *      nextFrame()` calls in custom callbacks.)\n   */\n  yieldEvery?: YieldEveryOptions;\n}\n\nexport function checkBatchSize(batchSize: number) {\n  tfc.util.assert(\n      batchSize > 0 && Number.isInteger(batchSize),\n      () => `batchSize is required to be a positive integer, but got ${\n          batchSize}`);\n}\n\n/**\n * Slice a Tensor or an Array of Tensors, by start and stop indices.\n *\n * Porting Note: The `_slice_arrays` function in PyKeras is covered by this\n *   function and `sliceArraysByIndices()` together.\n *\n * @param arrays: the input.\n * @param start: the starting index (inclusive).\n * @param stop: the stopping index (exclusive).\n * @returns The result of the slicing. If `arrays` is an `Array` of\n *   `tf.Tensor`s, the slicing will be applied to all elements of the `Array`\n *   in the same way.\n */\nexport function sliceArrays(\n    arrays: Tensor|Tensor[], start: number, stop: number): Tensor|Tensor[] {\n  if (arrays == null) {\n    return [null];\n  } else if (Array.isArray(arrays)) {\n    return arrays.map(array => sliceAlongFirstAxis(array, start, stop - start));\n  } else {  // Tensor.\n    return sliceAlongFirstAxis(arrays, start, stop - start);\n  }\n}\n\n/**\n * Slice a Tensor or an Array of Tensors, by random-order indices.\n *\n * Porting Note: The `_slice_arrays` function in PyKeras is covered by this\n *   function and `sliceArrays()` together.\n *\n * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.\n *   If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the\n *   same fashion.\n * @param indices The indices to use for slicing along the first (batch)\n *   dimension.\n * @returns Result(s) of the slicing.\n */\nexport function sliceArraysByIndices(\n    arrays: Tensor|Tensor[], indices: Tensor1D): Tensor|Tensor[] {\n  return tfc.tidy(() => {\n    if (arrays == null) {\n      return null;\n    } else if (Array.isArray(arrays)) {\n      return arrays.map(\n          array => (sliceArraysByIndices(array, indices) as Tensor));\n    } else {\n      // TODO(cais): indices should be a pre-constructed Tensor1D to avoid\n      //   tensor1d() calls.\n      return gather(\n          arrays,\n          indices.dtype === 'int32' ? indices : tfc.cast(indices, 'int32'));\n    }\n  });\n}\n\n/**\n * Returns a list of batch indices (tuples of indices).\n * @param size: Integer, total size of the data to slice into batches.\n * @param batchSize: Integer, batch size.\n * @returns An Array of [batchStart, batchEnd] tuples. batchStart is\n *   inclusive; batchEnd is exclusive. I.e., each batch consists of indices x\n *   that satisfy batchStart <= x < batchEnd.\n */\nexport function makeBatches(\n    size: number, batchSize: number): Array<[number, number]> {\n  const output: Array<[number, number]> = [];\n  let batchStart = 0;\n  let batchEnd: number = null;\n  while (batchStart < size) {\n    batchEnd = batchStart + batchSize;\n    if (batchEnd >= size) {\n      batchEnd = size;\n    }\n    output.push([batchStart, batchEnd]);\n    batchStart = batchEnd;\n  }\n  return output;\n}\n\n/**\n * Ensure tensors all have a rank of at least 2.\n *\n * If a tensor has a rank of 1, it is dimension-expanded to rank 2.\n * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.\n */\nexport function ensureTensorsRank2OrHigher(tensors: Tensor|Tensor[]): Tensor[] {\n  const outs: Tensor[] = [];\n  if (tensors instanceof Tensor) {\n    tensors = [tensors];\n  }\n\n  // Make Tensors at least 2D.\n  for (let i = 0; i < tensors.length; ++i) {\n    const tensor = tensors[i];\n    if (tensor.rank === 1) {\n      outs.push(expandDims(tensor, 1));\n    } else if (tensor.rank === 0) {\n      throw new Error(\n          'Expected tensor to be at least 1D, but received a 0D tensor ' +\n          '(scalar).');\n    } else {\n      outs.push(tensor);\n    }\n  }\n  return outs;\n}\n\n/**\n * Compare a set of tensors with a reference (old) set, discard the ones\n * in the new set that are not present in the reference set.\n *\n * This method is used for memory clenaup during calls such as\n * LayersModel.fit().\n *\n * @param tensors New set which may contain Tensors not present in\n *   `refTensors`.\n * @param refTensors Reference Tensor set.\n */\n// TODO(cais, kangyizhang): Deduplicate with tfjs-data.\nexport function disposeNewTensors(\n    tensors: Tensor|Tensor[]|{[inputName: string]: Tensor},\n    refTensors: Tensor|Tensor[]|{[inputName: string]: Tensor}): void {\n  if (tensors == null) {\n    return;\n  }\n  const oldTensorIds: number[] = [];\n  if (refTensors instanceof Tensor) {\n    oldTensorIds.push(refTensors.id);\n  } else if (Array.isArray(refTensors)) {\n    refTensors.forEach(t => oldTensorIds.push(t.id));\n  } else if (refTensors != null) {\n    // `oldTensors` is a map from string name to Tensor.\n    for (const name in refTensors) {\n      const oldTensor = refTensors[name];\n      oldTensorIds.push(oldTensor.id);\n    }\n  }\n\n  const tensorsToDispose: Tensor[] = [];\n  if (tensors instanceof Tensor) {\n    if (oldTensorIds.indexOf(tensors.id) === -1) {\n      tensorsToDispose.push(tensors);\n    }\n  } else if (Array.isArray(tensors)) {\n    tensors.forEach(t => {\n      if (oldTensorIds.indexOf(t.id) === -1) {\n        tensorsToDispose.push(t);\n      }\n    });\n  } else if (tensors != null) {\n    // `oldTensors` is a map from string name to Tensor.\n    for (const name in tensors) {\n      const tensor = tensors[name];\n      if (oldTensorIds.indexOf(tensor.id) === -1) {\n        tensorsToDispose.push(tensor);\n      }\n    }\n  }\n\n  tensorsToDispose.forEach(t => {\n    if (!t.isDisposed) {\n      t.dispose();\n    }\n  });\n}\n"]}