UNPKG

@tensorflow/tfjs-core

Version:

Hardware-accelerated JavaScript library for machine intelligence

1,520 lines (1,508 loc) 1.2 MB
/** * @license * Copyright 2024 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ 'use strict'; function _mergeNamespaces(n, m) { m.forEach(function (e) { e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) { if (k !== 'default' && !(k in n)) { var d = Object.getOwnPropertyDescriptor(e, k); Object.defineProperty(n, k, d.get ? d : { enumerable: true, get: function () { return e[k]; } }); } }); }); return n; } /****************************************************************************** Copyright (c) Microsoft Corporation. Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted. THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
***************************************************************************** */ /* global Reflect, Promise */ /** tslib helper: on first call picks a prototype-linking strategy (Object.setPrototypeOf, else __proto__ assignment, else copying b's own properties onto d), caches it by reassigning extendStatics, then applies it to (d, b). */ var extendStatics = function (d, b) { extendStatics = Object.setPrototypeOf || ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || function (d, b) { for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p]; }; return extendStatics(d, b); }; /** tslib helper for transpiled class inheritance: throws TypeError unless b is a function or null, links d to b through extendStatics, then gives d.prototype an object whose constructor is d and that inherits from b.prototype (Object.create(null) when b is null). */ function __extends(d, b) { if (typeof b !== "function" && b !== null) throw new TypeError("Class extends value " + String(b) + " is not a constructor or null"); extendStatics(d, b); function __() { this.constructor = d; } d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); } /** tslib helper for transpiled async functions: calls generator.apply(thisArg, _arguments || []) and drives the result inside a new P (P defaults to Promise); each yielded value is adopted into a P and its settlement is fed back through next() or throw(); the promise resolves with the final value and rejects when a step throws. */ function __awaiter(thisArg, _arguments, P, generator) { function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); } /** tslib helper emulating a generator as a state machine: returns an object exposing next, throw and return (plus Symbol.iterator when Symbol exists); each call runs step(), which repeatedly invokes body with the state object _ (label, sent, trys, ops) and interprets the [opcode, value] pair it returns. NOTE(review): the opcode numbering is a tslib internal and the statement order is load-bearing - confirm against tslib before changing anything here. */ function __generator(thisArg, body) { var _ = { label: 0, sent: function () { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function () { return this; }), g; function verb(n) { return function (v) { return step([n, v]); }; } function step(op) { if (f) throw new TypeError("Generator is already executing."); while (_) try { if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? 
y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; if (y = 0, t) op = [op[0] & 2, t.value]; switch (op[0]) { case 0: case 1: t = op; break; case 4: _.label++; return { value: op[1], done: false }; case 5: _.label++; y = op[1]; op = [0]; continue; case 7: op = _.ops.pop(); _.trys.pop(); continue; default: if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } if (t[2]) _.ops.pop(); _.trys.pop(); continue; } op = body.call(thisArg, _); } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; } } /** tslib helper: returns an iterator over o - the result of o[Symbol.iterator]() when that method exists, otherwise an index-based iterator for a truthy o with a numeric length; throws TypeError for anything else. */ function __values(o) { var s = typeof Symbol === "function" && Symbol.iterator, m = s && o[s], i = 0; if (m) return m.call(o); if (o && typeof o.length === "number") return { next: function () { if (o && i >= o.length) o = void 0; return { value: o && o[i++], done: !o }; } }; throw new TypeError(s ? "Object is not iterable." : "Symbol.iterator is not defined."); } /** tslib helper for array destructuring: returns o unchanged when it has no Symbol.iterator method; otherwise collects up to n items (all of them when n is undefined) into a new array, calls the iterator's return() if iteration stopped before it was done, and rethrows any error raised while iterating. */ function __read(o, n) { var m = typeof Symbol === "function" && o[Symbol.iterator]; if (!m) return o; var i = m.call(o), r, ar = [], e; try { while ((n === void 0 || n-- > 0) && !(r = i.next()).done) ar.push(r.value); } catch (error) { e = { error: error }; } finally { try { if (r && !r.done && (m = i["return"])) m.call(i); } finally { if (e) throw e.error; } } return ar; } /** tslib helper for spread: returns to.concat(...) of a real-array copy of from; when pack is truthy or exactly two arguments were passed, holes in from are first materialised as explicit entries. */ function __spreadArray(to, from, pack) { if (pack || arguments.length === 2) for (var i = 0, l = from.length, ar; i < l; i++) { if (ar || !(i in from)) { if (!ar) ar = Array.prototype.slice.call(from, 0, i); ar[i] = from[i]; } } return to.concat(ar || Array.prototype.slice.call(from)); } /** * @license * Copyright 2020 Google LLC. 
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var EPSILON_FLOAT32 = 1e-7; var EPSILON_FLOAT16 = 1e-4; /** Convenient class for storing tensor-related data. */ var DataStorage = /** @class */ (function () { function DataStorage(backend, dataMover) { this.backend = backend; this.dataMover = dataMover; this.data = new WeakMap(); this.dataIdsCount = 0; } DataStorage.prototype.get = function (dataId) { if (!this.data.has(dataId)) { this.dataMover.moveData(this.backend, dataId); } return this.data.get(dataId); }; DataStorage.prototype.set = function (dataId, value) { this.dataIdsCount++; this.data.set(dataId, value); }; DataStorage.prototype.has = function (dataId) { return this.data.has(dataId); }; DataStorage.prototype.delete = function (dataId) { this.dataIdsCount--; return this.data.delete(dataId); }; DataStorage.prototype.numDataIds = function () { return this.dataIdsCount; }; return DataStorage; }()); /** * The interface that defines the kernels that should be implemented when * adding a new backend. New backends don't need to implement every one of the * methods, this can be done gradually (throw an error for unimplemented * methods). 
*/ var KernelBackend = /** @class */ (function () { function KernelBackend() { } KernelBackend.prototype.refCount = function (dataId) { return notYetImplemented('refCount'); }; KernelBackend.prototype.incRef = function (dataId) { return notYetImplemented('incRef'); }; KernelBackend.prototype.timerAvailable = function () { return true; }; KernelBackend.prototype.time = function (f) { return notYetImplemented('time'); }; KernelBackend.prototype.read = function (dataId) { return notYetImplemented('read'); }; KernelBackend.prototype.readSync = function (dataId) { return notYetImplemented('readSync'); }; KernelBackend.prototype.readToGPU = function (dataId, options) { return notYetImplemented('readToGPU'); }; KernelBackend.prototype.numDataIds = function () { return notYetImplemented('numDataIds'); }; KernelBackend.prototype.disposeData = function (dataId, force) { return notYetImplemented('disposeData'); }; KernelBackend.prototype.write = function (values, shape, dtype) { return notYetImplemented('write'); }; KernelBackend.prototype.move = function (dataId, values, shape, dtype, refCount) { return notYetImplemented('move'); }; KernelBackend.prototype.createTensorFromGPUData = function (values, shape, dtype) { return notYetImplemented('createTensorFromGPUData'); }; KernelBackend.prototype.memory = function () { return notYetImplemented('memory'); }; /** Returns the highest precision for floats in bits (e.g. 16 or 32) */ KernelBackend.prototype.floatPrecision = function () { return notYetImplemented('floatPrecision'); }; /** Returns the smallest representable number. */ KernelBackend.prototype.epsilon = function () { return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; }; KernelBackend.prototype.dispose = function () { return notYetImplemented('dispose'); }; return KernelBackend; }()); function notYetImplemented(kernelName) { throw new Error("'".concat(kernelName, "' not yet implemented or not found in the registry. 
") + "This kernel may not be supported by the tfjs backend you have chosen"); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Shuffles the array in-place using Fisher-Yates algorithm. * * ```js * const a = [1, 2, 3, 4, 5]; * tf.util.shuffle(a); * console.log(a); * ``` * * @param array The array to shuffle in-place. * * @doc {heading: 'Util', namespace: 'util'} */ // tslint:disable-next-line:no-any function shuffle(array) { var counter = array.length; var index = 0; // While there are elements in the array while (counter > 0) { // Pick a random index index = (Math.random() * counter) | 0; // Decrease counter by 1 counter--; // And swap the last element with it swap(array, counter, index); } } /** * Shuffles two arrays in-place the same way using Fisher-Yates algorithm. * * ```js * const a = [1,2,3,4,5]; * const b = [11,22,33,44,55]; * tf.util.shuffleCombo(a, b); * console.log(a, b); * ``` * * @param array The first array to shuffle in-place. * @param array2 The second array to shuffle in-place with the same permutation * as the first array. 
* * @doc {heading: 'Util', namespace: 'util'} */ function shuffleCombo( // tslint:disable-next-line:no-any array, // tslint:disable-next-line:no-any array2) { if (array.length !== array2.length) { throw new Error("Array sizes must match to be shuffled together " + "First array length was ".concat(array.length) + "Second array length was ".concat(array2.length)); } var counter = array.length; var index = 0; // While there are elements in the array while (counter > 0) { // Pick a random index index = (Math.random() * counter) | 0; // Decrease counter by 1 counter--; // And swap the last element of each array with it swap(array, counter, index); swap(array2, counter, index); } } /** Clamps a value to a specified range. */ function clamp(min, x, max) { return Math.max(min, Math.min(x, max)); } function nearestLargerEven(val) { return val % 2 === 0 ? val : val + 1; } function swap(object, left, right) { var temp = object[left]; object[left] = object[right]; object[right] = temp; } function sum$1(arr) { var sum = 0; for (var i = 0; i < arr.length; i++) { sum += arr[i]; } return sum; } /** * Returns a sample from a uniform [a, b) distribution. * * @param a The minimum support (inclusive). * @param b The maximum support (exclusive). * @return A pseudorandom number on the half-open interval [a,b). */ function randUniform(a, b) { var r = Math.random(); return (b * r) + (1 - r) * a; } /** Returns the squared Euclidean distance between two vectors. */ function distSquared(a, b) { var result = 0; for (var i = 0; i < a.length; i++) { var diff = Number(a[i]) - Number(b[i]); result += diff * diff; } return result; } /** * Asserts that the expression is true. Otherwise throws an error with the * provided message. * * ```js * const x = 2; * tf.util.assert(x === 2, 'x is not 2'); * ``` * * @param expr The expression to assert (as a boolean). * @param msg A function that returns the message to report when throwing an * error. We use a function for performance reasons. 
* * @doc {heading: 'Util', namespace: 'util'} */ function assert(expr, msg) { if (!expr) { throw new Error(typeof msg === 'string' ? msg : msg()); } } function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) { if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; } assert(arraysEqual(shapeA, shapeB), function () { return errorMessagePrefix + " Shapes ".concat(shapeA, " and ").concat(shapeB, " must match"); }); } function assertNonNull(a) { assert(a != null, function () { return "The input to the tensor constructor must be a non-null value."; }); } /** * Returns the size (number of elements) of the tensor given its shape. * * ```js * const shape = [3, 4, 2]; * const size = tf.util.sizeFromShape(shape); * console.log(size); * ``` * * @doc {heading: 'Util', namespace: 'util'} */ function sizeFromShape(shape) { if (shape.length === 0) { // Scalar. return 1; } var size = shape[0]; for (var i = 1; i < shape.length; i++) { size *= shape[i]; } return size; } function isScalarShape(shape) { return shape.length === 0; } function arraysEqualWithNull(n1, n2) { if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } if (n1.length !== n2.length) { return false; } for (var i = 0; i < n1.length; i++) { if (n1[i] !== null && n2[i] !== null && n1[i] !== n2[i]) { return false; } } return true; } function arraysEqual(n1, n2) { if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } if (n1.length !== n2.length) { return false; } for (var i = 0; i < n1.length; i++) { if (n1[i] !== n2[i]) { return false; } } return true; } function isInt(a) { return a % 1 === 0; } function tanh$1(x) { // tslint:disable-next-line:no-any if (Math.tanh != null) { // tslint:disable-next-line:no-any return Math.tanh(x); } if (x === Infinity) { return 1; } else if (x === -Infinity) { return -1; } else { var e2x = Math.exp(2 * x); return (e2x - 1) / (e2x + 1); } } function sizeToSquarishShape(size) { var width = Math.ceil(Math.sqrt(size)); return 
[width, Math.ceil(size / width)]; } /** * Creates a new array with randomized indices to a given quantity. * * ```js * const randomTen = tf.util.createShuffledIndices(10); * console.log(randomTen); * ``` * * @param number Quantity of how many shuffled indices to create. * * @doc {heading: 'Util', namespace: 'util'} */ function createShuffledIndices(n) { var shuffledIndices = new Uint32Array(n); for (var i = 0; i < n; ++i) { shuffledIndices[i] = i; } shuffle(shuffledIndices); return shuffledIndices; } function rightPad(a, size) { if (size <= a.length) { return a; } return a + ' '.repeat(size - a.length); } function repeatedTry(checkFn, delayFn, maxCounter, scheduleFn) { if (delayFn === void 0) { delayFn = function (counter) { return 0; }; } return new Promise(function (resolve, reject) { var tryCount = 0; var tryFn = function () { if (checkFn()) { resolve(); return; } tryCount++; var nextBackoff = delayFn(tryCount); if (maxCounter != null && tryCount >= maxCounter) { reject(); return; } if (scheduleFn != null) { scheduleFn(tryFn, nextBackoff); } else { // google3 does not allow assigning another variable to setTimeout. // Don't refactor this so scheduleFn has a default value of setTimeout. setTimeout(tryFn, nextBackoff); } }; tryFn(); }); } /** * Given the full size of the array and a shape that may contain -1 as the * implicit dimension, returns the inferred shape where -1 is replaced. * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3]. * * @param shape The shape, which may contain -1 in some dimension. * @param size The full size (number of elements) of the array. * @return The inferred shape where -1 is replaced with the inferred size. */ function inferFromImplicitShape(shape, size) { var shapeProd = 1; var implicitIdx = -1; for (var i = 0; i < shape.length; ++i) { if (shape[i] >= 0) { shapeProd *= shape[i]; } else if (shape[i] === -1) { if (implicitIdx !== -1) { throw Error("Shapes can only have 1 implicit size. 
" + "Found -1 at dim ".concat(implicitIdx, " and dim ").concat(i)); } implicitIdx = i; } else if (shape[i] < 0) { throw Error("Shapes can not be < 0. Found ".concat(shape[i], " at dim ").concat(i)); } } if (implicitIdx === -1) { if (size > 0 && size !== shapeProd) { throw Error("Size(".concat(size, ") must match the product of shape ").concat(shape)); } return shape; } if (shapeProd === 0) { throw Error("Cannot infer the missing size in [".concat(shape, "] when ") + "there are 0 elements"); } if (size % shapeProd !== 0) { throw Error("The implicit shape can't be a fractional number. " + "Got ".concat(size, " / ").concat(shapeProd)); } var newShape = shape.slice(); newShape[implicitIdx] = size / shapeProd; return newShape; } function parseAxisParam(axis, shape) { var rank = shape.length; // Normalize input axis = axis == null ? shape.map(function (s, i) { return i; }) : [].concat(axis); // Check for valid range assert(axis.every(function (ax) { return ax >= -rank && ax < rank; }), function () { return "All values in axis param must be in range [-".concat(rank, ", ").concat(rank, ") but ") + "got axis ".concat(axis); }); // Check for only integers assert(axis.every(function (ax) { return isInt(ax); }), function () { return "All values in axis param must be integers but " + "got axis ".concat(axis); }); // Handle negative axis. return axis.map(function (a) { return a < 0 ? rank + a : a; }); } /** Reduces the shape by removing all dimensions of shape 1. */ function squeezeShape(shape, axis) { var newShape = []; var keptDims = []; var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; var axes = (axis == null || isEmptyArray) ? 
null : parseAxisParam(axis, shape).sort(); var j = 0; for (var i = 0; i < shape.length; ++i) { if (axes != null) { if (axes[j] === i && shape[i] !== 1) { throw new Error("Can't squeeze axis ".concat(i, " since its dim '").concat(shape[i], "' is not 1")); } if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { newShape.push(shape[i]); keptDims.push(i); } if (axes[j] <= i) { j++; } } if (shape[i] !== 1) { newShape.push(shape[i]); keptDims.push(i); } } return { newShape: newShape, keptDims: keptDims }; } function getTypedArrayFromDType(dtype, size) { return getArrayFromDType(dtype, size); } function getArrayFromDType(dtype, size) { var values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); } else if (dtype === 'int32') { values = new Int32Array(size); } else if (dtype === 'bool') { values = new Uint8Array(size); } else if (dtype === 'string') { values = new Array(size); } else { throw new Error("Unknown data type ".concat(dtype)); } return values; } function checkConversionForErrors(vals, dtype) { for (var i = 0; i < vals.length; i++) { var num = vals[i]; if (isNaN(num) || !isFinite(num)) { throw Error("A tensor of type ".concat(dtype, " being uploaded contains ").concat(num, ".")); } } } /** Returns true if the dtype is valid. */ function isValidDtype(dtype) { return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || dtype === 'int32' || dtype === 'string'; } /** * Returns true if the new type can't encode the old type without loss of * precision. 
*/ function hasEncodingLoss(oldType, newType) { if (newType === 'complex64') { return false; } if (newType === 'float32' && oldType !== 'complex64') { return false; } if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') { return false; } if (newType === 'bool' && oldType === 'bool') { return false; } return true; } function bytesPerElement(dtype) { if (dtype === 'float32' || dtype === 'int32') { return 4; } else if (dtype === 'complex64') { return 8; } else if (dtype === 'bool') { return 1; } else { throw new Error("Unknown dtype ".concat(dtype)); } } /** * Returns the approximate number of bytes allocated in the string array - 2 * bytes per character. Computing the exact bytes for a native string in JS * is not possible since it depends on the encoding of the html page that * serves the website. */ function bytesFromStringArray(arr) { if (arr == null) { return 0; } var bytes = 0; arr.forEach(function (x) { return bytes += x.length; }); return bytes; } /** Returns true if the value is a string. 
*/ function isString(value) { return typeof value === 'string' || value instanceof String; } function isBoolean(value) { return typeof value === 'boolean'; } function isNumber(value) { return typeof value === 'number'; } function inferDtype(values) { if (Array.isArray(values)) { return inferDtype(values[0]); } if (values instanceof Float32Array) { return 'float32'; } else if (values instanceof Int32Array || values instanceof Uint8Array || values instanceof Uint8ClampedArray) { return 'int32'; } else if (isNumber(values)) { return 'float32'; } else if (isString(values)) { return 'string'; } else if (isBoolean(values)) { return 'bool'; } return 'float32'; } function isFunction(f) { return !!(f && f.constructor && f.call && f.apply); } function nearestDivisor(size, start) { for (var i = start; i < size; ++i) { if (size % i === 0) { return i; } } return size; } function computeStrides(shape) { var rank = shape.length; if (rank < 2) { return []; } // Last dimension has implicit stride of 1, thus having D-1 (instead of D) // strides. var strides = new Array(rank - 1); strides[rank - 2] = shape[rank - 1]; for (var i = rank - 3; i >= 0; --i) { strides[i] = strides[i + 1] * shape[i + 1]; } return strides; } function createNestedArray(offset, shape, a, isComplex) { if (isComplex === void 0) { isComplex = false; } var ret = new Array(); if (shape.length === 1) { var d = shape[0] * (isComplex ? 2 : 1); for (var i = 0; i < d; i++) { ret[i] = a[offset + i]; } } else { var d = shape[0]; var rest = shape.slice(1); var len = rest.reduce(function (acc, c) { return acc * c; }) * (isComplex ? 2 : 1); for (var i = 0; i < d; i++) { ret[i] = createNestedArray(offset + i * len, rest, a, isComplex); } } return ret; } // Provide a nested array of TypedArray in given shape. function toNestedArray(shape, a, isComplex) { if (isComplex === void 0) { isComplex = false; } if (shape.length === 0) { // Scalar type should return a single number. 
return a[0]; } var size = shape.reduce(function (acc, c) { return acc * c; }) * (isComplex ? 2 : 1); if (size === 0) { // A tensor with shape zero should be turned into empty list. return []; } if (size !== a.length) { throw new Error("[".concat(shape, "] does not match the input size ").concat(a.length).concat(isComplex ? ' for a complex tensor' : '', ".")); } return createNestedArray(0, shape, a, isComplex); } function convertBackendValuesAndArrayBuffer(data, dtype) { // If is type Uint8Array[], return it directly. if (Array.isArray(data)) { return data; } if (dtype === 'float32') { return data instanceof Float32Array ? data : new Float32Array(data); } else if (dtype === 'int32') { return data instanceof Int32Array ? data : new Int32Array(data); } else if (dtype === 'bool' || dtype === 'string') { return Uint8Array.from(new Int32Array(data)); } else { throw new Error("Unknown dtype ".concat(dtype)); } } function makeOnesTypedArray(size, dtype) { var array = makeZerosTypedArray(size, dtype); for (var i = 0; i < array.length; i++) { array[i] = 1; } return array; } function makeZerosTypedArray(size, dtype) { if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(size); } else if (dtype === 'int32') { return new Int32Array(size); } else if (dtype === 'bool') { return new Uint8Array(size); } else { throw new Error("Unknown data type ".concat(dtype)); } } /** * Make nested `TypedArray` filled with zeros. * @param shape The shape information for the nested array. * @param dtype dtype of the array element. 
/*
 * makeZerosNestedTypedArray supports float32 (also used for a null or
 * undefined dtype), int32 and bool; any other dtype throws.
 */
function makeZerosNestedTypedArray(shape, dtype) {
    var size = shape.reduce(function (prev, curr) { return prev * curr; }, 1);
    if (dtype == null || dtype === 'float32') {
        return toNestedArray(shape, new Float32Array(size));
    }
    else if (dtype === 'int32') {
        return toNestedArray(shape, new Int32Array(size));
    }
    else if (dtype === 'bool') {
        return toNestedArray(shape, new Uint8Array(size));
    }
    else {
        throw new Error("Unknown data type ".concat(dtype));
    }
}
/** Asserts that every entry of `shape` is an integer greater than or equal to 0. */
function assertNonNegativeIntegerDimensions(shape) {
    shape.forEach(function (dimSize) {
        assert(Number.isInteger(dimSize) && dimSize >= 0, function () { return "Tensor must have a shape comprised of positive integers but got " +
            "shape [".concat(shape, "]."); });
    });
}
/**
 * Computes flat index for a given location (multidimensional index) in a
 * Tensor/multidimensional array.
 *
 * @param locs Location in the tensor.
 * @param rank Rank of the tensor.
 * @param strides Tensor strides.
 */
function locToIndex(locs, rank, strides) {
    if (rank === 0) {
        return 0;
    }
    else if (rank === 1) {
        return locs[0];
    }
    // The last coordinate has an implicit stride of 1.
    var index = locs[locs.length - 1];
    for (var i = 0; i < locs.length - 1; ++i) {
        index += strides[i] * locs[i];
    }
    return index;
}
/**
 * Computes the location (multidimensional index) in a
 * tensor/multidimensional array for a given flat index.
 *
 * @param index Index in flat array.
 * @param rank Rank of tensor.
 * @param strides Strides of tensor.
 */
function indexToLoc(index, rank, strides) {
    if (rank === 0) {
        return [];
    }
    else if (rank === 1) {
        return [index];
    }
    var locs = new Array(rank);
    for (var i = 0; i < locs.length - 1; ++i) {
        locs[i] = Math.floor(index / strides[i]);
        index -= locs[i] * strides[i];
    }
    locs[locs.length - 1] = index;
    return locs;
}
/**
 * This method asserts whether an object is a Promise instance.
 * @param object
 * @return A truthy value when `object` has a callable `then`; otherwise a
 *     falsy value (not necessarily `false`).
 */
// tslint:disable-next-line: no-any
function isPromise(object) {
    // We chose to not use 'obj instanceOf Promise' for two reasons:
    // 1. It only reliably works for es6 Promise, not other Promise
    // implementations.
    // 2. It doesn't work with framework that uses zone.js. zone.js monkey
    // patch the async calls, so it is possible the obj (patched) is
    // comparing to a pre-patched Promise.
    return object && object.then && typeof object.then === 'function';
}
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * The environment contains evaluated flags as well as the registered platform.
 * This is always used as a global singleton and can be retrieved with
 * `tf.env()`.
 *
 * @doc {heading: 'Environment'}
 */
var Environment = /** @class */ (function () {
    /**
     * @param global The global object; its `location.search`, when present,
     *     is scanned for URL flag overrides.
     */
    // tslint:disable-next-line: no-any
    function Environment(global) {
        this.global = global;
        this.flags = {};
        this.flagRegistry = {};
        this.urlFlags = {};
        // Jasmine spies on this in 'environment_test.ts'
        this.getQueryParams = getQueryParams;
        this.populateURLFlags();
    }
    /**
     * Records the active platform. Warns on the console when a platform was
     * already set, unless the IS_TEST or PROD flag is on.
     */
    Environment.prototype.setPlatform = function (platformName, platform) {
        if (this.platform != null) {
            if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
                console.warn("Platform ".concat(this.platformName, " has already been set. ") +
                    "Overwriting the platform with ".concat(platformName, "."));
            }
        }
        this.platformName = platformName;
        this.platform = platform;
    };
    /**
     * Registers a flag.
     *
     * @param flagName Name of the flag.
     * @param evaluationFn Called lazily to compute the flag value.
     * @param setHook Optional callback invoked with the value on `set`.
     */
    Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) {
        this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook };
        // Override the flag value from the URL. This has to happen here because
        // the environment is initialized before flags get registered.
        if (this.urlFlags[flagName] != null) {
            var flagValue = this.urlFlags[flagName];
            if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
                console.warn("Setting feature override from URL ".concat(flagName, ": ").concat(flagValue, "."));
            }
            this.set(flagName, flagValue);
        }
    };
    /**
     * Returns a promise for the flag value, awaiting the evaluation function
     * and caching its result when the flag is not cached yet.
     */
    Environment.prototype.getAsync = function (flagName) {
        return __awaiter(this, void 0, void 0, function () {
            var _a, _b;
            return __generator(this, function (_c) {
                switch (_c.label) {
                    case 0:
                        if (flagName in this.flags) {
                            return [2 /*return*/, this.flags[flagName]];
                        }
                        _a = this.flags;
                        _b = flagName;
                        return [4 /*yield*/, this.evaluateFlag(flagName)];
                    case 1:
                        _a[_b] = _c.sent();
                        return [2 /*return*/, this.flags[flagName]];
                }
            });
        });
    };
    /**
     * Returns the cached flag value, evaluating and caching it on first use.
     * Throws when the evaluation function returns a promise.
     */
    Environment.prototype.get = function (flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        var flagValue = this.evaluateFlag(flagName);
        if (isPromise(flagValue)) {
            throw new Error("Flag ".concat(flagName, " cannot be synchronously evaluated. ") +
                "Please use getAsync() instead.");
        }
        this.flags[flagName] = flagValue;
        return this.flags[flagName];
    };
    /** Same as `get`; no conversion is performed. */
    Environment.prototype.getNumber = function (flagName) {
        return this.get(flagName);
    };
    /** Same as `get`; no conversion is performed. */
    Environment.prototype.getBool = function (flagName) {
        return this.get(flagName);
    };
    /** Same as `get`; no conversion is performed. */
    Environment.prototype.getString = function (flagName) {
        return this.get(flagName);
    };
    /** Returns the live object holding the evaluated flags. */
    Environment.prototype.getFlags = function () {
        return this.flags;
    };
    Object.defineProperty(Environment.prototype, "features", {
        // For backwards compatibility.
        get: function () {
            return this.flags;
        },
        enumerable: false,
        configurable: true
    });
    /**
     * Sets a registered flag and runs its set hook, if any. Throws when the
     * flag has not been registered.
     */
    Environment.prototype.set = function (flagName, value) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error("Cannot set flag ".concat(flagName, " as it has not been registered."));
        }
        this.flags[flagName] = value;
        if (this.flagRegistry[flagName].setHook != null) {
            this.flagRegistry[flagName].setHook(value);
        }
    };
    /** Runs the flag's evaluation function; throws for unregistered flags. */
    Environment.prototype.evaluateFlag = function (flagName) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error("Cannot evaluate flag '".concat(flagName, "': no evaluation function found."));
        }
        return this.flagRegistry[flagName].evaluationFn();
    };
    /** Replaces all cached flag values with a shallow copy of `flags`. */
    Environment.prototype.setFlags = function (flags) {
        this.flags = Object.assign({}, flags);
    };
    /** Clears cached and URL flags, then re-reads the URL flags. */
    Environment.prototype.reset = function () {
        this.flags = {};
        this.urlFlags = {};
        this.populateURLFlags();
    };
    /**
     * Reads `?tfjsflags=KEY:value,KEY2:value2` from `global.location.search`
     * into `urlFlags`. Does nothing when no such location is available.
     */
    Environment.prototype.populateURLFlags = function () {
        var _this = this;
        if (typeof this.global === 'undefined' ||
            typeof this.global.location === 'undefined' ||
            typeof this.global.location.search === 'undefined') {
            return;
        }
        var urlParams = this.getQueryParams(this.global.location.search);
        if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
            var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
            keyValues.forEach(function (keyValue) {
                var _a = __read(keyValue.split(':'), 2), key = _a[0], value = _a[1];
                _this.urlFlags[key] = parseValue(key, value);
            });
        }
    };
    return Environment;
}());
/**
 * Parses a query string such as `?a=1&b` into an object of decoded
 * name/value pairs; a name without a value maps to ''.
 */
function getQueryParams(queryString) {
    var params = {};
    queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) {
        var t = [];
        for (var _i = 1; _i < arguments.length; _i++) {
            t[_i - 1] = arguments[_i];
        }
        decodeParam(params, t[0], t[1]);
        return t.join('=');
    });
    return params;
}
/** Stores the URI-decoded `value` ('' when missing) under the decoded `name`. */
function decodeParam(params, name, value) {
    params[decodeURIComponent(name)] = decodeURIComponent(value || '');
}
/**
 * Converts the textual value of a URL flag.
 *
 * @param flagName Name of the flag (not used for the conversion).
 * @param value Raw text after the ':' of a `KEY:value` pair.
 * @return A boolean for 'true'/'false' (any letter case), a number when the
 *     text round-trips through numeric conversion, otherwise the text itself.
 *     A missing value (a bare `KEY` with no ':') is returned unchanged so a
 *     malformed URL cannot abort environment construction.
 */
function parseValue(flagName, value) {
    if (value == null) {
        // Nothing to parse; registerFlag ignores null/undefined overrides.
        return value;
    }
    var lowerCaseValue = value.toLowerCase();
    if (lowerCaseValue === 'true' || lowerCaseValue === 'false') {
        return lowerCaseValue === 'true';
    }
    else if ("".concat(+lowerCaseValue) === lowerCaseValue) {
        return +lowerCaseValue;
    }
    else {
        return value;
    }
}
/**
 * Returns the current environment (a global singleton).
 *
 * The environment object contains the evaluated feature values as well as the
 * active platform.
 *
 * @doc {heading: 'Environment'}
 */
function env() {
    return exports.ENV;
}
exports.ENV = null;
/** Installs `environment` as the value returned by `env()`. */
function setEnvironmentGlobal(environment) {
    exports.ENV = environment;
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Note that the identifier globalNameSpace is scoped to this module, but will
// always resolve to the same global object regardless of how the module is
// resolved.
// Cache for the global object located by getGlobalNamespace(); stays
// undefined until the first successful lookup.
// tslint:disable-next-line:no-any
var globalNameSpace;
/**
 * Returns the global object of the current runtime, locating it on the first
 * call and reusing the cached reference afterwards.
 *
 * Candidates are probed in this order: `window`, `global`, `process`, `self`.
 *
 * @returns The first of those identifiers that is defined.
 * @throws Error if none of them is defined.
 */
// tslint:disable-next-line:no-any
function getGlobalNamespace() {
    if (globalNameSpace != null) {
        return globalNameSpace;
    }
    if (typeof window !== 'undefined') {
        globalNameSpace = window;
    }
    else if (typeof global !== 'undefined') {
        globalNameSpace = global;
    }
    else if (typeof process !== 'undefined') {
        globalNameSpace = process;
    }
    else if (typeof self !== 'undefined') {
        globalNameSpace = self;
    }
    else {
        throw new Error('Could not find a global object');
    }
    return globalNameSpace;
}
/**
 * Returns the `_tfGlobals` Map kept on the global namespace object, creating
 * and attaching an empty Map the first time it is requested.
 */
// tslint:disable-next-line:no-any
function getGlobalMap() {
    var namespace = getGlobalNamespace();
    if (namespace._tfGlobals == null) {
        namespace._tfGlobals = new Map();
    }
    return namespace._tfGlobals;
}
/**
 * Returns a globally accessible 'singleton' object.
 *
 * @param key the name of the object
 * @param init a function that creates the object; it is called only when
 *     nothing is stored under `key` yet.
 * @returns The object stored under `key` in the global map.
 */
function getGlobal(key, init) {
    var registry = getGlobalMap();
    if (!registry.has(key)) {
        registry.set(key, init());
    }
    return registry.get(key);
}
/**
 * Kernel name constants.
 *
 * Each identifier is bound to a string equal to its own name. Every constant
 * is a complete `var` statement on its own line, so no declaration depends on
 * a keyword or initializer that sits on a different physical line.
 * NOTE(review): presumably these are the `kernelName` values used with the
 * kernel and gradient registries - confirm against the callers.
 */
var Abs = 'Abs';
var Acos = 'Acos';
var Acosh = 'Acosh';
var Add = 'Add';
var AddN = 'AddN';
var All = 'All';
var Any = 'Any';
var ArgMax = 'ArgMax';
var ArgMin = 'ArgMin';
var Asin = 'Asin';
var Asinh = 'Asinh';
var Atan = 'Atan';
var Atanh = 'Atanh';
var Atan2 = 'Atan2';
var AvgPool = 'AvgPool';
var AvgPoolGrad = 'AvgPoolGrad';
var AvgPool3D = 'AvgPool3D';
var AvgPool3DGrad = 'AvgPool3DGrad';
var BatchMatMul = 'BatchMatMul';
var BatchToSpaceND = 'BatchToSpaceND';
var Bincount = 'Bincount';
var BitwiseAnd = 'BitwiseAnd';
var BroadcastTo = 'BroadcastTo';
var BroadcastArgs = 'BroadcastArgs';
var Cast = 'Cast';
var Ceil = 'Ceil';
var ClipByValue = 'ClipByValue';
var Complex = 'Complex';
var ComplexAbs = 'ComplexAbs';
var Concat = 'Concat';
var Conv2D = 'Conv2D';
var Conv2DBackpropFilter = 'Conv2DBackpropFilter';
var Conv2DBackpropInput = 'Conv2DBackpropInput';
var Conv3D = 'Conv3D';
var Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
var Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
var Cos = 'Cos';
var Cosh = 'Cosh';
var Cumprod = 'Cumprod';
var Cumsum = 'Cumsum';
var CropAndResize = 'CropAndResize';
var DenseBincount = 'DenseBincount';
var DepthToSpace = 'DepthToSpace';
var DepthwiseConv2dNative = 'DepthwiseConv2dNative';
var DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
var DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
var Diag = 'Diag';
var Dilation2D = 'Dilation2D';
var Dilation2DBackpropInput = 'Dilation2DBackpropInput';
var Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
var Draw = 'Draw';
var RealDiv = 'RealDiv';
var Einsum = 'Einsum';
var Elu = 'Elu';
var EluGrad = 'EluGrad';
var Erf = 'Erf';
var Equal = 'Equal';
var Exp = 'Exp';
var ExpandDims = 'ExpandDims';
var Expm1 = 'Expm1';
var FFT = 'FFT';
var Fill = 'Fill';
var FlipLeftRight = 'FlipLeftRight';
var Floor = 'Floor';
var FloorDiv = 'FloorDiv';
var FusedBatchNorm = 'FusedBatchNorm';
var GatherV2 = 'GatherV2';
var GatherNd = 'GatherNd';
var Greater = 'Greater';
var GreaterEqual = 'GreaterEqual';
var Identity = 'Identity';
var IFFT = 'IFFT';
var Imag = 'Imag';
var IsFinite = 'IsFinite';
var IsInf = 'IsInf';
var IsNan = 'IsNan';
var LeakyRelu = 'LeakyRelu';
var Less = 'Less';
var LessEqual = 'LessEqual';
var LinSpace = 'LinSpace';
var Log = 'Log';
var Log1p = 'Log1p';
var LogicalAnd = 'LogicalAnd';
var LogicalNot = 'LogicalNot';
var LogicalOr = 'LogicalOr';
var LogicalXor = 'LogicalXor';
var LogSoftmax = 'LogSoftmax';
var LowerBound = 'LowerBound';
var LRN = 'LRN';
var LRNGrad = 'LRNGrad';
var MatrixBandPart = 'MatrixBandPart';
var Max = 'Max';
var Maximum = 'Maximum';
var MaxPool = 'MaxPool';
var MaxPoolGrad = 'MaxPoolGrad';
var MaxPool3D = 'MaxPool3D';
var MaxPool3DGrad = 'MaxPool3DGrad';
var MaxPoolWithArgmax = 'MaxPoolWithArgmax';
var Mean = 'Mean';
var Min = 'Min';
var Minimum = 'Minimum';
var MirrorPad = 'MirrorPad';
var Mod = 'Mod';
var Multinomial = 'Multinomial';
var Multiply = 'Multiply';
var Neg = 'Neg';
var NotEqual = 'NotEqual';
var NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
var NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
var NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
var OnesLike = 'OnesLike';
var OneHot = 'OneHot';
var Pack = 'Pack';
var PadV2 = 'PadV2';
var Pool = 'Pool';
var Pow = 'Pow';
var Prelu = 'Prelu';
var Prod = 'Prod';
var RaggedGather = 'RaggedGather';
var RaggedRange = 'RaggedRange';
var RaggedTensorToTensor = 'RaggedTensorToTensor';
var Range = 'Range';
var Real = 'Real';
var Reciprocal = 'Reciprocal';
var Relu = 'Relu';
var Reshape = 'Reshape';
var ResizeNearestNeighbor = 'ResizeNearestNeighbor';
var ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
var ResizeBilinear = 'ResizeBilinear';
var ResizeBilinearGrad = 'ResizeBilinearGrad';
var Relu6 = 'Relu6';
var Reverse = 'Reverse';
var Round = 'Round';
var Rsqrt = 'Rsqrt';
var ScatterNd = 'ScatterNd';
var TensorScatterUpdate = 'TensorScatterUpdate';
var SearchSorted = 'SearchSorted';
var Select = 'Select';
var Selu = 'Selu';
var Slice = 'Slice';
var Sin = 'Sin';
var Sinh = 'Sinh';
var Sign = 'Sign';
var Sigmoid = 'Sigmoid';
var Softplus = 'Softplus';
var Sqrt = 'Sqrt';
var Sum = 'Sum';
var SpaceToBatchND = 'SpaceToBatchND';
var SplitV = 'SplitV';
var Softmax = 'Softmax';
var SparseFillEmptyRows = 'SparseFillEmptyRows';
var SparseReshape = 'SparseReshape';
var SparseSegmentMean = 'SparseSegmentMean';
var SparseSegmentSum = 'SparseSegmentSum';
var SparseToDense = 'SparseToDense';
var SquaredDifference = 'SquaredDifference';
var Square = 'Square';
var StaticRegexReplace = 'StaticRegexReplace';
var StridedSlice = 'StridedSlice';
var StringNGrams = 'StringNGrams';
var StringSplit = 'StringSplit';
var StringToHashBucketFast = 'StringToHashBucketFast';
var Sub = 'Sub';
var Tan = 'Tan';
var Tanh = 'Tanh';
var Tile = 'Tile';
var TopK = 'TopK';
var Transform = 'Transform';
var Transpose = 'Transpose';
var Unique = 'Unique';
var Unpack = 'Unpack';
var UnsortedSegmentSum = 'UnsortedSegmentSum';
var UpperBound = 'UpperBound';
var ZerosLike = 'ZerosLike';
/**
 * TensorFlow.js-only kernels
 */
var Step = 'Step';
var FromPixels = 'FromPixels';
var RotateWithOffset = 'RotateWithOffset';
var _FusedMatMul = '_FusedMatMul';
var FusedConv2D = 'FusedConv2D';
var FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
/**
 * Forwards all arguments to `console.warn`, unless the environment reports
 * the `IS_TEST` or the `PROD` flag as true, in which case nothing is printed.
 */
function warn() {
    var msg = Array.prototype.slice.call(arguments);
    if (env().getBool('IS_TEST') || env().getBool('PROD')) {
        return;
    }
    console.warn.apply(console, msg);
}
/**
 * Forwards all arguments to `console.log`, unless the environment reports
 * the `IS_TEST` or the `PROD` flag as true, in which case nothing is printed.
 */
function log$1() {
    var msg = Array.prototype.slice.call(arguments);
    if (env().getBool('IS_TEST') || env().getBool('PROD')) {
        return;
    }
    console.log.apply(console, msg);
}
// Both registries are fetched through getGlobal(), which creates the Map only
// if nothing is stored under the given key yet.
var kernelRegistry = getGlobal('kernelRegistry', function () {
    return new Map();
});
var gradRegistry = getGlobal('gradRegistry', function () {
    return new Map();
});
/**
 * Returns the kernel function (code) associated with the provided names.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @returns The entry stored in the kernel registry under the key built from
 *     both names, or `undefined` if there is none.
 */
function getKernel(kernelName, backendName) {
    return kernelRegistry.get(makeKey(kernelName, backendName));
}
/**
 * Returns the registered gradient info associated with the provided kernel.
 *
 * @param kernelName The official TF kernel name.
 * @returns The config stored in the gradient registry for `kernelName`, or
 *     `undefined` if none was registered.
 */
function getGradient(kernelName) {
    return gradRegistry.get(kernelName);
}
/**
 * Returns the configs of all kernels registered for one backend, in
 * registration order.
 *
 * The backend is read from each stored config rather than parsed out of the
 * registry key: splitting the key on '_' can never yield a backend name that
 * itself contains '_', and would attribute such a backend's kernels to the
 * truncated prefix of its name.
 *
 * @param backendName The official name of the backend.
 * @returns A new array with the matching kernel configs (possibly empty).
 */
function getKernelsForBackend(backendName) {
    var result = [];
    kernelRegistry.forEach(function (config) {
        if (config.backendName === backendName) {
            result.push(config);
        }
    });
    return result;
}
/**
 * Registers the function (forward pass) for the kernel in a global registry.
 * Registering the same kernel/backend pair again replaces the earlier config
 * after emitting a warning.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 * disposed.
 */
function registerKernel(config) {
    var kernelName = config.kernelName, backendName = config.backendName;
    var key = makeKey(kernelName, backendName);
    if (kernelRegistry.has(key)) {
        warn("The kernel '".concat(kernelName, "' for backend ") +
            "'".concat(backendName, "' is already registered"));
    }
    kernelRegistry.set(key, config);
}
/**
 * Registers a gradient function for a given kernel in the global registry,
 * to be used during the back-propagation of that kernel. An existing entry
 * for the same kernel is replaced; a warning about that is emitted only when
 * the `DEBUG` flag is on.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
function registerGradient(config) {
    var kernelName = config.kernelName;
    if (gradRegistry.has(kernelName)) {
        // TODO (yassogba) after 3.0 assess whether we need to keep this gated
        // to debug mode.
        if (env().getBool('DEBUG')) {
            warn("Overriding the gradient for '".concat(kernelName, "'"));
        }
    }
    gradRegistry.set(kernelName, config);
}
* * @param kernelName The official name of the kernel. * @param backendName The official name of the backend. * */ function unregisterKernel(kernelName, backendName) {