@tensorflow/tfjs-layers
TensorFlow layers API in JavaScript
/**
* @license
* Copyright 2024 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tfc from '@tensorflow/tfjs-core';
import { util, backend, tidy as tidy$1, tensor1d, serialization, zeros as zeros$2, ones as ones$3, mul as mul$1, scalar as scalar$1, randomUniform as randomUniform$1, truncatedNormal as truncatedNormal$1, eye, linalg, dispose, memory, cast as cast$2, env as env$1, nextFrame, add as add$3, div as div$1, keep, train, clone as clone$1, argMax, reshape as reshape$2, Tensor as Tensor$1, Optimizer, io, sum as sum$1, abs, relu, clipByValue, leakyRelu, prelu as prelu$1, elu as elu$2, greater as greater$2, sub as sub$1, exp as exp$1, logSumExp, transpose as transpose$1, any, notEqual, zerosLike as zerosLike$1, greaterEqual as greaterEqual$1, moments, stack as stack$1, tensor, range as range$1, unstack as unstack$1, image, expandDims as expandDims$2, denseBincount, max as max$1, min as min$1 } from '@tensorflow/tfjs-core';
function _mergeNamespaces(n, m) {
m.forEach(function (e) {
e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) {
if (k !== 'default' && !(k in n)) {
var d = Object.getOwnPropertyDescriptor(e, k);
Object.defineProperty(n, k, d.get ? d : {
enumerable: true,
get: function () { return e[k]; }
});
}
});
});
return n;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Explicit error types.
*
* See the following link for more information about why the code includes
* calls to setPrototypeOf:
*
* https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work
*/
// tslint:enable
/**
* Equivalent of Python's AttributeError.
*/
class AttributeError extends Error {
constructor(message) {
super(message);
// Set the prototype explicitly.
Object.setPrototypeOf(this, AttributeError.prototype);
}
}
/**
* Equivalent of Python's RuntimeError.
*/
class RuntimeError extends Error {
constructor(message) {
super(message);
// Set the prototype explicitly.
Object.setPrototypeOf(this, RuntimeError.prototype);
}
}
/**
* Equivalent of Python's ValueError.
*/
class ValueError extends Error {
constructor(message) {
super(message);
// Set the prototype explicitly.
Object.setPrototypeOf(this, ValueError.prototype);
}
}
/**
* Equivalent of Python's NotImplementedError.
*/
class NotImplementedError extends Error {
constructor(message) {
super(message);
// Set the prototype explicitly.
Object.setPrototypeOf(this, NotImplementedError.prototype);
}
}
/**
* Equivalent of Python's AssertionError.
*/
class AssertionError extends Error {
constructor(message) {
super(message);
// Set the prototype explicitly.
Object.setPrototypeOf(this, AssertionError.prototype);
}
}
/**
* @license
* Copyright 2022 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* LruCache: A mapping from string keys to values of type T. If the number of
* entries exceeds `maxEntries`, the LruCache deletes the least recently used
* entry.
*/
class LruCache {
constructor(maxEntries) {
this.maxEntries = maxEntries || 100;
this.cache = new Map();
}
/**
* Get the entry for the key and mark it as used recently.
*/
get(key) {
let entry;
if (this.cache.has(key)) {
entry = this.cache.get(key);
this.cache.delete(key);
this.cache.set(key, entry);
}
return entry;
}
/**
* Put the entry into the cache. If the key already existed, mark the key as
* used recently.
*/
put(key, value) {
if (this.cache.has(key)) {
this.cache.delete(key);
}
else if (this.cache.size >= this.maxEntries) {
const keyToDelete = this.cache.keys().next().value;
this.cache.delete(keyToDelete);
}
this.cache.set(key, value);
}
/**
* Get the MaxEntries of the cache.
*/
getMaxEntries() {
return this.maxEntries;
}
/**
* Set the maxEntries of the cache. If maxEntries is decreased, evict the least
* recently used entries until the cache fits within the new limit.
*/
setMaxEntries(maxEntries) {
if (maxEntries < 0) {
throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${maxEntries}.`);
}
if (this.maxEntries > maxEntries) {
for (let i = 0; i < this.maxEntries - maxEntries; i++) {
const keyToDelete = this.cache.keys().next().value;
this.cache.delete(keyToDelete);
}
}
this.maxEntries = maxEntries;
}
}
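// Illustrative usage of LruCache (a sketch for readers of this bundle, not part
// of the public tfjs-layers API):
//
//   const cache = new LruCache(2);
//   cache.put('a', 1);
//   cache.put('b', 2);
//   cache.get('a');     // Marks 'a' as most recently used.
//   cache.put('c', 3);  // Evicts 'b', the least recently used entry.
//   cache.get('b');     // undefined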
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// tslint:enable
/**
* If `value` is an Array, equivalent to Python's `value * numValues`.
* If `value` is not an Array, equivalent to Python's `[value] * numValues`
*/
// tslint:disable-next-line:no-any
function pyListRepeat(value, numValues) {
if (Array.isArray(value)) {
// tslint:disable-next-line:no-any
let newArray = [];
for (let i = 0; i < numValues; i++) {
newArray = newArray.concat(value);
}
return newArray;
}
else {
const newArray = new Array(numValues);
newArray.fill(value);
return newArray;
}
}
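// Example behavior (a sketch based on the implementation above):
//   pyListRepeat([1, 2], 3)  // -> [1, 2, 1, 2, 1, 2]  (Python: [1, 2] * 3)
//   pyListRepeat(7, 3)       // -> [7, 7, 7]           (Python: [7] * 3)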
function assert$1(val, message) {
if (!val) {
throw new AssertionError(message);
}
}
/**
* Count the number of elements of the `array` that are equal to `reference`.
*/
function count(array, reference) {
let counter = 0;
for (const item of array) {
if (item === reference) {
counter++;
}
}
return counter;
}
/**
* If an array is of length 1, just return the first element. Otherwise, return
* the full array.
* @param xs The input array.
*/
function singletonOrArray(xs) {
if (xs.length === 1) {
return xs[0];
}
return xs;
}
/**
* Normalizes a list/tensor into a list.
*
* If a tensor is passed, we return
* a list of size 1 containing the tensor.
*
* @param x target object to be normalized.
*/
// tslint:disable-next-line:no-any
function toList(x) {
if (Array.isArray(x)) {
return x;
}
return [x];
}
/**
* Converts string to snake-case.
* @param name
*/
function toSnakeCase(name) {
const intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');
const insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();
/*
If the class is private the name starts with "_" which is not secure
for creating scopes. We prefix the name with "private" in this case.
*/
if (insecure[0] !== '_') {
return insecure;
}
return 'private' + insecure;
}
function toCamelCase(identifier) {
// quick return for empty string or single character strings
if (identifier.length <= 1) {
return identifier;
}
// Check for the underscore indicating snake_case
if (identifier.indexOf('_') === -1) {
return identifier;
}
return identifier.replace(/[_]+(\w|$)/g, (m, p1) => p1.toUpperCase());
}
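// Example conversions (a sketch; the identifiers are hypothetical):
//   toSnakeCase('batchNormalization')   // -> 'batch_normalization'
//   toSnakeCase('_privateLayer')        // -> 'private_private_layer' ("private" prefix)
//   toCamelCase('batch_normalization')  // -> 'batchNormalization'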
// tslint:disable-next-line:no-any
let _GLOBAL_CUSTOM_OBJECTS = {};
function serializeKerasObject(instance) {
if (instance === null || instance === undefined) {
return null;
}
const dict = {};
dict['className'] = instance.getClassName();
dict['config'] = instance.getConfig();
return dict;
}
/**
* Replace ndarray-style scalar objects in serialization objects with numbers.
*
* Background: In some versions of tf.keras, certain scalar values in the HDF5
* model save file can be serialized as: `{'type': 'ndarray', 'value': num}`,
* where `num` is a plain number. This method converts such a serialization
* to a `number`.
*
* @param config The keras-format serialization object to be processed
* (in place).
*/
function convertNDArrayScalarsInConfig(config) {
if (config == null || typeof config !== 'object') {
return;
}
else if (Array.isArray(config)) {
config.forEach(configItem => convertNDArrayScalarsInConfig(configItem));
}
else {
const fields = Object.keys(config);
for (const field of fields) {
const value = config[field];
if (value != null && typeof value === 'object') {
if (!Array.isArray(value) && value['type'] === 'ndarray' &&
typeof value['value'] === 'number') {
config[field] = value['value'];
}
else {
convertNDArrayScalarsInConfig(value);
}
}
}
}
}
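// Example (a sketch; the field name is hypothetical): a config entry such as
//   { epsilon: { type: 'ndarray', value: 0.001 } }
// is rewritten in place to
//   { epsilon: 0.001 }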
/**
* Deserialize a saved Keras Object
* @param identifier either a string ID or a saved Keras dictionary
* @param moduleObjects a map from (Python-style) class names to object constructors
* @param customObjects a map from user-defined class names to object constructors
* @param printableModuleName debug text for the object being reconstituted
* @param fastWeightInit Optional flag to use fast weight initialization
* during deserialization. This is applicable to cases in which
* the initialization will be immediately overwritten by loaded weight
* values. Default: `false`.
* @returns a TensorFlow.js Layers object
*/
// tslint:disable:no-any
function deserializeKerasObject(identifier, moduleObjects = {}, customObjects = {}, printableModuleName = 'object', fastWeightInit = false) {
// tslint:enable
if (typeof identifier === 'string') {
const functionName = identifier;
let fn;
if (functionName in customObjects) {
fn = customObjects[functionName];
}
else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
}
else {
fn = moduleObjects[functionName];
if (fn == null) {
throw new ValueError(`Unknown ${printableModuleName}: ${identifier}. ` +
`This may be due to one of the following reasons:\n` +
`1. The ${printableModuleName} is defined in Python, in which ` +
`case it needs to be ported to TensorFlow.js or your JavaScript ` +
`code.\n` +
`2. The custom ${printableModuleName} is defined in JavaScript, ` +
`but is not registered properly with ` +
`tf.serialization.registerClass().`);
// TODO(cais): Add link to tutorial page on custom layers.
}
}
return fn;
}
else {
// In this case we are dealing with a Keras config dictionary.
const config = identifier;
if (config['className'] == null || config['config'] == null) {
throw new ValueError(`${printableModuleName}: Improper config format: ` +
`${JSON.stringify(config)}.\n` +
`'className' and 'config' must be set.`);
}
const className = config['className'];
let cls, fromConfig;
if (className in customObjects) {
[cls, fromConfig] = customObjects[className];
}
else if (className in _GLOBAL_CUSTOM_OBJECTS) {
[cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS[className];
}
else if (className in moduleObjects) {
[cls, fromConfig] = moduleObjects[className];
}
if (cls == null) {
throw new ValueError(`Unknown ${printableModuleName}: ${className}. ` +
`This may be due to one of the following reasons:\n` +
`1. The ${printableModuleName} is defined in Python, in which ` +
`case it needs to be ported to TensorFlow.js or your JavaScript ` +
`code.\n` +
`2. The custom ${printableModuleName} is defined in JavaScript, ` +
`but is not registered properly with ` +
`tf.serialization.registerClass().`);
// TODO(cais): Add link to tutorial page on custom layers.
}
if (fromConfig != null) {
// Porting notes: Instead of checking to see whether fromConfig accepts
// customObjects, we create a customObjects dictionary and tack it on to
// config['config'] as config['config'].customObjects. Objects can use it,
// if they want.
// tslint:disable-next-line:no-any
const customObjectsCombined = {};
for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) {
customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];
}
for (const key of Object.keys(customObjects)) {
customObjectsCombined[key] = customObjects[key];
}
// Add the customObjects to config
const nestedConfig = config['config'];
nestedConfig['customObjects'] = customObjectsCombined;
const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
for (const key of Object.keys(customObjects)) {
_GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
}
convertNDArrayScalarsInConfig(config['config']);
const returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit);
_GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
return returnObj;
}
else {
// Then `cls` may be a function returning a class.
// In this case by convention `config` holds
// the kwargs of the function.
const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
for (const key of Object.keys(customObjects)) {
_GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
}
// In python this is **config['config'], for tfjs-layers we require
// classes that use this fall-through construction method to take
// a config interface that mimics the expansion of named parameters.
const returnObj = new cls(config['config']);
_GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
return returnObj;
}
}
}
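// Illustrative input (a sketch; 'Dense' stands for any registered class name):
// a Keras-format config dictionary such as
//   { className: 'Dense', config: { units: 4, activation: 'relu' } }
// is resolved by looking up 'Dense' first in `customObjects`, then in the global
// custom objects, then in `moduleObjects`, and the instance is rebuilt via the
// registered `fromConfig` (or, if none is registered, by calling the constructor
// with `config['config']`).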
/**
* Compares two numbers for sorting.
* @param a
* @param b
*/
function numberCompare(a, b) {
return (a < b) ? -1 : ((a > b) ? 1 : 0);
}
/**
* Comparison of two numbers for reverse sorting.
* @param a
* @param b
*/
function reverseNumberCompare(a, b) {
return -1 * numberCompare(a, b);
}
/**
* Get the unique elements of an array.
* @param xs Array.
* @returns An Array consisting of the unique elements in `xs`.
*/
function unique(xs) {
if (xs == null) {
return xs;
}
const out = [];
// TODO(cais): Maybe improve performance by sorting.
for (const x of xs) {
if (out.indexOf(x) === -1) {
out.push(x);
}
}
return out;
}
/**
* Determine if an Object is empty (i.e., does not have own properties).
* @param obj Object
* @returns Whether the Object is empty.
* @throws ValueError: If object is `null` or `undefined`.
*/
function isObjectEmpty(obj) {
if (obj == null) {
throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`);
}
for (const key in obj) {
if (obj.hasOwnProperty(key)) {
return false;
}
}
return true;
}
/**
* Helper function used to build type union/enum run-time checkers.
* @param values The list of allowed values.
* @param label A string name for the type
* @param value The value to test.
* @throws ValueError: If the value is not in values nor `undefined`/`null`.
*/
function checkStringTypeUnionValue(values, label, value) {
if (value == null) {
return;
}
if (values.indexOf(value) < 0) {
throw new ValueError(`${value} is not a valid ${label}. Valid values are ${values} or null/undefined.`);
}
}
/**
* Helper function for verifying the types of inputs.
*
* Ensures that the elements of `x` are all of type `expectedType`.
* Also verifies that the length of `x` is within bounds.
*
* @param x Object to test.
* @param expectedType The string expected type of all of the elements in the
* Array.
* @param minLength Return false if x.length is less than this.
* @param maxLength Return false if x.length is greater than this.
* @returns true if and only if `x` is an `Array<expectedType>` with
* length >= `minLength` and <= `maxLength`.
*/
// tslint:disable:no-any
function checkArrayTypeAndLength(x, expectedType, minLength = 0, maxLength = Infinity) {
assert$1(minLength >= 0);
assert$1(maxLength >= minLength);
return (Array.isArray(x) && x.length >= minLength && x.length <= maxLength &&
x.every(e => typeof e === expectedType));
}
// tslint:enable:no-any
/**
* Assert that a value or an array of values consists of positive integers.
*
* @param value The value being asserted on. May be a single number or an array
* of numbers.
* @param name Name of the value, used to make the error message.
*/
function assertPositiveInteger(value, name) {
if (Array.isArray(value)) {
util.assert(value.length > 0, () => `${name} is unexpectedly an empty array.`);
value.forEach((v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`));
}
else {
util.assert(Number.isInteger(value) && value > 0, () => `Expected ${name} to be a positive integer, but got ` +
`${formatAsFriendlyString(value)}.`);
}
}
/**
* Format a value into a display-friendly, human-readable fashion.
*
* - `null` is formatted as `'null'`
* - Strings are formatted with a flanking pair of quotes.
* - Arrays are formatted with a flanking pair of square brackets.
*
* @param value The value to display.
* @return Formatted string.
*/
// tslint:disable-next-line:no-any
function formatAsFriendlyString(value) {
if (value === null) {
return 'null';
}
else if (Array.isArray(value)) {
return '[' + value.map(v => formatAsFriendlyString(v)).join(',') + ']';
}
else if (typeof value === 'string') {
return `"${value}"`;
}
else {
return `${value}`;
}
}
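// Example outputs (a sketch):
//   formatAsFriendlyString(null)         // -> 'null'
//   formatAsFriendlyString('relu')       // -> '"relu"'
//   formatAsFriendlyString([1, [2, 3]])  // -> '[1,[2,3]]'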
/**
* Returns a function `f2` (decorator) which wraps the original function
* `f`. `f2` guarantees that `f` can be called at most once
* every `waitMs` ms. If `f2` is called more often, it will return
* the last returned result of `f`.
*
* @param f The original function `f` to wrap.
* @param waitMs The time between two consecutive calls to `f` in ms.
*/
function debounce(f, waitMs, nowFunc) {
let lastTime = nowFunc != null ? nowFunc() : util.now();
let lastResult;
const f2 = (...args) => {
const now = nowFunc != null ? nowFunc() : util.now();
if (now - lastTime < waitMs) {
return lastResult;
}
lastTime = now;
lastResult = f(...args);
return lastResult;
};
return f2;
}
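// Illustrative usage (a sketch). Note that `lastResult` starts out undefined, so
// calls made within `waitMs` of creating the wrapper return undefined:
//
//   let count = 0;
//   const throttled = debounce(() => ++count, 1000);
//   throttled();  // Within 1000 ms of creation: returns undefined, `f` not run.
//   // ...at least 1000 ms later...
//   throttled();  // Runs `f` and returns 1.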
/**
* Returns the fusable activation given a layers identifier.
*
* @param activationName The layers identifier string.
* @return The name of the fusable activation.
*/
function mapActivationToFusedKernel(activationName) {
if (activationName === 'relu') {
return 'relu';
}
if (activationName === 'linear') {
return 'linear';
}
if (activationName === 'elu') {
return 'elu';
}
return null;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Utilities related to persistent state in the backend.
*/
/**
* An ID to track `tf.SymbolicTensor`s and derived classes.
* Required in different places in engine/topology.ts to identify unique
* tensors.
*/
let _nextUniqueTensorId = 0;
function getNextUniqueTensorId() {
return _nextUniqueTensorId++;
}
const _uidPrefixes = {};
/**
* Provides a unique UID given a string prefix.
*
* @param prefix
*/
function getUid(prefix = '') {
if (!(prefix in _uidPrefixes)) {
_uidPrefixes[prefix] = 0;
}
_uidPrefixes[prefix] += 1;
return prefix + _uidPrefixes[prefix].toString();
}
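// Example (a sketch; the 'dense' prefix is hypothetical):
//   getUid('dense')  // -> 'dense1'
//   getUid('dense')  // -> 'dense2'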
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast'];
const VALID_INTERPOLATION_FORMAT_VALUES = ['nearest', 'bilinear'];
const VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal'];
const VALID_POOL_MODE_VALUES = ['max', 'avg'];
const VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave'];
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// A map from the requested scoped name of a Tensor to the number of Tensors
// wanting that name so far. This allows enforcing name uniqueness by appending
// an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc.
const nameMap = new Map();
function checkDataFormat(value) {
checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
}
function checkInterpolationFormat(value) {
checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value);
}
function checkPaddingMode(value) {
checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);
}
function checkPoolMode(value) {
checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);
}
const _nameScopeStack = [];
const _nameScopeDivider = '/';
/**
* Enter namescope, which can be nested.
*/
function nameScope(name, fn) {
_nameScopeStack.push(name);
try {
const val = fn();
_nameScopeStack.pop();
return val;
}
catch (e) {
_nameScopeStack.pop();
throw e;
}
}
/**
* Get the current namescope as a flat, concatenated string.
*/
function currentNameScopePrefix() {
if (_nameScopeStack.length === 0) {
return '';
}
else {
return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
}
/**
* Get the name a Tensor (or Variable) would have if not uniqueified.
* @param tensorName
* @return Scoped name string.
*/
function getScopedTensorName(tensorName) {
if (!isValidTensorName(tensorName)) {
throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
}
return currentNameScopePrefix() + tensorName;
}
/**
* Get unique names for Tensors and Variables.
* @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
* `getScopedTensorName()`.
* @return A unique version of the given fully scoped name.
* If this is the first time that the scoped name is seen in this session,
* then the given `scopedName` is returned unaltered. If the same name is
* seen again (producing a collision), an incrementing suffix is added to the
* end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
*/
function getUniqueTensorName(scopedName) {
if (!isValidTensorName(scopedName)) {
throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
}
if (!nameMap.has(scopedName)) {
nameMap.set(scopedName, 0);
}
const index = nameMap.get(scopedName);
nameMap.set(scopedName, nameMap.get(scopedName) + 1);
if (index > 0) {
const result = `${scopedName}_${index}`;
// Mark the composed name as used in case someone wants
// to call getUniqueTensorName("name_1").
nameMap.set(result, 1);
return result;
}
else {
return scopedName;
}
}
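// Example (a sketch, assuming 'scope/kernel' has not been seen before):
//   getUniqueTensorName('scope/kernel')  // -> 'scope/kernel'
//   getUniqueTensorName('scope/kernel')  // -> 'scope/kernel_1'
//   getUniqueTensorName('scope/kernel')  // -> 'scope/kernel_2'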
const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);
/**
* Determine whether a string is a valid tensor name.
* @param name
* @returns A Boolean indicating whether `name` is a valid tensor name.
*/
function isValidTensorName(name) {
return !!name.match(tensorNameRegex);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Determine if a number is an integer.
*/
function isInteger(x) {
return x === parseInt(x.toString(), 10);
}
/**
* Calculate the product of an array of numbers.
* @param array The array to calculate the product over.
* @param begin Beginning index, inclusive.
* @param end Ending index, exclusive.
* @return The product.
*/
function arrayProd(array, begin, end) {
if (begin == null) {
begin = 0;
}
if (end == null) {
end = array.length;
}
let prod = 1;
for (let i = begin; i < end; ++i) {
prod *= array[i];
}
return prod;
}
/**
* Compute minimum value.
* @param array
* @return minimum value.
*/
function min(array) {
// same behavior as tf.min()
if (array.length === 0) {
return Number.NaN;
}
let min = Number.POSITIVE_INFINITY;
for (let i = 0; i < array.length; i++) {
const value = array[i];
if (value < min) {
min = value;
}
}
return min;
}
/**
* Compute maximum value.
* @param array
* @return maximum value
*/
function max(array) {
// same behavior as tf.max()
if (array.length === 0) {
return Number.NaN;
}
let max = Number.NEGATIVE_INFINITY;
for (let i = 0; i < array.length; i++) {
const value = array[i];
if (value > max) {
max = value;
}
}
return max;
}
/**
* Generate an array of integers in [begin, end).
* @param begin Beginning integer, inclusive.
* @param end Ending integer, exclusive.
* @returns Range array.
* @throws ValueError, iff `end` < `begin`.
*/
function range(begin, end) {
if (end < begin) {
throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`);
}
const out = [];
for (let i = begin; i < end; ++i) {
out.push(i);
}
return out;
}
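// Example (a sketch): range(2, 5) -> [2, 3, 4]; range(3, 3) -> [].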
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
let _epsilon;
/**
* Returns the value of the fuzz factor used in numeric expressions.
*/
function epsilon() {
if (_epsilon == null) {
_epsilon = backend().epsilon();
}
return _epsilon;
}
/**
* Returns the default image data format convention.
*/
function imageDataFormat() {
return 'channelsLast';
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Casts a tensor to a different dtype and returns it.
* @param x Input tensor.
* @param dtype String: 'float32'|'int32'|'bool'.
* @returns Tensor of the specified `dtype`.
*/
function cast$1(x, dtype) {
return tfc.cast(x, dtype);
}
/**
* Adds a 1-sized dimension at index "axis".
* @param x Input tensor.
* @param axis Position where to add the new axis.
* @returns Result of the dimension expansion.
*/
function expandDims$1(x, axis = -1) {
const outShape = x.shape.slice();
if (axis < 0) {
axis = outShape.length + axis + 1;
}
outShape.splice(axis, 0, 1);
return tfc.reshape(x, outShape);
}
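// Shape example (a sketch): for a tensor x of shape [2, 3],
//   expandDims$1(x)     // -> shape [2, 3, 1]  (axis defaults to -1)
//   expandDims$1(x, 0)  // -> shape [1, 2, 3]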
/**
* Repeats a 2D tensor.
*
* If `x` has shape `[samples, dim]` and `n` is 2, for example, the output
* will have shape `[samples, 2, dim]`.
*
* @param x Input tensor.
* @param n Integer, number of times to repeat.
* @returns The result of the repeat operation.
* @throws ValueError: If input tensor is not 2D.
*/
function repeat(x, n) {
return tidy$1(() => {
if (x.shape.length !== 2) {
throw new ValueError(`repeat() expects a rank-2 tensor, but received a ` +
`rank-${x.shape.length} tensor.`);
}
const y = expandDims$1(x, 1);
return tile$1(y, [1, n, 1]);
});
}
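// Shape example (a sketch): for x of shape [samples, dim],
//   repeat(x, 3)  // -> shape [samples, 3, dim]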
/**
* Flatten a Tensor into 1D.
* @param x Input tensor.
* @return The result of flattening `x`.
*/
function flatten$2(x) {
const newShape = [arrayProd(x.shape)];
return tfc.reshape(x, newShape);
}
/**
* Turn an nD tensor into a 2D tensor with the same 0th dimension.
* In other words, it flattens each data sample of the batch.
*
* @param x The tensor to flatten. The rank of this tensor is required to be 2
* or higher.
* @return The result of the flattening.
*/
function batchFlatten(x) {
if (x.rank <= 1) {
throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);
}
const newShape = [x.shape[0], arrayProd(x.shape, 1)];
return tfc.reshape(x, newShape);
}
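// Shape examples (a sketch): for x of shape [32, 10, 10, 3],
//   flatten$2(x)     // -> shape [9600]
//   batchFlatten(x)  // -> shape [32, 300]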
/**
* Do slicing along the first axis.
* @param array input `tf.Tensor`.
* @param start starting index, inclusive.
* @param size size of the slice along the first axis.
* @returns result of the slicing.
* @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
*/
function sliceAlongFirstAxis(array, start, size) {
return tidy$1(() => {
switch (array.rank) {
case 1:
return tfc.slice1d(array, start, size);
case 2:
return tfc.slice2d(array, [start, 0], [size, array.shape[1]]);
case 3:
return tfc.slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]);
case 4:
return tfc.slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]);
case 5:
return tfc.slice(array, [start, 0, 0, 0, 0], [
size, array.shape[1], array.shape[2], array.shape[3], array.shape[4]
]);
case 6:
return tfc.slice(array, [start, 0, 0, 0, 0, 0], [
size, array.shape[1], array.shape[2], array.shape[3], array.shape[4],
array.shape[5]
]);
default:
throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ` +
`${array.rank}`);
}
});
}
/**
* Do slicing along the last axis.
* @param array input `tf.Tensor`.
* @param start starting index, inclusive.
* @param size size of the slice along the last axis.
* @returns result of the slicing.
* @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
*/
function sliceAlongLastAxis(array, start, size) {
return tidy$1(() => {
switch (array.rank) {
case 1:
return tfc.slice1d(array, start, size);
case 2:
return tfc.slice2d(array, [0, start], [array.shape[0], size]);
case 3:
return tfc.slice3d(array, [0, 0, start], [array.shape[0], array.shape[1], size]);
case 4:
return tfc.slice4d(array, [0, 0, 0, start], [array.shape[0], array.shape[1], array.shape[2], size]);
default:
throw new ValueError(`sliceAlongLastAxis() received an unsupported tensor rank: ` +
`${array.rank}`);
}
});
}
/**
* Do slicing along the specified axis.
* @param array input `tf.Tensor`.
* @param start starting index, inclusive.
* @param size size of the slice along the chosen axis.
* @param axis the axis along which to slice (1-based: 1 is the first axis).
* @returns result of the slicing.
* @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
*/
function sliceAlongAxis(array, start, size, axis) {
return tidy$1(() => {
switch (array.rank) {
case 1:
return tfc.slice1d(array, start, size);
case 2:
switch (axis) {
case 1:
return sliceAlongFirstAxis(array, start, size);
case 2:
return sliceAlongLastAxis(array, start, size);
default:
throw new ValueError(`The axis is not within the rank of the tensor ` +
`${axis}`);
}
case 3:
switch (axis) {
case 1:
return sliceAlongFirstAxis(array, start, size);
case 2:
return tfc.slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]);
case 3:
return sliceAlongLastAxis(array, start, size);
default:
throw new ValueError(`The axis is not within the rank of the tensor ` +
`${axis}`);
}
case 4:
switch (axis) {
case 1:
return sliceAlongFirstAxis(array, start, size);
case 2:
return tfc.slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]);
case 3:
return tfc.slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]);
case 4:
return sliceAlongLastAxis(array, start, size);
default:
throw new ValueError(`The axis is not within the rank of the tensor ` +
`${axis}`);
}
default:
throw new ValueError(`sliceAlongAxis() received an unsupported tensor rank: ` +
`${array.rank}`);
}
});
}
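// Example (a sketch): `axis` is 1-based here, matching the switch cases above.
// For x of shape [4, 10, 8]:
//   sliceAlongAxis(x, 2, 5, 2)  // -> shape [4, 5, 8]  (slice along the middle axis)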
/**
* Concatenates a list of tensors alongside the specified axis.
* @param tensors `Array` of tensors to concatenate.
* @param axis Concatenation axis.
* @returns The result of the concatenation.
*/
function concatenate$1(tensors, axis = -1) {
let rank;
if (axis < 0) {
rank = tensors[0].rank;
if (rank !== 0) {
axis = rank;
}
else {
axis = 0;
}
}
if (axis === tensors[0].rank) {
// Porting Note: This is necessary because tfc.concat() requires axis to be
// in the interval [-rank, rank).
axis = -1;
}
// Porting Note: Sparse concat is not supported yet.
return tfc.concat(tensors, axis);
}
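// Shape example (a sketch): for two tensors a and b of shape [2, 3],
//   concatenate$1([a, b])     // default axis -1 (last axis) -> shape [2, 6]
//   concatenate$1([a, b], 0)  // -> shape [4, 3]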
/**
* Concatenate two tensors along the first dimension.
* @param a The 1st `tf.Tensor` to concatenate.
* @param b The 2nd `tf.Tensor` to concatenate.
* @returns Result of the concatenation.
* @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`.
*/
function concatAlongFirstAxis(a, b) {
switch (a.rank) {
case 1:
return tfc.concat1d([a, b]);
case 2:
return tfc.concat2d([a, b], 0);
case 3:
return tfc.concat3d([a, b], 0);
case 4:
return tfc.concat4d([a, b], 0);
default:
throw new ValueError(`concatAlongFirstAxis() received an unsupported ` +
`tensor rank: ${a.rank}`);
}
}
/**
* Creates a tensor by tiling `x` by `n`.
* @param x A tensor.
* @param n An Array of integers or a single integer. If an Array, the length
* must be the same as the number of dimensions in `x`. If a single integer,
* it will be treated as an Array of length 1.
*/
function tile$1(x, n) {
if (!Array.isArray(n)) {
n = [n];
}
if (x.rank !== n.length) {
throw new ValueError(`The length of input n (${n.length}) does not match ` +
`the number of dimensions in input x (${x.rank})`);
}
return tfc.tile(x, n);
}
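// Shape example (a sketch): for x of shape [2, 3],
//   tile$1(x, [1, 4])  // -> shape [2, 12]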
/* Creation of random tensors. */
/**
* Get a tensor with normal distribution of values.
*
* @param shape Shape of the tensor.
* @param mean mean value of the normal distribution.
* @param stddev standard deviation of the normal distribution.
* @param dtype
* @param seed
* @return The normal tensor.
*/
function randomNormal$1(shape, mean = 0.0, stddev = 1.0, dtype, seed) {
return tfc.randomNormal(shape, mean, stddev, dtype, seed);
}
/* Linear Algebra */
/**
* Multiply two tensors and returns the result as a tensor.
*
* For 2D tensors, this is equivalent to matrix multiplication (matMul).
* For tensors of higher ranks, it follows the Theano behavior,
* (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation:
*
* For N dimensions it is a sum product over the last axis of x and the
* second-to-last axis of y.
*
* @param a A tensor of at least rank 2.
* @param b A tensor of at least rank 2.
* @param activation (optional) A string identifying the activation
* function.
* @return Result of the dot operation.
*/
function dot$1(a, b, activation, bias) {
if ((a.rank < 2) || (b.rank < 2)) {
throw new NotImplementedError(`dot requires both inputs to be rank >= 2` +
` but got x shape = ${a.shape} and y shape = ${b.shape}`);
}
if (b.rank >= 3) {
const xLastDim = a.shape.slice(-1)[0];
const ySecondLastDim = b.shape.slice(-2)[0];
if (xLastDim !== ySecondLastDim) {
throw new NotImplementedError(`If rank y >= 3, then the second last dim` +
` of y must equal the last dim of x but got x shape = ${a.shape} and ` +
` y shape = ${b.shape}`);
}
}
// Handle basic 2D x 2D case.
if ((a.rank === 2) && (b.rank === 2)) {
const transposeA = false;
const transposeB = false;
// tfc.fused.matMul only fuses certain activation functions. Unsupported
// activation functions are treated as 'linear' activations, which is
// equivalent to a no-op.
return tfc.fused.matMul({
a,
b: b,
transposeA,
transposeB,
bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
activation
});
}
else {
// Reshape x into the analogous 2D Tensor.
const aFirstDims = a.shape.slice(); // Holds all but the last dim of x.
const aLastDim = aFirstDims.pop();
a = tfc.reshape(a, [-1, aLastDim]);
// Reshape y into the analogous 2D Tensor, and keep track of the
// required dimensions to reproduce the output shape.
const bShape = b.shape.slice();
const bLastDim = bShape.pop();
const ySecondLastDim = bShape.pop();
const yOtherDims = [...bShape, bLastDim];
// permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]
// where r is the rank of y.
const perm = Array.from({ length: b.rank }, (_, i) => {
if (i === 0) {
return b.rank - 2;
}
else if (i <= b.rank - 2) {
return i - 1;
}
return i;
});
b = tfc.reshape(tfc.transpose(b, perm), [ySecondLastDim, -1]);
// Multiply x and y as 2D Tensors, and then reshape back to original.
const outputShape = [...aFirstDims, ...yOtherDims];
const transposeA = false;
const transposeB = false;
return tfc.reshape(tfc.fused.matMul({
a,
b,
transposeA,
transposeB,
bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
activation
}), outputShape);
}
}
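// Shape examples (a sketch):
//   dot$1(a, b) with a of shape [2, 3], b of shape [3, 4]     -> shape [2, 4]
//   dot$1(a, b) with a of shape [2, 3], b of shape [4, 3, 5]  -> shape [2, 4, 5]
// The second case follows the Theano-style generalized dot described above.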
/* Elementary math functions. */
/**
* Retrieves the elements of indices `indices` in the tensor `reference`.
* @param reference A tensor.
* @param indices An integer tensor of indices or an `Array` of integers.
* @param axis Axis along which to perform the gather operation.
* @returns The result of the gathering as a tensor.
*/
function gather$1(reference, indices, axis) {
return tidy$1(() => {
if (Array.isArray(indices)) {
indices = tensor1d(indices, 'int32');
}
else {
indices = tfc.cast(indices, 'int32');
}
return tfc.gather(reference, indices, axis);
});
}
/**
* Element-wise square.
* @param x Input tensor.
* @return element-wise x^2
*/
function square$1(x) {
return tfc.mul(x, x);
}
/**
* Reshapes bias tensor according to rank of x.
*/
function reshapeBias(xRank, bias, dataFormat) {
const biasShape = bias.shape;
if (bias.rank !== 1 && bias.rank !== xRank) {
throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` +
`; expected it to be 1 or ${xRank}`);
}
if (xRank === 5) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return tfc.reshape(bias, [1, biasShape[0], 1, 1, 1]);
}
else {
return tfc.reshape(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
}
}
else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return tfc.reshape(bias, [1, 1, 1, 1, biasShape[0]]);
}
else {
return tfc.reshape(bias, [1].concat(biasShape));
}
}
}
else if (xRank === 4) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return tfc.reshape(bias, [1, biasShape[0], 1, 1]);
}
else {
return tfc.reshape(bias, [1, biasShape[2], biasShape[0], biasShape[1]]);
}
}
else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return tfc.reshape(bias, [1, 1, 1, biasShape[0]]);
}
else {
return tfc.reshape(bias, [1].concat(biasShape));
}
}
}
else if (xRank === 3) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return tfc.reshape(bias, [1, biasShape[0], 1]);
}
else {
return tfc.reshape(bias, [1, biasShape[1], biasShape[0]]);
}
}
else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return tfc.reshape(bias, [1, 1, biasShape[0]]);
}
else {
return tfc.reshape(bias, [1].concat(biasShape));
}
}
}
else if (xRank < 3) {
return bias;
}
throw new ValueError(`Unsupported input rank by biasAdd: ${xRank}`);
}
/* Neural-network operations. */
/**
* Add a bias to a tensor.
*
* @param x The tensor to add the bias to.
* @param bias The bias to add to `x`. Must be 1D or the same rank as `x`.
* @return Result of the bias adding.
* @throws ValueError: If the rank of `bias` is incorrect.
*/
function biasAdd(x, bias, dataFormat) {
return tidy$1(() => {
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
checkDataFormat(dataFormat);
return tfc.add(x, reshapeBias(x.rank, bias, dataFormat));
});
}
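// Shape example (a sketch): for x of shape [batch, h, w, c] with the default
// 'channelsLast' data format and a rank-1 bias of shape [c], the bias is
// reshaped to [1, 1, 1, c] and broadcast-added to x.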
/**
* Exponential linear unit (ELU).
* @param x A tensor or variable to compute the activation function for.
* @param alpha: A scalar, a scaling factor for the negative section.
* @return Output of the ELU operation.
*/
function elu$1(x, alpha = 1) {
// TODO(cais): Add support for alpha values other than 1.
if (alpha !== 1) {
throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented ` +
`yet.`);
}
return tfc.elu(x);
}
/**
* Softsign of a tensor.
*
* Defined as x / (abs(x) + 1), element-wise.
*
* @param x: Input.
* @returns Output.
*/
function softsign(x) {
return tidy$1(() => tfc.div(x, tfc.add(tfc.abs(x), 1)));
}
/**
* Sets entries in `x` to zero at random, while scaling the entire tensor.
*
* @param x input tensor.
* @param level fraction of the entries in the tensor that will be set to 0.
* @param noiseShape shape of randomly generated keep/drop flags, must be
* broadcastable to the shape of `x`. Optional.
* @param seed random seed to ensure determinism. Optional.
* @returns Result of the dropout operation.
*/
function dropout$1(x, level, noiseShape, seed) {
return tidy$1(() => tfc.dropout(x, level, noiseShape, seed));
}
/**
* Element-wise, segment-wise linear approximation of sigmoid.
*
* Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
* In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
*
* @param x Input tensor.
* @returns Output tensor.
*/
function hardSigmoid(x) {
return tidy$1(() => {
const y = tfc.add(.5, tfc.mul(.2, x));
return tfc.clipByValue(y, 0, 1);
});
}
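// Example values (a sketch): hardSigmoid maps -3 -> 0, 0 -> 0.5, 1 -> 0.7, 3 -> 1.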
/**
* Invoke `x` in the training phase, and `alt` otherwise.
*
* Porting Note: We do not create placeholder tensors for the `training`
* boolean flag here, because there is no such thing in the TF.js imperative
* backend.
*
* @param x The function to invoke iff `training` is `true`.
* @param alt The function to invoke iff `training` is `false`.
* @param training Boolean flag for whether training phase is active.
* @returns The return value of `x()` if `training` is `true`, or the return
* value of `alt()` if `training` is `false`.
*/
function inTrainPhase(x, alt, training = false) {
return training ? x() : alt();
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg'];
const VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal'];
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
function checkFanMode(value) {
checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value);
}
function checkDistribution(value) {
checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value);
}
/**
* Initializer base class.
*
* @doc {
* heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
*/
class Initializer extends serialization.Serializable {
fromConfigUsesCustomObjects() {
return false;
}
getConfig() {
return {};
}
}
class Zeros extends Initializer {
apply(shape, dtype) {
return zeros$2(shape, dtype);
}
}
/** @nocollapse */
Zeros.className = 'Zeros';
serialization.registerClass(Zeros);
class Ones extends Initializer {
apply(shape, dtype) {
return ones$3(shape, dtype);
}
}
/** @nocollapse */
Ones.className = 'Ones';
serialization.registerClass(Ones);
class Constant extends Initializer {
constructor(args) {
super();
if (typeof args !== 'object') {
throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`);
}
if (args.value === undefined) {
throw new ValueError(`config must have value set but got ${args}`);
}
this.value = args.value;
}
apply(shape, dtype) {
return tidy$1(() => mul$1(scalar$1(this.value), ones$3(shape, dtype)));
}
getConfig() {
return {
value: this.value,
};
}
}
/** @nocollapse */
Constant.className = 'Constant';
serialization.registerClass(Constant);
class RandomUniform extends Initializer {
constructor(args) {
super();
this.DEFAULT_MINVAL = -0.05;
this.DEFAULT_MAXVAL = 0.05;
this.minval = args.minval || this.DEFAULT_MINVAL;
this.maxval = args.maxval || this.DEFAULT_MAXVAL;
this.seed = args.seed;
}