UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

1,499 lines (1,486 loc) 1.28 MB
/** * @license * Copyright 2024 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ import * as tfc from '@tensorflow/tfjs-core'; import { util, backend, tidy as tidy$1, tensor1d, serialization, zeros as zeros$2, ones as ones$3, mul as mul$1, scalar as scalar$1, randomUniform as randomUniform$1, truncatedNormal as truncatedNormal$1, eye, linalg, dispose, memory, cast as cast$2, env as env$1, nextFrame, add as add$3, div as div$1, keep, train, clone as clone$1, argMax, reshape as reshape$2, Tensor as Tensor$1, Optimizer, io, sum as sum$1, abs, relu, clipByValue, leakyRelu, prelu as prelu$1, elu as elu$2, greater as greater$2, sub as sub$1, exp as exp$1, logSumExp, transpose as transpose$1, any, notEqual, zerosLike as zerosLike$1, greaterEqual as greaterEqual$1, moments, stack as stack$1, tensor, range as range$1, unstack as unstack$1, image, expandDims as expandDims$2, denseBincount, max as max$1, min as min$1 } from '@tensorflow/tfjs-core'; function _mergeNamespaces(n, m) { m.forEach(function (e) { e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) { if (k !== 'default' && !(k in n)) { var d = Object.getOwnPropertyDescriptor(e, k); Object.defineProperty(n, k, d.get ? 
d : { enumerable: true, get: function () { return e[k]; } }); } }); }); return n; } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Explicit error types. * * See the following link for more information about why the code includes * calls to setPrototypeOf: * * https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work */ // tslint:enable /** * Equivalent of Python's AttributeError. */ class AttributeError extends Error { constructor(message) { super(message); // Set the prototype explicitly. Object.setPrototypeOf(this, AttributeError.prototype); } } /** * Equivalent of Python's RuntimeError. */ class RuntimeError extends Error { constructor(message) { super(message); // Set the prototype explicitly. Object.setPrototypeOf(this, RuntimeError.prototype); } } /** * Equivalent of Python's ValueError. */ class ValueError extends Error { constructor(message) { super(message); // Set the prototype explicitly. Object.setPrototypeOf(this, ValueError.prototype); } } /** * Equivalent of Python's NotImplementedError. */ class NotImplementedError extends Error { constructor(message) { super(message); // Set the prototype explicitly. Object.setPrototypeOf(this, NotImplementedError.prototype); } } /** * Equivalent of Python's AssertionError. */ class AssertionError extends Error { constructor(message) { super(message); // Set the prototype explicitly. Object.setPrototypeOf(this, AssertionError.prototype); } } /** * @license * Copyright 2022 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. 
* ============================================================================= */ /** * LruCache: A mapping from the String to T. If the number of the entries is * exceeding the `maxEntries`, the LruCache will delete the least recently * used entry. */ class LruCache { constructor(maxEntries) { this.maxEntries = maxEntries || 100; this.cache = new Map(); } /** * Get the entry for the key and mark it as used recently. */ get(key) { let entry; if (this.cache.has(key)) { entry = this.cache.get(key); this.cache.delete(key); this.cache.set(key, entry); } return entry; } /** * Put the entry into the cache. If the key already existed, mark the key as * used recently. */ put(key, value) { if (this.cache.has(key)) { this.cache.delete(key); } else if (this.cache.size >= this.maxEntries) { const keyToDelete = this.cache.keys().next().value; this.cache.delete(keyToDelete); } this.cache.set(key, value); } /** * Get the MaxEntries of the cache. */ getMaxEntries() { return this.maxEntries; } /** * Set the MaxEntries of the cache. If the maxEntries is decreased, reduce * entries in the cache. */ setMaxEntries(maxEntries) { if (maxEntries < 0) { throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${maxEntries}.`); } if (this.maxEntries > maxEntries) { for (let i = 0; i < this.maxEntries - maxEntries; i++) { const keyToDelete = this.cache.keys().next().value; this.cache.delete(keyToDelete); } } this.maxEntries = maxEntries; } } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ // tslint:enable /** * If `value` is an Array, equivalent to Python's `value * numValues`. 
* If `value` is not an Array, equivalent to Python's `[value] * numValues` */ // tslint:disable-next-line:no-any function pyListRepeat(value, numValues) { if (Array.isArray(value)) { // tslint:disable-next-line:no-any let newArray = []; for (let i = 0; i < numValues; i++) { newArray = newArray.concat(value); } return newArray; } else { const newArray = new Array(numValues); newArray.fill(value); return newArray; } } function assert$1(val, message) { if (!val) { throw new AssertionError(message); } } /** * Count the number of elements of the `array` that are equal to `reference`. */ function count(array, refernce) { let counter = 0; for (const item of array) { if (item === refernce) { counter++; } } return counter; } /** * If an array is of length 1, just return the first element. Otherwise, return * the full array. * @param tensors */ function singletonOrArray(xs) { if (xs.length === 1) { return xs[0]; } return xs; } /** * Normalizes a list/tensor into a list. * * If a tensor is passed, we return * a list of size 1 containing the tensor. * * @param x target object to be normalized. */ // tslint:disable-next-line:no-any function toList(x) { if (Array.isArray(x)) { return x; } return [x]; } /** * Converts string to snake-case. * @param name */ function toSnakeCase(name) { const intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2'); const insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase(); /* If the class is private the name starts with "_" which is not secure for creating scopes. We prefix the name with "private" in this case. 
*/ if (insecure[0] !== '_') { return insecure; } return 'private' + insecure; } function toCamelCase(identifier) { // quick return for empty string or single character strings if (identifier.length <= 1) { return identifier; } // Check for the underscore indicating snake_case if (identifier.indexOf('_') === -1) { return identifier; } return identifier.replace(/[_]+(\w|$)/g, (m, p1) => p1.toUpperCase()); } // tslint:disable-next-line:no-any let _GLOBAL_CUSTOM_OBJECTS = {}; function serializeKerasObject(instance) { if (instance === null || instance === undefined) { return null; } const dict = {}; dict['className'] = instance.getClassName(); dict['config'] = instance.getConfig(); return dict; } /** * Replace ndarray-style scalar objects in serialization objects with numbers. * * Background: In some versions of tf.keras, certain scalar values in the HDF5 * model save file can be serialized as: `{'type': 'ndarray', 'value': num}`, * where in `num` is a plain number. This method converts such serialization * to a `number`. * * @param config The keras-format serialization object to be processed * (in place). 
*/ function convertNDArrayScalarsInConfig(config) { if (config == null || typeof config !== 'object') { return; } else if (Array.isArray(config)) { config.forEach(configItem => convertNDArrayScalarsInConfig(configItem)); } else { const fields = Object.keys(config); for (const field of fields) { const value = config[field]; if (value != null && typeof value === 'object') { if (!Array.isArray(value) && value['type'] === 'ndarray' && typeof value['value'] === 'number') { config[field] = value['value']; } else { convertNDArrayScalarsInConfig(value); } } } } } /** * Deserialize a saved Keras Object * @param identifier either a string ID or a saved Keras dictionary * @param moduleObjects a list of Python class names to object constructors * @param customObjects a list of Python class names to object constructors * @param printableModuleName debug text for the object being reconstituted * @param fastWeightInit Optional flag to use fast weight initialization * during deserialization. This is applicable to cases in which * the initialization will be immediately overwritten by loaded weight * values. Default: `false`. * @returns a TensorFlow.js Layers object */ // tslint:disable:no-any function deserializeKerasObject(identifier, moduleObjects = {}, customObjects = {}, printableModuleName = 'object', fastWeightInit = false) { // tslint:enable if (typeof identifier === 'string') { const functionName = identifier; let fn; if (functionName in customObjects) { fn = customObjects[functionName]; } else if (functionName in _GLOBAL_CUSTOM_OBJECTS) { fn = _GLOBAL_CUSTOM_OBJECTS[functionName]; } else { fn = moduleObjects[functionName]; if (fn == null) { throw new ValueError(`Unknown ${printableModuleName}: ${identifier}. ` + `This may be due to one of the following reasons:\n` + `1. The ${printableModuleName} is defined in Python, in which ` + `case it needs to be ported to TensorFlow.js or your JavaScript ` + `code.\n` + `2. 
The custom ${printableModuleName} is defined in JavaScript, ` + `but is not registered properly with ` + `tf.serialization.registerClass().`); // TODO(cais): Add link to tutorial page on custom layers. } } return fn; } else { // In this case we are dealing with a Keras config dictionary. const config = identifier; if (config['className'] == null || config['config'] == null) { throw new ValueError(`${printableModuleName}: Improper config format: ` + `${JSON.stringify(config)}.\n` + `'className' and 'config' must set.`); } const className = config['className']; let cls, fromConfig; if (className in customObjects) { [cls, fromConfig] = customObjects[className]; } else if (className in _GLOBAL_CUSTOM_OBJECTS) { [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS['className']; } else if (className in moduleObjects) { [cls, fromConfig] = moduleObjects[className]; } if (cls == null) { throw new ValueError(`Unknown ${printableModuleName}: ${className}. ` + `This may be due to one of the following reasons:\n` + `1. The ${printableModuleName} is defined in Python, in which ` + `case it needs to be ported to TensorFlow.js or your JavaScript ` + `code.\n` + `2. The custom ${printableModuleName} is defined in JavaScript, ` + `but is not registered properly with ` + `tf.serialization.registerClass().`); // TODO(cais): Add link to tutorial page on custom layers. } if (fromConfig != null) { // Porting notes: Instead of checking to see whether fromConfig accepts // customObjects, we create a customObjects dictionary and tack it on to // config['config'] as config['config'].customObjects. Objects can use it, // if they want. 
// tslint:disable-next-line:no-any const customObjectsCombined = {}; for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) { customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key]; } for (const key of Object.keys(customObjects)) { customObjectsCombined[key] = customObjects[key]; } // Add the customObjects to config const nestedConfig = config['config']; nestedConfig['customObjects'] = customObjectsCombined; const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS); for (const key of Object.keys(customObjects)) { _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key]; } convertNDArrayScalarsInConfig(config['config']); const returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit); _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects); return returnObj; } else { // Then `cls` may be a function returning a class. // In this case by convention `config` holds // the kwargs of the function. const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS); for (const key of Object.keys(customObjects)) { _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key]; } // In python this is **config['config'], for tfjs-layers we require // classes that use this fall-through construction method to take // a config interface that mimics the expansion of named parameters. const returnObj = new cls(config['config']); _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects); return returnObj; } } } /** * Compares two numbers for sorting. * @param a * @param b */ function numberCompare(a, b) { return (a < b) ? -1 : ((a > b) ? 1 : 0); } /** * Comparison of two numbers for reverse sorting. * @param a * @param b */ function reverseNumberCompare(a, b) { return -1 * numberCompare(a, b); } /** * Get the unique elements of an array. * @param xs Array. * @returns An Array consisting of the unique elements in `xs`. */ function unique(xs) { if (xs == null) { return xs; } const out = []; // TODO(cais): Maybe improve performance by sorting. 
for (const x of xs) { if (out.indexOf(x) === -1) { out.push(x); } } return out; } /** * Determine if an Object is empty (i.e., does not have own properties). * @param obj Object * @returns Whether the Object is empty. * @throws ValueError: If object is `null` or `undefined`. */ function isObjectEmpty(obj) { if (obj == null) { throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`); } for (const key in obj) { if (obj.hasOwnProperty(key)) { return false; } } return true; } /** * Helper function used to build type union/enum run-time checkers. * @param values The list of allowed values. * @param label A string name for the type * @param value The value to test. * @throws ValueError: If the value is not in values nor `undefined`/`null`. */ function checkStringTypeUnionValue(values, label, value) { if (value == null) { return; } if (values.indexOf(value) < 0) { throw new ValueError(`${value} is not a valid ${label}. Valid values are ${values} or null/undefined.`); } } /** * Helper function for verifying the types of inputs. * * Ensures that the elements of `x` are all of type `expectedType`. * Also verifies that the length of `x` is within bounds. * * @param x Object to test. * @param expectedType The string expected type of all of the elements in the * Array. * @param minLength Return false if x.length is less than this. * @param maxLength Return false if x.length is greater than this. * @returns true if and only if `x` is an `Array<expectedType>` with * length >= `minLength` and <= `maxLength`. */ // tslint:disable:no-any function checkArrayTypeAndLength(x, expectedType, minLength = 0, maxLength = Infinity) { assert$1(minLength >= 0); assert$1(maxLength >= minLength); return (Array.isArray(x) && x.length >= minLength && x.length <= maxLength && x.every(e => typeof e === expectedType)); } // tslint:enable:no-any /** * Assert that a value or an array of value are positive integer. * * @param value The value being asserted on. 
May be a single number or an array * of numbers. * @param name Name of the value, used to make the error message. */ function assertPositiveInteger(value, name) { if (Array.isArray(value)) { util.assert(value.length > 0, () => `${name} is unexpectedly an empty array.`); value.forEach((v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`)); } else { util.assert(Number.isInteger(value) && value > 0, () => `Expected ${name} to be a positive integer, but got ` + `${formatAsFriendlyString(value)}.`); } } /** * Format a value into a display-friendly, human-readable fashion. * * - `null` is formatted as `'null'` * - Strings are formated with flanking pair of quotes. * - Arrays are formatted with flanking pair of square brackets. * * @param value The value to display. * @return Formatted string. */ // tslint:disable-next-line:no-any function formatAsFriendlyString(value) { if (value === null) { return 'null'; } else if (Array.isArray(value)) { return '[' + value.map(v => formatAsFriendlyString(v)).join(',') + ']'; } else if (typeof value === 'string') { return `"${value}"`; } else { return `${value}`; } } /** * Returns a function `f2` (decorator) which wraps the original function * `f`. `f2` guarantees that `f` can be called at most once * every `waitMs` ms. If `f2` is called more often, it will return * the last returned result of `f`. * * @param f The original function `f` to wrap. * @param waitMs The time between two consecutive calls to `f` in ms. */ function debounce(f, waitMs, nowFunc) { let lastTime = nowFunc != null ? nowFunc() : util.now(); let lastResult; const f2 = (...args) => { const now = nowFunc != null ? nowFunc() : util.now(); if (now - lastTime < waitMs) { return lastResult; } lastTime = now; lastResult = f(...args); return lastResult; }; return f2; } /** * Returns the fusable activation given a layers identifier. * * @param activationName The layers identifier string. * @return The name of the fusable activation. 
*/ function mapActivationToFusedKernel(activationName) { if (activationName === 'relu') { return 'relu'; } if (activationName === 'linear') { return 'linear'; } if (activationName === 'elu') { return 'elu'; } return null; } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Utilities related to persistent state in the backend. */ /** * An ID to track `tf.SymbolicTensor`s and derived classes. * Required in different places in engine/topology.ts to identify unique * tensors. */ let _nextUniqueTensorId = 0; function getNextUniqueTensorId() { return _nextUniqueTensorId++; } const _uidPrefixes = {}; /** * Provides a unique UID given a string prefix. * * @param prefix */ function getUid(prefix = '') { if (!(prefix in _uidPrefixes)) { _uidPrefixes[prefix] = 0; } _uidPrefixes[prefix] += 1; return prefix + _uidPrefixes[prefix].toString(); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast']; const VALID_INTERPOLATION_FORMAT_VALUES = ['nearest', 'bilinear']; const VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal']; const VALID_POOL_MODE_VALUES = ['max', 'avg']; const VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave']; /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. 
* ============================================================================= */ // A map from the requested scoped name of a Tensor to the number of Tensors // wanting that name so far. This allows enforcing name uniqueness by appending // an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc. const nameMap = new Map(); function checkDataFormat(value) { checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value); } function checkInterpolationFormat(value) { checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value); } function checkPaddingMode(value) { checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value); } function checkPoolMode(value) { checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value); } const _nameScopeStack = []; const _nameScopeDivider = '/'; /** * Enter namescope, which can be nested. */ function nameScope(name, fn) { _nameScopeStack.push(name); try { const val = fn(); _nameScopeStack.pop(); return val; } catch (e) { _nameScopeStack.pop(); throw e; } } /** * Get the current namescope as a flat, concatenated string. */ function currentNameScopePrefix() { if (_nameScopeStack.length === 0) { return ''; } else { return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider; } } /** * Get the name a Tensor (or Variable) would have if not uniqueified. * @param tensorName * @return Scoped name string. */ function getScopedTensorName(tensorName) { if (!isValidTensorName(tensorName)) { throw new Error('Not a valid tensor name: \'' + tensorName + '\''); } return currentNameScopePrefix() + tensorName; } /** * Get unique names for Tensors and Variables. * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by * `getScopedTensorName()`. * @return A unique version of the given fully scoped name. * If this is the first time that the scoped name is seen in this session, * then the given `scopedName` is returned unaltered. 
If the same name is * seen again (producing a collision), an incrementing suffix is added to the * end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc. */ function getUniqueTensorName(scopedName) { if (!isValidTensorName(scopedName)) { throw new Error('Not a valid tensor name: \'' + scopedName + '\''); } if (!nameMap.has(scopedName)) { nameMap.set(scopedName, 0); } const index = nameMap.get(scopedName); nameMap.set(scopedName, nameMap.get(scopedName) + 1); if (index > 0) { const result = `${scopedName}_${index}`; // Mark the composed name as used in case someone wants // to call getUniqueTensorName("name_1"). nameMap.set(result, 1); return result; } else { return scopedName; } } const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/); /** * Determine whether a string is a valid tensor name. * @param name * @returns A Boolean indicating whether `name` is a valid tensor name. */ function isValidTensorName(name) { return !!name.match(tensorNameRegex); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Determine if a number is an integer. */ function isInteger(x) { return x === parseInt(x.toString(), 10); } /** * Calculate the product of an array of numbers. * @param array The array to calculate the product over. * @param begin Beginning index, inclusive. * @param end Ending index, exclusive. * @return The product. */ function arrayProd(array, begin, end) { if (begin == null) { begin = 0; } if (end == null) { end = array.length; } let prod = 1; for (let i = begin; i < end; ++i) { prod *= array[i]; } return prod; } /** * Compute minimum value. * @param array * @return minimum value. 
*/ function min(array) { // same behavior as tf.min() if (array.length === 0) { return Number.NaN; } let min = Number.POSITIVE_INFINITY; for (let i = 0; i < array.length; i++) { const value = array[i]; if (value < min) { min = value; } } return min; } /** * Compute maximum value. * @param array * @return maximum value */ function max(array) { // same behavior as tf.max() if (array.length === 0) { return Number.NaN; } let max = Number.NEGATIVE_INFINITY; for (let i = 0; i < array.length; i++) { const value = array[i]; if (value > max) { max = value; } } return max; } /** * Generate an array of integers in [begin, end). * @param begin Beginning integer, inclusive. * @param end Ending integer, exclusive. * @returns Range array. * @throws ValueError, iff `end` < `begin`. */ function range(begin, end) { if (end < begin) { throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`); } const out = []; for (let i = begin; i < end; ++i) { out.push(i); } return out; } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ let _epsilon; /** * Returns the value of the fuzz factor used in numeric expressions. */ function epsilon() { if (_epsilon == null) { _epsilon = backend().epsilon(); } return _epsilon; } /** * Returns the default image data format convention. */ function imageDataFormat() { return 'channelsLast'; } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Casts a tensor to a different dtype and returns it. * @param x Input tensor. * @param dtype String: 'float32'|'int32'|'bool'. 
* @returns Tensor of the specified `dtype`. */ function cast$1(x, dtype) { return tfc.cast(x, dtype); } /** * Adds a 1-sized dimension at index "axis". * @param x Input tensor. * @param axis Position where to add the new axis. * @returns Result of the dimension expansion. */ function expandDims$1(x, axis = -1) { const outShape = x.shape.slice(); if (axis < 0) { axis = outShape.length + axis + 1; } outShape.splice(axis, 0, 1); return tfc.reshape(x, outShape); } /** * Repeats a 2D tensor. * * If `x` has shape `[samples, dim]` and `n` is 2, for example, the output * will have shape `[samples, 2, dim]`. * * @param x Input tensor. * @param n Integer, number of times to repeat. * @returns The result of the repeat operation. * @throws ValueError: If input tensor is not 2D. */ function repeat(x, n) { return tidy$1(() => { if (x.shape.length !== 2) { throw new ValueError(`repeat() expects a rank-2 tensor, but received a ` + `rank-${x.shape.length} tensor.`); } const y = expandDims$1(x, 1); return tile$1(y, [1, n, 1]); }); } /** * Flatten a Tensor into 1D. * @param x Input tensor. * @return The result of the flattening `x`. */ function flatten$2(x) { const newShape = [arrayProd(x.shape)]; return tfc.reshape(x, newShape); } /** * Turn a nD tensor into a 2D tensor with same 0th dimension. * In other words, it flattens each data samples of a batch. * * @param x The tensor to flatten. The rank of this tensor is required to be 2 * or higher. * @return The result of the flattening. */ function batchFlatten(x) { if (x.rank <= 1) { throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`); } const newShape = [x.shape[0], arrayProd(x.shape, 1)]; return tfc.reshape(x, newShape); } /** * Do slicing along the first axis. * @param array input `tf.Tensor`. * @param start starting index, inclusive. * @param size size of the slice along the first axis. * @returns result of the slicing. 
* @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`. */ function sliceAlongFirstAxis(array, start, size) { return tidy$1(() => { switch (array.rank) { case 1: return tfc.slice1d(array, start, size); case 2: return tfc.slice2d(array, [start, 0], [size, array.shape[1]]); case 3: return tfc.slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]); case 4: return tfc.slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]); case 5: return tfc.slice(array, [start, 0, 0, 0, 0], [ size, array.shape[1], array.shape[2], array.shape[3], array.shape[4] ]); case 6: return tfc.slice(array, [start, 0, 0, 0, 0, 0], [ size, array.shape[1], array.shape[2], array.shape[3], array.shape[4], array.shape[5] ]); default: throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ` + `${array.rank}`); } }); } /** * Do slicing along the last axis. * @param array input `tf.Tensor`. * @param start starting index, inclusive. * @param size size of the slice along the last axis. * @returns result of the slicing. * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`. */ function sliceAlongLastAxis(array, start, size) { return tidy$1(() => { switch (array.rank) { case 1: return tfc.slice1d(array, start, size); case 2: return tfc.slice2d(array, [0, start], [array.shape[0], size]); case 3: return tfc.slice3d(array, [0, 0, start], [array.shape[0], array.shape[1], size]); case 4: return tfc.slice4d(array, [0, 0, 0, start], [array.shape[0], array.shape[1], array.shape[2], size]); default: throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` + `${array.rank}`); } }); } /** * Do slicing along the sepcified axis. * @param array input `tf.Tensor`. * @param start starting index, inclusive. * @param size of the slice along the chosen axis. * @param choose an axis. * @returns result of the slicing. 
* @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`. */ function sliceAlongAxis(array, start, size, axis) { return tidy$1(() => { switch (array.rank) { case 1: return tfc.slice1d(array, start, size); case 2: switch (axis) { case 1: return sliceAlongFirstAxis(array, start, size); case 2: return sliceAlongLastAxis(array, start, size); default: throw new ValueError(`The axis is not within the rank of the tensor ` + `${axis}`); } case 3: switch (axis) { case 1: return sliceAlongFirstAxis(array, start, size); case 2: return tfc.slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]); case 3: return sliceAlongLastAxis(array, start, size); default: throw new ValueError(`The axis is not within the rank of the tensor ` + `${axis}`); } case 4: switch (axis) { case 1: return sliceAlongFirstAxis(array, start, size); case 2: return tfc.slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]); case 3: return tfc.slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]); case 4: return sliceAlongLastAxis(array, start, size); default: throw new ValueError(`The axis is not within the rank of the tensor ` + `${axis}`); } default: throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` + `${array.rank}`); } }); } /** * Concatenates a list of tensors alongside the specified axis. * @param tensors `Array` of tensors to concatenate. * @param axis Concatenation axis. * @returns The result of the concatenation. */ function concatenate$1(tensors, axis = -1) { let rank; if (axis < 0) { rank = tensors[0].rank; if (rank !== 0) { axis = rank; } else { axis = 0; } } if (axis === tensors[0].rank) { // Porting Note: This is necessary because tfc.concat() requires axis to be // in the interval [-rank, rank). axis = -1; } // Porting Note: Sparse concat is not supported yet. return tfc.concat(tensors, axis); } /** * Concatenate two arrays along the first dimension. 
* @param a The 1st `tf.Tensor` to concatenate. * @param b The 2nd `tf.Tensor` to concatenate. * @returns Result of the concatenation. * @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`. */ function concatAlongFirstAxis(a, b) { switch (a.rank) { case 1: return tfc.concat1d([a, b]); case 2: return tfc.concat2d([a, b], 0); case 3: return tfc.concat3d([a, b], 0); case 4: return tfc.concat4d([a, b], 0); default: throw new ValueError(`concatAlongFirstAxis() received an unsupported ` + `tensor rank: ${a.rank}`); } } /** * Creates a tensor by tiling `x` by `n`. * @param x A tensor. * @param n An Array of integers or a single integer. If an Array, the length * must be the same as the number of dimensions in `x`. If a single integer, * it will be treated as an Array of length 1. */ function tile$1(x, n) { if (!Array.isArray(n)) { n = [n]; } if (x.rank !== n.length) { throw new ValueError(`The length of input n (${n.length}) does not match ` + `the number of dimensions in input x (${x.rank})`); } return tfc.tile(x, n); } /* Creation of random tensors. */ /** * Get a tensor with normal distribution of values. * * @param shape Shape of the tensor. * @param mean mean value of the normal distribution. * @param stddev standard deviation of the normal distribution. * @param dtype * @param seed * @return The normal tensor. */ function randomNormal$1(shape, mean = 0.0, stddev = 1.0, dtype, seed) { return tfc.randomNormal(shape, mean, stddev, dtype, seed); } /* Linear Algebra */ /** * Multiply two tensors and returns the result as a tensor. * * For 2D tensors, this is equivalent to matrix multiplication (matMul). * For tensors of higher ranks, it follows the Theano behavior, * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation: * * For N dimensions it is a sum product over the last axis of x and the * second-to-last of y: * * @param a A tensor of at least rank 2. * @param b A tensor of at least rank 2. 
* @param activation (optional) A string identifying the activation * function. * @return Result of the dot operation. */ function dot$1(a, b, activation, bias) { if ((a.rank < 2) || (b.rank < 2)) { throw new NotImplementedError(`dot requires both inputs to be rank >= 2` + ` but got x shape = ${a.shape} and y shape = ${b.shape}`); } if (b.rank >= 3) { const xLastDim = a.shape.slice(-1)[0]; const ySecondLastDim = b.shape.slice(-2)[0]; if (xLastDim !== ySecondLastDim) { throw new NotImplementedError(`If rank y >= 3, then the second last dim` + ` of y must equal the last dim of x but got x shape = ${a.shape} and ` + ` y shape = ${b.shape}`); } } // Handle basic 2D x 2D case. if ((a.rank === 2) && (b.rank === 2)) { const transposeA = false; const transposeB = false; // tfc.fused.matMul only fuses certain activation functions. Unsupported // activation functions are treated as 'linear' activations, which is // equivalent to a no-op. return tfc.fused.matMul({ a, b: b, transposeA, transposeB, bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null, activation }); } else { // Reshape x into the analogous 2D Tensor. const aFirstDims = a.shape.slice(); // Holds all but the last dim of x. const aLastDim = aFirstDims.pop(); a = tfc.reshape(a, [-1, aLastDim]); // Reshape y into the analogous 2D Tensor, and keep track of the // required dimensions to reproduce the output shape. const bShape = b.shape.slice(); const bLastDim = bShape.pop(); const ySecondLastDim = bShape.pop(); const yOtherDims = [...bShape, bLastDim]; // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1] // where r is the rank of y. const perm = Array.from({ length: b.rank }, (_, i) => { if (i === 0) { return b.rank - 2; } else if (i <= b.rank - 2) { return i - 1; } return i; }); b = tfc.reshape(tfc.transpose(b, perm), [ySecondLastDim, -1]); // Multiply x and y as 2D Tensors, and then reshape back to original. 
const outputShape = [...aFirstDims, ...yOtherDims]; const transposeA = false; const transposeB = false; return tfc.reshape(tfc.fused.matMul({ a, b, transposeA, transposeB, bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null, activation }), outputShape); } } /* Elementary math functions. */ /** * Retrieves the elements of indices `indices` in the tensor `reference`. * @param reference A tensor. * @param indices An integer tensor of indices or an `Array` of integers. * @param axis Axis along which to perform the gather operation. * @returns The result of the gathering as a tensor. */ function gather$1(reference, indices, axis) { return tidy$1(() => { if (Array.isArray(indices)) { indices = tensor1d(indices, 'int32'); } else { indices = tfc.cast(indices, 'int32'); } return tfc.gather(reference, indices, axis); }); } /** * Element-wise square. * @param x Input tensor. * @return element-wise x^2 */ function square$1(x) { return tfc.mul(x, x); } /** * Reshapes bias tensor according to rank of x. 
*/ function reshapeBias(xRank, bias, dataFormat) { const biasShape = bias.shape; if (bias.rank !== 1 && bias.rank !== xRank) { throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` + `; expected it to be 1 or ${xRank}`); } if (xRank === 5) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return tfc.reshape(bias, [1, biasShape[0], 1, 1, 1]); } else { return tfc.reshape(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return tfc.reshape(bias, [1, 1, 1, 1, biasShape[0]]); } else { return tfc.reshape(bias, [1].concat(biasShape)); } } } else if (xRank === 4) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return tfc.reshape(bias, [1, biasShape[0], 1, 1]); } else { return tfc.reshape(bias, [1, biasShape[2], biasShape[0], biasShape[1]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return tfc.reshape(bias, [1, 1, 1, biasShape[0]]); } else { return tfc.reshape(bias, [1].concat(biasShape)); } } } else if (xRank === 3) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return tfc.reshape(bias, [1, biasShape[0], 1]); } else { return tfc.reshape(bias, [1, biasShape[1], biasShape[0]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return tfc.reshape(bias, [1, 1, biasShape[0]]); } else { return tfc.reshape(bias, [1].concat(biasShape)); } } } else if (xRank < 3) { return bias; } throw new ValueError(`Unsupported input rank by biasAdd: ${bias.rank}`); } /* Neural-network operations. */ /** * Add a bias to a tensor. * * @param x The tensor to add the bias to. * @param bias The bias to add to `x`. Must be 1D or the same rank as `x`. * @return Result of the bias adding. * @throws ValueError: If the rank of `bias` is incorrect. 
*/ function biasAdd(x, bias, dataFormat) { return tidy$1(() => { if (dataFormat == null) { dataFormat = imageDataFormat(); } checkDataFormat(dataFormat); return tfc.add(x, reshapeBias(x.rank, bias, dataFormat)); }); } /** * Exponential linear unit (ELU). * @param x A tensor or variable to compute the activation function for. * @param alpha: A scalar, a scaling factor for the negative section. * @return Output of the ELU operation. */ function elu$1(x, alpha = 1) { // TODO(cais): Add support for alpha values other than 1. if (alpha !== 1) { throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented ` + `yet.`); } return tfc.elu(x); } /** * Softsign of a tensor. * * Defined as x / (abs(x) + 1), element-wise. * * @param x: Input. * @returns Output. */ function softsign(x) { return tidy$1(() => tfc.div(x, tfc.add(tfc.abs(x), 1))); } /** * Sets entries in `x` to zero at random, while scaling the entire tensor. * * @param x input tensor. * @param level fraction of the entries in the tensor that will be set to 0. * @param noiseShape shape of randomly generated keep/drop flags, must be * broadcastable to the shape of `x`. Optional. * @param seed random seed to ensure determinism. Optional. * @returns Result of the dropout operation. */ function dropout$1(x, level, noiseShape, seed) { return tidy$1(() => tfc.dropout(x, level, noiseShape, seed)); } /** * Element-wise, segment-wise linear approximation of sigmoid. * * Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`. * In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`. * * @param x Input tensor. * @returns Output tensor. */ function hardSigmoid(x) { return tidy$1(() => { const y = tfc.add(.5, tfc.mul(.2, x)); return tfc.clipByValue(y, 0, 1); }); } /** * Invoke `x` in the training phase, and `alt` otherwise. * * Porting Note: We do not create placeholder tensors for the `training` * boolean flag here, because there is no such thing in the TF.js imperative * backend. 
* * @param x The function to invoke iff `training` is `true`. * @param alt The function to invoke iff `training` is `false`. * @param training Boolean flag for whether training phase is active. * @returns The return value of `x()` if `training` is `true`, or the return * value of `alt()` if `training` is `false`. */ function inTrainPhase(x, alt, training = false) { return training ? x() : alt(); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg']; const VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal']; /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ function checkFanMode(value) { checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value); } function checkDistribution(value) { checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value); } /** * Initializer base class. 
*
 * `Zeros`, `Ones` and `Constant` below extend this class, implement
 * `apply(shape, dtype)`, and are registered with
 * `serialization.registerClass` under their `className`.
 *
 * @doc {
 *   heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
 */
class Initializer extends serialization.Serializable {
  // Always returns false at the base level. NOTE(review): presumably read by
  // the deserializer to decide whether custom objects are needed; confirm
  // against serialization.Serializable.
  fromConfigUsesCustomObjects() { return false; }
  // The base class has no configuration of its own; Constant (below)
  // overrides this to report its `value`.
  getConfig() { return {}; }
}
// Initializer producing all-zero tensors via `zeros$2` (tfjs-core `zeros`).
class Zeros extends Initializer {
  apply(shape, dtype) { return zeros$2(shape, dtype); }
}
/** @nocollapse */
Zeros.className = 'Zeros';
serialization.registerClass(Zeros);
// Initializer producing all-one tensors via `ones$3` (tfjs-core `ones`).
class Ones extends Initializer {
  apply(shape, dtype) { return ones$3(shape, dtype); }
}
/** @nocollapse */
Ones.className = 'Ones';
serialization.registerClass(Ones);
// Initializer that fills tensors with the `value` given at construction,
// computed as scalar(value) * ones(shape, dtype) inside a tidy scope.
class Constant extends Initializer {
  constructor(args) {
    super();
    // NOTE(review): `typeof null === 'object'`, so a null `args` passes this
    // check and then fails on `args.value` with a TypeError rather than a
    // ValueError. `${args}` also renders objects as "[object Object]" in
    // both messages.
    if (typeof args !== 'object') {
      throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`);
    }
    if (args.value === undefined) {
      throw new ValueError(`config must have value set but got ${args}`);
    }
    this.value = args.value;
  }
  apply(shape, dtype) {
    return tidy$1(() => mul$1(scalar$1(this.value), ones$3(shape, dtype)));
  }
  // Serializes the only configurable field, `value`.
  getConfig() {
    return {
      value: this.value,
    };
  }
}
/** @nocollapse */
Constant.className = 'Constant';
serialization.registerClass(Constant);
class RandomUniform extends Initializer {
  constructor(args) {
    super();
    this.DEFAULT_MINVAL = -0.05;
    this.DEFAULT_MAXVAL = 0.05;
    // NOTE(review): `||` also replaces an explicit 0 (e.g. `maxval: 0`) with
    // the default; a null/undefined check may be what is intended.
    this.minval = args.minval || this.DEFAULT_MINVAL;
    this.maxval = args.maxval || this.DEFAULT_MAXVAL;
    this.seed = args.seed;
  }