// @tensorflow/tfjs-layers
// Version:
// TensorFlow layers API in JavaScript
// 1,444 lines (1,428 loc) • 1.47 MB
// JavaScript
/**
* @license
* Copyright 2024 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
'use strict';
var tfc = require('@tensorflow/tfjs-core');
/**
 * Wraps a CommonJS export object in a prototype-less namespace object.
 *
 * Every own enumerable key of `e` except 'default' is exposed as a live
 * binding (an existing getter is reused, otherwise a getter reading `e[key]`
 * is created). The original export object is stored under `default`.
 *
 * @param e The required module object (may be falsy).
 * @returns The namespace object.
 */
function _interopNamespaceDefault(e) {
    var namespace = Object.create(null);
    var exportNames = e ? Object.keys(e) : [];
    for (var idx = 0; idx < exportNames.length; idx++) {
        if (exportNames[idx] === 'default') {
            continue;
        }
        defineLiveBinding(exportNames[idx]);
    }
    namespace.default = e;
    return namespace;
    // Re-exports a single key, keeping it in sync with the source object.
    function defineLiveBinding(exportName) {
        var descriptor = Object.getOwnPropertyDescriptor(e, exportName);
        if (!descriptor.get) {
            descriptor = {
                enumerable: true,
                get: function () { return e[exportName]; }
            };
        }
        Object.defineProperty(namespace, exportName, descriptor);
    }
}
/**
 * Copies live bindings from each object in `m` onto the namespace `n`.
 *
 * Falsy entries, strings and arrays in `m` are ignored. A key is copied only
 * when it is not 'default' and is not already present in `n`, so earlier
 * sources win over later ones.
 *
 * @param n The namespace object to extend (mutated and returned).
 * @param m The list of source objects.
 * @returns `n`.
 */
function _mergeNamespaces(n, m) {
    m.forEach(function (source) {
        if (!source || typeof source === 'string' || Array.isArray(source)) {
            return;
        }
        var keys = Object.keys(source);
        for (var idx = 0; idx < keys.length; idx++) {
            copyBinding(source, keys[idx]);
        }
    });
    return n;
    // Defines `key` on `n` as a getter unless it must be skipped.
    function copyBinding(source, key) {
        if (key === 'default' || key in n) {
            return;
        }
        var descriptor = Object.getOwnPropertyDescriptor(source, key);
        Object.defineProperty(n, key, descriptor.get ? descriptor : {
            enumerable: true,
            get: function () { return source[key]; }
        });
    }
}
// Namespace view of the required '@tensorflow/tfjs-core' module, built by
// _interopNamespaceDefault: named exports are exposed as live getters and the
// raw module object is available as `.default`.
var tfc__namespace = /*#__PURE__*/_interopNamespaceDefault(tfc);
/******************************************************************************
Copyright (c) Microsoft Corporation.
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
***************************************************************************** */
/* global Reflect, Promise */
/**
 * Makes constructor `d` inherit the static members of constructor `b`.
 *
 * On first use the best available strategy is selected and stored back into
 * `extendStatics`, so later calls go straight to it: `Object.setPrototypeOf`,
 * then `__proto__` assignment, then copying own properties.
 */
var extendStatics = function (d, b) {
    var viaProtoAssignment = ({ __proto__: [] } instanceof Array) &&
        function (derived, base) { derived.__proto__ = base; };
    var viaPropertyCopy = function (derived, base) {
        for (var key in base) {
            if (Object.prototype.hasOwnProperty.call(base, key)) {
                derived[key] = base[key];
            }
        }
    };
    extendStatics = Object.setPrototypeOf || viaProtoAssignment || viaPropertyCopy;
    return extendStatics(d, b);
};
/**
 * Down-level implementation of `class d extends b`.
 *
 * @param d The derived constructor.
 * @param b The base constructor, or `null`.
 * @throws TypeError if `b` is neither a function nor `null`.
 */
function __extends(d, b) {
    var isConstructor = typeof b === "function";
    if (!isConstructor && b !== null) {
        throw new TypeError("Class extends value " + String(b) + " is not a constructor or null");
    }
    // Static side first, then the prototype chain.
    extendStatics(d, b);
    if (b === null) {
        d.prototype = Object.create(b);
        return;
    }
    // Surrogate avoids running the base constructor while linking prototypes.
    function Surrogate() { this.constructor = d; }
    Surrogate.prototype = b.prototype;
    d.prototype = new Surrogate();
}
/**
 * Down-level helper that runs a generator function as an async function.
 *
 * The generator is started with `thisArg` / `_arguments`; every value it
 * yields is adopted as a promise, and the settled value (or rejection) is fed
 * back into the generator via `next` / `throw` until it completes.
 *
 * @param thisArg `this` value for the generator function.
 * @param _arguments Arguments for the generator function (defaults to []).
 * @param P Promise constructor to use; falls back to the global `Promise`.
 * @param generator The generator function to drive.
 * @returns A promise settled with the generator's return value, or rejected
 *     with any error thrown while stepping it.
 */
function __awaiter(thisArg, _arguments, P, generator) {
// Pass instances of P through untouched; wrap anything else in a resolved P.
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the generator with a fulfilled value; a synchronous throw rejects.
function fulfilled(value) { try {
step(generator.next(value));
}
catch (e) {
reject(e);
} }
// Resume the generator by throwing the rejection reason into it.
function rejected(value) { try {
step(generator["throw"](value));
}
catch (e) {
reject(e);
} }
// Either finish with the return value or wait on the next yielded value.
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
// `generator` is rebound from the generator function to its iterator here.
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
}
/**
 * Down-level helper that emulates a generator with a label-driven state
 * machine.
 *
 * `body` is called repeatedly with the state object `_` and returns an
 * instruction `[opcode, operand]`. The returned object exposes `next`,
 * `throw` and `return` (plus `Symbol.iterator` when available), which enter
 * the machine with opcodes 0, 1 and 2 respectively.
 *
 * State fields, as used below:
 *  - `_.label`: the next label `body` should resume at.
 *  - `_.sent()`: the value passed in by the caller; throws it when the
 *    incoming instruction has its low bit set (i.e. came from `throw`).
 *  - `_.trys`: stack of protected regions; entries are indexed as
 *    t[0]..t[3] below (presumably try-start, catch, finally and end labels,
 *    as emitted by the TypeScript compiler — confirm against tslib).
 *  - `_.ops`: instructions deferred until a finally region has run.
 *
 * @param thisArg `this` value used when invoking `body`.
 * @param body The compiled generator body.
 * @returns An iterator-like object.
 */
function __generator(thisArg, body) {
var _ = { label: 0, sent: function () { if (t[0] & 1)
throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function () { return this; }), g;
// Builds the public method that injects opcode `n` with the caller's value.
function verb(n) { return function (v) { return step([n, v]); }; }
function step(op) {
// `f` is the re-entrancy flag: set while the machine is running.
if (f)
throw new TypeError("Generator is already executing.");
// `_` is set to 0 once the generator is finished, which ends the loop.
while (_)
try {
// While delegating to an inner iterator `y`, forward the operation to it
// and hand back its result as long as it is not done.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done)
return t;
// The delegate finished: continue with its final value.
if (y = 0, t)
op = [op[0] & 2, t.value];
switch (op[0]) {
// 0 / 1: value or error sent in by the caller; kept in `t` for `_.sent()`.
case 0:
case 1:
t = op;
break;
// 4: yield a value to the caller.
case 4:
_.label++;
return { value: op[1], done: false };
// 5: start delegating to the iterator in op[1].
case 5:
_.label++;
y = op[1];
op = [0];
continue;
// 7: resume the instruction that was deferred while a finally region ran.
case 7:
op = _.ops.pop();
_.trys.pop();
continue;
default:
// A throw (6) or return (2) outside any protected region ends the generator.
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) {
_ = 0;
continue;
}
// 3: jump to a label, allowed directly when it stays inside the current region.
if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) {
_.label = op[1];
break;
}
// An error raised before the t[1] label is routed to that label.
if (op[0] === 6 && _.label < t[1]) {
_.label = t[1];
t = op;
break;
}
// Otherwise run the t[2] region first and defer the current instruction.
if (t && _.label < t[2]) {
_.label = t[2];
_.ops.push(op);
break;
}
if (t[2])
_.ops.pop();
_.trys.pop();
continue;
}
op = body.call(thisArg, _);
}
catch (e) {
// Errors thrown by `body` re-enter the machine as opcode 6.
op = [6, e];
y = 0;
}
finally {
f = t = 0;
}
// Opcodes with bit 0 or bit 2 set at this point carry an error to rethrow.
if (op[0] & 5)
throw op[1];
return { value: op[0] ? op[1] : void 0, done: true };
}
}
/**
 * Down-level helper behind `for...of`: returns an iterator over `o`.
 *
 * Uses `o[Symbol.iterator]` when it exists; otherwise falls back to index
 * based iteration for array-likes (objects with a numeric `length`).
 *
 * @param o The iterable or array-like to iterate.
 * @returns An object with a `next()` method.
 * @throws TypeError if `o` is neither iterable nor array-like.
 */
function __values(o) {
    var iteratorKey = typeof Symbol === "function" && Symbol.iterator;
    var factory = iteratorKey && o[iteratorKey];
    if (factory) {
        return factory.call(o);
    }
    if (o && typeof o.length === "number") {
        var cursor = 0;
        return {
            next: function () {
                // Drop the reference once exhausted so `done` stays true.
                if (o && cursor >= o.length) {
                    o = void 0;
                }
                return { value: o && o[cursor++], done: !o };
            }
        };
    }
    throw new TypeError(iteratorKey ? "Object is not iterable." : "Symbol.iterator is not defined.");
}
/**
 * Down-level helper behind array destructuring / spread of iterables.
 *
 * @param o The value to read from. If it has no `Symbol.iterator` method
 *     (or `Symbol` is unavailable) it is returned unchanged.
 * @param n Optional maximum number of items to read; all items when omitted.
 * @returns An array with the items read from the iterator.
 */
function __read(o, n) {
var m = typeof Symbol === "function" && o[Symbol.iterator];
if (!m)
return o;
var i = m.call(o), r, ar = [], e;
try {
while ((n === void 0 || n-- > 0) && !(r = i.next()).done)
ar.push(r.value);
}
catch (error) {
// Remember the failure so it can be rethrown after the iterator is closed.
e = { error: error };
}
finally {
try {
// Close the iterator when iteration stopped before it reported `done`.
if (r && !r.done && (m = i["return"]))
m.call(i);
}
finally {
// The original error takes precedence over anything thrown while closing.
if (e)
throw e.error;
}
}
return ar;
}
/**
 * Down-level helper behind array spread: returns `to` concatenated with the
 * items of `from`.
 *
 * When `pack` is truthy, or when only two arguments are passed, holes in
 * `from` are materialised as `undefined` entries; otherwise `from` is copied
 * as-is.
 *
 * @param to The leading array.
 * @param from The array or array-like whose items are appended.
 * @param pack Whether to fill holes in `from`.
 * @returns A new array.
 */
function __spreadArray(to, from, pack) {
    var packed;
    if (pack || arguments.length === 2) {
        var total = from.length;
        for (var idx = 0; idx < total; idx++) {
            // Nothing to do until the first hole is found.
            if (!packed && (idx in from)) {
                continue;
            }
            if (!packed) {
                packed = Array.prototype.slice.call(from, 0, idx);
            }
            packed[idx] = from[idx];
        }
    }
    return to.concat(packed || Array.prototype.slice.call(from));
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Explicit error types.
*
* See the following link for more information about why the code includes
* calls to setPrototypeOf:
*
* https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work
*/
// tslint:enable
/**
 * Error type mirroring Python's AttributeError.
 */
var AttributeError = /** @class */ (function (_base) {
    __extends(AttributeError, _base);
    function AttributeError(message) {
        var instance = _base.call(this, message) || this;
        // The base Error call yields a plain Error; re-point its prototype so
        // that `instanceof AttributeError` holds.
        Object.setPrototypeOf(instance, AttributeError.prototype);
        return instance;
    }
    return AttributeError;
}(Error));
/**
 * Error type mirroring Python's RuntimeError.
 */
var RuntimeError = /** @class */ (function (_base) {
    __extends(RuntimeError, _base);
    function RuntimeError(message) {
        var instance = _base.call(this, message) || this;
        // Re-point the prototype so that `instanceof RuntimeError` holds.
        Object.setPrototypeOf(instance, RuntimeError.prototype);
        return instance;
    }
    return RuntimeError;
}(Error));
/**
 * Error type mirroring Python's ValueError.
 */
var ValueError = /** @class */ (function (_base) {
    __extends(ValueError, _base);
    function ValueError(message) {
        var instance = _base.call(this, message) || this;
        // Re-point the prototype so that `instanceof ValueError` holds.
        Object.setPrototypeOf(instance, ValueError.prototype);
        return instance;
    }
    return ValueError;
}(Error));
/**
 * Error type mirroring Python's NotImplementedError.
 */
var NotImplementedError = /** @class */ (function (_base) {
    __extends(NotImplementedError, _base);
    function NotImplementedError(message) {
        var instance = _base.call(this, message) || this;
        // Re-point the prototype so that `instanceof NotImplementedError` holds.
        Object.setPrototypeOf(instance, NotImplementedError.prototype);
        return instance;
    }
    return NotImplementedError;
}(Error));
/**
 * Error type mirroring Python's AssertionError.
 */
var AssertionError = /** @class */ (function (_base) {
    __extends(AssertionError, _base);
    function AssertionError(message) {
        var instance = _base.call(this, message) || this;
        // Re-point the prototype so that `instanceof AssertionError` holds.
        Object.setPrototypeOf(instance, AssertionError.prototype);
        return instance;
    }
    return AssertionError;
}(Error));
/**
 * Error type mirroring Python's IndexError.
 *
 * The class is built by an immediately-invoked function whose result is not
 * bound to any name in this bundle.
 */
/** @class */ ((function (_base) {
    __extends(IndexError, _base);
    function IndexError(message) {
        var instance = _base.call(this, message) || this;
        // Re-point the prototype so that `instanceof IndexError` holds.
        Object.setPrototypeOf(instance, IndexError.prototype);
        return instance;
    }
    return IndexError;
})(Error));
/**
* @license
* Copyright 2022 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * LruCache: a mapping from string keys to values with a bounded size. When
 * the number of entries would exceed `maxEntries`, the least recently used
 * entry is deleted.
 *
 * Recency is tracked through the insertion order of the backing `Map`: the
 * first key is the least recently used one, and touching an entry re-inserts
 * it at the end.
 */
var LruCache = /** @class */ (function () {
    /**
     * @param maxEntries Capacity of the cache. Falsy values (including 0)
     *     fall back to the default of 100.
     */
    function LruCache(maxEntries) {
        this.maxEntries = maxEntries || 100;
        this.cache = new Map();
    }
    /**
     * Get the entry for the key and mark it as used recently.
     *
     * @param key The key to look up.
     * @returns The stored value, or `undefined` if the key is absent.
     */
    LruCache.prototype.get = function (key) {
        var entry;
        if (this.cache.has(key)) {
            entry = this.cache.get(key);
            // Re-insert so the key moves to the most-recently-used position.
            this.cache.delete(key);
            this.cache.set(key, entry);
        }
        return entry;
    };
    /**
     * Put the entry into the cache. If the key already existed, mark the key
     * as used recently; otherwise evict the least recently used entry when
     * the cache is full.
     *
     * @param key The key to store under.
     * @param value The value to store.
     */
    LruCache.prototype.put = function (key, value) {
        if (this.cache.has(key)) {
            this.cache.delete(key);
        }
        else if (this.cache.size >= this.maxEntries) {
            var keyToDelete = this.cache.keys().next().value;
            this.cache.delete(keyToDelete);
        }
        this.cache.set(key, value);
    };
    /**
     * Get the MaxEntries of the cache.
     */
    LruCache.prototype.getMaxEntries = function () {
        return this.maxEntries;
    };
    /**
     * Set the MaxEntries of the cache. If the cache currently holds more
     * entries than the new limit, the least recently used entries are
     * deleted until it fits; entries that still fit are kept.
     *
     * @param maxEntries The new capacity; must be at least 0.
     * @throws Error if `maxEntries` is negative.
     */
    LruCache.prototype.setMaxEntries = function (maxEntries) {
        if (maxEntries < 0) {
            throw new Error("The maxEntries of LRU caches must be at least 0, but got ".concat(maxEntries, "."));
        }
        // Evict based on the number of stored entries, not on the difference
        // between the old and new limits, so a partly filled cache does not
        // lose entries that still fit.
        while (this.cache.size > maxEntries) {
            var keyToDelete = this.cache.keys().next().value;
            this.cache.delete(keyToDelete);
        }
        this.maxEntries = maxEntries;
    };
    return LruCache;
}());
// tslint:enable
/**
 * Python-style list repetition.
 *
 * For an Array `value` this mirrors Python's `value * numValues`; for any
 * other `value` it mirrors `[value] * numValues`.
 *
 * @param value The array to repeat, or the single item to fill with.
 * @param numValues Number of repetitions.
 * @returns A new array.
 */
// tslint:disable-next-line:no-any
function pyListRepeat(value, numValues) {
    if (!Array.isArray(value)) {
        return new Array(numValues).fill(value);
    }
    // Collect one reference per repetition and flatten them in one concat.
    // tslint:disable-next-line:no-any
    var copies = [];
    for (var rep = 0; rep < numValues; rep++) {
        copies.push(value);
    }
    return Array.prototype.concat.apply([], copies);
}
/**
 * Throws an AssertionError carrying `message` unless `val` is truthy.
 *
 * @param val The condition to check.
 * @param message Message for the error.
 */
function assert$1(val, message) {
    if (val) {
        return;
    }
    throw new AssertionError(message);
}
/**
 * Count the number of elements of `array` that are strictly equal (`===`) to
 * `reference`.
 *
 * @param array The iterable or array-like to scan.
 * @param reference The value to compare each element against.
 * @returns The number of matching elements.
 */
function count(array, reference) {
    var matches = 0;
    var iterator = __values(array);
    var step;
    var failure;
    try {
        while (!(step = iterator.next()).done) {
            if (step.value === reference) {
                matches++;
            }
        }
    }
    catch (err) {
        failure = { error: err };
    }
    finally {
        try {
            // Close the iterator if iteration was interrupted part-way.
            if (step && !step.done && iterator.return) {
                iterator.return();
            }
        }
        finally {
            // The iteration error wins over anything thrown while closing.
            if (failure) {
                throw failure.error;
            }
        }
    }
    return matches;
}
/**
 * Unwraps single-element arrays: returns the sole element when `xs` has
 * length 1, and `xs` itself otherwise.
 *
 * @param xs The array to unwrap.
 */
function singletonOrArray(xs) {
    return xs.length === 1 ? xs[0] : xs;
}
/**
 * Normalizes a value into a list.
 *
 * Arrays are returned as they are (same reference); any other value is
 * wrapped in a new single-element array.
 *
 * @param x target object to be normalized.
 */
// tslint:disable-next-line:no-any
function toList(x) {
    return Array.isArray(x) ? x : [x];
}
/**
 * Converts a camel-case or Pascal-case string to snake-case.
 *
 * A result that would start with "_" (e.g. from a private class name) is
 * prefixed with "private", because a leading underscore is not safe for
 * creating scopes.
 *
 * @param name The string to convert.
 * @returns The snake-case string.
 */
function toSnakeCase(name) {
    var snake = name
        .replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2')
        .replace(/([a-z])([A-Z])/g, '$1_$2')
        .toLowerCase();
    return snake.charAt(0) === '_' ? 'private' + snake : snake;
}
/**
 * Converts a snake_case identifier to camelCase.
 *
 * Strings of length 0 or 1 and strings without an underscore are returned
 * unchanged. Otherwise each run of underscores is removed and the word
 * character following it (if any) is upper-cased.
 *
 * @param identifier The string to convert.
 * @returns The camelCase string.
 */
function toCamelCase(identifier) {
    var needsConversion = identifier.length > 1 && identifier.indexOf('_') !== -1;
    if (!needsConversion) {
        return identifier;
    }
    return identifier.replace(/[_]+(\w|$)/g, function (match, nextChar) {
        return nextChar.toUpperCase();
    });
}
// Module-wide registry of custom objects, keyed by name. It is consulted by
// deserializeKerasObject after the caller-supplied `customObjects`, and is
// temporarily extended with those `customObjects` while an object is being
// deserialized.
// tslint:disable-next-line:no-any
var _GLOBAL_CUSTOM_OBJECTS = {};
/**
 * Serializes an object into `{className, config}` form.
 *
 * @param instance An object exposing `getClassName()` and `getConfig()`.
 * @returns The serialization dictionary, or `null` when `instance` is
 *     `null` or `undefined`.
 */
function serializeKerasObject(instance) {
    if (instance == null) {
        return null;
    }
    return {
        className: instance.getClassName(),
        config: instance.getConfig()
    };
}
/**
 * Replace ndarray-style scalar objects in a serialization object with plain
 * numbers, in place.
 *
 * Background: in some versions of tf.keras, certain scalar values in the
 * HDF5 model save file can be serialized as
 * `{'type': 'ndarray', 'value': num}`, where `num` is a plain number. Every
 * such object found as a field value, at any nesting depth, is replaced by
 * its `value`.
 *
 * @param config The keras-format serialization object to be processed
 *     (in place). Non-object values are ignored.
 */
function convertNDArrayScalarsInConfig(config) {
    if (config == null || typeof config !== 'object') {
        return;
    }
    if (Array.isArray(config)) {
        config.forEach(function (item) { return convertNDArrayScalarsInConfig(item); });
        return;
    }
    Object.keys(config).forEach(function (field) {
        var value = config[field];
        if (value == null || typeof value !== 'object') {
            return;
        }
        var isNDArrayScalar = !Array.isArray(value) &&
            value['type'] === 'ndarray' &&
            typeof value['value'] === 'number';
        if (isNDArrayScalar) {
            config[field] = value['value'];
        }
        else {
            convertNDArrayScalarsInConfig(value);
        }
    });
}
/**
 * Deserialize a saved Keras Object
 *
 * Lookup order for both string identifiers and config dictionaries is:
 * `customObjects`, then the global custom-object registry, then
 * `moduleObjects`.
 *
 * While the object is being constructed, `customObjects` is merged into the
 * global registry so nested deserialization can see it; the registry is
 * restored afterwards, including when construction throws.
 *
 * @param identifier either a string ID or a saved Keras dictionary
 * @param moduleObjects a list of Python class names to object constructors
 * @param customObjects a list of Python class names to object constructors
 * @param printableModuleName debug text for the object being reconstituted
 * @param fastWeightInit Optional flag to use fast weight initialization
 *   during deserialization. This is applicable to cases in which
 *   the initialization will be immediately overwritten by loaded weight
 *   values. Default: `false`.
 * @returns a TensorFlow.js Layers object
 * @throws ValueError if the identifier is unknown, or if a config dictionary
 *   lacks `className` or `config`.
 */
// tslint:disable:no-any
function deserializeKerasObject(identifier, moduleObjects, customObjects, printableModuleName, fastWeightInit) {
    var _a, _b, _c;
    if (moduleObjects === void 0) { moduleObjects = {}; }
    if (customObjects === void 0) { customObjects = {}; }
    if (printableModuleName === void 0) { printableModuleName = 'object'; }
    if (fastWeightInit === void 0) { fastWeightInit = false; }
    // tslint:enable
    // Builds the error raised when a name cannot be resolved anywhere.
    function unknownObjectError(name) {
        // TODO(cais): Add link to tutorial page on custom layers.
        return new ValueError("Unknown ".concat(printableModuleName, ": ").concat(name, ". ") +
            "This may be due to one of the following reasons:\n" +
            "1. The ".concat(printableModuleName, " is defined in Python, in which ") +
            "case it needs to be ported to TensorFlow.js or your JavaScript " +
            "code.\n" +
            "2. The custom ".concat(printableModuleName, " is defined in JavaScript, ") +
            "but is not registered properly with " +
            "tf.serialization.registerClass().");
    }
    // Runs `build` with `customObjects` merged into the global registry and
    // always restores the registry afterwards, even if `build` throws.
    function withCustomObjectsInstalled(build) {
        var backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
        Object.keys(customObjects).forEach(function (key) {
            _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
        });
        try {
            return build();
        }
        finally {
            _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
        }
    }
    if (typeof identifier === 'string') {
        var functionName = identifier;
        var fn = void 0;
        if (functionName in customObjects) {
            fn = customObjects[functionName];
        }
        else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
            fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
        }
        else {
            fn = moduleObjects[functionName];
            if (fn == null) {
                throw unknownObjectError(identifier);
            }
        }
        return fn;
    }
    // In this case we are dealing with a Keras config dictionary.
    var config = identifier;
    if (config['className'] == null || config['config'] == null) {
        throw new ValueError("".concat(printableModuleName, ": Improper config format: ") +
            "".concat(JSON.stringify(config), ".\n") +
            "'className' and 'config' must set.");
    }
    var className = config['className'];
    var cls = void 0, fromConfig = void 0;
    if (className in customObjects) {
        _a = __read(customObjects[className], 2), cls = _a[0], fromConfig = _a[1];
    }
    else if (className in _GLOBAL_CUSTOM_OBJECTS) {
        // Index by the value of `className`, not by the literal string.
        _b = __read(_GLOBAL_CUSTOM_OBJECTS[className], 2), cls = _b[0], fromConfig = _b[1];
    }
    else if (className in moduleObjects) {
        _c = __read(moduleObjects[className], 2), cls = _c[0], fromConfig = _c[1];
    }
    if (cls == null) {
        throw unknownObjectError(className);
    }
    if (fromConfig != null) {
        // Porting notes: Instead of checking to see whether fromConfig accepts
        // customObjects, we create a customObjects dictionary and tack it on to
        // config['config'] as config['config'].customObjects. Objects can use it,
        // if they want.
        // tslint:disable-next-line:no-any
        var customObjectsCombined_1 = {};
        Object.keys(_GLOBAL_CUSTOM_OBJECTS).forEach(function (key) {
            customObjectsCombined_1[key] = _GLOBAL_CUSTOM_OBJECTS[key];
        });
        // Caller-supplied entries take precedence over global ones.
        Object.keys(customObjects).forEach(function (key) {
            customObjectsCombined_1[key] = customObjects[key];
        });
        // Add the customObjects to config
        var nestedConfig = config['config'];
        nestedConfig['customObjects'] = customObjectsCombined_1;
        return withCustomObjectsInstalled(function () {
            convertNDArrayScalarsInConfig(config['config']);
            return fromConfig(cls, config['config'], customObjects, fastWeightInit);
        });
    }
    // Then `cls` may be a function returning a class.
    // In this case by convention `config` holds
    // the kwargs of the function.
    return withCustomObjectsInstalled(function () {
        // In python this is **config['config'], for tfjs-layers we require
        // classes that use this fall-through construction method to take
        // a config interface that mimics the expansion of named parameters.
        return new cls(config['config']);
    });
}
/**
 * Comparator for sorting numbers in ascending order.
 *
 * @param a First number.
 * @param b Second number.
 * @returns -1 if `a < b`, 1 if `a > b`, and 0 otherwise.
 */
function numberCompare(a, b) {
    if (a < b) {
        return -1;
    }
    if (a > b) {
        return 1;
    }
    return 0;
}
/**
 * Comparator for sorting numbers in descending order: the negation of
 * `numberCompare(a, b)`.
 *
 * @param a First number.
 * @param b Second number.
 */
function reverseNumberCompare(a, b) {
    var ascending = numberCompare(a, b);
    return -ascending;
}
/**
 * Get the unique elements of an array, keeping first occurrences in their
 * original order.
 *
 * @param xs Array. `null` / `undefined` are returned as they are.
 * @returns An Array consisting of the unique elements in `xs`.
 */
function unique(xs) {
    if (xs == null) {
        return xs;
    }
    var seen = [];
    var iterator = __values(xs);
    var step;
    var failure;
    try {
        // TODO(cais): Maybe improve performance by sorting.
        while (!(step = iterator.next()).done) {
            var candidate = step.value;
            if (seen.indexOf(candidate) === -1) {
                seen.push(candidate);
            }
        }
    }
    catch (err) {
        failure = { error: err };
    }
    finally {
        try {
            // Close the iterator if iteration was interrupted part-way.
            if (step && !step.done && iterator.return) {
                iterator.return();
            }
        }
        finally {
            if (failure) {
                throw failure.error;
            }
        }
    }
    return seen;
}
/**
 * Determine if an Object is empty (i.e., does not have own enumerable
 * properties).
 *
 * The own-property check goes through `Object.prototype.hasOwnProperty`, so
 * it also works for prototype-less objects (`Object.create(null)`) and for
 * objects that define their own `hasOwnProperty` key.
 *
 * @param obj Object
 * @returns Whether the Object is empty.
 * @throws ValueError: If object is `null` or `undefined`.
 */
function isObjectEmpty(obj) {
    if (obj == null) {
        throw new ValueError("Invalid value in obj: ".concat(JSON.stringify(obj)));
    }
    for (var key in obj) {
        if (Object.prototype.hasOwnProperty.call(obj, key)) {
            return false;
        }
    }
    return true;
}
/**
 * Run-time checker for string union / enum types.
 *
 * `null` and `undefined` are always accepted; any other value must be one of
 * `values`.
 *
 * @param values The list of allowed values.
 * @param label A string name for the type, used in the error message.
 * @param value The value to test.
 * @throws ValueError: If the value is not in values nor `undefined`/`null`.
 */
function checkStringTypeUnionValue(values, label, value) {
    var isUnset = value == null;
    if (isUnset || values.indexOf(value) >= 0) {
        return;
    }
    throw new ValueError("".concat(value, " is not a valid ").concat(label, ". Valid values are ").concat(values, " or null/undefined."));
}
/**
 * Helper function for verifying the types of inputs.
 *
 * Checks that `x` is an Array whose elements all have `typeof` equal to
 * `expectedType` and whose length lies within the given bounds.
 *
 * @param x Object to test.
 * @param expectedType The string expected type of all of the elements in the
 * Array.
 * @param minLength Return false if x.length is less than this. Default: 0.
 * @param maxLength Return false if x.length is greater than this.
 *   Default: Infinity.
 * @returns true if and only if `x` is an `Array<expectedType>` with
 * length >= `minLength` and <= `maxLength`.
 */
// tslint:disable:no-any
function checkArrayTypeAndLength(x, expectedType, minLength, maxLength) {
    if (minLength === void 0) { minLength = 0; }
    if (maxLength === void 0) { maxLength = Infinity; }
    // The bounds themselves must describe a valid, non-negative interval.
    assert$1(minLength >= 0);
    assert$1(maxLength >= minLength);
    if (!Array.isArray(x)) {
        return false;
    }
    if (!(x.length >= minLength && x.length <= maxLength)) {
        return false;
    }
    return x.every(function (element) { return typeof element === expectedType; });
}
// tslint:enable:no-any
/**
 * Assert that a value, or every element of an array of values, is a positive
 * integer.
 *
 * Arrays must be non-empty; their elements are checked recursively and are
 * reported with 1-based positions ("element 2 of <name>").
 *
 * @param value The value being asserted on. May be a single number or an array
 * of numbers.
 * @param name Name of the value, used to make the error message.
 */
function assertPositiveInteger(value, name) {
    if (!Array.isArray(value)) {
        var isPositiveInt = Number.isInteger(value) && value > 0;
        tfc.util.assert(isPositiveInt, function () {
            return "Expected ".concat(name, " to be a positive integer, but got ") +
                "".concat(formatAsFriendlyString(value), ".");
        });
        return;
    }
    tfc.util.assert(value.length > 0, function () {
        return "".concat(name, " is unexpectedly an empty array.");
    });
    value.forEach(function (element, position) {
        return assertPositiveInteger(element, "element ".concat(position + 1, " of ").concat(name));
    });
}
/**
 * Format a value into a display-friendly, human-readable fashion.
 *
 * - `null` is formatted as `'null'`
 * - Strings are formatted with a flanking pair of double quotes.
 * - Arrays are formatted with a flanking pair of square brackets, their
 *   elements formatted recursively and joined with commas.
 * - Anything else is converted with default string conversion.
 *
 * @param value The value to display.
 * @return Formatted string.
 */
// tslint:disable-next-line:no-any
function formatAsFriendlyString(value) {
    if (value === null) {
        return 'null';
    }
    if (Array.isArray(value)) {
        var parts = value.map(function (item) { return formatAsFriendlyString(item); });
        return '[' + parts.join(',') + ']';
    }
    if (typeof value === 'string') {
        return "\"".concat(value, "\"");
    }
    return "".concat(value);
}
/**
 * Returns a function `f2` (decorator) which wraps the original function
 * `f`. `f2` guarantees that `f` is invoked at most once every `waitMs` ms.
 * If `f2` is called sooner than that, it returns the last result of `f`
 * without invoking it. The waiting window starts when `debounce` is called,
 * so an early first call returns `undefined`.
 *
 * @param f The original function `f` to wrap.
 * @param waitMs The time between two consecutive calls to `f` in ms.
 * @param nowFunc Optional clock function returning the current time in ms;
 *   `tfc.util.now()` is used when it is not provided.
 */
function debounce(f, waitMs, nowFunc) {
    var readClock = function () {
        return nowFunc != null ? nowFunc() : tfc.util.now();
    };
    var lastTime = readClock();
    var lastResult;
    return function () {
        var args = Array.prototype.slice.call(arguments);
        var now = readClock();
        if (now - lastTime < waitMs) {
            return lastResult;
        }
        lastTime = now;
        lastResult = f.apply(void 0, args);
        return lastResult;
    };
}
/**
 * Returns the fusable activation given a layers identifier.
 *
 * Only 'relu', 'linear' and 'elu' are recognised; they map to themselves.
 *
 * @param activationName The layers identifier string.
 * @return The name of the fusable activation, or `null` if there is none.
 */
function mapActivationToFusedKernel(activationName) {
    var fusable = ['relu', 'linear', 'elu'];
    return fusable.indexOf(activationName) !== -1 ? activationName : null;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Utilities related to persistent state in the backend.
*/
/**
 * Counter backing the IDs used to track `tf.SymbolicTensor`s and derived
 * classes. Required in different places in engine/topology.ts to identify
 * unique tensors.
 */
var _nextUniqueTensorId = 0;
/**
 * Returns the current ID and advances the counter by one.
 */
function getNextUniqueTensorId() {
    var id = _nextUniqueTensorId;
    _nextUniqueTensorId = id + 1;
    return id;
}
// Number of UIDs handed out so far, per prefix.
var _uidPrefixes = {};
/**
 * Provides a unique UID given a string prefix.
 *
 * The UID is the prefix followed by a per-prefix counter starting at 1
 * (e.g. 'dense1', 'dense2', ...).
 *
 * @param prefix The prefix to use. Default: ''.
 * @returns The UID string.
 */
function getUid(prefix) {
    if (prefix === void 0) { prefix = ''; }
    // Use an own-property check: the `in` operator would also match keys
    // inherited from Object.prototype (e.g. 'constructor'), which would skip
    // the initialisation and corrupt the counter.
    if (!Object.prototype.hasOwnProperty.call(_uidPrefixes, prefix)) {
        _uidPrefixes[prefix] = 0;
    }
    _uidPrefixes[prefix] += 1;
    return prefix + _uidPrefixes[prefix].toString();
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Allowed values for the string-union options validated in this file.
// Accepted by checkDataFormat().
var VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast'];
// Accepted by checkInterpolationFormat().
var VALID_INTERPOLATION_FORMAT_VALUES = ['nearest', 'bilinear'];
// Accepted by checkPaddingMode().
var VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal'];
// Accepted by checkPoolMode().
var VALID_POOL_MODE_VALUES = ['max', 'avg'];
// Presumably the merge modes accepted by the Bidirectional wrapper; its
// checker is not visible in this part of the file.
var VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave'];
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// A map from the requested scoped name of a Tensor to the number of times
// that name has been requested so far. getUniqueTensorName() uses it to
// enforce name uniqueness by appending an incrementing index, e.g.
// scope/name, scope/name_1, scope/name_2, etc.
var nameMap = new Map();
/**
 * Validates a data-format value.
 *
 * @param value The value to test; `null`/`undefined` are accepted.
 * @throws ValueError if `value` is not one of VALID_DATA_FORMAT_VALUES.
 */
function checkDataFormat(value) {
    var allowed = VALID_DATA_FORMAT_VALUES;
    checkStringTypeUnionValue(allowed, 'DataFormat', value);
}
/**
 * Validates an interpolation-format value.
 *
 * @param value The value to test; `null`/`undefined` are accepted.
 * @throws ValueError if `value` is not one of
 *   VALID_INTERPOLATION_FORMAT_VALUES.
 */
function checkInterpolationFormat(value) {
    var allowed = VALID_INTERPOLATION_FORMAT_VALUES;
    checkStringTypeUnionValue(allowed, 'InterpolationFormat', value);
}
/**
 * Validates a padding-mode value.
 *
 * @param value The value to test; `null`/`undefined` are accepted.
 * @throws ValueError if `value` is not one of VALID_PADDING_MODE_VALUES.
 */
function checkPaddingMode(value) {
    var allowed = VALID_PADDING_MODE_VALUES;
    checkStringTypeUnionValue(allowed, 'PaddingMode', value);
}
/**
 * Validates a pool-mode value.
 *
 * @param value The value to test; `null`/`undefined` are accepted.
 * @throws ValueError if `value` is not one of VALID_POOL_MODE_VALUES.
 */
function checkPoolMode(value) {
    var allowed = VALID_POOL_MODE_VALUES;
    checkStringTypeUnionValue(allowed, 'PoolMode', value);
}
// Stack of the name scopes currently entered, outermost first.
var _nameScopeStack = [];
// Separator placed between (and after) scope names in a scoped name.
var _nameScopeDivider = '/';
/**
 * Enter namescope, which can be nested.
 *
 * `name` is pushed onto the scope stack for the duration of `fn` and popped
 * again afterwards, whether `fn` returns or throws.
 *
 * @param name The scope name to enter.
 * @param fn The function to run inside the scope.
 * @returns The return value of `fn`.
 */
function nameScope(name, fn) {
    _nameScopeStack.push(name);
    try {
        return fn();
    }
    finally {
        _nameScopeStack.pop();
    }
}
/**
 * Get the current namescope as a flat, concatenated string.
 *
 * @returns '' when no scope is active; otherwise the active scope names
 *   joined by the divider, followed by a trailing divider.
 */
function currentNameScopePrefix() {
    var depth = _nameScopeStack.length;
    return depth === 0 ?
        '' :
        _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
/**
 * Get the name a Tensor (or Variable) would have if not uniqueified: the
 * current name-scope prefix followed by `tensorName`.
 *
 * @param tensorName The unscoped tensor name.
 * @return Scoped name string.
 * @throws Error if `tensorName` is not a valid tensor name.
 */
function getScopedTensorName(tensorName) {
    var isValid = isValidTensorName(tensorName);
    if (!isValid) {
        throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
    }
    var prefix = currentNameScopePrefix();
    return prefix + tensorName;
}
/**
 * Get unique names for Tensors and Variables.
 *
 * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
 *  `getScopedTensorName()`.
 * @return A unique version of the given fully scoped name.
 *   If this is the first time that the scoped name is seen in this session,
 *   then the given `scopedName` is returned unaltered. If the same name is
 *   seen again (producing a collision), an incrementing suffix is added to the
 *   end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
 * @throws Error if `scopedName` is not a valid tensor name.
 */
function getUniqueTensorName(scopedName) {
    if (!isValidTensorName(scopedName)) {
        throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
    }
    var timesSeen = nameMap.has(scopedName) ? nameMap.get(scopedName) : 0;
    nameMap.set(scopedName, timesSeen + 1);
    if (!(timesSeen > 0)) {
        return scopedName;
    }
    var suffixed = "".concat(scopedName, "_").concat(timesSeen);
    // Mark the composed name as used in case someone wants
    // to call getUniqueTensorName("name_1").
    nameMap.set(suffixed, 1);
    return suffixed;
}
// A valid tensor name starts with a letter or digit, followed by any number
// of letters, digits, '-', '.', '_' or '/'.
var tensorNameRegex = /^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/;
/**
 * Determine whether a string is a valid tensor name.
 *
 * @param name The candidate name.
 * @returns A Boolean indicating whether `name` is a valid tensor name.
 */
function isValidTensorName(name) {
    return name.match(tensorNameRegex) !== null;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Determine if a number is an integer.
 *
 * Delegates to `Number.isInteger`, so integers of any magnitude are
 * recognised (a string round-trip through `parseInt` would misjudge values
 * that print in exponent notation, such as 1e21). Non-number inputs yield
 * `false`.
 *
 * @param x The value to test.
 * @returns Whether `x` is a finite integer number.
 */
function isInteger(x) {
    return Number.isInteger(x);
}
/**
 * Calculate the product of a (slice of an) array of numbers.
 *
 * @param array The array to calculate the product over.
 * @param begin Beginning index, inclusive. Default: 0.
 * @param end Ending index, exclusive. Default: `array.length`.
 * @return The product; 1 for an empty range.
 */
function arrayProd(array, begin, end) {
    var start = begin == null ? 0 : begin;
    var stop = end == null ? array.length : end;
    var product = 1;
    var idx = start;
    while (idx < stop) {
        product *= array[idx];
        ++idx;
    }
    return product;
}
/**
 * Compute minimum value.
 *
 * @param array The numbers to scan.
 * @return minimum value, or NaN for an empty array (same behavior as
 *   tf.min()).
 */
function min(array) {
    var total = array.length;
    if (total === 0) {
        return Number.NaN;
    }
    var smallest = Number.POSITIVE_INFINITY;
    var idx = 0;
    while (idx < total) {
        var candidate = array[idx];
        smallest = candidate < smallest ? candidate : smallest;
        idx++;
    }
    return smallest;
}
/**
 * Compute maximum value.
 *
 * @param array The numbers to scan.
 * @return maximum value, or NaN for an empty array (same behavior as
 *   tf.max()).
 */
function max(array) {
    var total = array.length;
    if (total === 0) {
        return Number.NaN;
    }
    var largest = Number.NEGATIVE_INFINITY;
    var idx = 0;
    while (idx < total) {
        var candidate = array[idx];
        largest = candidate > largest ? candidate : largest;
        idx++;
    }
    return largest;
}
/**
 * Generate an array of integers in [begin, end).
 *
 * @param begin Beginning integer, inclusive.
 * @param end Ending integer, exclusive.
 * @returns Range array; empty when `begin` equals `end`.
 * @throws ValueError, iff `end` < `begin`.
 */
function range(begin, end) {
    if (end < begin) {
        throw new ValueError("end (".concat(end, ") < begin (").concat(begin, ") is forbidden."));
    }
    var values = [];
    var current = begin;
    while (current < end) {
        values.push(current);
        ++current;
    }
    return values;
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Cache for the backend's epsilon; stays unset until epsilon() is first called.
var _epsilon;
/**
 * Returns the value of the fuzz factor used in numeric expressions.
 *
 * The value is obtained from `tfc.backend().epsilon()` on the first call and
 * served from the `_epsilon` cache afterwards.
 */
function epsilon() {
    if (_epsilon != null) {
        return _epsilon;
    }
    _epsilon = tfc.backend().epsilon();
    return _epsilon;
}
/**
 * Returns the default image data format convention.
 * @returns The constant string 'channelsLast'.
 */
function imageDataFormat() {
    var defaultFormat = 'channelsLast';
    return defaultFormat;
}
/**
 * Convert a tensor to another dtype by delegating to the tfjs-core `cast` op.
 * @param x Input tensor.
 * @param dtype String: 'float32'|'int32'|'bool'.
 * @returns Whatever the tfjs-core `cast` op returns for `x` and `dtype`.
 */
function cast$1(x, dtype) {
    var converted = tfc__namespace.cast(x, dtype);
    return converted;
}
/**
 * Insert a dimension of size 1 into the shape of `x` at index `axis`.
 * @param x Input tensor.
 * @param axis Index at which the new dimension is inserted. Defaults to -1
 *   (append at the end). Negative values count from the end of the expanded
 *   shape.
 * @returns The result of reshaping `x` to the expanded shape.
 */
function expandDims$1(x, axis) {
    var position = axis === void 0 ? -1 : axis;
    // Work on a copy so the tensor's own shape array is left untouched.
    var expanded = x.shape.slice();
    if (position < 0) {
        // The +1 makes -1 refer to the slot after the current last dimension.
        position = expanded.length + position + 1;
    }
    expanded.splice(position, 0, 1);
    return tfc__namespace.reshape(x, expanded);
}
/**
 * Repeats a 2D tensor along a newly inserted middle axis.
 *
 * For an input of shape `[samples, dim]` and `n` equal to 2, the output has
 * shape `[samples, 2, dim]`.
 *
 * @param x Input tensor.
 * @param n Integer, number of times to repeat.
 * @returns The result of the repeat operation.
 * @throws ValueError: If input tensor is not 2D.
 */
function repeat(x, n) {
    return tfc.tidy(function () {
        var rank = x.shape.length;
        if (rank !== 2) {
            throw new ValueError("repeat() expects a rank-2 tensor, but received a " +
                "rank-" + String(rank) + " tensor.");
        }
        // Insert a size-1 axis at index 1, then tile it `n` times.
        var withRepeatAxis = expandDims$1(x, 1);
        var multiples = [1, n, 1];
        return tile$1(withRepeatAxis, multiples);
    });
}
/**
 * Reshape a tensor to 1D, keeping all of its elements.
 * @param x Input tensor.
 * @return `x` reshaped to `[product of x.shape]`.
 */
function flatten$2(x) {
    var elementCount = arrayProd(x.shape);
    return tfc__namespace.reshape(x, [elementCount]);
}
/**
 * Collapse every dimension after the first into one, giving a 2D tensor whose
 * 0th dimension is unchanged. In other words, each sample of a batch is
 * flattened.
 *
 * @param x The tensor to flatten. Its rank must be 2 or higher.
 * @return `x` reshaped to `[x.shape[0], product of the remaining dims]`.
 * @throws ValueError when the rank of `x` is 1 or lower.
 */
function batchFlatten(x) {
    var rank = x.rank;
    if (rank <= 1) {
        throw new ValueError("batchFlatten requires a minimum rank of 2. Got rank: " + String(rank) + ".");
    }
    var batchSize = x.shape[0];
    var perSample = arrayProd(x.shape, 1);
    return tfc__namespace.reshape(x, [batchSize, perSample]);
}
/**
 * Slice a tensor along its first axis, keeping every other axis whole.
 * @param array input `tf.Tensor` of rank 1 through 6.
 * @param start starting index along the first axis, inclusive.
 * @param size number of entries to take along the first axis.
 * @returns result of the slicing.
 * @throws ValueError: If the rank of `array` is not one of 1-6.
 */
function sliceAlongFirstAxis(array, start, size) {
    return tfc.tidy(function () {
        var rank = array.rank;
        if (rank === 1) {
            return tfc__namespace.slice1d(array, start, size);
        }
        var supported = rank === 2 || rank === 3 || rank === 4 || rank === 5 || rank === 6;
        if (!supported) {
            throw new ValueError("sliceAlongFirstAxis() received an unsupported tensor rank: " +
                String(rank));
        }
        // begin = [start, 0, ..., 0]; extent = [size, shape[1], ..., shape[rank - 1]].
        var begin = [start];
        var extent = [size];
        for (var dim = 1; dim < rank; dim++) {
            begin.push(0);
            extent.push(array.shape[dim]);
        }
        if (rank === 2) {
            return tfc__namespace.slice2d(array, begin, extent);
        }
        if (rank === 3) {
            return tfc__namespace.slice3d(array, begin, extent);
        }
        if (rank === 4) {
            return tfc__namespace.slice4d(array, begin, extent);
        }
        // Ranks 5 and 6 go through the generic slice op.
        return tfc__namespace.slice(array, begin, extent);
    });
}
/**
 * Slice a tensor along its last axis, keeping every other axis whole.
 * @param array input `tf.Tensor` of rank 1 through 4.
 * @param start starting index along the last axis, inclusive.
 * @param size number of entries to take along the last axis.
 * @returns result of the slicing.
 * @throws ValueError: If the rank of `array` is not one of 1-4.
 */
function sliceAlongLastAxis(array, start, size) {
    return tfc.tidy(function () {
        var rank = array.rank;
        if (rank === 1) {
            return tfc__namespace.slice1d(array, start, size);
        }
        if (rank !== 2 && rank !== 3 && rank !== 4) {
            throw new ValueError("sliceAlongLastAxis() received an unsupported tensor rank: " +
                String(rank));
        }
        // begin = [0, ..., 0, start]; extent = [shape[0], ..., shape[rank - 2], size].
        var begin = [];
        var extent = [];
        for (var dim = 0; dim < rank - 1; dim++) {
            begin.push(0);
            extent.push(array.shape[dim]);
        }
        begin.push(start);
        extent.push(size);
        if (rank === 2) {
            return tfc__namespace.slice2d(array, begin, extent);
        }
        if (rank === 3) {
            return tfc__namespace.slice3d(array, begin, extent);
        }
        return tfc__namespace.slice4d(array, begin, extent);
    });
}
/**
 * Do slicing along the specified axis.
 *
 * `axis` is 1-based: 1 selects the first axis and `array.rank` the last.
 * For a rank-1 tensor `axis` is not consulted.
 *
 * @param array input `tf.Tensor` of rank 1 through 4.
 * @param start starting index along the chosen axis, inclusive.
 * @param size size of the slice along the chosen axis.
 * @param axis 1-based index of the axis to slice along.
 * @returns result of the slicing.
 * @throws ValueError: If the rank of `array` is not one of 1-4, or if `axis`
 *   is not in `1..array.rank` for a tensor of rank 2-4.
 */
function sliceAlongAxis(array, start, size, axis) {
    return tfc.tidy(function () {
        switch (array.rank) {
            case 1:
                // Only one axis exists, so `axis` is ignored.
                return tfc__namespace.slice1d(array, start, size);
            case 2:
                switch (axis) {
                    case 1:
                        return sliceAlongFirstAxis(array, start, size);
                    case 2:
                        return sliceAlongLastAxis(array, start, size);
                    default:
                        throw new ValueError("The axis is not within the rank of the tensor " +
                            "".concat(axis));
                }
            case 3:
                switch (axis) {
                    case 1:
                        return sliceAlongFirstAxis(array, start, size);
                    case 2:
                        return tfc__namespace.slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]);
                    case 3:
                        return sliceAlongLastAxis(array, start, size);
                    default:
                        throw new ValueError("The axis is not within the rank of the tensor " +
                            "".concat(axis));
                }
            case 4:
                switch (axis) {
                    case 1:
                        return sliceAlongFirstAxis(array, start, size);
                    case 2:
                        return tfc__namespace.slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]);
                    case 3:
                        return tfc__namespace.slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]);
                    case 4:
                        return sliceAlongLastAxis(array, start, size);
                    default:
                        throw new ValueError("The axis is not within the rank of the tensor " +
                            "".concat(axis));
                }
            default:
                // Name this function in the message so the error points at the right call site.
                throw new ValueError("sliceAlongAxis() received an unsupported tensor rank: " +
                    "".concat(array.rank));
        }
    });
}
/**
 * Concatenates a list of tensors alongside the specified axis.
 *
 * Axis handling (the rank is taken from `tensors[0]`):
 *  - a negative axis on a non-scalar tensor is forwarded to `tfc.concat()`
 *    unchanged, so `-2` addresses the second-to-last axis rather than being
 *    collapsed onto the last one;
 *  - an axis equal to the rank is mapped to -1 (the last axis);
 *  - for rank-0 tensors a negative axis is treated as 0, which the previous
 *    rule then maps to -1.
 *
 * @param tensors `Array` of tensors to concatenate. Must be non-empty.
 * @param axis Concatenation axis. Defaults to -1 (the last axis).
 * @returns The result of the concatenation.
 */
function concatenate$1(tensors, axis) {
    if (axis === void 0) { axis = -1; }
    var rank = tensors[0].rank;
    if (axis < 0 && rank === 0) {
        axis = 0;
    }
    if (axis === rank) {
        // Porting Note: This is necessary because tfc.concat() requires axis to be
        // in the interval [-rank, rank).
        axis = -1;
    }
    // Porting Note: Sparse concat is not supported yet.
    return tfc__namespace.concat(tensors, axis);
}
/**
 * Join two tensors along axis 0.
 * @param a The 1st `tf.Tensor` to concatenate; its rank selects the op used.
 * @param b The 2nd `tf.Tensor` to concatenate.
 * @returns Result of the concatenation.
 * @throws ValueError: If the rank of `a` is not one of 1-4.
 */
function concatAlongFirstAxis(a, b) {
    var pair = [a, b];
    var rank = a.rank;
    if (rank === 1) {
        return tfc__namespace.concat1d(pair);
    }
    if (rank === 2) {
        return tfc__namespace.concat2d(pair, 0);
    }
    if (rank === 3) {
        return tfc__namespace.concat3d(pair, 0);
    }
    if (rank === 4) {
        return tfc__namespace.concat4d(pair, 0);
    }
    throw new ValueError("concatAlongFirstAxis() received an unsupported " +
        "tensor rank: " + String(rank));
}
/**
 * Creates a tensor by tiling `x` by `n`.
 * @param x A tensor.
 * @param n An Array of integers or a single integer. An Array must have one
 *   entry per dimension of `x`; a single integer is wrapped into an Array of
 *   length 1.
 * @returns The result of the tfjs-core `tile` op.
 * @throws ValueError when the number of entries in `n` differs from the rank
 *   of `x`.
 */
function tile$1(x, n) {
    var reps = Array.isArray(n) ? n : [n];
    if (x.rank === reps.length) {
        return tfc__namespace.tile(x, reps);
    }
    throw new ValueError("The length of input n (" + String(reps.length) + ") does not match " +
        "the number of dimensions in input x (" + String(x.rank) + ")");
}
/* Creation of random tensors. */
/**
 * Create a tensor filled with normally distributed random values by
 * delegating to the tfjs-core `randomNormal` op.
 *
 * @param shape Shape of the tensor.
 * @param mean mean value of the normal distribution. Defaults to 0.
 * @param stddev standard deviation of the normal distribution. Defaults to 1.
 * @param dtype Forwarded unchanged to the tfjs-core op.
 * @param seed Forwarded unchanged to the tfjs-core op.
 * @return The normal tensor.
 */
function randomNormal$1(shape, mean, stddev, dtype, seed) {
    var mu = mean === void 0 ? 0.0 : mean;
    var sigma = stddev === void 0 ? 1.0 : stddev;
    return tfc__namespace.randomNormal(shape, mu, sigma, dtype, seed);
}
/* Linear Algebra */
/**
* Multiply two tensors and returns the result as a tensor.
*
* For 2D tensors, this is equivalent to matrix multiplication (matMul).
* For tensors of higher ranks, it follows the Theano behavior,
* (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation:
*
* For N dimensions it is a sum product over the last axis of x and the
* second-to-last of y:
*
* @param a A tensor of at least rank 2.
* @param b A tensor of at least rank 2.
*