UNPKG

@tensorflow/tfjs

Version:

An open-source machine learning framework.

1,553 lines (1,526 loc) 4.02 MB
/** * @license * Copyright 2024 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function _mergeNamespaces(n, m) { m.forEach(function (e) { e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) { if (k !== 'default' && !(k in n)) { var d = Object.getOwnPropertyDescriptor(e, k); Object.defineProperty(n, k, d.get ? d : { enumerable: true, get: function () { return e[k]; } }); } }); }); return Object.freeze(n); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const EPSILON_FLOAT32$1 = 1e-7; const EPSILON_FLOAT16$1 = 1e-4; /** Convenient class for storing tensor-related data. 
*/ class DataStorage { constructor(backend, dataMover) { this.backend = backend; this.dataMover = dataMover; this.data = new WeakMap(); this.dataIdsCount = 0; } get(dataId) { if (!this.data.has(dataId)) { this.dataMover.moveData(this.backend, dataId); } return this.data.get(dataId); } set(dataId, value) { this.dataIdsCount++; this.data.set(dataId, value); } has(dataId) { return this.data.has(dataId); } delete(dataId) { this.dataIdsCount--; return this.data.delete(dataId); } numDataIds() { return this.dataIdsCount; } } /** * The interface that defines the kernels that should be implemented when * adding a new backend. New backends don't need to implement every one of the * methods, this can be done gradually (throw an error for unimplemented * methods). */ class KernelBackend { refCount(dataId) { return notYetImplemented('refCount'); } incRef(dataId) { return notYetImplemented('incRef'); } timerAvailable() { return true; } time(f) { return notYetImplemented('time'); } read(dataId) { return notYetImplemented('read'); } readSync(dataId) { return notYetImplemented('readSync'); } readToGPU(dataId, options) { return notYetImplemented('readToGPU'); } numDataIds() { return notYetImplemented('numDataIds'); } disposeData(dataId, force) { return notYetImplemented('disposeData'); } write(values, shape, dtype) { return notYetImplemented('write'); } move(dataId, values, shape, dtype, refCount) { return notYetImplemented('move'); } createTensorFromGPUData(values, shape, dtype) { return notYetImplemented('createTensorFromGPUData'); } memory() { return notYetImplemented('memory'); } /** Returns the highest precision for floats in bits (e.g. 16 or 32) */ floatPrecision() { return notYetImplemented('floatPrecision'); } /** Returns the smallest representable number. */ epsilon() { return this.floatPrecision() === 32 ? 
EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1; } dispose() { return notYetImplemented('dispose'); } } function notYetImplemented(kernelName) { throw new Error(`'${kernelName}' not yet implemented or not found in the registry. ` + `This kernel may not be supported by the tfjs backend you have chosen`); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Shuffles the array in-place using Fisher-Yates algorithm. * * ```js * const a = [1, 2, 3, 4, 5]; * tf.util.shuffle(a); * console.log(a); * ``` * * @param array The array to shuffle in-place. * * @doc {heading: 'Util', namespace: 'util'} */ // tslint:disable-next-line:no-any function shuffle(array) { let counter = array.length; let index = 0; // While there are elements in the array while (counter > 0) { // Pick a random index index = (Math.random() * counter) | 0; // Decrease counter by 1 counter--; // And swap the last element with it swap(array, counter, index); } } /** * Shuffles two arrays in-place the same way using Fisher-Yates algorithm. * * ```js * const a = [1,2,3,4,5]; * const b = [11,22,33,44,55]; * tf.util.shuffleCombo(a, b); * console.log(a, b); * ``` * * @param array The first array to shuffle in-place. * @param array2 The second array to shuffle in-place with the same permutation * as the first array. 
* * @doc {heading: 'Util', namespace: 'util'} */ function shuffleCombo( // tslint:disable-next-line:no-any array, // tslint:disable-next-line:no-any array2) { if (array.length !== array2.length) { throw new Error(`Array sizes must match to be shuffled together ` + `First array length was ${array.length}` + `Second array length was ${array2.length}`); } let counter = array.length; let index = 0; // While there are elements in the array while (counter > 0) { // Pick a random index index = (Math.random() * counter) | 0; // Decrease counter by 1 counter--; // And swap the last element of each array with it swap(array, counter, index); swap(array2, counter, index); } } /** Clamps a value to a specified range. */ function clamp(min, x, max) { return Math.max(min, Math.min(x, max)); } function nearestLargerEven(val) { return val % 2 === 0 ? val : val + 1; } function swap(object, left, right) { const temp = object[left]; object[left] = object[right]; object[right] = temp; } function sum$4(arr) { let sum = 0; for (let i = 0; i < arr.length; i++) { sum += arr[i]; } return sum; } /** * Returns a sample from a uniform [a, b) distribution. * * @param a The minimum support (inclusive). * @param b The maximum support (exclusive). * @return A pseudorandom number on the half-open interval [a,b). */ function randUniform(a, b) { const r = Math.random(); return (b * r) + (1 - r) * a; } /** Returns the squared Euclidean distance between two vectors. */ function distSquared(a, b) { let result = 0; for (let i = 0; i < a.length; i++) { const diff = Number(a[i]) - Number(b[i]); result += diff * diff; } return result; } /** * Asserts that the expression is true. Otherwise throws an error with the * provided message. * * ```js * const x = 2; * tf.util.assert(x === 2, 'x is not 2'); * ``` * * @param expr The expression to assert (as a boolean). * @param msg A function that returns the message to report when throwing an * error. We use a function for performance reasons. 
* * @doc {heading: 'Util', namespace: 'util'} */ function assert$1(expr, msg) { if (!expr) { throw new Error(typeof msg === 'string' ? msg : msg()); } } function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') { assert$1(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); } function assertNonNull(a) { assert$1(a != null, () => `The input to the tensor constructor must be a non-null value.`); } /** * Returns the size (number of elements) of the tensor given its shape. * * ```js * const shape = [3, 4, 2]; * const size = tf.util.sizeFromShape(shape); * console.log(size); * ``` * * @doc {heading: 'Util', namespace: 'util'} */ function sizeFromShape(shape) { if (shape.length === 0) { // Scalar. return 1; } let size = shape[0]; for (let i = 1; i < shape.length; i++) { size *= shape[i]; } return size; } function isScalarShape(shape) { return shape.length === 0; } function arraysEqualWithNull(n1, n2) { if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } if (n1.length !== n2.length) { return false; } for (let i = 0; i < n1.length; i++) { if (n1[i] !== null && n2[i] !== null && n1[i] !== n2[i]) { return false; } } return true; } function arraysEqual(n1, n2) { if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } if (n1.length !== n2.length) { return false; } for (let i = 0; i < n1.length; i++) { if (n1[i] !== n2[i]) { return false; } } return true; } function isInt(a) { return a % 1 === 0; } function tanh$3(x) { // tslint:disable-next-line:no-any if (Math.tanh != null) { // tslint:disable-next-line:no-any return Math.tanh(x); } if (x === Infinity) { return 1; } else if (x === -Infinity) { return -1; } else { const e2x = Math.exp(2 * x); return (e2x - 1) / (e2x + 1); } } function sizeToSquarishShape(size) { const width = Math.ceil(Math.sqrt(size)); return [width, Math.ceil(size / width)]; } /** * Creates a new array with randomized indices to a given quantity. 
* * ```js * const randomTen = tf.util.createShuffledIndices(10); * console.log(randomTen); * ``` * * @param number Quantity of how many shuffled indices to create. * * @doc {heading: 'Util', namespace: 'util'} */ function createShuffledIndices(n) { const shuffledIndices = new Uint32Array(n); for (let i = 0; i < n; ++i) { shuffledIndices[i] = i; } shuffle(shuffledIndices); return shuffledIndices; } function rightPad(a, size) { if (size <= a.length) { return a; } return a + ' '.repeat(size - a.length); } function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter, scheduleFn) { return new Promise((resolve, reject) => { let tryCount = 0; const tryFn = () => { if (checkFn()) { resolve(); return; } tryCount++; const nextBackoff = delayFn(tryCount); if (maxCounter != null && tryCount >= maxCounter) { reject(); return; } if (scheduleFn != null) { scheduleFn(tryFn, nextBackoff); } else { // google3 does not allow assigning another variable to setTimeout. // Don't refactor this so scheduleFn has a default value of setTimeout. setTimeout(tryFn, nextBackoff); } }; tryFn(); }); } /** * Given the full size of the array and a shape that may contain -1 as the * implicit dimension, returns the inferred shape where -1 is replaced. * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3]. * * @param shape The shape, which may contain -1 in some dimension. * @param size The full size (number of elements) of the array. * @return The inferred shape where -1 is replaced with the inferred size. */ function inferFromImplicitShape(shape, size) { let shapeProd = 1; let implicitIdx = -1; for (let i = 0; i < shape.length; ++i) { if (shape[i] >= 0) { shapeProd *= shape[i]; } else if (shape[i] === -1) { if (implicitIdx !== -1) { throw Error(`Shapes can only have 1 implicit size. ` + `Found -1 at dim ${implicitIdx} and dim ${i}`); } implicitIdx = i; } else if (shape[i] < 0) { throw Error(`Shapes can not be < 0. 
Found ${shape[i]} at dim ${i}`); } } if (implicitIdx === -1) { if (size > 0 && size !== shapeProd) { throw Error(`Size(${size}) must match the product of shape ${shape}`); } return shape; } if (shapeProd === 0) { throw Error(`Cannot infer the missing size in [${shape}] when ` + `there are 0 elements`); } if (size % shapeProd !== 0) { throw Error(`The implicit shape can't be a fractional number. ` + `Got ${size} / ${shapeProd}`); } const newShape = shape.slice(); newShape[implicitIdx] = size / shapeProd; return newShape; } function parseAxisParam(axis, shape) { const rank = shape.length; // Normalize input axis = axis == null ? shape.map((s, i) => i) : [].concat(axis); // Check for valid range assert$1(axis.every(ax => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but ` + `got axis ${axis}`); // Check for only integers assert$1(axis.every(ax => isInt(ax)), () => `All values in axis param must be integers but ` + `got axis ${axis}`); // Handle negative axis. return axis.map(a => a < 0 ? rank + a : a); } /** Reduces the shape by removing all dimensions of shape 1. */ function squeezeShape(shape, axis) { const newShape = []; const keptDims = []; const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; const axes = (axis == null || isEmptyArray) ? 
null : parseAxisParam(axis, shape).sort(); let j = 0; for (let i = 0; i < shape.length; ++i) { if (axes != null) { if (axes[j] === i && shape[i] !== 1) { throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`); } if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { newShape.push(shape[i]); keptDims.push(i); } if (axes[j] <= i) { j++; } } if (shape[i] !== 1) { newShape.push(shape[i]); keptDims.push(i); } } return { newShape, keptDims }; } function getTypedArrayFromDType(dtype, size) { return getArrayFromDType(dtype, size); } function getArrayFromDType(dtype, size) { let values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); } else if (dtype === 'int32') { values = new Int32Array(size); } else if (dtype === 'bool') { values = new Uint8Array(size); } else if (dtype === 'string') { values = new Array(size); } else { throw new Error(`Unknown data type ${dtype}`); } return values; } function checkConversionForErrors(vals, dtype) { for (let i = 0; i < vals.length; i++) { const num = vals[i]; if (isNaN(num) || !isFinite(num)) { throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`); } } } /** Returns true if the dtype is valid. */ function isValidDtype(dtype) { return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || dtype === 'int32' || dtype === 'string'; } /** * Returns true if the new type can't encode the old type without loss of * precision. 
*/ function hasEncodingLoss(oldType, newType) { if (newType === 'complex64') { return false; } if (newType === 'float32' && oldType !== 'complex64') { return false; } if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') { return false; } if (newType === 'bool' && oldType === 'bool') { return false; } return true; } function bytesPerElement(dtype) { if (dtype === 'float32' || dtype === 'int32') { return 4; } else if (dtype === 'complex64') { return 8; } else if (dtype === 'bool') { return 1; } else { throw new Error(`Unknown dtype ${dtype}`); } } /** * Returns the approximate number of bytes allocated in the string array - 2 * bytes per character. Computing the exact bytes for a native string in JS * is not possible since it depends on the encoding of the html page that * serves the website. */ function bytesFromStringArray(arr) { if (arr == null) { return 0; } let bytes = 0; arr.forEach(x => bytes += x.length); return bytes; } /** Returns true if the value is a string. 
*/ function isString(value) { return typeof value === 'string' || value instanceof String; } function isBoolean(value) { return typeof value === 'boolean'; } function isNumber(value) { return typeof value === 'number'; } function inferDtype(values) { if (Array.isArray(values)) { return inferDtype(values[0]); } if (values instanceof Float32Array) { return 'float32'; } else if (values instanceof Int32Array || values instanceof Uint8Array || values instanceof Uint8ClampedArray) { return 'int32'; } else if (isNumber(values)) { return 'float32'; } else if (isString(values)) { return 'string'; } else if (isBoolean(values)) { return 'bool'; } return 'float32'; } function isFunction(f) { return !!(f && f.constructor && f.call && f.apply); } function nearestDivisor(size, start) { for (let i = start; i < size; ++i) { if (size % i === 0) { return i; } } return size; } function computeStrides(shape) { const rank = shape.length; if (rank < 2) { return []; } // Last dimension has implicit stride of 1, thus having D-1 (instead of D) // strides. const strides = new Array(rank - 1); strides[rank - 2] = shape[rank - 1]; for (let i = rank - 3; i >= 0; --i) { strides[i] = strides[i + 1] * shape[i + 1]; } return strides; } function createNestedArray(offset, shape, a, isComplex = false) { const ret = new Array(); if (shape.length === 1) { const d = shape[0] * (isComplex ? 2 : 1); for (let i = 0; i < d; i++) { ret[i] = a[offset + i]; } } else { const d = shape[0]; const rest = shape.slice(1); const len = rest.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1); for (let i = 0; i < d; i++) { ret[i] = createNestedArray(offset + i * len, rest, a, isComplex); } } return ret; } // Provide a nested array of TypedArray in given shape. function toNestedArray(shape, a, isComplex = false) { if (shape.length === 0) { // Scalar type should return a single number. return a[0]; } const size = shape.reduce((acc, c) => acc * c) * (isComplex ? 
2 : 1); if (size === 0) { // A tensor with shape zero should be turned into empty list. return []; } if (size !== a.length) { throw new Error(`[${shape}] does not match the input size ${a.length}${isComplex ? ' for a complex tensor' : ''}.`); } return createNestedArray(0, shape, a, isComplex); } function convertBackendValuesAndArrayBuffer(data, dtype) { // If is type Uint8Array[], return it directly. if (Array.isArray(data)) { return data; } if (dtype === 'float32') { return data instanceof Float32Array ? data : new Float32Array(data); } else if (dtype === 'int32') { return data instanceof Int32Array ? data : new Int32Array(data); } else if (dtype === 'bool' || dtype === 'string') { return Uint8Array.from(new Int32Array(data)); } else { throw new Error(`Unknown dtype ${dtype}`); } } function makeOnesTypedArray(size, dtype) { const array = makeZerosTypedArray(size, dtype); for (let i = 0; i < array.length; i++) { array[i] = 1; } return array; } function makeZerosTypedArray(size, dtype) { if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(size); } else if (dtype === 'int32') { return new Int32Array(size); } else if (dtype === 'bool') { return new Uint8Array(size); } else { throw new Error(`Unknown data type ${dtype}`); } } /** * Make nested `TypedArray` filled with zeros. * @param shape The shape information for the nested array. * @param dtype dtype of the array element. 
*/ function makeZerosNestedTypedArray(shape, dtype) { const size = shape.reduce((prev, curr) => prev * curr, 1); if (dtype == null || dtype === 'float32') { return toNestedArray(shape, new Float32Array(size)); } else if (dtype === 'int32') { return toNestedArray(shape, new Int32Array(size)); } else if (dtype === 'bool') { return toNestedArray(shape, new Uint8Array(size)); } else { throw new Error(`Unknown data type ${dtype}`); } } function assertNonNegativeIntegerDimensions(shape) { shape.forEach(dimSize => { assert$1(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got ` + `shape [${shape}].`); }); } /** * Computes flat index for a given location (multidimentionsal index) in a * Tensor/multidimensional array. * * @param locs Location in the tensor. * @param rank Rank of the tensor. * @param strides Tensor strides. */ function locToIndex(locs, rank, strides) { if (rank === 0) { return 0; } else if (rank === 1) { return locs[0]; } let index = locs[locs.length - 1]; for (let i = 0; i < locs.length - 1; ++i) { index += strides[i] * locs[i]; } return index; } /** * Computes the location (multidimensional index) in a * tensor/multidimentional array for a given flat index. * * @param index Index in flat array. * @param rank Rank of tensor. * @param strides Strides of tensor. */ function indexToLoc(index, rank, strides) { if (rank === 0) { return []; } else if (rank === 1) { return [index]; } const locs = new Array(rank); for (let i = 0; i < locs.length - 1; ++i) { locs[i] = Math.floor(index / strides[i]); index -= locs[i] * strides[i]; } locs[locs.length - 1] = index; return locs; } /** * This method asserts whether an object is a Promise instance. * @param object */ // tslint:disable-next-line: no-any function isPromise(object) { // We chose to not use 'obj instanceOf Promise' for two reasons: // 1. It only reliably works for es6 Promise, not other Promise // implementations. // 2. 
It doesn't work with framework that uses zone.js. zone.js monkey // patch the async calls, so it is possible the obj (patched) is // comparing to a pre-patched Promise. return object && object.then && typeof object.then === 'function'; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true. const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags'; /** * The environment contains evaluated flags as well as the registered platform. * This is always used as a global singleton and can be retrieved with * `tf.env()`. * * @doc {heading: 'Environment'} */ class Environment { // tslint:disable-next-line: no-any constructor(global) { this.global = global; this.flags = {}; this.flagRegistry = {}; this.urlFlags = {}; // Jasmine spies on this in 'environment_test.ts' this.getQueryParams = getQueryParams; this.populateURLFlags(); } setPlatform(platformName, platform) { if (this.platform != null) { if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn(`Platform ${this.platformName} has already been set. ` + `Overwriting the platform with ${platformName}.`); } } this.platformName = platformName; this.platform = platform; } registerFlag(flagName, evaluationFn, setHook) { this.flagRegistry[flagName] = { evaluationFn, setHook }; // Override the flag value from the URL. 
This has to happen here because // the environment is initialized before flags get registered. if (this.urlFlags[flagName] != null) { const flagValue = this.urlFlags[flagName]; if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`); } this.set(flagName, flagValue); } } async getAsync(flagName) { if (flagName in this.flags) { return this.flags[flagName]; } this.flags[flagName] = await this.evaluateFlag(flagName); return this.flags[flagName]; } get(flagName) { if (flagName in this.flags) { return this.flags[flagName]; } const flagValue = this.evaluateFlag(flagName); if (isPromise(flagValue)) { throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` + `Please use getAsync() instead.`); } this.flags[flagName] = flagValue; return this.flags[flagName]; } getNumber(flagName) { return this.get(flagName); } getBool(flagName) { return this.get(flagName); } getString(flagName) { return this.get(flagName); } getFlags() { return this.flags; } // For backwards compatibility. 
get features() { return this.flags; } set(flagName, value) { if (this.flagRegistry[flagName] == null) { throw new Error(`Cannot set flag ${flagName} as it has not been registered.`); } this.flags[flagName] = value; if (this.flagRegistry[flagName].setHook != null) { this.flagRegistry[flagName].setHook(value); } } evaluateFlag(flagName) { if (this.flagRegistry[flagName] == null) { throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`); } return this.flagRegistry[flagName].evaluationFn(); } setFlags(flags) { this.flags = Object.assign({}, flags); } reset() { this.flags = {}; this.urlFlags = {}; this.populateURLFlags(); } populateURLFlags() { if (typeof this.global === 'undefined' || typeof this.global.location === 'undefined' || typeof this.global.location.search === 'undefined') { return; } const urlParams = this.getQueryParams(this.global.location.search); if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(','); keyValues.forEach(keyValue => { const [key, value] = keyValue.split(':'); this.urlFlags[key] = parseValue(key, value); }); } } } function getQueryParams(queryString) { const params = {}; queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => { decodeParam(params, t[0], t[1]); return t.join('='); }); return params; } function decodeParam(params, name, value) { params[decodeURIComponent(name)] = decodeURIComponent(value || ''); } function parseValue(flagName, value) { const lowerCaseValue = value.toLowerCase(); if (lowerCaseValue === 'true' || lowerCaseValue === 'false') { return lowerCaseValue === 'true'; } else if (`${+lowerCaseValue}` === lowerCaseValue) { return +lowerCaseValue; } else { return value; } } /** * Returns the current environment (a global singleton). * * The environment object contains the evaluated feature values as well as the * active platform. 
* * @doc {heading: 'Environment'} */ function env() { return ENV$4; } let ENV$4 = null; function setEnvironmentGlobal(environment) { ENV$4 = environment; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Note that the identifier globalNameSpace is scoped to this module, but will // always resolve to the same global object regardless of how the module is // resolved. // tslint:disable-next-line:no-any let globalNameSpace; // tslint:disable-next-line:no-any function getGlobalNamespace() { if (globalNameSpace == null) { // tslint:disable-next-line:no-any let ns; if (typeof (window) !== 'undefined') { ns = window; } else if (typeof (global) !== 'undefined') { ns = global; } else if (typeof (process) !== 'undefined') { ns = process; } else if (typeof (self) !== 'undefined') { ns = self; } else { throw new Error('Could not find a global object'); } globalNameSpace = ns; } return globalNameSpace; } // tslint:disable-next-line:no-any function getGlobalMap() { const ns = getGlobalNamespace(); if (ns._tfGlobals == null) { ns._tfGlobals = new Map(); } return ns._tfGlobals; } /** * Returns a globally accessible 'singleton' object. * * @param key the name of the object * @param init a function to initialize to initialize this object * the first time it is fetched. 
*/ function getGlobal(key, init) { const globalMap = getGlobalMap(); if (globalMap.has(key)) { return globalMap.get(key); } else { const singleton = init(); globalMap.set(key, singleton); return globalMap.get(key); } } const Abs = 'Abs'; const Acos = 'Acos'; const Acosh = 'Acosh'; const Add$1 = 'Add'; const AddN = 'AddN'; const All = 'All'; const Any = 'Any'; const ArgMax = 'ArgMax'; const ArgMin = 'ArgMin'; const Asin = 'Asin'; const Asinh = 'Asinh'; const Atan = 'Atan'; const Atanh = 'Atanh'; const Atan2 = 'Atan2'; const AvgPool = 'AvgPool'; const AvgPoolGrad = 'AvgPoolGrad'; const AvgPool3D = 'AvgPool3D'; const AvgPool3DGrad = 'AvgPool3DGrad'; const BatchMatMul = 'BatchMatMul'; const BatchToSpaceND = 'BatchToSpaceND'; const Bincount = 'Bincount'; const BitwiseAnd = 'BitwiseAnd'; const BroadcastTo = 'BroadcastTo'; const BroadcastArgs = 'BroadcastArgs'; const Cast = 'Cast'; const Ceil = 'Ceil'; const ClipByValue = 'ClipByValue'; const Complex = 'Complex'; const ComplexAbs = 'ComplexAbs'; const Concat = 'Concat'; const Conv2D$1 = 'Conv2D'; const Conv2DBackpropFilter = 'Conv2DBackpropFilter'; const Conv2DBackpropInput = 'Conv2DBackpropInput'; const Conv3D$1 = 'Conv3D'; const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2'; const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2'; const Cos = 'Cos'; const Cosh = 'Cosh'; const Cumprod = 'Cumprod'; const Cumsum = 'Cumsum'; const CropAndResize = 'CropAndResize'; const DenseBincount = 'DenseBincount'; const DepthToSpace = 'DepthToSpace'; const DepthwiseConv2dNative = 'DepthwiseConv2dNative'; const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter'; const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput'; const Diag = 'Diag'; const Dilation2D = 'Dilation2D'; const Dilation2DBackpropInput = 'Dilation2DBackpropInput'; const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter'; const Draw = 'Draw'; const RealDiv = 'RealDiv'; const Einsum = 'Einsum'; const Elu$1 = 'Elu'; const 
EluGrad = 'EluGrad'; const Erf = 'Erf'; const Equal = 'Equal'; const Exp = 'Exp'; const ExpandDims = 'ExpandDims'; const Expm1 = 'Expm1'; const FFT = 'FFT'; const Fill = 'Fill'; const FlipLeftRight = 'FlipLeftRight'; const Floor = 'Floor'; const FloorDiv = 'FloorDiv'; const FusedBatchNorm = 'FusedBatchNorm'; const GatherV2 = 'GatherV2'; const GatherNd = 'GatherNd'; const Greater = 'Greater'; const GreaterEqual = 'GreaterEqual'; const Identity$1 = 'Identity'; const IFFT = 'IFFT'; const Imag = 'Imag'; const IsFinite = 'IsFinite'; const IsInf = 'IsInf'; const IsNan = 'IsNan'; const LeakyRelu = 'LeakyRelu'; const Less = 'Less'; const LessEqual = 'LessEqual'; const LinSpace = 'LinSpace'; const Log = 'Log'; const Log1p = 'Log1p'; const LogicalAnd = 'LogicalAnd'; const LogicalNot = 'LogicalNot'; const LogicalOr = 'LogicalOr'; const LogicalXor = 'LogicalXor'; const LogSoftmax$1 = 'LogSoftmax'; const LowerBound = 'LowerBound'; const LRN = 'LRN'; const LRNGrad = 'LRNGrad'; const MatrixBandPart = 'MatrixBandPart'; const Max = 'Max'; const Maximum$1 = 'Maximum'; const MaxPool = 'MaxPool'; const MaxPoolGrad = 'MaxPoolGrad'; const MaxPool3D = 'MaxPool3D'; const MaxPool3DGrad = 'MaxPool3DGrad'; const MaxPoolWithArgmax = 'MaxPoolWithArgmax'; const Mean = 'Mean'; const Min = 'Min'; const Minimum$1 = 'Minimum'; const MirrorPad = 'MirrorPad'; const Mod = 'Mod'; const Multinomial = 'Multinomial'; const Multiply$1 = 'Multiply'; const Neg = 'Neg'; const NotEqual = 'NotEqual'; const NonMaxSuppressionV3 = 'NonMaxSuppressionV3'; const NonMaxSuppressionV4 = 'NonMaxSuppressionV4'; const NonMaxSuppressionV5 = 'NonMaxSuppressionV5'; const OnesLike = 'OnesLike'; const OneHot = 'OneHot'; const Pack = 'Pack'; const PadV2 = 'PadV2'; const Pool = 'Pool'; const Pow = 'Pow'; const Prelu = 'Prelu'; const Prod = 'Prod'; const RaggedGather = 'RaggedGather'; const RaggedRange = 'RaggedRange'; const RaggedTensorToTensor = 'RaggedTensorToTensor'; const Range = 'Range'; const Real = 'Real'; const Reciprocal 
= 'Reciprocal'; const Relu$1 = 'Relu'; const Reshape$1 = 'Reshape'; const ResizeNearestNeighbor = 'ResizeNearestNeighbor'; const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad'; const ResizeBilinear = 'ResizeBilinear'; const ResizeBilinearGrad = 'ResizeBilinearGrad'; const Relu6$1 = 'Relu6'; const Reverse = 'Reverse'; const Round = 'Round'; const Rsqrt = 'Rsqrt'; const ScatterNd = 'ScatterNd'; const TensorScatterUpdate = 'TensorScatterUpdate'; const SearchSorted = 'SearchSorted'; const Select = 'Select'; const Selu$1 = 'Selu'; const Slice = 'Slice'; const Sin = 'Sin'; const Sinh = 'Sinh'; const Sign = 'Sign'; const Sigmoid$1 = 'Sigmoid'; const Softplus$1 = 'Softplus'; const Sqrt = 'Sqrt'; const Sum = 'Sum'; const SpaceToBatchND = 'SpaceToBatchND'; const SplitV = 'SplitV'; const Softmax$2 = 'Softmax'; const SparseFillEmptyRows = 'SparseFillEmptyRows'; const SparseReshape = 'SparseReshape'; const SparseSegmentMean = 'SparseSegmentMean'; const SparseSegmentSum = 'SparseSegmentSum'; const SparseToDense = 'SparseToDense'; const SquaredDifference = 'SquaredDifference'; const Square = 'Square'; const StaticRegexReplace = 'StaticRegexReplace'; const StridedSlice = 'StridedSlice'; const StringNGrams = 'StringNGrams'; const StringSplit = 'StringSplit'; const StringToHashBucketFast = 'StringToHashBucketFast'; const Sub = 'Sub'; const Tan = 'Tan'; const Tanh$1 = 'Tanh'; const Tile = 'Tile'; const TopK = 'TopK'; const Transform = 'Transform'; const Transpose = 'Transpose'; const Unique = 'Unique'; const Unpack = 'Unpack'; const UnsortedSegmentSum = 'UnsortedSegmentSum'; const UpperBound = 'UpperBound'; const ZerosLike = 'ZerosLike'; /** * TensorFlow.js-only kernels */ const Step = 'Step'; const FromPixels = 'FromPixels'; const RotateWithOffset = 'RotateWithOffset'; const _FusedMatMul = '_FusedMatMul'; const FusedConv2D = 'FusedConv2D'; const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D'; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function warn(...msg) { if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn(...msg); } } function log$3(...msg) { if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.log(...msg); } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const kernelRegistry = getGlobal('kernelRegistry', () => new Map()); const gradRegistry = getGlobal('gradRegistry', () => new Map()); /** * Returns the kernel function (code) associated with the provided names. * * @param kernelName The official name of the kernel. * @param backendName The official name of the backend. 
*/ function getKernel(kernelName, backendName) { const key = makeKey(kernelName, backendName); return kernelRegistry.get(key); } /** * Returns the registered gradient info associated with the provided kernel. * @param kernelName The official TF kernel name. */ function getGradient(kernelName) { return gradRegistry.get(kernelName); } function getKernelsForBackend(backendName) { const it = kernelRegistry.entries(); const result = []; while (true) { const { done, value } = it.next(); if (done) { break; } const [key, config] = value; const [backend,] = key.split('_'); if (backend === backendName) { result.push(config); } } return result; } /** * Registers the function (forward pass) for the kernel in a global registry. * * @param config A config object with the following properties: * - `kernelName` The official name of the kernel. * - `backendName` The official name of the backend. * - `kernelFunc` The function to run during the forward pass of the kernel. * - `setupFunc` Optional. Gets called once, after the backend initializes. * - `disposeFunc` Optional. Gets called once, right before the backend is * disposed. */ function registerKernel(config) { const { kernelName, backendName } = config; const key = makeKey(kernelName, backendName); if (kernelRegistry.has(key)) { warn(`The kernel '${kernelName}' for backend ` + `'${backendName}' is already registered`); } kernelRegistry.set(key, config); } /** * Registers a gradient function for a given kernel in the global registry, * to be used during the back-propagation of that kernel. * * @param config An object with the following properties: * - `kernelName` The name of the kernel that the gradient function is for. * - `gradFunc` The function to run during back-propagation. */ function registerGradient(config) { const { kernelName } = config; if (gradRegistry.has(kernelName)) { // TODO (yassogba) after 3.0 assess whether we need to keep this gated // to debug mode. 
if (env().getBool('DEBUG')) { warn(`Overriding the gradient for '${kernelName}'`); } } gradRegistry.set(kernelName, config); } /** * Removes the kernel function from the registry. * * @param kernelName The official name of the kernel. * @param backendName The official name of the backend. * */ function unregisterKernel(kernelName, backendName) { const key = makeKey(kernelName, backendName); if (!kernelRegistry.has(key)) { throw new Error(`The kernel '${kernelName}' for backend ` + `'${backendName}' is not registered`); } kernelRegistry.delete(key); } /** Removes the registered gradient from the global registry. */ function unregisterGradient(kernelName) { if (!gradRegistry.has(kernelName)) { throw new Error(`The gradient '${kernelName}' for backend is not registered`); } gradRegistry.delete(kernelName); } /** * Finds kernels that have already been registered to a backend and re-registers * them for a new backend. Useful for registering custom backends. * @param registeredBackendName Already registered backend. * @param newBackendName New backend. */ function copyRegisteredKernels(registeredBackendName, newBackendName) { const kernels = getKernelsForBackend(registeredBackendName); kernels.forEach(kernelConfig => { const newKernelConfig = Object.assign({}, kernelConfig, { backendName: newBackendName }); registerKernel(newKernelConfig); }); } function makeKey(kernelName, backendName) { return `${backendName}_${kernelName}`; } /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Returns true only when `a` is an instance of one of the four typed-array
 * classes tested below (Float32Array, Int32Array, Uint8Array,
 * Uint8ClampedArray). Any other value, including other typed arrays such as
 * Float64Array or Int16Array, yields false.
 */
function isTypedArrayBrowser(a) {
    return a instanceof Float32Array || a instanceof Int32Array ||
        a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}

// NOTE(review): the helpers below look like bundler-emitted CommonJS interop
// code; confirm before editing by hand.

// The global object: globalThis, else window, else global, else self, else {}.
var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};

// Returns `x.default` only when `x` is flagged `__esModule` and owns a
// `default` key; otherwise returns `x` unchanged.
function getDefaultExportFromCjs (x) {
	return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x;
}

// Returns `n.default` whenever `n` owns a `default` key; otherwise `n`.
function getDefaultExportFromNamespaceIfPresent (n) {
	return n && Object.prototype.hasOwnProperty.call(n, 'default') ? n['default'] : n;
}

// Returns `n.default` only when `n` owns a `default` key AND has exactly one
// own enumerable key; otherwise `n`.
function getDefaultExportFromNamespaceIfNotNamed (n) {
	return n && Object.prototype.hasOwnProperty.call(n, 'default') && Object.keys(n).length === 1 ? n['default'] : n;
}

/**
 * Builds a CommonJS-style view of the namespace object `n`.
 * - If `n` is already flagged `__esModule`, it is returned unchanged.
 * - If `n.default` is a function, the result is a wrapper function that calls
 *   it (or, when invoked with `new`, constructs it with the same arguments)
 *   and shares its `prototype`; otherwise the result is a plain object.
 * The result is flagged `__esModule` and receives one property per own
 * enumerable key of `n`: accessor descriptors are copied as-is, anything else
 * becomes an enumerable getter that reads `n[k]` at access time.
 */
function getAugmentedNamespace(n) {
	if (n.__esModule) return n;
	var f = n.default;
	if (typeof f == "function") {
		var a = function a () {
			// `this instanceof a` detects `new a(...)`; forward it as `new f(...)`
			// by binding the received arguments (the leading null is the ignored
			// bound `this`) and constructing the bound function.
			if (this instanceof a) {
				var args = [null];
				args.push.apply(args, arguments);
				var Ctor = Function.bind.apply(f, args);
				return new Ctor();
			}
			return f.apply(this, arguments);
		};
		a.prototype = f.prototype;
	} else a = {};
	Object.defineProperty(a, '__esModule', {value: true});
	Object.keys(n).forEach(function (k) {
		var d = Object.getOwnPropertyDescriptor(n, k);
		Object.defineProperty(a, k, d.get ? d : {
			enumerable: true,
			get: function () {
				return n[k];
			}
		});
	});
	return a;
}

// `Long$1` is a function declaration further below, so it is hoisted and
// already defined when this alias is evaluated.
var long = Long$1;

/**
 * wasm optimizations, to do native i64 multiplication and divide
 */
// Decoding the export section of the module below gives six function exports:
// `mul`, `div_s`, `div_u`, `rem_s`, `rem_u` and `get_high`. If compiling or
// instantiating the module throws (e.g. no WebAssembly support), `wasm` stays
// null; the empty catch is a deliberate fallback, not a swallowed bug.
var wasm = null;
try {
  wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([
    // magic "\0asm" + version 1
    0, 97, 115, 109, 1, 0, 0, 0,
    // type section: () -> i32 and (i32, i32, i32, i32) -> i32
    1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127,
    // function section: 6 functions (type 0, then five of type 1)
    3, 7, 6, 0, 1, 1, 1, 1, 1,
    // global section: one mutable i32 initialised to 0
    6, 6, 1, 127, 1, 65, 0, 11,
    // export section: 6 exports
    7, 50, 6,
    3, 109, 117, 108, 0, 1,
    5, 100, 105, 118, 95, 115, 0, 2,
    5, 100, 105, 118, 95, 117, 0, 3,
    5, 114, 101, 109, 95, 115, 0, 4,
    5, 114, 101, 109, 95, 117, 0, 5,
    8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0,
    // code section: 6 bodies
    10, 191, 1, 6,
    // get_high: returns the global
    4, 0, 35, 0, 11,
    // The five 36-byte bodies below differ only in the operator byte
    // (126 mul, 127 div_s, 128 div_u, 129 rem_s, 130 rem_u). NOTE(review): each
    // appears to join its (lo, hi) i32 argument pairs into i64s, apply the op,
    // store the high 32 bits of the result in the global and return the low
    // 32 bits; confirm against the WebAssembly opcode table.
    36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11,
    36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11,
    36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11,
    36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11,
    36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11
  ])), {}).exports;
} catch (e) {
  // no wasm support :(
}

/**
 * Constructs a 64 bit two's-complement integer, given its low and high 32 bit values as *signed* integers.
 * See the from* functions below for more convenient ways of constructing Longs.
 * @exports Long
 * @class A Long class for representing a 64 bit two's-complement integer value.
* @param {number} low The low (signed) 32 bits of the long * @param {number} high The high (signed) 32 bits of the long * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @constructor */ function Long$1(low, high, unsigned) { /** * The low 32 bits as a signed value. * @type {number} */ this.low = low | 0; /** * The high 32 bits as a signed value. * @type {number} */ this.high = high | 0; /** * Whether unsigned or not. * @type {boolean} */ this.unsigned = !!unsigned; } // The internal representation of a long is the two given signed, 32-bit values. // We use 32-bit pieces because these are the size of integers on which // Javascript performs bit-operations. For operations like addition and // multiplication, we split each number into 16 bit pieces, which can easily be // multiplied within Javascript's floating-point representation without overflow // or change in sign. // // In the algorithms below, we frequently reduce the negative case to the // positive case by negating the input(s) and then post-processing the result. // Note that we must ALWAYS check specially whether those values are MIN_VALUE // (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as // a positive number, it overflows back into a negative). Not handling this // case would often result in infinite recursion. // // Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from* // methods on which they depend. /** * An indicator used to reliably determine if an object is a Long or not. * @type {boolean} * @const * @private */ Long$1.prototype.__isLong__; Object.defineProperty(Long$1.prototype, "__isLong__", { value: true }); /** * @function * @param {*} obj Object * @returns {boolean} * @inner */ function isLong(obj) { return (obj && obj["__isLong__"]) === true; } /** * Tests if the specified object is a Long. 
* @function
 * @param {*} obj Object
 * @returns {boolean}
 */
// Expose the module-level `isLong` helper as a static on the Long class.
Long$1.isLong = isLong;
/**
 * A cache of the Long representations of small integer values.
 * @type {!Object}
 * @inner
 */
// Starts empty; nothing in this chunk writes to it. NOTE(review): presumably
// filled lazily by a from* factory defined later in the file — confirm there.
var INT_CACHE = {};
/**
 * A cache of the Long representations o