handsfree
Version:
Quickly integrate face, hand, and/or pose tracking into your frontend projects in a snap ✨👌
1,289 lines (1,275 loc) • 1.46 MB
JavaScript
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
(function (global, factory) {
typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
typeof define === 'function' && define.amd ? define(['exports'], factory) :
(global = global || self, factory(global.tf = global.tf || {}));
}(this, (function (exports) { 'use strict';
/*! *****************************************************************************
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at http://www.apache.org/licenses/LICENSE-2.0
THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
MERCHANTABLITY OR NON-INFRINGEMENT.
See the Apache Version 2.0 License for specific language governing permissions
and limitations under the License.
***************************************************************************** */
/* global Reflect, Promise */
/**
 * TypeScript/tslib helper that copies "static" members from constructor `b`
 * onto constructor `d`. On first call it picks the best available strategy
 * and replaces itself with it, so feature detection runs only once.
 */
var extendStatics = function (d, b) {
    var copyStatics;
    if (Object.setPrototypeOf) {
        copyStatics = Object.setPrototypeOf;
    }
    else if ({ __proto__: [] } instanceof Array) {
        copyStatics = function (d, b) { d.__proto__ = b; };
    }
    else {
        copyStatics = function (d, b) {
            for (var p in b) {
                if (b.hasOwnProperty(p)) {
                    d[p] = b[p];
                }
            }
        };
    }
    extendStatics = copyStatics;
    return extendStatics(d, b);
};
/**
 * TypeScript/tslib helper for ES5 class inheritance: links `d`'s statics to
 * `b` and makes `d.prototype` inherit from `b.prototype` (or from nothing
 * when `b` is null), with `constructor` pointing back at `d`.
 */
function __extends(d, b) {
    extendStatics(d, b);
    if (b === null) {
        d.prototype = Object.create(b);
        return;
    }
    // Intermediate constructor so `b` itself is never invoked here.
    function PrototypeBridge() { this.constructor = d; }
    PrototypeBridge.prototype = b.prototype;
    d.prototype = new PrototypeBridge();
}
/**
 * TypeScript/tslib helper backing compiled `async` functions. It drives the
 * iterator produced by `generator`, resuming it with the settled value of
 * each yielded item. The returned promise settles with the final return
 * value or with the first uncaught error.
 *
 * @param thisArg `this` for the generator function.
 * @param _arguments Arguments for the generator function (defaults to none).
 * @param P Promise constructor to use; defaults to the global `Promise`.
 * @param generator The generator function to drive.
 */
function __awaiter(thisArg, _arguments, P, generator) {
    var PromiseCtor = P || Promise;
    return new PromiseCtor(function (resolve, reject) {
        // Created inside the executor so a synchronous throw rejects the promise.
        var iterator = generator.apply(thisArg, _arguments || []);
        function advance(result) {
            if (result.done) {
                resolve(result.value);
                return;
            }
            new PromiseCtor(function (settle) { settle(result.value); })
                .then(onFulfilled, onRejected);
        }
        function onFulfilled(value) {
            try {
                advance(iterator.next(value));
            }
            catch (e) {
                reject(e);
            }
        }
        function onRejected(reason) {
            try {
                advance(iterator["throw"](reason));
            }
            catch (e) {
                reject(e);
            }
        }
        advance(iterator.next());
    });
}
/**
 * TypeScript/tslib helper that emulates a generator for ES5 targets.
 *
 * `body` is a compiled state machine. It is called with the state object `_`
 * and returns an instruction tuple `[opcode, value]`. The opcodes handled
 * below are:
 * - 0/1/2: next/throw/return requests from the caller (see `verb`);
 * - 3: jump to a label;
 * - 4: yield a value;
 * - 5: delegate to an inner iterator, stored in `y`;
 * - 6: an exception thrown by `body`;
 * - 7: resume an operation deferred by a finally (pops `_.ops` / `_.trys`).
 *
 * NOTE(review): each `_.trys` entry appears to be
 * [tryLabel, catchLabel, finallyLabel, endLabel]. This is inferred from how
 * t[0..3] are used below; confirm against tslib.
 *
 * @param thisArg The `this` value `body` is invoked with.
 * @param body The compiled generator body.
 * @return An iterator with `next`, `throw` and `return` methods, plus
 *     `Symbol.iterator` when Symbols are available.
 */
function __generator(thisArg, body) {
var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
// Wraps a caller request (0 = next, 1 = throw, 2 = return) into an op tuple.
function verb(n) { return function (v) { return step([n, v]); }; }
function step(op) {
// Re-entrancy guard: `f` is set while a step is running and cleared in `finally`.
if (f) throw new TypeError("Generator is already executing.");
// `_` is set to 0 once a return or an unhandled throw finishes the generator.
while (_) try {
// While delegating (`y` set by opcode 5), forward the request to the inner
// iterator; a not-done inner result is handed straight back to the caller.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
if (y = 0, t) op = [op[0] & 2, t.value];
switch (op[0]) {
case 0: case 1: t = op; break;
case 4: _.label++; return { value: op[1], done: false };
case 5: _.label++; y = op[1]; op = [0]; continue;
case 7: op = _.ops.pop(); _.trys.pop(); continue;
// Jumps, returns and exceptions: route through the innermost try region, if any.
default:
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
if (t[2]) _.ops.pop();
_.trys.pop(); continue;
}
// Run the state machine until it emits its next instruction.
op = body.call(thisArg, _);
} catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
// After completion: rethrow a pending exception, otherwise report done with the return value.
if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
}
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * The environment contains evaluated flags as well as the registered platform.
 * This is always used as a global singleton and can be retrieved with
 * `tf.env()`.
 *
 * Flags are registered together with an evaluation function. They are
 * evaluated lazily on first read and the result is cached in `flags`. A value
 * supplied through the URL query string (`?tfjsflags=...`) overrides a flag
 * when the flag is registered.
 */
/** @doc {heading: 'Environment'} */
var Environment = /** @class */ (function () {
    /**
     * @param global The global object to read URL flags from. Only its
     *     `location.search` property is consulted, and only when present.
     */
    // tslint:disable-next-line: no-any
    function Environment(global) {
        this.global = global;
        this.flags = {};
        this.flagRegistry = {};
        this.urlFlags = {};
        this.populateURLFlags();
    }
    /**
     * Sets the active platform under `platformName`. If a platform was
     * already set, it warns and then overwrites it.
     */
    Environment.prototype.setPlatform = function (platformName, platform) {
        if (this.platform != null) {
            // Name the incoming platform. Concatenating the platform object
            // itself would only print "[object Object]".
            console.warn("Platform " + this.platformName + " has already been set. " +
                ("Overwriting the platform with " + platformName + "."));
        }
        this.platformName = platformName;
        this.platform = platform;
    };
    /**
     * Registers a flag.
     *
     * @param flagName The flag's name.
     * @param evaluationFn Computes the flag's value on first read.
     * @param setHook Optional; invoked with the new value on every `set()`.
     */
    Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) {
        this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook };
        // Override the flag value from the URL. This has to happen here because
        // the environment is initialized before flags get registered.
        if (this.urlFlags[flagName] != null) {
            var flagValue = this.urlFlags[flagName];
            console.warn("Setting feature override from URL " + flagName + ": " + flagValue + ".");
            this.set(flagName, flagValue);
        }
    };
    /**
     * Returns a promise for the flag's value. When the value is not cached,
     * it awaits the evaluation function and caches the result. Use this for
     * flags whose evaluation function returns a Promise.
     */
    Environment.prototype.getAsync = function (flagName) {
        return __awaiter(this, void 0, void 0, function () {
            var _a, _b;
            return __generator(this, function (_c) {
                switch (_c.label) {
                    case 0:
                        if (flagName in this.flags) {
                            return [2 /*return*/, this.flags[flagName]];
                        }
                        _a = this.flags;
                        _b = flagName;
                        return [4 /*yield*/, this.evaluateFlag(flagName)];
                    case 1:
                        _a[_b] = _c.sent();
                        return [2 /*return*/, this.flags[flagName]];
                }
            });
        });
    };
    /**
     * Returns the flag's value, evaluating and caching it on first access.
     *
     * @throws Error if the flag is not registered, or if its evaluation
     *     function returns a Promise (use `getAsync()` for those flags).
     */
    Environment.prototype.get = function (flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        var flagValue = this.evaluateFlag(flagName);
        if (flagValue instanceof Promise) {
            throw new Error("Flag " + flagName + " cannot be synchronously evaluated. " +
                "Please use getAsync() instead.");
        }
        this.flags[flagName] = flagValue;
        return this.flags[flagName];
    };
    /** Same as `get()`; no runtime type conversion is performed. */
    Environment.prototype.getNumber = function (flagName) {
        return this.get(flagName);
    };
    /** Same as `get()`; no runtime type conversion is performed. */
    Environment.prototype.getBool = function (flagName) {
        return this.get(flagName);
    };
    /** Returns the live map of cached flag values. */
    Environment.prototype.getFlags = function () {
        return this.flags;
    };
    Object.defineProperty(Environment.prototype, "features", {
        // For backwards compatibility.
        get: function () {
            return this.flags;
        },
        enumerable: true,
        configurable: true
    });
    /**
     * Sets a registered flag's value and invokes its set hook, if any.
     *
     * @throws Error if the flag has not been registered.
     */
    Environment.prototype.set = function (flagName, value) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error("Cannot set flag " + flagName + " as it has not been registered.");
        }
        this.flags[flagName] = value;
        if (this.flagRegistry[flagName].setHook != null) {
            this.flagRegistry[flagName].setHook(value);
        }
    };
    /**
     * Runs the flag's evaluation function, without caching the result.
     *
     * @throws Error if the flag has not been registered.
     */
    Environment.prototype.evaluateFlag = function (flagName) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found.");
        }
        return this.flagRegistry[flagName].evaluationFn();
    };
    /** Replaces all cached flag values with a shallow copy of `flags`. */
    Environment.prototype.setFlags = function (flags) {
        this.flags = Object.assign({}, flags);
    };
    /**
     * Clears cached flag values and re-reads URL flags. Registered flags
     * are kept.
     */
    Environment.prototype.reset = function () {
        this.flags = {};
        this.urlFlags = {};
        this.populateURLFlags();
    };
    /**
     * Parses `?tfjsflags=NAME:value,...` from `global.location.search` into
     * `urlFlags`. Does nothing when no location/search is available.
     */
    Environment.prototype.populateURLFlags = function () {
        var _this = this;
        if (typeof this.global === 'undefined' ||
            typeof this.global.location === 'undefined' ||
            typeof this.global.location.search === 'undefined') {
            return;
        }
        var urlParams = getQueryParams(this.global.location.search);
        if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
            var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
            keyValues.forEach(function (keyValue) {
                var _a = keyValue.split(':'), key = _a[0], value = _a[1];
                _this.urlFlags[key] = parseValue(key, value);
            });
        }
    };
    return Environment;
}());
/**
 * Parses a URL query string (e.g. `location.search`) into an object that
 * maps URI-decoded parameter names to URI-decoded values. A parameter
 * without `=` maps to ''.
 */
function getQueryParams(queryString) {
    var params = {};
    // `replace` is only used to visit every match of the global regex; its
    // return value is discarded.
    queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (match, name, value) {
        decodeParam(params, name, value);
        return match;
    });
    return params;
}
/** URI-decodes `name` and `value` (a missing value becomes '') into `params`. */
function decodeParam(params, name, value) {
    params[decodeURIComponent(name)] = decodeURIComponent(value || '');
}
/**
 * Parses a URL flag value. 'true'/'false' (case-insensitive) become
 * booleans, and canonical numeric strings become numbers.
 *
 * @param flagName Flag name, used in the error message.
 * @param value Raw value text; undefined when the flag had no `:value`.
 * @throws Error if the value is missing or cannot be parsed.
 */
function parseValue(flagName, value) {
    if (value == null) {
        // e.g. `?tfjsflags=DEBUG`: without this guard `toLowerCase` would fail
        // with an opaque TypeError instead of this descriptive error.
        throw new Error("Could not parse value flag value " + value + " for flag " + flagName + ".");
    }
    var normalized = value.toLowerCase();
    if (normalized === 'true' || normalized === 'false') {
        return normalized === 'true';
    }
    // Only accept strings that round-trip exactly through Number.
    if ("" + +normalized === normalized) {
        return +normalized;
    }
    throw new Error("Could not parse value flag value " + normalized + " for flag " + flagName + ".");
}
/**
 * Returns the current environment (a global singleton).
 *
 * The environment object contains the evaluated feature values as well as the
 * active platform. It is read from `exports.ENV`, which is initialized to
 * `null` below and replaced through `setEnvironmentGlobal`.
 */
/** @doc {heading: 'Environment'} */
function env() {
    var current = exports.ENV;
    return current;
}
exports.ENV = null;
/** Installs `environment` as the value returned by `env()`. */
function setEnvironmentGlobal(environment) {
    exports.ENV = environment;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Note that the identifier globalNameSpace is scoped to this module, but will
// always resolve to the same global object regardless of how the module is
// resolved.
// tslint:disable-next-line:no-any
var globalNameSpace;
/**
 * Returns the host's global object by probing `window`, `global`, `process`
 * and `self`, in that order. The result is cached after the first lookup.
 *
 * @throws Error if none of those globals exist.
 */
// tslint:disable-next-line:no-any
function getGlobalNamespace() {
    if (globalNameSpace != null) {
        return globalNameSpace;
    }
    // `typeof` guards are required: naming an undeclared global directly
    // would throw a ReferenceError.
    if (typeof (window) !== 'undefined') {
        globalNameSpace = window;
    }
    else if (typeof (global) !== 'undefined') {
        globalNameSpace = global;
    }
    else if (typeof (process) !== 'undefined') {
        globalNameSpace = process;
    }
    else if (typeof (self) !== 'undefined') {
        globalNameSpace = self;
    }
    else {
        throw new Error('Could not find a global object');
    }
    return globalNameSpace;
}
/**
 * Returns the `Map` stored at `_tfGlobals` on the global object, creating it
 * on first use.
 */
// tslint:disable-next-line:no-any
function getGlobalMap() {
    var ns = getGlobalNamespace();
    if (ns._tfGlobals == null) {
        ns._tfGlobals = new Map();
    }
    return ns._tfGlobals;
}
/**
 * Returns a globally accessible 'singleton' object.
 *
 * @param key the name of the object
 * @param init a function that creates the object the first time it is
 *     fetched; it is not called when `key` already exists.
 */
function getGlobal(key, init) {
    var globalMap = getGlobalMap();
    if (!globalMap.has(key)) {
        globalMap.set(key, init());
    }
    return globalMap.get(key);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var kernelRegistry = getGlobal('kernelRegistry', function () { return new Map(); });
var gradRegistry = getGlobal('gradRegistry', function () { return new Map(); });
/**
 * Returns the kernel function (code) associated with the provided names.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @return The registered kernel config, or undefined if there is none.
 */
function getKernel(kernelName, backendName) {
    var key = makeKey(kernelName, backendName);
    return kernelRegistry.get(key);
}
/**
 * Returns the registered gradient info associated with the provided kernel.
 * @param kernelName The official TF kernel name.
 * @return The registered gradient config, or undefined if there is none.
 */
function getGradient(kernelName) {
    return gradRegistry.get(kernelName);
}
/**
 * Returns the configs of every kernel registered for `backendName`, in
 * registration order.
 *
 * @param backendName The official name of the backend.
 */
function getKernelsForBackend(backendName) {
    var result = [];
    // Match on the backend name stored in each config instead of splitting the
    // registry key on '_'. The split only yields the first '_'-separated
    // segment, so it never matched backend names that contain '_'.
    kernelRegistry.forEach(function (config) {
        if (config.backendName === backendName) {
            result.push(config);
        }
    });
    return result;
}
/**
 * Registers the function (forward pass) for the kernel in a global registry.
 * Re-registering an existing kernel/backend pair logs a warning and replaces
 * the previous config.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 * disposed.
 */
function registerKernel(config) {
    var kernelName = config.kernelName;
    var backendName = config.backendName;
    var key = makeKey(kernelName, backendName);
    if (kernelRegistry.has(key)) {
        console.warn("The kernel '" + kernelName + "' for backend " +
            ("'" + backendName + "' is already registered"));
    }
    kernelRegistry.set(key, config);
}
/**
 * Registers a gradient function for a given kernel in the global registry,
 * to be used during the back-propagation of that kernel. An existing entry
 * is replaced; the override is only logged when the DEBUG flag is on.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
function registerGradient(config) {
    var kernelName = config.kernelName;
    if (gradRegistry.has(kernelName)) {
        // TODO (yassogba) after 3.0 assess whether we need to keep this gated
        // to debug mode.
        if (env().getBool('DEBUG')) {
            console.warn("Overriding the gradient for '" + kernelName + "'");
        }
    }
    gradRegistry.set(kernelName, config);
}
/**
 * Removes the kernel function from the registry.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @throws Error if no such kernel is registered.
 */
function unregisterKernel(kernelName, backendName) {
    var key = makeKey(kernelName, backendName);
    if (!kernelRegistry.has(key)) {
        throw new Error("The kernel '" + kernelName + "' for backend " +
            ("'" + backendName + "' is not registered"));
    }
    kernelRegistry.delete(key);
}
/**
 * Removes the registered gradient from the global registry.
 *
 * @throws Error if no gradient is registered for `kernelName`.
 */
function unregisterGradient(kernelName) {
    if (!gradRegistry.has(kernelName)) {
        // Gradients are keyed by kernel name only. The old message mentioned a
        // backend without naming one.
        throw new Error("The gradient '" + kernelName + "' is not registered");
    }
    gradRegistry.delete(kernelName);
}
/** Builds the kernel registry key, `<backendName>_<kernelName>`. */
function makeKey(kernelName, backendName) {
    return backendName + "_" + kernelName;
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shuffles the array in-place using Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1, 2, 3, 4, 5];
 * tf.util.shuffle(a);
 * console.log(a);
 * ```
 *
 * @param array The array to shuffle in-place.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
// tslint:disable-next-line:no-any
function shuffle(array) {
    // Walk from the back. Each step swaps the last unshuffled slot with a
    // random slot chosen from the still-unshuffled prefix (inclusive).
    for (var last = array.length - 1; last >= 0; last--) {
        var pick = (Math.random() * (last + 1)) | 0;
        var held = array[last];
        array[last] = array[pick];
        array[pick] = held;
    }
}
/** Clamps `x` to the range [min, max]. */
function clamp(min, x, max) {
    var upperBounded = Math.min(x, max);
    return Math.max(min, upperBounded);
}
/** Returns `val` if it is even, otherwise `val + 1`. */
function nearestLargerEven(val) {
    if (val % 2 !== 0) {
        return val + 1;
    }
    return val;
}
/** Sums the elements of `arr` from left to right (0 for an empty array). */
function sum(arr) {
    var total = 0;
    var idx = 0;
    while (idx < arr.length) {
        total += arr[idx];
        idx++;
    }
    return total;
}
/**
 * Returns a sample from a uniform [a, b) distribution.
 *
 * @param a The minimum support (inclusive).
 * @param b The maximum support (exclusive).
 * @return A pseudorandom number on the half-open interval [a,b).
 */
function randUniform(a, b) {
    var r = Math.random();
    var fromB = b * r;
    var fromA = (1 - r) * a;
    return fromB + fromA;
}
/**
 * Returns the squared Euclidean distance between two vectors. Elements are
 * coerced with `Number`, and `a.length` sets the number of compared entries.
 */
function distSquared(a, b) {
    var total = 0;
    for (var k = 0; k < a.length; k++) {
        var delta = Number(a[k]) - Number(b[k]);
        total += delta * delta;
    }
    return total;
}
/**
 * Asserts that the expression is true. Otherwise throws an error with the
 * provided message.
 *
 * ```js
 * const x = 2;
 * tf.util.assert(x === 2, 'x is not 2');
 * ```
 *
 * @param expr The expression to assert (as a boolean).
 * @param msg The error message, either as a string or as a function that
 *     returns it. A function lets callers defer building the message until
 *     the assertion actually fails, which is cheaper on the happy path.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function assert(expr, msg) {
    if (expr) {
        return;
    }
    var message = typeof msg === 'string' ? msg : msg();
    throw new Error(message);
}
/**
 * Asserts that two shapes are element-wise equal.
 *
 * @param errorMessagePrefix Optional text prepended to the failure message.
 */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) {
    var prefix = errorMessagePrefix === void 0 ? '' : errorMessagePrefix;
    var matches = arraysEqual(shapeA, shapeB);
    assert(matches, function () {
        return prefix + (" Shapes " + shapeA + " and " + shapeB + " must match");
    });
}
/** Asserts that a tensor-constructor input is neither null nor undefined. */
function assertNonNull(a) {
    var present = a != null;
    assert(present, function () { return "The input to the tensor constructor must be a non-null value."; });
}
// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
 * Flattens an arbitrarily nested array.
 *
 * ```js
 * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
 * const flat = tf.util.flatten(a);
 * console.log(flat);
 * ```
 *
 * @param arr The nested array to flatten.
 * @param result The destination array which holds the elements. A missing or
 *     null value starts a fresh array. The destination is also returned.
 * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
 * to false.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function flatten(arr, result, skipTypedArray) {
    var out = result == null ? [] : result;
    var keepTypedArrays = skipTypedArray === void 0 ? false : skipTypedArray;
    var isNested = Array.isArray(arr) || (!keepTypedArrays && isTypedArray(arr));
    if (!isNested) {
        out.push(arr);
        return out;
    }
    for (var i = 0; i < arr.length; ++i) {
        flatten(arr[i], out, keepTypedArrays);
    }
    return out;
}
/**
 * Returns the size (number of elements) of the tensor given its shape.
 *
 * ```js
 * const shape = [3, 4, 2];
 * const size = tf.util.sizeFromShape(shape);
 * console.log(size);
 * ```
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function sizeFromShape(shape) {
    var rank = shape.length;
    // A scalar (rank 0) holds exactly one element.
    if (rank === 0) {
        return 1;
    }
    var product = shape[0];
    for (var dim = 1; dim < rank; dim++) {
        product *= shape[dim];
    }
    return product;
}
/** Returns true for the rank-0 (scalar) shape `[]`. */
function isScalarShape(shape) {
    return 0 === shape.length;
}
/**
 * Returns true when both arrays are the same reference, or have the same
 * length and strictly equal elements. A null/undefined argument only matches
 * itself.
 */
function arraysEqual(n1, n2) {
    if (n1 === n2) {
        return true;
    }
    if (n1 == null || n2 == null || n1.length !== n2.length) {
        return false;
    }
    for (var idx = 0; idx < n1.length; idx++) {
        if (n1[idx] !== n2[idx]) {
            return false;
        }
    }
    return true;
}
/** Returns true when `a` has no fractional part. */
function isInt(a) {
    var remainder = a % 1;
    return remainder === 0;
}
/**
 * Hyperbolic tangent. Uses `Math.tanh` when available, otherwise falls back
 * to (e^{2x} - 1) / (e^{2x} + 1).
 */
function tanh(x) {
    // tslint:disable-next-line:no-any
    if (Math.tanh != null) {
        // tslint:disable-next-line:no-any
        return Math.tanh(x);
    }
    // The infinities are special-cased: for x = Infinity the formula below
    // evaluates to Infinity / Infinity = NaN.
    if (x === Infinity || x === -Infinity) {
        return x > 0 ? 1 : -1;
    }
    var e2x = Math.exp(2 * x);
    return (e2x - 1) / (e2x + 1);
}
/**
 * Returns a [width, height] pair with width = ceil(sqrt(size)) and enough
 * rows to hold `size` elements.
 */
function sizeToSquarishShape(size) {
    var width = Math.ceil(Math.sqrt(size));
    var height = Math.ceil(size / width);
    return [width, height];
}
/**
 * Creates a new array with randomized indices to a given quantity.
 *
 * ```js
 * const randomTen = tf.util.createShuffledIndices(10);
 * console.log(randomTen);
 * ```
 *
 * @param n Quantity of how many shuffled indices to create.
 * @return A Uint32Array holding a permutation of 0..n-1.
 */
/** @doc {heading: 'Util', namespace: 'util'} */
function createShuffledIndices(n) {
    var shuffledIndices = new Uint32Array(n);
    for (var k = 0; k < n; k++) {
        shuffledIndices[k] = k;
    }
    shuffle(shuffledIndices);
    return shuffledIndices;
}
/** Pads `a` with trailing spaces up to `size` characters; never truncates. */
function rightPad(a, size) {
    var missing = size - a.length;
    return missing > 0 ? a + ' '.repeat(missing) : a;
}
/**
 * Polls `checkFn` until it returns truthy.
 *
 * @param checkFn Called synchronously once, then after each delay.
 * @param delayFn Maps the failed-attempt count to the next delay in ms
 *     (defaults to 0).
 * @param maxCounter Optional cap on failed attempts.
 * @return A promise that resolves when `checkFn` succeeds, or rejects (with
 *     no reason) once `maxCounter` failed attempts are reached.
 */
function repeatedTry(checkFn, delayFn, maxCounter) {
    var backoff = delayFn === void 0 ? function (counter) { return 0; } : delayFn;
    return new Promise(function (resolve, reject) {
        var attempts = 0;
        function attempt() {
            if (checkFn()) {
                resolve();
                return;
            }
            attempts += 1;
            // The delay is computed before the cap check, so `delayFn` also
            // sees the final failing attempt count.
            var waitMs = backoff(attempts);
            var exhausted = maxCounter != null && attempts >= maxCounter;
            if (exhausted) {
                reject();
                return;
            }
            setTimeout(attempt, waitMs);
        }
        attempt();
    });
}
/**
 * Given the full size of the array and a shape that may contain -1 as the
 * implicit dimension, returns the inferred shape where -1 is replaced.
 * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
 *
 * @param shape The shape, which may contain -1 in some dimension.
 * @param size The full size (number of elements) of the array.
 * @return The inferred shape where -1 is replaced with the inferred size.
 *     When there is no -1, `shape` itself is returned (not a copy).
 */
function inferFromImplicitShape(shape, size) {
    var knownProduct = 1;
    var implicitIdx = -1;
    for (var dim = 0; dim < shape.length; ++dim) {
        var extent = shape[dim];
        if (extent === -1) {
            if (implicitIdx !== -1) {
                throw Error("Shapes can only have 1 implicit size. " +
                    ("Found -1 at dim " + implicitIdx + " and dim " + dim));
            }
            implicitIdx = dim;
        }
        else if (extent >= 0) {
            knownProduct *= extent;
        }
        else if (extent < 0) {
            throw Error("Shapes can not be < 0. Found " + extent + " at dim " + dim);
        }
    }
    if (implicitIdx === -1) {
        // Fully specified: only validate (size 0 skips the check).
        if (size > 0 && size !== knownProduct) {
            throw Error("Size(" + size + ") must match the product of shape " + shape);
        }
        return shape;
    }
    if (knownProduct === 0) {
        throw Error("Cannot infer the missing size in [" + shape + "] when " +
            "there are 0 elements");
    }
    if (size % knownProduct !== 0) {
        throw Error("The implicit shape can't be a fractional number. " +
            ("Got " + size + " / " + knownProduct));
    }
    var inferred = shape.slice();
    inferred[implicitIdx] = size / knownProduct;
    return inferred;
}
/**
 * Normalizes an axis argument into an array of non-negative axis indices.
 *
 * @param axis A single axis, an array of axes, or null/undefined for every
 *     axis of `shape`. Negative values count from the end.
 * @param shape The shape the axes refer to.
 * @throws (via `assert`) if any axis is out of [-rank, rank) or not an
 *     integer.
 */
function parseAxisParam(axis, shape) {
    var rank = shape.length;
    var axes = axis == null ?
        shape.map(function (s, i) { return i; }) :
        [].concat(axis);
    var inRange = axes.every(function (ax) { return ax >= -rank && ax < rank; });
    assert(inRange, function () {
        return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " +
            ("got axis " + axes);
    });
    var allIntegers = axes.every(function (ax) { return isInt(ax); });
    assert(allIntegers, function () {
        return "All values in axis param must be integers but " +
            ("got axis " + axes);
    });
    // Map negative axes to their positive equivalents.
    return axes.map(function (a) { return a < 0 ? rank + a : a; });
}
/**
 * Reduces the shape by removing dimensions of size 1.
 *
 * @param shape The input shape.
 * @param axis Optional axis or axes to squeeze. When null/undefined or an
 *     empty array, every size-1 dimension is removed. Otherwise only the
 *     listed axes (negative values count from the end) are removed.
 * @return `newShape`, the squeezed shape, and `keptDims`, the indices of the
 *     input dimensions that were kept.
 * @throws Error if a listed axis does not have size 1.
 */
function squeezeShape(shape, axis) {
    var newShape = [];
    var keptDims = [];
    var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
    // Sort numerically. The default Array#sort compares elements as strings
    // and would order [2, 10] as [10, 2], which breaks the ordered walk below
    // for tensors of rank > 10.
    var axes = (axis == null || isEmptyArray) ?
        null :
        parseAxisParam(axis, shape).sort(function (a, b) { return a - b; });
    var j = 0;
    for (var i = 0; i < shape.length; ++i) {
        if (axes != null) {
            if (axes[j] === i && shape[i] !== 1) {
                throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1");
            }
            // A size-1 dimension that is not listed in `axes` is kept.
            if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
                newShape.push(shape[i]);
                keptDims.push(i);
            }
            if (axes[j] <= i) {
                j++;
            }
        }
        if (shape[i] !== 1) {
            newShape.push(shape[i]);
            keptDims.push(i);
        }
    }
    return { newShape: newShape, keptDims: keptDims };
}
/**
 * Allocates a zero-filled typed array of `size` elements for `dtype`
 * (null/undefined means 'float32').
 *
 * @throws Error for any other dtype, including 'string'.
 */
function getTypedArrayFromDType(dtype, size) {
    switch (dtype == null ? 'float32' : dtype) {
        case 'float32':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        default:
            throw new Error("Unknown data type " + dtype);
    }
}
/**
 * Like `getTypedArrayFromDType`, but 'string' yields a plain Array of
 * length `size`.
 */
function getArrayFromDType(dtype, size) {
    switch (dtype == null ? 'float32' : dtype) {
        case 'float32':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        case 'string':
            return new Array(size);
        default:
            throw new Error("Unknown data type " + dtype);
    }
}
/**
 * Throws if any value in `vals` is NaN or infinite.
 *
 * @param dtype Only used in the error message.
 */
function checkConversionForErrors(vals, dtype) {
    for (var idx = 0; idx < vals.length; idx++) {
        var candidate = vals[idx];
        var invalid = isNaN(candidate) || !isFinite(candidate);
        if (invalid) {
            throw Error("A tensor of type " + dtype + " being uploaded contains " + candidate + ".");
        }
    }
}
/** Returns true if the dtype is valid. */
function isValidDtype(dtype) {
    var knownDtypes = ['bool', 'complex64', 'float32', 'int32', 'string'];
    return knownDtypes.indexOf(dtype) !== -1;
}
/**
 * Returns true if the new type can't encode the old type without loss of
 * precision.
 *
 * complex64 accepts everything; float32 accepts all but complex64; int32
 * rejects float32 and complex64; bool only accepts bool. Any other target
 * type is treated as lossy.
 */
function hasEncodingLoss(oldType, newType) {
    switch (newType) {
        case 'complex64':
            return false;
        case 'float32':
            return oldType === 'complex64';
        case 'int32':
            return oldType === 'float32' || oldType === 'complex64';
        case 'bool':
            return oldType !== 'bool';
        default:
            return true;
    }
}
/** Returns true for the typed array kinds used as tensor storage. */
function isTypedArray(a) {
    if (a instanceof Float32Array || a instanceof Int32Array) {
        return true;
    }
    return a instanceof Uint8Array;
}
/**
 * Returns the storage size in bytes of one element of `dtype`.
 *
 * @throws Error for dtypes without a fixed element size (e.g. 'string').
 */
function bytesPerElement(dtype) {
    switch (dtype) {
        case 'float32':
        case 'int32':
            return 4;
        case 'complex64':
            return 8;
        case 'bool':
            return 1;
        default:
            throw new Error("Unknown dtype " + dtype);
    }
}
/**
 * Returns the sum of the `length` property of every element of `arr`, or 0
 * when `arr` is null/undefined.
 *
 * NOTE(review): no per-character scaling is applied. If the elements are
 * encoded byte arrays (e.g. Uint8Array), the result is their exact byte
 * count. If they are JS strings, it is their UTF-16 code-unit count, not a
 * byte count. Confirm which kind callers pass.
 */
function bytesFromStringArray(arr) {
    if (arr == null) {
        return 0;
    }
    return arr.reduce(function (total, x) { return total + x.length; }, 0);
}
/** Returns true if the value is a string primitive or a String object. */
function isString(value) {
    if (typeof value === 'string') {
        return true;
    }
    return value instanceof String;
}
/** Returns true if the value is a boolean primitive. */
function isBoolean(value) {
    var kind = typeof value;
    return kind === 'boolean';
}
/** Returns true if the value is a number primitive (including NaN). */
function isNumber(value) {
    var kind = typeof value;
    return kind === 'number';
}
/**
 * Infers a dtype from a value, a typed array, or a (nested) array.
 *
 * Nested arrays are classified by their first leaf element. Int32Array and
 * Uint8Array map to 'int32', strings to 'string', and booleans to 'bool'.
 * Float32Array, numbers and anything unrecognized map to 'float32'.
 */
function inferDtype(values) {
    var sample = values;
    while (Array.isArray(sample)) {
        sample = sample[0];
    }
    if (sample instanceof Float32Array) {
        return 'float32';
    }
    if (sample instanceof Int32Array || sample instanceof Uint8Array) {
        return 'int32';
    }
    if (isNumber(sample)) {
        return 'float32';
    }
    if (isString(sample)) {
        return 'string';
    }
    if (isBoolean(sample)) {
        return 'bool';
    }
    return 'float32';
}
/** Duck-typed check: truthy and has `constructor`, `call` and `apply`. */
function isFunction(f) {
    if (!f) {
        return false;
    }
    return !!(f.constructor && f.call && f.apply);
}
/**
 * Returns the smallest divisor of `size` that is >= `start` and < `size`,
 * or `size` itself if there is none.
 */
function nearestDivisor(size, start) {
    var candidate = start;
    while (candidate < size && size % candidate !== 0) {
        candidate++;
    }
    return candidate < size ? candidate : size;
}
/**
 * Returns the row-major strides of `shape`, omitting the innermost (implicit
 * 1) stride, so rank D yields D-1 entries and rank < 2 yields [].
 */
function computeStrides(shape) {
    var rank = shape.length;
    if (rank < 2) {
        return [];
    }
    var strides = new Array(rank - 1);
    var running = shape[rank - 1];
    for (var axis = rank - 2; axis >= 0; --axis) {
        strides[axis] = running;
        running *= shape[axis];
    }
    return strides;
}
/**
 * Converts a (possibly nested) array or typed array into the typed array
 * matching `dtype` (null/undefined and 'complex64' use Float32Array). The
 * input is returned as-is when it already has the right type. For 'bool',
 * values that round to a non-zero number become 1.
 *
 * @throws Error for 'string' or an unknown dtype.
 */
function toTypedArray(a, dtype) {
    if (dtype === 'string') {
        throw new Error('Cannot convert a string[] to a TypedArray');
    }
    var values = Array.isArray(a) ? flatten(a) : a;
    if (env().getBool('DEBUG')) {
        checkConversionForErrors(values, dtype);
    }
    if (noConversionNeeded(values, dtype)) {
        return values;
    }
    if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
        return new Float32Array(values);
    }
    if (dtype === 'int32') {
        return new Int32Array(values);
    }
    if (dtype === 'bool') {
        var mask = new Uint8Array(values.length);
        for (var idx = 0; idx < mask.length; ++idx) {
            if (Math.round(values[idx]) !== 0) {
                mask[idx] = 1;
            }
        }
        return mask;
    }
    throw new Error("Unknown data type " + dtype);
}
/**
 * Recursively builds a nested JS array of the given `shape` from the flat
 * array `a`, starting at element `offset`.
 */
function createNestedArray(offset, shape, a) {
    var ret = [];
    var outerSize = shape[0];
    if (shape.length === 1) {
        for (var i = 0; i < outerSize; i++) {
            ret[i] = a[offset + i];
        }
        return ret;
    }
    var innerShape = shape.slice(1);
    // Number of flat elements spanned by one step along the outer dimension.
    var innerSize = innerShape.reduce(function (acc, c) { return acc * c; });
    for (var j = 0; j < outerSize; j++) {
        ret[j] = createNestedArray(offset + j * innerSize, innerShape, a);
    }
    return ret;
}
/**
 * Provides a nested array view of the flat (typed) array `a` in the given
 * `shape`. A rank-0 shape yields the single element `a[0]`, and a shape with
 * a zero dimension yields [].
 *
 * @throws Error if the shape's size differs from `a.length`.
 */
function toNestedArray(shape, a) {
    if (shape.length === 0) {
        return a[0];
    }
    var total = shape.reduce(function (acc, c) { return acc * c; });
    if (total === 0) {
        return [];
    }
    if (total !== a.length) {
        throw new Error("[" + shape + "] does not match the input size " + a.length + ".");
    }
    return createNestedArray(0, shape, a);
}
/** Returns true when `a` is already the typed array kind used for `dtype`. */
function noConversionNeeded(a, dtype) {
    switch (dtype) {
        case 'float32':
            return a instanceof Float32Array;
        case 'int32':
            return a instanceof Int32Array;
        case 'bool':
            return a instanceof Uint8Array;
        default:
            return false;
    }
}
/** Returns a typed array of `size` ones for `dtype`. */
function makeOnesTypedArray(size, dtype) {
    return makeZerosTypedArray(size, dtype).fill(1);
}
/**
 * Returns a zero-filled typed array of `size` elements for `dtype`
 * (null/undefined and 'complex64' use Float32Array).
 *
 * @throws Error for any other dtype.
 */
function makeZerosTypedArray(size, dtype) {
    switch (dtype == null ? 'float32' : dtype) {
        case 'float32':
        case 'complex64':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        default:
            throw new Error("Unknown data type " + dtype);
    }
}
/**
 * Make nested `TypedArray` filled with zeros.
 * @param shape The shape information for the nested array.
 * @param dtype dtype of the array element ('complex64' is not supported).
 */
function makeZerosNestedTypedArray(shape, dtype) {
    var size = shape.reduce(function (prev, curr) { return prev * curr; }, 1);
    var zeros;
    if (dtype == null || dtype === 'float32') {
        zeros = new Float32Array(size);
    }
    else if (dtype === 'int32') {
        zeros = new Int32Array(size);
    }
    else if (dtype === 'bool') {
        zeros = new Uint8Array(size);
    }
    else {
        throw new Error("Unknown data type " + dtype);
    }
    return toNestedArray(shape, zeros);
}
/**
* Returns the current high-resolution time in milliseconds relative to an
* arbitrary time in the past. It works across different platforms (node.js,
* browsers).
*
* ```js
* console.log(tf.util.now());
* ```
*/
/** @doc {heading: 'Util', namespace: 'util'} */
function now() {
return env().platform.now();
}
/**
 * Checks that every entry of `shape` is an integer greater than or equal to
 * zero, reporting a failure through `assert`.
 *
 * Zero passes the check (`dimSize >= 0`), so the failure message says
 * "non-negative" rather than the previous, inaccurate "positive".
 *
 * @param shape Array of dimension sizes to validate.
 */
function assertNonNegativeIntegerDimensions(shape) {
    // The message describes the whole shape, so a single check over all
    // dimensions reports exactly what the per-element loop did.
    var allValid = shape.every(function (dimSize) {
        return Number.isInteger(dimSize) && dimSize >= 0;
    });
    assert(allValid, function () {
        return "Tensor must have a shape comprised of non-negative integers but got " +
            ("shape [" + shape + "].");
    });
}
/**
 * Returns a platform-specific implementation of
 * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
 *
 * The call is forwarded unchanged to the registered platform's `fetch`. That
 * may be the global `fetch` (on `window`, `process`, etc.) or a
 * platform-specific substitute.
 *
 * ```js
 * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
 * // handle response
 * ```
 */
/** @doc {heading: 'Util'} */
function fetch$1(path, requestInits) {
    var platform = env().platform;
    return platform.fetch(path, requestInits);
}
/**
 * Encodes the provided string into bytes using the provided encoding scheme.
 *
 * @param s The string to encode.
 * @param encoding The encoding scheme. Any falsy value (including an omitted
 *     argument) falls back to utf-8.
 */
/** @doc {heading: 'Util'} */
function encodeString(s, encoding) {
    // `||` handles both an omitted argument and falsy values such as ''.
    var scheme = encoding || 'utf-8';
    return env().platform.encode(s, scheme);
}
/**
 * Decodes the provided bytes into a string using the provided encoding scheme.
 *
 * @param bytes The bytes to decode.
 * @param encoding The encoding scheme. Any falsy value (including an omitted
 *     argument) falls back to utf-8.
 */
/** @doc {heading: 'Util'} */
function decodeString(bytes, encoding) {
    // `||` handles both an omitted argument and falsy values such as ''.
    var scheme = encoding || 'utf-8';
    return env().platform.decode(bytes, scheme);
}
/**
 * Computes the flat index for a location (multidimensional index) in a
 * Tensor/multidimensional array.
 *
 * @param locs Location in the tensor.
 * @param rank Rank of the tensor.
 * @param strides Tensor strides; one entry per leading axis, with the last
 *     axis implicitly having stride 1.
 */
function locToIndex(locs, rank, strides) {
    switch (rank) {
        case 0:
            return 0;
        case 1:
            return locs[0];
    }
    var lastAxis = locs.length - 1;
    // Start from the innermost coordinate, then add each outer axis offset.
    var flat = locs[lastAxis];
    for (var axis = 0; axis < lastAxis; axis++) {
        flat += strides[axis] * locs[axis];
    }
    return flat;
}
/**
 * Computes the location (multidimensional index) in a tensor/multidimensional
 * array for a given flat index.
 *
 * @param index Index in flat array.
 * @param rank Rank of tensor.
 * @param strides Strides of tensor; one entry per leading axis.
 */
function indexToLoc(index, rank, strides) {
    if (rank === 0) {
        return [];
    }
    if (rank === 1) {
        return [index];
    }
    var coords = new Array(rank);
    var lastAxis = coords.length - 1;
    var remainder = index;
    // Peel off each outer axis via integer division by its stride; whatever
    // remains is the innermost coordinate.
    for (var axis = 0; axis < lastAxis; axis++) {
        var quotient = Math.floor(remainder / strides[axis]);
        coords[axis] = quotient;
        remainder -= quotient * strides[axis];
    }
    coords[lastAxis] = remainder;
    return coords;
}
/**
 * Namespace object that gathers the util helper functions into one value.
 * Presumably this is what gets exported as `tf.util` — confirm against the
 * export section. Key order is the enumeration order seen by callers.
 */
var util = {
__proto__: null, // null prototype: the namespace inherits no Object.prototype members
shuffle: shuffle,
clamp: clamp,
nearestLargerEven: nearestLargerEven,
sum: sum,
randUniform: randUniform,
distSquared: distSquared,
assert: assert,
assertShapesMatch: assertShapesMatch,
assertNonNull: assertNonNull,
flatten: flatten,
sizeFromShape: sizeFromShape,
isScalarShape: isScalarShape,
arraysEqual: arraysEqual,
isInt: isInt,
tanh: tanh,
sizeToSquarishShape: sizeToSquarishShape,
createShuffledIndices: createShuffledIndices,
rightPad: rightPad,
repeatedTry: repeatedTry,
inferFromImplicitShape: inferFromImplicitShape,
parseAxisParam: parseAxisParam,
squeezeShape: squeezeShape,
getTypedArrayFromDType: getTypedArrayFromDType,
getArrayFromDType: getArrayFromDType,
checkConversionForErrors: checkConversionForErrors,
isValidDtype: isValidDtype,
hasEncodingLoss: hasEncodingLoss,
isTypedArray: isTypedArray,
bytesPerElement: bytesPerElement,
bytesFromStringArray: bytesFromStringArray,
isString: isString,
isBoolean: isBoolean,
isNumber: isNumber,
inferDtype: inferDtype,
isFunction: isFunction,
nearestDivisor: nearestDivisor,
computeStrides: computeStrides,
toTypedArray: toTypedArray,
toNestedArray: toNestedArray,
makeOnesTypedArray: makeOnesTypedArray,
makeZerosTypedArray: makeZerosTypedArray,
makeZerosNestedTypedArray: makeZerosNestedTypedArray,
now: now,
assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions,
// `fetch$1` is the local name of the util fetch wrapper; the `$1` suffix is
// presumably a bundler rename to avoid shadowing the global `fetch`.
fetch: fetch$1,
encodeString: encodeString,
decodeString: decodeString,
locToIndex: locToIndex,
indexToLoc: indexToLoc
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Profiler = /** @class */ (function () {
    /**
     * Times kernel executions and forwards per-kernel profiles to a logger.
     *
     * @param backendTimer Object exposing `time(fn)`. That method must run
     *     `fn` and return a promise of timing info; `kernelMs` and an optional
     *     `getExtraProfileInfo()` are read from it.
     * @param logger Receiver of `logKernelProfile(...)` calls. When null or
     *     undefined, a new `Logger` is created instead.
     */
    function Profiler(backendTimer, logger) {
        this.backendTimer = backendTimer;
        this.logger = logger;
        if (logger == null) {
            this.logger = new Logger();
        }
    }
    /**
     * Runs `f` under the backend timer and returns a profile record.
     *
     * Each output's data is read asynchronously and passed to
     * `checkComputationForErrors`. That check is fire-and-forget and does not
     * delay the return value.
     *
     * @param kernelName Name recorded in the profile and used in error checks.
     * @param inputs Kernel inputs, stored as-is in the profile.
     * @param f Function executing the kernel; must return an array of outputs.
     * @returns `{kernelName, outputs, inputs, timeMs, extraInfo}`, where
     *     `timeMs` and `extraInfo` are promises derived from the timer result
     *     (`extraInfo` resolves to '' when the timing has no
     *     `getExtraProfileInfo`).
     */
    Profiler.prototype.profileKernel = function (kernelName, inputs, f) {
        var outputs;
        var holdResultWrapperFn = function () {
            outputs = f();
        };
        // `outputs` is read right after this call, so this relies on
        // `time()` invoking the wrapper synchronously.
        var timer = this.backendTimer.time(holdResultWrapperFn);
        // forEach rather than map: the loop runs only for its side effect, so
        // there is no result array to build and discard.
        outputs.forEach(function (r) {
            // Dangling promise here because we don't want to propagate up
            // asynchronicity.
            r.data().then(function (tensorVals) {
                checkComputationForErrors(tensorVals, r.dtype, kernelName);
            });
        });
        var kernelProfile = {
            kernelName: kernelName,
            outputs: outputs,
            inputs: inputs,
            timeMs: timer.then(function (timing) { return timing.kernelMs; }),
            extraInfo: timer.then(function (timing) {
                return timing.getExtraProfileInfo != null ?
                    timing.getExtraProfileInfo() :
                    '';
            })
        };
        return kernelProfile;
    };
    /**
     * For each output in `kernelProfile`, waits for its data, the timing and
     * the extra info. It then calls
     * `logger.logKernelProfile(kernelName, output, data, timeMs, inputs, extraInfo)`.
     *
     * @param kernelProfile A record shaped like the one `profileKernel` returns.
     */
    Profiler.prototype.logKernelProfile = function (kernelProfile) {
        var _this = this;
        var kernelName = kernelProfile.kernelName, outputs = kernelProfile.outputs, timeMs = kernelProfile.timeMs, inputs = kernelProfile.inputs, extraInfo = kernelProfile.extraInfo;
        outputs.forEach(function (result) {
            Promise.all([result.data(), timeMs, extraInfo]).then(function (valueContainer) {
                _this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
            });
        });
    };
    return Profiler;
}());
function checkComputationForErrors(vals, dtype, kernelName) {
if (dtype !== 'float32') {
// Only floating point computations will generate NaN values
return false;
}
for (var i = 0; i < vals.length; i++) {