denoiser
Version:
OIDN Denoiser with tensorflow.js
1,553 lines (1,513 loc) • 3.82 MB
JavaScript
/**
 * Merges the own enumerable keys of each namespace object in `m` into `n`
 * and returns `n` after freezing it.
 *
 * Falsy entries, strings and arrays in `m` are ignored. A key is skipped when
 * it is `'default'` or already present in `n`. Accessor properties are copied
 * with their original descriptor; data properties are exposed through a getter
 * that reads from the source object on every access.
 *
 * @param n Target namespace object (mutated, then frozen).
 * @param m Array of candidate source namespaces.
 * @return The frozen `n`.
 */
function _mergeNamespaces(n, m) {
  m.forEach((source) => {
    if (!source || typeof source === 'string' || Array.isArray(source)) {
      return;
    }
    for (const key of Object.keys(source)) {
      if (key === 'default' || key in n) {
        continue;
      }
      const descriptor = Object.getOwnPropertyDescriptor(source, key);
      const forwarding = {
        enumerable: true,
        get: () => source[key]
      };
      Object.defineProperty(n, key, descriptor.get ? descriptor : forwarding);
    }
  });
  return Object.freeze(n);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Value returned by KernelBackend.epsilon() when floatPrecision() is 32.
const EPSILON_FLOAT32$1 = 1e-7;
// Value returned by KernelBackend.epsilon() for any other float precision.
const EPSILON_FLOAT16$1 = 1e-4;
/**
 * WeakMap-backed store that associates a data id (used as the WeakMap key)
 * with a value, while keeping a manual count of stored ids.
 *
 * The counter is adjusted on every `set`/`delete` call, regardless of whether
 * the id was already present or absent.
 */
class DataStorage {
  /**
   * @param backend Passed to `dataMover.moveData` on a lookup miss.
   * @param dataMover Object exposing `moveData(backend, dataId)`.
   */
  constructor(backend, dataMover) {
    this.backend = backend;
    this.dataMover = dataMover;
    this.data = new WeakMap();
    this.dataIdsCount = 0;
  }
  /**
   * Returns the value stored for `dataId`. When the id is unknown, the data
   * mover is asked to move it to this backend first.
   */
  get(dataId) {
    const isKnown = this.data.has(dataId);
    if (!isKnown) {
      this.dataMover.moveData(this.backend, dataId);
    }
    return this.data.get(dataId);
  }
  /** Stores `value` under `dataId` and bumps the id counter. */
  set(dataId, value) {
    this.dataIdsCount += 1;
    this.data.set(dataId, value);
  }
  /** Returns whether `dataId` currently has an entry. */
  has(dataId) {
    return this.data.has(dataId);
  }
  /** Decrements the id counter and removes the entry; returns the WeakMap result. */
  delete(dataId) {
    this.dataIdsCount -= 1;
    return this.data.delete(dataId);
  }
  /** Returns the current value of the id counter. */
  numDataIds() {
    return this.dataIdsCount;
  }
}
/**
 * Base class listing the methods a backend is expected to provide. Every
 * method except `timerAvailable` and `epsilon` delegates to
 * `notYetImplemented`, so a new backend can override them gradually and the
 * remaining ones fail with a descriptive error.
 */
class KernelBackend {
  // --- Reference counting -------------------------------------------------
  refCount(dataId) {
    return notYetImplemented('refCount');
  }
  incRef(dataId) {
    return notYetImplemented('incRef');
  }
  // --- Timing -------------------------------------------------------------
  /** The base class reports that a timer is available. */
  timerAvailable() {
    return true;
  }
  time(f) {
    return notYetImplemented('time');
  }
  // --- Reading data -------------------------------------------------------
  read(dataId) {
    return notYetImplemented('read');
  }
  readSync(dataId) {
    return notYetImplemented('readSync');
  }
  readToGPU(dataId, options) {
    return notYetImplemented('readToGPU');
  }
  // --- Data bookkeeping ---------------------------------------------------
  numDataIds() {
    return notYetImplemented('numDataIds');
  }
  disposeData(dataId, force) {
    return notYetImplemented('disposeData');
  }
  // --- Writing / moving data ----------------------------------------------
  write(values, shape, dtype) {
    return notYetImplemented('write');
  }
  move(dataId, values, shape, dtype, refCount) {
    return notYetImplemented('move');
  }
  createTensorFromGPUData(values, shape, dtype) {
    return notYetImplemented('createTensorFromGPUData');
  }
  memory() {
    return notYetImplemented('memory');
  }
  /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
  floatPrecision() {
    return notYetImplemented('floatPrecision');
  }
  /**
   * Returns the float32 epsilon constant when `floatPrecision()` is 32 and
   * the float16 epsilon constant otherwise.
   */
  epsilon() {
    if (this.floatPrecision() === 32) {
      return EPSILON_FLOAT32$1;
    }
    return EPSILON_FLOAT16$1;
  }
  dispose() {
    return notYetImplemented('dispose');
  }
}
/**
 * Always throws an Error stating that `kernelName` is not implemented or not
 * registered.
 *
 * @param kernelName Name interpolated into the error message.
 */
function notYetImplemented(kernelName) {
  const message = [
    `'${kernelName}' not yet implemented or not found in the registry. `,
    `This kernel may not be supported by the tfjs backend you have chosen`
  ].join('');
  throw new Error(message);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shuffles the array in-place using Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1, 2, 3, 4, 5];
 * tf.util.shuffle(a);
 * console.log(a);
 * ```
 *
 * @param array The array to shuffle in-place.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
// tslint:disable-next-line:no-any
function shuffle(array) {
  // `remaining` is the size of the not-yet-fixed prefix.
  for (let remaining = array.length; remaining > 0;) {
    // Random position inside the unfixed prefix.
    const pick = (Math.random() * remaining) | 0;
    remaining--;
    // Move the picked element to the end of the prefix.
    swap(array, remaining, pick);
  }
}
/**
 * Clamps `x` to the range [`min`, `max`].
 *
 * Note the argument order: lower bound, value, upper bound.
 */
function clamp(min, x, max) {
  const cappedAbove = Math.min(x, max);
  return Math.max(min, cappedAbove);
}
/**
 * Returns `val` unchanged when `val % 2 === 0`, otherwise `val + 1`.
 */
function nearestLargerEven(val) {
  if (val % 2 === 0) {
    return val;
  }
  return val + 1;
}
/**
 * Exchanges the values stored at keys `left` and `right` of `object`,
 * mutating it in place.
 */
function swap(object, left, right) {
  [object[left], object[right]] = [object[right], object[left]];
}
/**
 * Adds up the elements of `arr` by index, starting from 0.
 *
 * @param arr Indexable collection with a `length`.
 * @return The accumulated total (0 for an empty collection).
 */
function sum$4(arr) {
  let total = 0;
  let idx = 0;
  while (idx < arr.length) {
    total += arr[idx];
    idx += 1;
  }
  return total;
}
/**
 * Throws an Error when `expr` is falsy; does nothing otherwise.
 *
 * ```js
 * const x = 2;
 * tf.util.assert(x === 2, 'x is not 2');
 * ```
 *
 * @param expr The expression to assert (as a boolean).
 * @param msg Either the message string itself or a function returning it.
 *     The function form is only invoked when the assertion fails.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function assert$1(expr, msg) {
  if (expr) {
    return;
  }
  const text = typeof msg === 'string' ? msg : msg();
  throw new Error(text);
}
/**
 * Asserts that `shapeA` and `shapeB` are equal according to `arraysEqual`.
 *
 * @param shapeA First shape.
 * @param shapeB Second shape.
 * @param errorMessagePrefix Text placed in front of the failure message.
 */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
  const shapesAgree = arraysEqual(shapeA, shapeB);
  const describeMismatch = () =>
      errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`;
  assert$1(shapesAgree, describeMismatch);
}
/**
 * Asserts that `a` is neither `null` nor `undefined`.
 */
function assertNonNull(a) {
  const isPresent = a != null;
  assert$1(isPresent, () => `The input to the tensor constructor must be a non-null value.`);
}
/**
 * Returns the size (number of elements) of the tensor given its shape, i.e.
 * the product of all dimensions. An empty shape (scalar) has size 1.
 *
 * ```js
 * const shape = [3, 4, 2];
 * const size = tf.util.sizeFromShape(shape);
 * console.log(size);
 * ```
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function sizeFromShape(shape) {
  const rank = shape.length;
  if (rank === 0) {
    // Scalar.
    return 1;
  }
  let product = shape[0];
  let axis = 1;
  while (axis < rank) {
    product *= shape[axis];
    axis++;
  }
  return product;
}
/**
 * Returns true when `shape` has no dimensions (its `length` is 0).
 */
function isScalarShape(shape) {
  const { length: rank } = shape;
  return rank === 0;
}
/**
 * Shallow, strict (`===`) element-wise comparison of two indexable
 * collections.
 *
 * @return true when both arguments are the same reference, or when both are
 *     non-null, have equal `length` and every pair of elements is strictly
 *     equal; false otherwise.
 */
function arraysEqual(n1, n2) {
  if (n1 === n2) {
    return true;
  }
  if (n1 == null || n2 == null || n1.length !== n2.length) {
    return false;
  }
  for (let idx = 0; idx < n1.length; idx++) {
    const same = n1[idx] === n2[idx];
    if (!same) {
      return false;
    }
  }
  return true;
}
/**
 * Returns true when `a % 1` is exactly 0, i.e. `a` has no fractional part.
 * NaN and infinities yield false.
 */
function isInt(a) {
  const fractional = a % 1;
  return fractional === 0;
}
/**
 * Computes a two-element shape whose first entry is `ceil(sqrt(size))` and
 * whose second entry is `ceil(size / first)`.
 *
 * @param size Number of elements to fit.
 * @return `[first, second]` as described above.
 */
function sizeToSquarishShape(size) {
  const side = Math.ceil(Math.sqrt(size));
  const other = Math.ceil(size / side);
  return [side, other];
}
/**
 * Appends spaces to `a` until it is `size` characters long. When `a` is
 * already at least `size` long it is returned unchanged.
 */
function rightPad(a, size) {
  const missing = size - a.length;
  if (missing <= 0) {
    return a;
  }
  return a + ' '.repeat(missing);
}
/**
 * Calls `checkFn` repeatedly until it returns a truthy value.
 *
 * The first attempt runs synchronously. After each failed attempt the try
 * counter is incremented and `delayFn(counter)` is evaluated; if `maxCounter`
 * is set and the counter has reached it, the promise is rejected (with no
 * reason). Otherwise the next attempt is scheduled with the computed delay.
 *
 * @param checkFn Predicate polled on every attempt.
 * @param delayFn Maps the attempt counter to the delay passed to the scheduler.
 * @param maxCounter Optional maximum number of failed attempts.
 * @param scheduleFn Optional scheduler `(fn, delay)`; `setTimeout` is used
 *     when it is null/undefined.
 * @return Promise that resolves (with no value) once `checkFn` succeeds.
 */
function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter, scheduleFn) {
  return new Promise((resolve, reject) => {
    let attempts = 0;
    const attempt = () => {
      if (checkFn()) {
        resolve();
        return;
      }
      attempts += 1;
      // The delay is computed before the limit check, as callers may observe
      // the delayFn invocation.
      const delay = delayFn(attempts);
      const exhausted = maxCounter != null && attempts >= maxCounter;
      if (exhausted) {
        reject();
        return;
      }
      if (scheduleFn == null) {
        // google3 does not allow assigning another variable to setTimeout.
        // Don't refactor this so scheduleFn has a default value of setTimeout.
        setTimeout(attempt, delay);
      }
      else {
        scheduleFn(attempt, delay);
      }
    };
    attempt();
  });
}
/**
 * Resolves a shape that may contain a single -1 placeholder against the
 * total element count.
 * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
 *
 * When no -1 is present the input `shape` itself is returned (after checking
 * that a positive `size` equals the product of the dimensions).
 *
 * @param shape The shape, which may contain -1 in some dimension.
 * @param size The full size (number of elements) of the array.
 * @return The inferred shape where -1 is replaced with the inferred size.
 * @throws Error when more than one -1 is present, when a dimension is
 *     negative but not -1, when the size does not match, or when the missing
 *     dimension cannot be inferred as a whole number.
 */
function inferFromImplicitShape(shape, size) {
  let knownProduct = 1;
  let implicitAxis = -1;
  for (let i = 0; i < shape.length; ++i) {
    const dim = shape[i];
    if (dim >= 0) {
      knownProduct *= dim;
      continue;
    }
    if (dim !== -1) {
      if (dim < 0) {
        throw Error(`Shapes can not be < 0. Found ${dim} at dim ${i}`);
      }
      // Values that are neither >= 0 nor < 0 (e.g. NaN) are ignored.
      continue;
    }
    if (implicitAxis !== -1) {
      throw Error(`Shapes can only have 1 implicit size. ` +
          `Found -1 at dim ${implicitAxis} and dim ${i}`);
    }
    implicitAxis = i;
  }
  const hasImplicit = implicitAxis !== -1;
  if (!hasImplicit) {
    if (size > 0 && size !== knownProduct) {
      throw Error(`Size(${size}) must match the product of shape ${shape}`);
    }
    return shape;
  }
  if (knownProduct === 0) {
    throw Error(`Cannot infer the missing size in [${shape}] when ` +
        `there are 0 elements`);
  }
  if (size % knownProduct !== 0) {
    throw Error(`The implicit shape can't be a fractional number. ` +
        `Got ${size} / ${knownProduct}`);
  }
  const resolved = shape.slice();
  resolved[implicitAxis] = size / knownProduct;
  return resolved;
}
/**
 * Normalizes an axis argument into an array of non-negative axis indices.
 *
 * @param axis A single axis, an array of axes, or null/undefined meaning
 *     "all axes of `shape`".
 * @param shape The shape the axes refer to; its length is the rank.
 * @return Array of axes where each negative entry `a` became `rank + a`.
 * @throws Error (via `assert$1`) when an axis is outside `[-rank, rank)` or
 *     is not an integer.
 */
function parseAxisParam(axis, shape) {
  const rank = shape.length;
  // Normalize input
  const axes = axis == null ? shape.map((s, i) => i) : [].concat(axis);
  // Check for valid range
  const allInRange = axes.every(ax => ax >= -rank && ax < rank);
  assert$1(allInRange, () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
      `got axis ${axes}`);
  // Check for only integers
  const allIntegers = axes.every(ax => isInt(ax));
  assert$1(allIntegers, () => `All values in axis param must be integers but ` +
      `got axis ${axes}`);
  // Handle negative axis.
  return axes.map(a => (a < 0 ? rank + a : a));
}
/**
 * Reduces the shape by removing dimensions of size 1.
 *
 * @param shape The shape to squeeze.
 * @param axis Optional axis or axes to squeeze. When null/undefined or an
 *     empty array, every dimension of size 1 is removed. Otherwise only the
 *     listed axes are removed and size-1 dimensions not listed are kept.
 * @return `{newShape, keptDims}` where `keptDims` holds the indices of the
 *     original dimensions that survive in `newShape`.
 * @throws Error when a listed axis refers to a dimension whose size is not 1.
 */
function squeezeShape(shape, axis) {
  const newShape = [];
  const keptDims = [];
  const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
  // Sort numerically: the default Array.prototype.sort compares elements as
  // strings, which would order axis 10 before axis 2 and break the single
  // forward scan below for shapes with more than 10 dimensions.
  const axes = (axis == null || isEmptyArray) ?
      null :
      parseAxisParam(axis, shape).sort((a, b) => a - b);
  // Cursor into the sorted `axes` list.
  let j = 0;
  for (let i = 0; i < shape.length; ++i) {
    if (axes != null) {
      if (axes[j] === i && shape[i] !== 1) {
        throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
      }
      // A size-1 dimension that was not requested is preserved.
      if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
        newShape.push(shape[i]);
        keptDims.push(i);
      }
      if (axes[j] <= i) {
        j++;
      }
    }
    if (shape[i] !== 1) {
      newShape.push(shape[i]);
      keptDims.push(i);
    }
  }
  return { newShape, keptDims };
}
/**
 * Thin alias that forwards to `getArrayFromDType` with the same arguments.
 */
function getTypedArrayFromDType(dtype, size) {
  const values = getArrayFromDType(dtype, size);
  return values;
}
/**
 * Allocates a container of `size` elements matching `dtype`.
 *
 * @param dtype null/undefined or 'float32' -> Float32Array, 'int32' ->
 *     Int32Array, 'bool' -> Uint8Array, 'string' -> plain Array.
 * @param size Number of elements.
 * @throws Error for any other dtype.
 */
function getArrayFromDType(dtype, size) {
  if (dtype == null || dtype === 'float32') {
    return new Float32Array(size);
  }
  if (dtype === 'int32') {
    return new Int32Array(size);
  }
  if (dtype === 'bool') {
    return new Uint8Array(size);
  }
  if (dtype === 'string') {
    return new Array(size);
  }
  throw new Error(`Unknown data type ${dtype}`);
}
/**
 * Scans `vals` and throws on the first element that is NaN or not finite.
 *
 * @param vals Indexable collection of values to validate.
 * @param dtype Only used in the error message.
 * @throws Error naming the dtype and the offending value.
 */
function checkConversionForErrors(vals, dtype) {
  let idx = 0;
  while (idx < vals.length) {
    const candidate = vals[idx];
    // Global isFinite() is false for NaN as well as +/-Infinity, so it covers
    // both rejected cases.
    if (!isFinite(candidate)) {
      throw Error(`A tensor of type ${dtype} being uploaded contains ${candidate}.`);
    }
    idx++;
  }
}
/**
 * Returns true if `dtype` is one of 'bool', 'complex64', 'float32', 'int32'
 * or 'string'.
 */
function isValidDtype(dtype) {
  const knownDtypes = ['bool', 'complex64', 'float32', 'int32', 'string'];
  return knownDtypes.includes(dtype);
}
/**
 * Returns true if the new type can't encode the old type without loss of
 * precision.
 *
 * Rules implemented here, by target type:
 *  - 'complex64' accepts everything;
 *  - 'float32' accepts everything except 'complex64';
 *  - 'int32' accepts everything except 'float32' and 'complex64';
 *  - 'bool' accepts only 'bool';
 *  - any other target type is reported as lossy.
 */
function hasEncodingLoss(oldType, newType) {
  switch (newType) {
    case 'complex64':
      return false;
    case 'float32':
      return oldType === 'complex64';
    case 'int32':
      return oldType === 'float32' || oldType === 'complex64';
    case 'bool':
      return oldType !== 'bool';
    default:
      return true;
  }
}
/**
 * Returns the element width in bytes for a numeric dtype: 4 for 'float32'
 * and 'int32', 8 for 'complex64', 1 for 'bool'.
 *
 * @throws Error for any other dtype.
 */
function bytesPerElement(dtype) {
  switch (dtype) {
    case 'float32':
    case 'int32':
      return 4;
    case 'complex64':
      return 8;
    case 'bool':
      return 1;
    default:
      throw new Error(`Unknown dtype ${dtype}`);
  }
}
/**
 * Returns the sum of the `length` of every element in `arr`, or 0 when `arr`
 * is null/undefined.
 *
 * Each element's `length` is added as-is (no per-character multiplier is
 * applied), so the result is a byte count only if the elements' lengths are
 * themselves measured in bytes.
 */
function bytesFromStringArray(arr) {
  if (arr == null) {
    return 0;
  }
  return arr.reduce((total, item) => total + item.length, 0);
}
/**
 * Returns true if the value is a string primitive or a `String` wrapper
 * object.
 */
function isString(value) {
  if (typeof value === 'string') {
    return true;
  }
  return value instanceof String;
}
/** Returns true if the value is a boolean primitive. */
function isBoolean(value) {
  const kind = typeof value;
  return kind === 'boolean';
}
/** Returns true if the value is a number primitive (including NaN). */
function isNumber(value) {
  const kind = typeof value;
  return kind === 'number';
}
/**
 * Guesses a dtype name from a value.
 *
 * Nested arrays are unwrapped by following their first element. The
 * resulting leaf is then classified: Float32Array or number -> 'float32';
 * Int32Array, Uint8Array or Uint8ClampedArray -> 'int32'; string ->
 * 'string'; boolean -> 'bool'. Anything else falls back to 'float32'.
 */
function inferDtype(values) {
  // Descend through first elements until a non-array leaf is reached.
  let leaf = values;
  while (Array.isArray(leaf)) {
    leaf = leaf[0];
  }
  if (leaf instanceof Float32Array) {
    return 'float32';
  }
  const isIntegerTypedArray = leaf instanceof Int32Array ||
      leaf instanceof Uint8Array || leaf instanceof Uint8ClampedArray;
  if (isIntegerTypedArray) {
    return 'int32';
  }
  if (isNumber(leaf)) {
    return 'float32';
  }
  if (isString(leaf)) {
    return 'string';
  }
  if (isBoolean(leaf)) {
    return 'bool';
  }
  return 'float32';
}
/**
 * Duck-typed function check: returns true when `f` is truthy and has truthy
 * `constructor`, `call` and `apply` members.
 */
function isFunction(f) {
  return Boolean(f && f.constructor && f.call && f.apply);
}
/**
 * Returns the smallest value `i` with `start <= i < size` (stepping by 1)
 * that divides `size` evenly, or `size` itself when there is none.
 */
function nearestDivisor(size, start) {
  let candidate = start;
  while (candidate < size) {
    if (size % candidate === 0) {
      return candidate;
    }
    ++candidate;
  }
  return size;
}
/**
 * Computes the strides for `shape`, omitting the innermost one.
 *
 * The result has `rank - 1` entries: entry `i` is the product of all
 * dimensions after `i`. Shapes of rank 0 or 1 yield an empty array.
 */
function computeStrides(shape) {
  const rank = shape.length;
  if (rank < 2) {
    return [];
  }
  // Last dimension has implicit stride of 1, thus having D-1 (instead of D)
  // strides.
  const strides = new Array(rank - 1);
  let axis = rank - 2;
  strides[axis] = shape[rank - 1];
  // Walk towards the outermost axis, accumulating the product.
  while (--axis >= 0) {
    strides[axis] = strides[axis + 1] * shape[axis + 1];
  }
  return strides;
}
/**
 * Recursively builds a nested plain array of the given `shape` from the flat
 * buffer `a`, starting at `offset`.
 *
 * @param offset Index in `a` of the first element of this sub-array.
 * @param shape Remaining dimensions; must have at least one entry.
 * @param a Flat, indexable source of values.
 * @param isComplex When true, the innermost dimension holds twice as many
 *     values per row.
 * @return The nested array.
 */
function createNestedArray(offset, shape, a, isComplex = false) {
  const widthFactor = isComplex ? 2 : 1;
  const nested = [];
  if (shape.length === 1) {
    // Innermost dimension: copy the raw values.
    const count = shape[0] * widthFactor;
    for (let i = 0; i < count; i++) {
      nested.push(a[offset + i]);
    }
    return nested;
  }
  const outer = shape[0];
  const inner = shape.slice(1);
  // Number of flat values consumed by one entry along the outer dimension.
  const blockSize = inner.reduce((acc, c) => acc * c) * widthFactor;
  for (let i = 0; i < outer; i++) {
    nested.push(createNestedArray(offset + i * blockSize, inner, a, isComplex));
  }
  return nested;
}
/**
 * Converts the flat buffer `a` into a nested array of the given `shape`.
 *
 * @param shape Target shape. An empty shape returns `a[0]`.
 * @param a Flat, indexable source of values.
 * @param isComplex When true, the expected flat size is doubled.
 * @return The scalar, an empty array (when the shape has zero elements), or
 *     the nested array produced by `createNestedArray`.
 * @throws Error when the size implied by `shape` differs from `a.length`.
 */
function toNestedArray(shape, a, isComplex = false) {
  if (shape.length === 0) {
    // Scalar type should return a single number.
    return a[0];
  }
  const widthFactor = isComplex ? 2 : 1;
  const expectedSize = shape.reduce((acc, c) => acc * c) * widthFactor;
  if (expectedSize === 0) {
    // A tensor with shape zero should be turned into empty list.
    return [];
  }
  if (expectedSize !== a.length) {
    const suffix = isComplex ? ' for a complex tensor' : '';
    throw new Error(`[${shape}] does not match the input size ${a.length}${suffix}.`);
  }
  return createNestedArray(0, shape, a, isComplex);
}
/**
 * Normalizes backend data into a typed array for `dtype`.
 *
 * @param data Either a plain array (returned untouched), an already-typed
 *     array, or anything accepted by the typed array constructors.
 * @param dtype 'float32' -> Float32Array, 'int32' -> Int32Array; 'bool' and
 *     'string' are read as Int32Array and then narrowed to Uint8Array.
 * @throws Error for any other dtype (when `data` is not a plain array).
 */
function convertBackendValuesAndArrayBuffer(data, dtype) {
  // If is type Uint8Array[], return it directly.
  if (Array.isArray(data)) {
    return data;
  }
  switch (dtype) {
    case 'float32':
      return data instanceof Float32Array ? data : new Float32Array(data);
    case 'int32':
      return data instanceof Int32Array ? data : new Int32Array(data);
    case 'bool':
    case 'string':
      return Uint8Array.from(new Int32Array(data));
    default:
      throw new Error(`Unknown dtype ${dtype}`);
  }
}
/**
 * Creates the array returned by `makeZerosTypedArray(size, dtype)` and sets
 * every element to 1.
 */
function makeOnesTypedArray(size, dtype) {
  const ones = makeZerosTypedArray(size, dtype);
  ones.fill(1);
  return ones;
}
/**
 * Allocates a zero-initialized typed array of `size` elements.
 *
 * @param size Number of elements.
 * @param dtype null/undefined, 'float32' or 'complex64' -> Float32Array;
 *     'int32' -> Int32Array; 'bool' -> Uint8Array.
 * @throws Error for any other dtype.
 */
function makeZerosTypedArray(size, dtype) {
  if (dtype == null) {
    return new Float32Array(size);
  }
  switch (dtype) {
    case 'float32':
    case 'complex64':
      return new Float32Array(size);
    case 'int32':
      return new Int32Array(size);
    case 'bool':
      return new Uint8Array(size);
    default:
      throw new Error(`Unknown data type ${dtype}`);
  }
}
/**
 * Builds a nested array of the given shape filled with zeros, by allocating
 * a zeroed typed array and passing it through `toNestedArray`.
 * @param shape The shape information for the nested array.
 * @param dtype dtype of the array element: null/undefined or 'float32',
 *     'int32', or 'bool'. Any other value throws.
 */
function makeZerosNestedTypedArray(shape, dtype) {
  const size = shape.reduce((prev, curr) => prev * curr, 1);
  let zeros;
  if (dtype == null || dtype === 'float32') {
    zeros = new Float32Array(size);
  }
  else if (dtype === 'int32') {
    zeros = new Int32Array(size);
  }
  else if (dtype === 'bool') {
    zeros = new Uint8Array(size);
  }
  else {
    throw new Error(`Unknown data type ${dtype}`);
  }
  return toNestedArray(shape, zeros);
}
/**
 * Asserts that every dimension in `shape` is an integer greater than or
 * equal to 0. Note that the failure message says "positive integers" even
 * though 0 passes the check.
 */
function assertNonNegativeIntegerDimensions(shape) {
  const describeFailure = () =>
      `Tensor must have a shape comprised of positive integers but got ` +
      `shape [${shape}].`;
  shape.forEach(dimSize => {
    const isValid = Number.isInteger(dimSize) && dimSize >= 0;
    assert$1(isValid, describeFailure);
  });
}
/**
 * Computes the flat index for a given location (multidimensional index) in a
 * Tensor/multidimensional array.
 *
 * @param locs Location in the tensor.
 * @param rank Rank of the tensor.
 * @param strides Tensor strides.
 * @return 0 for rank 0, `locs[0]` for rank 1, otherwise the last coordinate
 *     plus the stride-weighted sum of the preceding coordinates.
 */
function locToIndex(locs, rank, strides) {
  if (rank === 0) {
    return 0;
  }
  if (rank === 1) {
    return locs[0];
  }
  const lastAxis = locs.length - 1;
  // The innermost coordinate has an implicit stride of 1.
  let flat = locs[lastAxis];
  for (let axis = 0; axis < lastAxis; ++axis) {
    flat += strides[axis] * locs[axis];
  }
  return flat;
}
/**
 * Computes the location (multidimensional index) in a
 * tensor/multidimensional array for a given flat index.
 *
 * @param index Index in flat array.
 * @param rank Rank of tensor.
 * @param strides Strides of tensor.
 * @return `[]` for rank 0, `[index]` for rank 1, otherwise an array of
 *     `rank` coordinates.
 */
function indexToLoc(index, rank, strides) {
  if (rank === 0) {
    return [];
  }
  if (rank === 1) {
    return [index];
  }
  const locs = new Array(rank);
  const lastAxis = locs.length - 1;
  let remainder = index;
  for (let axis = 0; axis < lastAxis; ++axis) {
    const stride = strides[axis];
    const coord = Math.floor(remainder / stride);
    locs[axis] = coord;
    remainder -= coord * stride;
  }
  // Whatever is left addresses the innermost dimension.
  locs[lastAxis] = remainder;
  return locs;
}
/**
 * Duck-typed check for a Promise-like (thenable) value.
 *
 * Returns `true` when `object` is truthy and has a `then` function. For
 * other inputs the result is falsy but not necessarily `false`: a falsy
 * `object` or a falsy `then` member is returned as-is.
 * @param object
 */
// tslint:disable-next-line: no-any
function isPromise(object) {
  // We chose to not use 'obj instanceOf Promise' for two reasons:
  // 1. It only reliably works for es6 Promise, not other Promise
  // implementations.
  // 2. It doesn't work with framework that uses zone.js. zone.js monkey
  // patch the async calls, so it is possible the obj (patched) is
  // comparing to a pre-patched Promise.
  if (!object) {
    return object;
  }
  const thenMember = object.then;
  if (!thenMember) {
    return thenMember;
  }
  return typeof object.then === 'function';
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Name of the URL query parameter that carries flag overrides.
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
// Environment.populateURLFlags() splits the parameter value on ',' and each
// entry on ':'.
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * The environment contains evaluated flags as well as the registered platform.
 * This is always used as a global singleton and can be retrieved with
 * `tf.env()`.
 *
 * State kept per instance:
 *  - `flags`: flag name -> evaluated (or explicitly set) value;
 *  - `flagRegistry`: flag name -> `{evaluationFn, setHook}`;
 *  - `urlFlags`: overrides parsed from the `tfjsflags` URL query parameter.
 *
 * @doc {heading: 'Environment'}
 */
class Environment {
  // tslint:disable-next-line: no-any
  constructor(global) {
    this.global = global;
    this.flags = {};
    this.flagRegistry = {};
    this.urlFlags = {};
    // Jasmine spies on this in 'environment_test.ts'
    this.getQueryParams = getQueryParams;
    this.populateURLFlags();
  }
  /** Registers the platform, warning when one was already set. */
  setPlatform(platformName, platform) {
    const alreadySet = this.platform != null;
    if (alreadySet && !env().getBool('IS_TEST') && !env().getBool('PROD')) {
      console.warn(`Platform ${this.platformName} has already been set. ` +
          `Overwriting the platform with ${platformName}.`);
    }
    this.platformName = platformName;
    this.platform = platform;
  }
  /**
   * Registers how a flag is evaluated and an optional hook run when it is
   * set. If the URL supplied a value for this flag, it is applied right away.
   */
  registerFlag(flagName, evaluationFn, setHook) {
    this.flagRegistry[flagName] = { evaluationFn, setHook };
    // Override the flag value from the URL. This has to happen here because
    // the environment is initialized before flags get registered.
    const urlValue = this.urlFlags[flagName];
    if (urlValue == null) {
      return;
    }
    if (!env().getBool('IS_TEST') && !env().getBool('PROD')) {
      console.warn(`Setting feature override from URL ${flagName}: ${urlValue}.`);
    }
    this.set(flagName, urlValue);
  }
  /** Returns the cached flag value, awaiting its evaluation on first use. */
  async getAsync(flagName) {
    if (!(flagName in this.flags)) {
      this.flags[flagName] = await this.evaluateFlag(flagName);
    }
    return this.flags[flagName];
  }
  /**
   * Returns the cached flag value, evaluating it synchronously on first use.
   * Throws when the evaluation function returns a promise.
   */
  get(flagName) {
    if (!(flagName in this.flags)) {
      const evaluated = this.evaluateFlag(flagName);
      if (isPromise(evaluated)) {
        throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` +
            `Please use getAsync() instead.`);
      }
      this.flags[flagName] = evaluated;
    }
    return this.flags[flagName];
  }
  // The typed getters below are plain aliases of get().
  getNumber(flagName) {
    return this.get(flagName);
  }
  getBool(flagName) {
    return this.get(flagName);
  }
  getString(flagName) {
    return this.get(flagName);
  }
  getFlags() {
    return this.flags;
  }
  // For backwards compatibility.
  get features() {
    return this.flags;
  }
  /** Stores a value for a registered flag and runs its set hook, if any. */
  set(flagName, value) {
    const registration = this.flagRegistry[flagName];
    if (registration == null) {
      throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
    }
    this.flags[flagName] = value;
    if (registration.setHook != null) {
      registration.setHook(value);
    }
  }
  /** Runs the registered evaluation function for `flagName`. */
  evaluateFlag(flagName) {
    const registration = this.flagRegistry[flagName];
    if (registration == null) {
      throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
    }
    return registration.evaluationFn();
  }
  /** Replaces all flag values with a shallow copy of `flags`. */
  setFlags(flags) {
    this.flags = Object.assign({}, flags);
  }
  /** Clears flag values and re-reads the URL overrides. */
  reset() {
    this.flags = {};
    this.urlFlags = {};
    this.populateURLFlags();
  }
  /**
   * Fills `urlFlags` from `global.location.search` when it is available;
   * does nothing otherwise.
   */
  populateURLFlags() {
    const root = this.global;
    const hasSearch = typeof root !== 'undefined' &&
        typeof root.location !== 'undefined' &&
        typeof root.location.search !== 'undefined';
    if (!hasSearch) {
      return;
    }
    const urlParams = this.getQueryParams(this.global.location.search);
    if (!(TENSORFLOWJS_FLAGS_PREFIX in urlParams)) {
      return;
    }
    const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
    for (const keyValue of keyValues) {
      const [key, value] = keyValue.split(':');
      this.urlFlags[key] = parseValue(key, value);
    }
  }
}
/**
 * Parses a URL query string into a plain object of decoded name/value pairs.
 *
 * Every `?name=value` / `&name=value` occurrence is handed to `decodeParam`;
 * the `=value` part is optional.
 *
 * @param queryString The raw query string (e.g. `location.search`).
 * @return Object mapping parameter names to values.
 */
function getQueryParams(queryString) {
  const params = {};
  const pairPattern = /[?&]([^=?&]+)(?:=([^&]*))?/g;
  for (const match of queryString.matchAll(pairPattern)) {
    decodeParam(params, match[1], match[2]);
  }
  return params;
}
/**
 * Stores one URI-decoded query parameter on `params`.
 *
 * @param params Object to write to.
 * @param name Encoded parameter name.
 * @param value Encoded parameter value; a falsy value is stored as ''.
 */
function decodeParam(params, name, value) {
  const key = decodeURIComponent(name);
  params[key] = decodeURIComponent(value || '');
}
/**
 * Converts the textual value of a URL flag into a boolean, a number or a
 * string.
 *
 * @param flagName Name of the flag (not used for the conversion).
 * @param value Raw value text. May be null/undefined when the URL entry had
 *     no `:value` part (e.g. `?tfjsflags=DEBUG`).
 * @return `true`/`false` for 'true'/'false' in any letter case; a number when
 *     the lower-cased text round-trips exactly through numeric conversion;
 *     the original `value` otherwise. A null/undefined `value` is returned
 *     unchanged.
 */
function parseValue(flagName, value) {
  // An entry without a ':' yields no value; calling toLowerCase() on it would
  // throw a TypeError. Pass it through so callers that test `!= null` simply
  // ignore the flag.
  if (value == null) {
    return value;
  }
  const lowerCaseValue = value.toLowerCase();
  if (lowerCaseValue === 'true' || lowerCaseValue === 'false') {
    return lowerCaseValue === 'true';
  }
  // Only treat the text as numeric when it is the canonical rendering of the
  // number, so strings such as '01' or '' stay strings.
  if (`${+lowerCaseValue}` === lowerCaseValue) {
    return +lowerCaseValue;
  }
  return value;
}
// Module-level holder for the environment singleton; null until
// setEnvironmentGlobal() is called.
let ENV$5 = null;
/**
 * Returns the current environment (a global singleton).
 *
 * The environment object contains the evaluated feature values as well as the
 * active platform.
 *
 * @doc {heading: 'Environment'}
 */
function env() {
  return ENV$5;
}
/**
 * Replaces the environment singleton returned by `env()`.
 *
 * @param environment The new environment object.
 */
function setEnvironmentGlobal(environment) {
  ENV$5 = environment;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Note that the identifier globalNameSpace is scoped to this module, but will
// always resolve to the same global object regardless of how the module is
// resolved.
// tslint:disable-next-line:no-any
// Cache for the resolved global object; filled on the first lookup.
let globalNameSpace;
/**
 * Returns the global object, resolving it once and caching the result.
 *
 * Candidates are probed in this order: `window`, `global`, `process`,
 * `self`. The first one that is defined wins.
 *
 * @throws Error when none of the candidates is defined.
 */
// tslint:disable-next-line:no-any
function getGlobalNamespace() {
  if (globalNameSpace != null) {
    return globalNameSpace;
  }
  if (typeof (window) !== 'undefined') {
    globalNameSpace = window;
  }
  else if (typeof (global) !== 'undefined') {
    globalNameSpace = global;
  }
  else if (typeof (process) !== 'undefined') {
    globalNameSpace = process;
  }
  else if (typeof (self) !== 'undefined') {
    globalNameSpace = self;
  }
  else {
    throw new Error('Could not find a global object');
  }
  return globalNameSpace;
}
/**
 * Returns the `_tfGlobals` Map stored on the global namespace object,
 * creating it when it is missing.
 */
// tslint:disable-next-line:no-any
function getGlobalMap() {
  const ns = getGlobalNamespace();
  const needsInit = ns._tfGlobals == null;
  if (needsInit) {
    ns._tfGlobals = new Map();
  }
  return ns._tfGlobals;
}
/**
 * Returns a globally accessible 'singleton' object.
 *
 * @param key the name of the object
 * @param init a function that creates the object; it is only called when
 *     no entry exists yet for `key`.
 * @return The object stored under `key` in the global map.
 */
function getGlobal(key, init) {
  const globalMap = getGlobalMap();
  if (!globalMap.has(key)) {
    globalMap.set(key, init());
  }
  return globalMap.get(key);
}
// Kernel name constants. Each constant holds a kernel's name as a string;
// a `$N` suffix on the identifier does not appear in the string value.
// NOTE(review): presumably these are the `kernelName` values used with the
// kernel/gradient registries — confirm against the callers.
const Abs = 'Abs';
const Acos = 'Acos';
const Acosh = 'Acosh';
const Add$1 = 'Add';
const AddN = 'AddN';
const All = 'All';
const Any = 'Any';
const ArgMax = 'ArgMax';
const ArgMin = 'ArgMin';
const Asin = 'Asin';
const Asinh = 'Asinh';
const Atan = 'Atan';
const Atanh = 'Atanh';
const Atan2 = 'Atan2';
const AvgPool = 'AvgPool';
const AvgPoolGrad = 'AvgPoolGrad';
const AvgPool3D = 'AvgPool3D';
const AvgPool3DGrad = 'AvgPool3DGrad';
const BatchMatMul = 'BatchMatMul';
const BatchToSpaceND = 'BatchToSpaceND';
const Bincount = 'Bincount';
const BitwiseAnd = 'BitwiseAnd';
const BroadcastTo = 'BroadcastTo';
const BroadcastArgs = 'BroadcastArgs';
const Cast = 'Cast';
const Ceil = 'Ceil';
const ClipByValue = 'ClipByValue';
const Complex = 'Complex';
const ComplexAbs = 'ComplexAbs';
const Concat = 'Concat';
const Conv2D$1 = 'Conv2D';
const Conv2DBackpropFilter = 'Conv2DBackpropFilter';
const Conv2DBackpropInput = 'Conv2DBackpropInput';
const Conv3D$1 = 'Conv3D';
const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
const Cos = 'Cos';
const Cosh = 'Cosh';
const Cumprod = 'Cumprod';
const Cumsum = 'Cumsum';
const CropAndResize = 'CropAndResize';
const DenseBincount = 'DenseBincount';
const DepthToSpace = 'DepthToSpace';
const DepthwiseConv2dNative = 'DepthwiseConv2dNative';
const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
const Diag = 'Diag';
const Dilation2D = 'Dilation2D';
const Dilation2DBackpropInput = 'Dilation2DBackpropInput';
const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
const Draw = 'Draw';
const RealDiv = 'RealDiv';
const Einsum = 'Einsum';
const Elu$1 = 'Elu';
const EluGrad = 'EluGrad';
const Erf = 'Erf';
const Equal = 'Equal';
const Exp = 'Exp';
const ExpandDims = 'ExpandDims';
const Expm1 = 'Expm1';
const FFT = 'FFT';
const Fill = 'Fill';
const FlipLeftRight = 'FlipLeftRight';
const Floor = 'Floor';
const FloorDiv = 'FloorDiv';
const FusedBatchNorm = 'FusedBatchNorm';
const GatherV2 = 'GatherV2';
const GatherNd = 'GatherNd';
const Greater = 'Greater';
const GreaterEqual = 'GreaterEqual';
const Identity$1 = 'Identity';
const IFFT = 'IFFT';
const Imag = 'Imag';
const IsFinite = 'IsFinite';
const IsInf = 'IsInf';
const IsNan = 'IsNan';
const LeakyRelu = 'LeakyRelu';
const Less = 'Less';
const LessEqual = 'LessEqual';
const LinSpace = 'LinSpace';
const Log = 'Log';
const Log1p = 'Log1p';
const LogicalAnd = 'LogicalAnd';
const LogicalNot = 'LogicalNot';
const LogicalOr = 'LogicalOr';
const LogSoftmax$1 = 'LogSoftmax';
const LRN = 'LRN';
const LRNGrad = 'LRNGrad';
const Max = 'Max';
const Maximum$1 = 'Maximum';
const MaxPool = 'MaxPool';
const MaxPoolGrad = 'MaxPoolGrad';
const MaxPool3D = 'MaxPool3D';
const MaxPool3DGrad = 'MaxPool3DGrad';
const MaxPoolWithArgmax = 'MaxPoolWithArgmax';
const Mean = 'Mean';
const Min = 'Min';
const Minimum$1 = 'Minimum';
const MirrorPad = 'MirrorPad';
const Mod = 'Mod';
const Multinomial = 'Multinomial';
const Multiply$1 = 'Multiply';
const Neg = 'Neg';
const NotEqual = 'NotEqual';
const NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
const NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
const OnesLike = 'OnesLike';
const OneHot = 'OneHot';
const Pack = 'Pack';
const PadV2 = 'PadV2';
const Pow = 'Pow';
const Prelu = 'Prelu';
const Prod = 'Prod';
const RaggedGather = 'RaggedGather';
const RaggedRange = 'RaggedRange';
const RaggedTensorToTensor = 'RaggedTensorToTensor';
const Range = 'Range';
const Real = 'Real';
const Reciprocal = 'Reciprocal';
const Relu$1 = 'Relu';
const Reshape$1 = 'Reshape';
const ResizeNearestNeighbor = 'ResizeNearestNeighbor';
const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
const ResizeBilinear = 'ResizeBilinear';
const ResizeBilinearGrad = 'ResizeBilinearGrad';
const Relu6$1 = 'Relu6';
const Reverse = 'Reverse';
const Round = 'Round';
const Rsqrt = 'Rsqrt';
const ScatterNd = 'ScatterNd';
const TensorScatterUpdate = 'TensorScatterUpdate';
const SearchSorted = 'SearchSorted';
const Select = 'Select';
const Selu$1 = 'Selu';
const Slice = 'Slice';
const Sin = 'Sin';
const Sinh = 'Sinh';
const Sign = 'Sign';
const Sigmoid$1 = 'Sigmoid';
const Softplus$1 = 'Softplus';
const Sqrt = 'Sqrt';
const Sum = 'Sum';
const SpaceToBatchND = 'SpaceToBatchND';
const SplitV = 'SplitV';
const Softmax$2 = 'Softmax';
const SparseFillEmptyRows = 'SparseFillEmptyRows';
const SparseReshape = 'SparseReshape';
const SparseSegmentMean = 'SparseSegmentMean';
const SparseSegmentSum = 'SparseSegmentSum';
const SparseToDense = 'SparseToDense';
const SquaredDifference = 'SquaredDifference';
const Square = 'Square';
const StaticRegexReplace = 'StaticRegexReplace';
const StridedSlice = 'StridedSlice';
const StringNGrams = 'StringNGrams';
const StringSplit = 'StringSplit';
const StringToHashBucketFast = 'StringToHashBucketFast';
const Sub = 'Sub';
const Tan = 'Tan';
const Tanh$1 = 'Tanh';
const Tile = 'Tile';
const TopK = 'TopK';
const Transform = 'Transform';
const Transpose = 'Transpose';
const Unique = 'Unique';
const Unpack = 'Unpack';
const UnsortedSegmentSum = 'UnsortedSegmentSum';
const ZerosLike = 'ZerosLike';
/**
 * TensorFlow.js-only kernels
 */
const Step = 'Step';
const FromPixels = 'FromPixels';
const RotateWithOffset = 'RotateWithOffset';
const _FusedMatMul = '_FusedMatMul';
const FusedConv2D = 'FusedConv2D';
const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Forwards its arguments to `console.warn` unless the environment's
 * 'IS_TEST' or 'PROD' flag is truthy.
 */
function warn(...msg) {
  const silenced = env().getBool('IS_TEST') || env().getBool('PROD');
  if (silenced) {
    return;
  }
  console.warn(...msg);
}
/**
 * Forwards its arguments to `console.log` unless the environment's
 * 'IS_TEST' or 'PROD' flag is truthy.
 */
function log$4(...msg) {
  const silenced = env().getBool('IS_TEST') || env().getBool('PROD');
  if (silenced) {
    return;
  }
  console.log(...msg);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Registry Maps obtained through getGlobal(), so they are stored in (and
// looked up from) the shared global map under the given names.
// kernelRegistry: key built by makeKey(kernelName, backendName) -> kernel config.
const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
// gradRegistry: kernelName -> gradient config.
const gradRegistry = getGlobal('gradRegistry', () => new Map());
/**
 * Looks up the kernel config registered for a kernel/backend pair.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @return The registered config, or undefined when there is none.
 */
function getKernel(kernelName, backendName) {
  return kernelRegistry.get(makeKey(kernelName, backendName));
}
/**
 * Looks up the gradient config registered for a kernel.
 *
 * @param kernelName The official TF kernel name.
 * @returns Whatever the gradient registry holds under `kernelName` (the
 *     config object stored by `registerGradient`).
 */
function getGradient(kernelName) {
    const gradConfig = gradRegistry.get(kernelName);
    return gradConfig;
}
/**
 * Collects every kernel config registered for the given backend.
 *
 * The backend is read from each config's own `backendName` field rather than
 * by splitting the registry key on '_': splitting truncates backend names that
 * themselves contain an underscore, so their kernels were never returned (and
 * were wrongly attributed to the truncated prefix).
 *
 * @param backendName The official name of the backend.
 * @returns An array of the matching config objects, in registry iteration
 *     order; empty when nothing matches.
 */
function getKernelsForBackend(backendName) {
    const result = [];
    for (const config of kernelRegistry.values()) {
        if (config.backendName === backendName) {
            result.push(config);
        }
    }
    return result;
}
/**
 * Stores a kernel config in the global kernel registry under the key for its
 * kernel/backend pair. When that key is already taken a warning is issued and
 * the previous entry is replaced.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 * disposed.
 */
function registerKernel(config) {
    const kernelName = config.kernelName;
    const backendName = config.backendName;
    const registryKey = makeKey(kernelName, backendName);
    const alreadyPresent = kernelRegistry.has(registryKey);
    if (alreadyPresent) {
        warn(`The kernel '${kernelName}' for backend '${backendName}' is already registered`);
    }
    kernelRegistry.set(registryKey, config);
}
/**
 * Stores a gradient config in the global gradient registry under its kernel
 * name, replacing any previous entry. A warning about the override is issued
 * only when the environment's `DEBUG` flag is set.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
function registerGradient(config) {
    const name = config.kernelName;
    const overriding = gradRegistry.has(name);
    // TODO (yassogba) after 3.0 assess whether we need to keep this gated
    // to debug mode.
    if (overriding && env().getBool('DEBUG')) {
        warn(`Overriding the gradient for '${name}'`);
    }
    gradRegistry.set(name, config);
}
/**
 * Builds the kernel-registry key for a kernel/backend pair: the backend name,
 * an underscore, then the kernel name.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @returns {string} The combined key.
 */
function makeKey(kernelName, backendName) {
    // String.prototype.concat stringifies its arguments the same way a
    // template literal does.
    return ''.concat(backendName, '_', kernelName);
}
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reports whether a value is an instance of one of four typed-array classes:
 * Float32Array, Int32Array, Uint8Array or Uint8ClampedArray.
 *
 * @param {*} a The value to test.
 * @returns {boolean} `true` for an instance of any of those four classes.
 */
function isTypedArrayBrowser(a) {
    if (a instanceof Float32Array) {
        return true;
    }
    if (a instanceof Int32Array) {
        return true;
    }
    if (a instanceof Uint8Array) {
        return true;
    }
    return a instanceof Uint8ClampedArray;
}
// Resolves a global object: the first of `globalThis`, `window`, `global` and
// `self` that is defined, falling back to a fresh empty object if none is.
var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};
/**
 * Unwraps the default export of a module object.
 *
 * @param {*} x A module namespace/exports value.
 * @returns {*} `x.default` when `x` is truthy, has a truthy `__esModule`
 *     property and owns a `default` property; otherwise `x` itself.
 */
function getDefaultExportFromCjs (x) {
    const ownsDefault = x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default');
    if (ownsDefault) {
        return x['default'];
    }
    return x;
}
// Alias for the Long constructor. `Long$1` is a function declaration further
// down in this scope, so it is hoisted and already defined at this point.
var long = Long$1;
/**
 * Optional WebAssembly helper for native 64-bit integer multiplication,
 * division and remainder. Stays `null` when the module below cannot be
 * compiled or instantiated.
 */
var wasm = null;
try {
// The byte array is a precompiled module exporting `mul`, `div_s`, `div_u`,
// `rem_s`, `rem_u` and `get_high`. Each arithmetic export takes two 64-bit
// operands as (low, high) pairs of 32-bit integers, returns the low 32 bits
// of the result and stores the high 32 bits in a module global that
// `get_high` reads back.
wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([
0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11
])), {}).exports;
} catch (e) {
// Deliberate best-effort: any failure here (for example no WebAssembly
// support) is ignored and `wasm` keeps its `null` value.
}
/**
 * Constructs a 64 bit two's-complement integer from its low and high 32 bit
 * halves, each given as a *signed* integer. The from* functions offer more
 * convenient ways of constructing Longs.
 *
 * Instances carry three fields:
 * - `low`: the low 32 bits as a signed value,
 * - `high`: the high 32 bits as a signed value,
 * - `unsigned`: whether the value is unsigned.
 *
 * @exports Long
 * @class A Long class for representing a 64 bit two's-complement integer value.
 * @param {number} low The low (signed) 32 bits of the long
 * @param {number} high The high (signed) 32 bits of the long
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @constructor
 */
function Long$1(low, high, unsigned) {
    // `| 0` coerces each half to a signed 32-bit integer.
    this.low = low | 0;
    this.high = high | 0;
    // Any truthy value selects unsigned mode; a missing argument means signed.
    this.unsigned = Boolean(unsigned);
}
// The internal representation of a long is the two given signed, 32-bit values.
// We use 32-bit pieces because these are the size of integers on which
// Javascript performs bit-operations. For operations like addition and
// multiplication, we split each number into 16 bit pieces, which can easily be
// multiplied within Javascript's floating-point representation without overflow
// or change in sign.
//
// In the algorithms below, we frequently reduce the negative case to the
// positive case by negating the input(s) and then post-processing the result.
// Note that we must ALWAYS check specially whether those values are MIN_VALUE
// (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as
// a positive number, it overflows back into a negative). Not handling this
// case would often result in infinite recursion.
//
// Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from*
// methods on which they depend.
/**
 * An indicator used to reliably determine if an object is a Long or not.
 * @type {boolean}
 * @const
 * @private
 */
// Bare property reference; presumably present only so the JSDoc above has a
// declaration to attach to — the value is set by the next statement.
Long$1.prototype.__isLong__;
// Defines the marker on the prototype. Only `value` is supplied, so the
// property takes Object.defineProperty's defaults: non-writable,
// non-enumerable and non-configurable.
Object.defineProperty(Long$1.prototype, "__isLong__", { value: true });
/**
 * Checks a value for the Long marker property.
 *
 * @function
 * @param {*} obj Object
 * @returns {boolean} `true` only when `obj` is truthy and its `__isLong__`
 *     property is strictly `true`.
 * @inner
 */
function isLong(obj) {
    if (!obj) {
        return false;
    }
    return obj["__isLong__"] === true;
}
/**
 * Tests if the specified object is a Long, by checking that it is truthy and
 * that its `__isLong__` property is strictly `true`.
 * @function
 * @param {*} obj Object
 * @returns {boolean}
 */
Long$1.isLong = isLong;
/**
 * A cache of the Long representations of small integer values.
 * Filled lazily by `fromInt` in signed mode, keyed by the 32-bit value, for
 * values in the range -128 <= value < 128.
 * @type {!Object}
 * @inner
 */
var INT_CACHE = {};
/**
 * A cache of the Long representations of small unsigned integer values.
 * Filled lazily by `fromInt` in unsigned mode, keyed by the 32-bit value, for
 * values in the range 0 <= value < 256.
 * @type {!Object}
 * @inner
 */
var UINT_CACHE = {};
/**
 * Builds a Long from a 32 bit integer, reusing cached instances for small
 * values (signed: -128 <= value < 128, unsigned: 0 <= value < 256).
 *
 * The input is first coerced to 32 bits (`>>> 0` in unsigned mode, `| 0`
 * otherwise). In both modes the high word is -1 when bit 31 of the coerced
 * value is set and 0 otherwise.
 *
 * @param {number} value
 * @param {boolean=} unsigned
 * @returns {!Long}
 * @inner
 */
function fromInt(value, unsigned) {
    if (unsigned) {
        const u32 = value >>> 0;
        const cacheable = u32 < 256;
        if (cacheable && UINT_CACHE[u32]) {
            return UINT_CACHE[u32];
        }
        const made = fromBits(u32, (u32 | 0) < 0 ? -1 : 0, true);
        if (cacheable) {
            UINT_CACHE[u32] = made;
        }
        return made;
    }
    const i32 = value | 0;
    const cacheable = -128 <= i32 && i32 < 128;
    if (cacheable && INT_CACHE[i32]) {
        return INT_CACHE[i32];
    }
    const made = fromBits(i32, i32 < 0 ? -1 : 0, false);
    if (cacheable) {
        INT_CACHE[i32] = made;
    }
    return made;
}
/**
 * Returns a Long representing the given 32 bit integer value.
 * Small values (signed: -128 <= value < 128, unsigned: 0 <= value < 256) are
 * served from a cache, so repeated calls return the same instance.
 * @function
 * @param {number} value The 32 bit integer in question
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {!Long} The corresponding Long value
 */
Long$1.fromInt = fromInt;
/**
 * Builds a Long from a JavaScript number.
 *
 * NaN (as judged by the global `isNaN`) maps to the zero constant of the
 * requested signedness. Values at or beyond the range bounds are clamped:
 * unsigned input below zero gives UZERO and input >= TWO_PWR_64_DBL gives
 * MAX_UNSIGNED_VALUE; signed input <= -TWO_PWR_63_DBL gives MIN_VALUE and
 * input with `value + 1 >= TWO_PWR_63_DBL` gives MAX_VALUE. Remaining negative
 * input is converted via its magnitude and then negated.
 *
 * @param {number} value
 * @param {boolean=} unsigned
 * @returns {!Long}
 * @inner
 */
function fromNumber(value, unsigned) {
    if (isNaN(value)) {
        return unsigned ? UZERO : ZERO;
    }
    if (unsigned) {
        if (value < 0) {
            return UZERO;
        }
        if (value >= TWO_PWR_64_DBL) {
            return MAX_UNSIGNED_VALUE;
        }
    } else if (value <= -TWO_PWR_63_DBL) {
        return MIN_VALUE;
    } else if (value + 1 >= TWO_PWR_63_DBL) {
        return MAX_VALUE;
    }
    if (value < 0) {
        return fromNumber(-value, unsigned).neg();
    }
    // Split the magnitude into its low and high 32-bit words.
    const lowWord = (value % TWO_PWR_32_DBL) | 0;
    const highWord = (value / TWO_PWR_32_DBL) | 0;
    return fromBits(lowWord, highWord, unsigned);
}
/**
 * Returns a Long representing the given number. NaN yields zero. Values at or
 * beyond the representable range, including positive and negative Infinity,
 * are clamped to the corresponding bound for the requested signedness (in
 * unsigned mode every negative value yields unsigned zero).
 * @function
 * @param {number} value The number in question
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {!Long} The corresponding Long value
 */
Long$1.fromNumber = fromNumber;
/**
 * Creates a new Long directly from its two 32 bit halves by invoking the
 * `Long$1` constructor with the arguments unchanged.
 *
 * @param {number} lowBits
 * @param {number} highBits
 * @param {boolean=} unsigned
 * @returns {!Long}
 * @inner
 */
function fromBits(lowBits, highBits, unsigned) {
    const created = new Long$1(lowBits, highBits, unsigned);
    return created;
}
/**
 * Returns a Long representing the 64 bit integer that comes by concatenating the given low and high bits. Each is
 * assumed to use 32 bits. A new instance is created on every call.
 * @function
 * @param {number} lowBits The low 32 bits
 * @param {number} highBits The high 32 bits
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {!Long} The corresponding Long value
 */
Long$1.fromBits = fromBits;
/**
 * Local alias for `Math.pow`; `fromString` uses it to compute powers of the
 * radix.
 * @function
 * @param {number} base
 * @param {number} exponent
 * @returns {number}
 * @inner
 */
var pow_dbl = Math.pow; // Used 4 times (4*8 to 15+4)
/**
 * Parses the textual representation of an integer into a Long.
 *
 * The literal strings "NaN", "Infinity", "+Infinity" and "-Infinity" produce
 * the zero constant matching the requested signedness (UZERO when unsigned,
 * ZERO otherwise), mirroring how `fromNumber` treats NaN.
 *
 * @param {string} str Text to parse; a leading '-' negates the result.
 * @param {(boolean|number)=} unsigned Whether the result is unsigned. For
 *     goog.math.long compatibility a number passed here is taken as the radix
 *     and the result is signed.
 * @param {number=} radix Radix of the text (2-36), defaults to 10.
 * @returns {!Long}
 * @throws {Error} If `str` is empty or contains a '-' anywhere but at index 0.
 * @throws {RangeError} If the radix is outside 2-36.
 * @inner
 */
function fromString(str, unsigned, radix) {
    if (str.length === 0) {
        throw Error('empty string');
    }
    // Normalize the flag first, so that a radix passed in the `unsigned` slot
    // is never mistaken for "unsigned" by the special-value check below.
    let isUnsigned;
    let base = radix;
    if (typeof unsigned === 'number') {
        // For goog.math.long compatibility
        base = unsigned;
        isUnsigned = false;
    } else {
        isUnsigned = !!unsigned;
    }
    if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity") {
        return isUnsigned ? UZERO : ZERO;
    }
    base = base || 10;
    if (base < 2 || 36 < base) {
        throw RangeError('radix');
    }
    const hyphenAt = str.indexOf('-');
    if (hyphenAt > 0) {
        throw Error('interior hyphen');
    }
    if (hyphenAt === 0) {
        return fromString(str.substring(1), isUnsigned, base).neg();
    }
    // Do several (8) digits each time through the loop, so as to
    // minimize the calls to the very expensive emulated div.
    const radixToPower = fromNumber(pow_dbl(base, 8));
    let result = ZERO;
    for (let i = 0; i < str.length; i += 8) {
        const size = Math.min(8, str.length - i);
        const value = parseInt(str.substring(i, i + size), base);
        if (size < 8) {
            // Final, shorter chunk: scale by radix^size instead of radix^8.
            const power = fromNumber(pow_dbl(base, size));
            result = result.mul(power).add(fromNumber(value));
        } else {
            result = result.mul(radixToPower);
            result = result.add(fromNumber(value));
        }
    }
    result.unsigned = isUnsigned;
    return result;
}
/**
 * Returns a Long representation of the given string, written using the specified radix.
 * An empty string, a '-' anywhere but at the start, or a radix outside 2-36
 * causes an exception; the literal strings "NaN", "Infinity", "+Infinity" and
 * "-Infinity" yield a zero constant instead of throwing.
 * @function
 * @param {string} str The textual representation of the Long
 * @param {(boolean|number)=} unsigned Whether unsigned or not, defaults to signed. A number given here is used as the radix.
 * @param {number=} radix The radix in which the text is written (2-36), defaults to 10
 * @returns {!Long} The corresponding Long value
 */
Long$1.fromString = fromString;
/**
 * Converts a value to a Long by dispatching on its type: numbers go through
 * `fromNumber`, strings through `fromString`, and anything else is read as a
 * `{low, high, unsigned}` record and passed to `fromBits`.
 *
 * @function
 * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val
 * @param {boolean=} unsigned Overrides the record's own `unsigned` field only
 *     when it is a boolean.
 * @returns {!Long}
 * @inner
 */
function fromValue(val, unsigned) {
    switch (typeof val) {
        case 'number':
            return fromNumber(val, unsigned);
        case 'string':
            return fromString(val, unsigned);
        default:
            // Throws for non-objects, converts non-instanceof Long:
            return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned);
    }
}
/**
 * Converts the specified value to a Long using the appropriate from* function for its type:
 * `fromNumber` for numbers, `fromString` for strings, and `fromBits` (reading
 * `low`, `high` and `unsigned`) for anything else.
 * @function
 * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val Value
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed. For object input it replaces the object's own `unsigned` field only when it is a boolean.
 * @returns {!Long}
 */
Long$1.fromValue = fromValue;
// NOTE: the compiler should inline these constant values below and then remove these variables, so there should be
// no runtime penalty for these.
/**
* @type {number}
* @const
* @inner
*