@tensorflow/tfjs
Version:
An open-source machine learning framework.
1,564 lines (1,552 loc) • 1.34 MB
JavaScript
/**
* @license
* Copyright 2024 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
'use strict';
var tfjsCore = require('@tensorflow/tfjs-core');
var LongExports = require('long');
var tfjsLayers = require('@tensorflow/tfjs-layers');
var tfjsConverter = require('@tensorflow/tfjs-converter');
var tfjsData = require('@tensorflow/tfjs-data');
var tfjsBackendCpu = require('@tensorflow/tfjs-backend-cpu');
var tfjsBackendWebgl = require('@tensorflow/tfjs-backend-webgl');
/**
 * Wraps a CommonJS export object in an ES-module-like namespace object.
 *
 * The result has a null prototype. Every own enumerable key of `e` except
 * `default` is re-exposed: accessor properties keep their original
 * descriptor, while data properties become live getters that read `e[key]`
 * on each access. The original object itself is stored under `default`.
 *
 * @param e The module exports object (may be null/undefined).
 * @return The namespace object.
 */
function _interopNamespaceDefault(e) {
    const namespace = Object.create(null);
    if (e) {
        for (const key of Object.keys(e)) {
            if (key === 'default') {
                continue;
            }
            const descriptor = Object.getOwnPropertyDescriptor(e, key);
            const liveView = {
                enumerable: true,
                get: function () { return e[key]; }
            };
            Object.defineProperty(namespace, key, descriptor.get ? descriptor : liveView);
        }
    }
    namespace.default = e;
    return namespace;
}
// Namespace-object views of the CommonJS `long` and `@tensorflow/tfjs-data`
// modules, built by `_interopNamespaceDefault`: named exports are exposed as
// properties and the original module object is available as `.default`.
var LongExports__namespace = /*#__PURE__*/_interopNamespaceDefault(LongExports);
var tfjsData__namespace = /*#__PURE__*/_interopNamespaceDefault(tfjsData);
// Kernel name constants. Each constant's value is the string spelling of its
// own identifier. NOTE(review): presumably used as keys when looking up kernel
// implementations/gradients; the lookup code is not visible in this chunk.
const Abs = 'Abs';
const Acos = 'Acos';
const Acosh = 'Acosh';
const Add = 'Add';
const AddN = 'AddN';
const All = 'All';
const Any = 'Any';
const ArgMax = 'ArgMax';
const ArgMin = 'ArgMin';
const Asin = 'Asin';
const Asinh = 'Asinh';
const Atan = 'Atan';
const Atanh = 'Atanh';
const Atan2 = 'Atan2';
const AvgPool = 'AvgPool';
const AvgPoolGrad = 'AvgPoolGrad';
const AvgPool3D = 'AvgPool3D';
const AvgPool3DGrad = 'AvgPool3DGrad';
const BatchMatMul = 'BatchMatMul';
const BatchToSpaceND = 'BatchToSpaceND';
const Bincount = 'Bincount';
const BitwiseAnd = 'BitwiseAnd';
const BroadcastTo = 'BroadcastTo';
const BroadcastArgs = 'BroadcastArgs';
const Cast = 'Cast';
const Ceil = 'Ceil';
const ClipByValue = 'ClipByValue';
const Complex = 'Complex';
const ComplexAbs = 'ComplexAbs';
const Concat = 'Concat';
const Conv2D = 'Conv2D';
const Conv2DBackpropFilter = 'Conv2DBackpropFilter';
const Conv2DBackpropInput = 'Conv2DBackpropInput';
const Conv3D = 'Conv3D';
const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
const Cos = 'Cos';
const Cosh = 'Cosh';
const Cumprod = 'Cumprod';
const Cumsum = 'Cumsum';
const CropAndResize = 'CropAndResize';
const DenseBincount = 'DenseBincount';
const DepthToSpace = 'DepthToSpace';
const DepthwiseConv2dNative = 'DepthwiseConv2dNative';
const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
const Diag = 'Diag';
const Dilation2D = 'Dilation2D';
const Dilation2DBackpropInput = 'Dilation2DBackpropInput';
const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
const Draw = 'Draw';
const RealDiv = 'RealDiv';
const Einsum = 'Einsum';
const Elu = 'Elu';
const EluGrad = 'EluGrad';
const Erf = 'Erf';
const Equal = 'Equal';
const Exp = 'Exp';
const ExpandDims = 'ExpandDims';
const Expm1 = 'Expm1';
const FFT = 'FFT';
const Fill = 'Fill';
const FlipLeftRight = 'FlipLeftRight';
const Floor = 'Floor';
const FloorDiv = 'FloorDiv';
const FusedBatchNorm = 'FusedBatchNorm';
const GatherV2 = 'GatherV2';
const GatherNd = 'GatherNd';
const Greater = 'Greater';
const GreaterEqual = 'GreaterEqual';
const Identity = 'Identity';
const IFFT = 'IFFT';
const Imag = 'Imag';
const IsFinite = 'IsFinite';
const IsInf = 'IsInf';
const IsNan = 'IsNan';
const LeakyRelu = 'LeakyRelu';
const Less = 'Less';
const LessEqual = 'LessEqual';
const LinSpace = 'LinSpace';
const Log = 'Log';
const Log1p = 'Log1p';
const LogicalAnd = 'LogicalAnd';
const LogicalNot = 'LogicalNot';
const LogicalOr = 'LogicalOr';
const LogicalXor = 'LogicalXor';
const LogSoftmax = 'LogSoftmax';
const LowerBound = 'LowerBound';
const LRN = 'LRN';
const LRNGrad = 'LRNGrad';
const MatrixBandPart = 'MatrixBandPart';
const Max = 'Max';
const Maximum = 'Maximum';
const MaxPool = 'MaxPool';
const MaxPoolGrad = 'MaxPoolGrad';
const MaxPool3D = 'MaxPool3D';
const MaxPool3DGrad = 'MaxPool3DGrad';
const MaxPoolWithArgmax = 'MaxPoolWithArgmax';
const Mean = 'Mean';
const Min = 'Min';
const Minimum = 'Minimum';
const MirrorPad = 'MirrorPad';
const Mod = 'Mod';
const Multinomial = 'Multinomial';
const Multiply = 'Multiply';
const Neg = 'Neg';
const NotEqual = 'NotEqual';
const NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
const NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
const OnesLike = 'OnesLike';
const OneHot = 'OneHot';
const Pack = 'Pack';
const PadV2 = 'PadV2';
const Pool = 'Pool';
const Pow = 'Pow';
const Prelu = 'Prelu';
const Prod = 'Prod';
const RaggedGather = 'RaggedGather';
const RaggedRange = 'RaggedRange';
const RaggedTensorToTensor = 'RaggedTensorToTensor';
const Range = 'Range';
const Real = 'Real';
const Reciprocal = 'Reciprocal';
const Relu = 'Relu';
const Reshape = 'Reshape';
const ResizeNearestNeighbor = 'ResizeNearestNeighbor';
const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
const ResizeBilinear = 'ResizeBilinear';
const ResizeBilinearGrad = 'ResizeBilinearGrad';
const Relu6 = 'Relu6';
const Reverse = 'Reverse';
const Round = 'Round';
const Rsqrt = 'Rsqrt';
const ScatterNd = 'ScatterNd';
const TensorScatterUpdate = 'TensorScatterUpdate';
const SearchSorted = 'SearchSorted';
const Select = 'Select';
const Selu = 'Selu';
const Slice = 'Slice';
const Sin = 'Sin';
const Sinh = 'Sinh';
const Sign = 'Sign';
const Sigmoid = 'Sigmoid';
const Softplus = 'Softplus';
const Sqrt = 'Sqrt';
const Sum = 'Sum';
const SpaceToBatchND = 'SpaceToBatchND';
const SplitV = 'SplitV';
const Softmax = 'Softmax';
const SparseFillEmptyRows = 'SparseFillEmptyRows';
const SparseReshape = 'SparseReshape';
const SparseSegmentMean = 'SparseSegmentMean';
const SparseSegmentSum = 'SparseSegmentSum';
const SparseToDense = 'SparseToDense';
const SquaredDifference = 'SquaredDifference';
const Square = 'Square';
const StaticRegexReplace = 'StaticRegexReplace';
const StridedSlice = 'StridedSlice';
const StringNGrams = 'StringNGrams';
const StringSplit = 'StringSplit';
const StringToHashBucketFast = 'StringToHashBucketFast';
const Sub = 'Sub';
const Tan = 'Tan';
const Tanh = 'Tanh';
const Tile = 'Tile';
const TopK = 'TopK';
const Transform = 'Transform';
const Transpose = 'Transpose';
const Unique = 'Unique';
const Unpack = 'Unpack';
const UnsortedSegmentSum = 'UnsortedSegmentSum';
const UpperBound = 'UpperBound';
const ZerosLike = 'ZerosLike';
/**
 * TensorFlow.js-only kernels
 * (kernel names that exist only in TensorFlow.js; values follow the same
 * "value equals identifier" convention as the list above).
 */
const Step = 'Step';
const FromPixels = 'FromPixels';
const RotateWithOffset = 'RotateWithOffset';
const _FusedMatMul = '_FusedMatMul';
const FusedConv2D = 'FusedConv2D';
const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Values returned by `KernelBackend.prototype.epsilon()`: EPSILON_FLOAT32 when
// the backend reports 32-bit float precision, EPSILON_FLOAT16 otherwise.
const EPSILON_FLOAT32 = 1e-7;
const EPSILON_FLOAT16 = 1e-4;
/**
 * Convenient class for storing tensor-related data, keyed by data id.
 *
 * Entries live in a WeakMap, so a data id object can be garbage-collected
 * once nothing else references it. `numDataIds()` reports how many ids are
 * currently stored.
 */
class DataStorage {
    /**
     * @param backend The backend that owns this storage; passed to
     *     `dataMover.moveData` when a lookup misses.
     * @param dataMover Object with a `moveData(backend, dataId)` method that
     *     is asked to bring missing data into this storage.
     */
    constructor(backend, dataMover) {
        this.backend = backend;
        this.dataMover = dataMover;
        this.data = new WeakMap();
        this.dataIdsCount = 0;
    }
    /**
     * Returns the value stored for `dataId`. If it is absent, first asks the
     * data mover to move it to this backend; the result may still be
     * undefined if the move did not store anything.
     */
    get(dataId) {
        if (!this.data.has(dataId)) {
            this.dataMover.moveData(this.backend, dataId);
        }
        return this.data.get(dataId);
    }
    /** Stores `value` for `dataId`, replacing any previous value. */
    set(dataId, value) {
        // Only a previously unknown id adds to the count; overwriting an
        // existing entry must not inflate numDataIds().
        if (!this.data.has(dataId)) {
            this.dataIdsCount++;
        }
        this.data.set(dataId, value);
    }
    /** Returns true if a value is stored for `dataId`. */
    has(dataId) {
        return this.data.has(dataId);
    }
    /**
     * Removes the entry for `dataId`.
     * @return true if an entry was removed, false if none existed.
     */
    delete(dataId) {
        const removed = this.data.delete(dataId);
        // Deleting an unknown id must not drive the count down (or negative).
        if (removed) {
            this.dataIdsCount--;
        }
        return removed;
    }
    /** Number of data ids currently stored. */
    numDataIds() {
        return this.dataIdsCount;
    }
}
/**
 * Base class describing what a backend must provide. Backends can be filled
 * in gradually: apart from `timerAvailable()` and `epsilon()`, every method
 * here throws through `notYetImplemented` until a subclass overrides it.
 */
class KernelBackend {
    /** Reference count of `dataId`. Must be overridden. */
    refCount(dataId) {
        return notYetImplemented('refCount');
    }
    /** Increments the reference count of `dataId`. Must be overridden. */
    incRef(dataId) {
        return notYetImplemented('incRef');
    }
    /** Whether `time()` is supported; true by default. */
    timerAvailable() {
        return true;
    }
    /** Times the execution of `f`. Must be overridden. */
    time(f) {
        return notYetImplemented('time');
    }
    /** Asynchronously reads the values behind `dataId`. Must be overridden. */
    read(dataId) {
        return notYetImplemented('read');
    }
    /** Synchronously reads the values behind `dataId`. Must be overridden. */
    readSync(dataId) {
        return notYetImplemented('readSync');
    }
    /** Reads `dataId` into a GPU-side resource. Must be overridden. */
    readToGPU(dataId, options) {
        return notYetImplemented('readToGPU');
    }
    /** Number of data ids tracked by the backend. Must be overridden. */
    numDataIds() {
        return notYetImplemented('numDataIds');
    }
    /** Releases the storage behind `dataId`. Must be overridden. */
    disposeData(dataId, force) {
        return notYetImplemented('disposeData');
    }
    /** Stores new values and returns their data id. Must be overridden. */
    write(values, shape, dtype) {
        return notYetImplemented('write');
    }
    /** Adopts data under an existing `dataId`. Must be overridden. */
    move(dataId, values, shape, dtype, refCount) {
        return notYetImplemented('move');
    }
    /** Wraps existing GPU data as a tensor. Must be overridden. */
    createTensorFromGPUData(values, shape, dtype) {
        return notYetImplemented('createTensorFromGPUData');
    }
    /** Memory usage information. Must be overridden. */
    memory() {
        return notYetImplemented('memory');
    }
    /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
    floatPrecision() {
        return notYetImplemented('floatPrecision');
    }
    /**
     * Returns the smallest representable number: EPSILON_FLOAT32 when
     * `floatPrecision()` is 32, EPSILON_FLOAT16 for any other precision.
     */
    epsilon() {
        if (this.floatPrecision() === 32) {
            return EPSILON_FLOAT32;
        }
        return EPSILON_FLOAT16;
    }
    /** Releases backend-wide resources. Must be overridden. */
    dispose() {
        return notYetImplemented('dispose');
    }
}
/**
 * Always throws an Error reporting that `kernelName` is unavailable in the
 * chosen backend. It never returns normally; callers such as
 * `KernelBackend` write `return notYetImplemented(...)` for brevity.
 *
 * @param kernelName Name of the missing kernel/method, quoted in the message.
 */
function notYetImplemented(kernelName) {
    const message = `'${kernelName}' not yet implemented or not found in the registry. ` +
        `This kernel may not be supported by the tfjs backend you have chosen`;
    throw new Error(message);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Shuffles the array in-place using Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1, 2, 3, 4, 5];
 * tf.util.shuffle(a);
 * console.log(a);
 * ```
 *
 * @param array The array to shuffle in-place.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
// tslint:disable-next-line:no-any
function shuffle(array) {
    // Walk from the back: each step swaps the last not-yet-fixed slot with a
    // uniformly chosen slot at or before it (Math.random is called once per
    // element).
    for (let remaining = array.length; remaining > 0;) {
        const pick = (Math.random() * remaining) | 0;
        remaining--;
        swap(array, remaining, pick);
    }
}
/**
 * Shuffles two arrays in-place the same way using Fisher-Yates algorithm.
 *
 * ```js
 * const a = [1,2,3,4,5];
 * const b = [11,22,33,44,55];
 * tf.util.shuffleCombo(a, b);
 * console.log(a, b);
 * ```
 *
 * @param array The first array to shuffle in-place.
 * @param array2 The second array to shuffle in-place with the same permutation
 * as the first array.
 * @throws Error if the two arrays have different lengths.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function shuffleCombo(
    // tslint:disable-next-line:no-any
    array,
    // tslint:disable-next-line:no-any
    array2) {
    if (array.length !== array2.length) {
        // Sentence separators were previously missing, which produced
        // "...togetherFirst array length was 2Second array length was 1".
        throw new Error(`Array sizes must match to be shuffled together. ` +
            `First array length was ${array.length}. ` +
            `Second array length was ${array2.length}.`);
    }
    // Same walk as `shuffle`, but each random index is applied to both
    // arrays so they receive an identical permutation.
    for (let remaining = array.length; remaining > 0;) {
        const pick = (Math.random() * remaining) | 0;
        remaining--;
        swap(array, remaining, pick);
        swap(array2, remaining, pick);
    }
}
/**
 * Clamps `x` to the range [min, max] as `max(min, min(x, max))`; if
 * min > max the lower bound wins.
 */
function clamp(min, x, max) {
    const cappedAbove = Math.min(x, max);
    return Math.max(min, cappedAbove);
}
/** Returns `val` if it is even, otherwise `val + 1`. */
function nearestLargerEven(val) {
    if (val % 2 !== 0) {
        return val + 1;
    }
    return val;
}
/** Exchanges `object[left]` and `object[right]` in place. */
function swap(object, left, right) {
    const leftValue = object[left];
    const rightValue = object[right];
    object[left] = rightValue;
    object[right] = leftValue;
}
/** Returns the sum of `arr`'s elements, accumulated left to right from 0. */
function sum$1(arr) {
    let total = 0;
    let idx = 0;
    while (idx < arr.length) {
        total += arr[idx];
        idx++;
    }
    return total;
}
/**
 * Returns a sample from a uniform [a, b) distribution by linearly
 * interpolating between `a` and `b` with a single `Math.random()` draw.
 *
 * @param a The minimum support (inclusive).
 * @param b The maximum support (exclusive).
 * @return A pseudorandom number on the half-open interval [a,b).
 */
function randUniform(a, b) {
    const t = Math.random();
    const weightOfA = 1 - t;
    return (b * t) + weightOfA * a;
}
/**
 * Returns the squared Euclidean distance between two vectors. Elements are
 * coerced with `Number()`; iteration follows `a.length`.
 */
function distSquared(a, b) {
    let accumulated = 0;
    for (let k = 0; k < a.length; k++) {
        const delta = Number(a[k]) - Number(b[k]);
        accumulated += delta * delta;
    }
    return accumulated;
}
/**
 * Asserts that the expression is true. Otherwise throws an error with the
 * provided message.
 *
 * ```js
 * const x = 2;
 * tf.util.assert(x === 2, 'x is not 2');
 * ```
 *
 * @param expr The expression to assert (as a boolean).
 * @param msg Either the message string, or a function that returns the
 *     message; the function is only called when the assertion fails, so
 *     callers avoid building messages on the happy path.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function assert(expr, msg) {
    if (expr) {
        return;
    }
    const message = typeof msg === 'string' ? msg : msg();
    throw new Error(message);
}
/**
 * Throws (via `assert`) unless `shapeA` and `shapeB` are element-wise equal
 * according to `arraysEqual`. `errorMessagePrefix` is prepended to the
 * message.
 */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
    const shapesMatch = arraysEqual(shapeA, shapeB);
    const describe = () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`;
    assert(shapesMatch, describe);
}
/** Throws (via `assert`) when `a` is null or undefined. */
function assertNonNull(a) {
    const isPresent = a != null;
    assert(isPresent, () => `The input to the tensor constructor must be a non-null value.`);
}
/**
 * Returns the size (number of elements) of the tensor given its shape, i.e.
 * the product of all dimensions (1 for a scalar / empty shape).
 *
 * ```js
 * const shape = [3, 4, 2];
 * const size = tf.util.sizeFromShape(shape);
 * console.log(size);
 * ```
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function sizeFromShape(shape) {
    if (shape.length === 0) {
        // A rank-0 (scalar) shape holds exactly one element.
        return 1;
    }
    // No initial value: the product starts from shape[0].
    return shape.reduce((product, dim) => product * dim);
}
/** True when `shape` has no dimensions (rank 0). */
function isScalarShape(shape) {
    const rank = shape.length;
    return rank === 0;
}
/**
 * Element-wise equality where a `null` entry on either side acts as a
 * wildcard. Identical references are equal; a null/undefined array only
 * equals itself; arrays of different length are unequal.
 */
function arraysEqualWithNull(n1, n2) {
    if (n1 === n2) {
        return true;
    }
    if (n1 == null || n2 == null || n1.length !== n2.length) {
        return false;
    }
    for (let k = 0; k < n1.length; k++) {
        const bothKnown = n1[k] !== null && n2[k] !== null;
        if (bothKnown && n1[k] !== n2[k]) {
            return false;
        }
    }
    return true;
}
/**
 * Strict (`===`) element-wise equality of two array-likes. Identical
 * references are equal; a null/undefined array only equals itself.
 */
function arraysEqual(n1, n2) {
    if (n1 === n2) {
        return true;
    }
    if (n1 == null || n2 == null || n1.length !== n2.length) {
        return false;
    }
    for (let k = 0; k < n1.length; k++) {
        if (n1[k] !== n2[k]) {
            return false;
        }
    }
    return true;
}
/** True when `a % 1` is exactly 0 (integers, including -0 and integral floats). */
function isInt(a) {
    const fractionalPart = a % 1;
    return fractionalPart === 0;
}
/**
 * Hyperbolic tangent. Uses `Math.tanh` when the runtime provides it,
 * otherwise computes (e^{2x} - 1) / (e^{2x} + 1).
 */
function tanh$1(x) {
    // tslint:disable-next-line:no-any
    if (Math.tanh != null) {
        // tslint:disable-next-line:no-any
        return Math.tanh(x);
    }
    // Infinities are answered directly; for +Infinity the ratio below would
    // be Infinity / Infinity = NaN.
    if (x === Infinity) {
        return 1;
    }
    if (x === -Infinity) {
        return -1;
    }
    const expOfTwoX = Math.exp(2 * x);
    return (expOfTwoX - 1) / (expOfTwoX + 1);
}
/**
 * Returns a near-square 2D shape [width, height] with width = ceil(sqrt(size))
 * and height = ceil(size / width), so width * height >= size.
 */
function sizeToSquarishShape(size) {
    const side = Math.ceil(Math.sqrt(size));
    const rows = Math.ceil(size / side);
    return [side, rows];
}
/**
 * Creates a new array with randomized indices to a given quantity: a
 * Uint32Array holding 0..n-1, shuffled in place with `shuffle`.
 *
 * ```js
 * const randomTen = tf.util.createShuffledIndices(10);
 * console.log(randomTen);
 * ```
 *
 * @param number Quantity of how many shuffled indices to create.
 *
 * @doc {heading: 'Util', namespace: 'util'}
 */
function createShuffledIndices(n) {
    const indices = new Uint32Array(n);
    // Writes at positions >= indices.length would be ignored by the typed
    // array anyway, so bounding by its length is equivalent to bounding by n.
    for (let position = 0; position < indices.length; position++) {
        indices[position] = position;
    }
    shuffle(indices);
    return indices;
}
/** Pads string `a` with trailing spaces to length `size`; longer strings are returned unchanged. */
function rightPad(a, size) {
    if (a.length >= size) {
        return a;
    }
    const padding = ' '.repeat(size - a.length);
    return a + padding;
}
/**
 * Polls `checkFn` until it returns truthy.
 *
 * @param checkFn Called once immediately, then once per scheduled retry.
 * @param delayFn Maps the 1-based failed-attempt count to the delay before
 *     the next attempt; it is called even on the attempt that exhausts
 *     `maxCounter`.
 * @param maxCounter Optional limit on failed attempts; once reached the
 *     promise rejects (with no reason).
 * @param scheduleFn Optional `(fn, delay)` scheduler; `setTimeout` otherwise.
 * @return A promise that resolves (with no value) when `checkFn` succeeds.
 */
function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter, scheduleFn) {
    return new Promise((resolve, reject) => {
        let failures = 0;
        const attempt = () => {
            if (checkFn()) {
                resolve();
                return;
            }
            failures += 1;
            const waitMs = delayFn(failures);
            const exhausted = maxCounter != null && failures >= maxCounter;
            if (exhausted) {
                reject();
                return;
            }
            if (scheduleFn == null) {
                // google3 does not allow assigning another variable to setTimeout.
                // Don't refactor this so scheduleFn has a default value of setTimeout.
                setTimeout(attempt, waitMs);
            }
            else {
                scheduleFn(attempt, waitMs);
            }
        };
        attempt();
    });
}
/**
 * Given the full size of the array and a shape that may contain -1 as the
 * implicit dimension, returns the inferred shape where -1 is replaced.
 * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
 *
 * When there is no -1, the input `shape` object itself is returned (after
 * checking it against `size` when size > 0); otherwise a copy is returned.
 *
 * @param shape The shape, which may contain -1 in some dimension.
 * @param size The full size (number of elements) of the array.
 * @return The inferred shape where -1 is replaced with the inferred size.
 * @throws Error on more than one -1, on other negative dims, when the sizes
 *     do not match, or when the implicit dim cannot be inferred exactly.
 */
function inferFromImplicitShape(shape, size) {
    let knownProduct = 1;
    let missingDim = -1;
    for (let dimIdx = 0; dimIdx < shape.length; ++dimIdx) {
        const dim = shape[dimIdx];
        if (dim >= 0) {
            knownProduct *= dim;
            continue;
        }
        if (dim === -1) {
            if (missingDim !== -1) {
                throw Error(`Shapes can only have 1 implicit size. ` +
                    `Found -1 at dim ${missingDim} and dim ${dimIdx}`);
            }
            missingDim = dimIdx;
            continue;
        }
        if (dim < 0) {
            throw Error(`Shapes can not be < 0. Found ${dim} at dim ${dimIdx}`);
        }
    }
    if (missingDim === -1) {
        if (size > 0 && size !== knownProduct) {
            throw Error(`Size(${size}) must match the product of shape ${shape}`);
        }
        return shape;
    }
    if (knownProduct === 0) {
        throw Error(`Cannot infer the missing size in [${shape}] when ` +
            `there are 0 elements`);
    }
    if (size % knownProduct !== 0) {
        throw Error(`The implicit shape can't be a fractional number. ` +
            `Got ${size} / ${knownProduct}`);
    }
    const inferred = shape.slice();
    inferred[missingDim] = size / knownProduct;
    return inferred;
}
/**
 * Normalizes an axis argument into an array of non-negative axis indices.
 *
 * @param axis A single axis, an array of axes, or null/undefined meaning
 *     "all axes of `shape`". Negative values count from the end.
 * @param shape The shape whose rank bounds the valid range [-rank, rank).
 * @return Axis indices with negatives converted to `rank + axis`.
 */
function parseAxisParam(axis, shape) {
    const rank = shape.length;
    const axes = axis == null ? shape.map((_, dim) => dim) : [].concat(axis);
    const allInRange = axes.every(ax => ax >= -rank && ax < rank);
    assert(allInRange, () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
        `got axis ${axes}`);
    const allIntegers = axes.every(ax => isInt(ax));
    assert(allIntegers, () => `All values in axis param must be integers but ` +
        `got axis ${axes}`);
    return axes.map(ax => ax < 0 ? rank + ax : ax);
}
/**
 * Reduces the shape by removing dimensions of size 1.
 *
 * @param shape The input shape.
 * @param axis Optional axis (or axes) to squeeze. When null/undefined or an
 *     empty array, every size-1 dimension is removed; otherwise only the
 *     listed axes are removed and other size-1 dimensions are kept.
 * @return `newShape` (the squeezed shape) and `keptDims` (the indices of the
 *     input dimensions that survive, in order).
 * @throws Error if a listed axis does not have size 1.
 */
function squeezeShape(shape, axis) {
    const newShape = [];
    const keptDims = [];
    const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
    // The walk below consumes `axes` in ascending order, so it must be sorted
    // numerically; the default sort is lexicographic ([2, 10] -> [10, 2]),
    // which silently skipped squeezing some axes for rank > 10.
    const axes = (axis == null || isEmptyArray) ?
        null :
        parseAxisParam(axis, shape).sort((a, b) => a - b);
    let j = 0;
    for (let i = 0; i < shape.length; ++i) {
        if (axes != null) {
            if (axes[j] === i && shape[i] !== 1) {
                throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
            }
            // A size-1 dim that was not requested is kept.
            if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
                newShape.push(shape[i]);
                keptDims.push(i);
            }
            if (axes[j] <= i) {
                j++;
            }
        }
        if (shape[i] !== 1) {
            newShape.push(shape[i]);
            keptDims.push(i);
        }
    }
    return { newShape, keptDims };
}
/** Alias of `getArrayFromDType`: allocates a zero/empty array of `size` for `dtype`. */
function getTypedArrayFromDType(dtype, size) {
    const allocated = getArrayFromDType(dtype, size);
    return allocated;
}
/**
 * Allocates storage of length `size` for `dtype`: Float32Array for
 * float32 (also when dtype is null/undefined), Int32Array for int32,
 * Uint8Array for bool, and a plain (sparse) Array for string.
 *
 * @throws Error for any other dtype (including complex64).
 */
function getArrayFromDType(dtype, size) {
    switch (dtype) {
        case undefined:
        case null:
        case 'float32':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        case 'string':
            return new Array(size);
        default:
            throw new Error(`Unknown data type ${dtype}`);
    }
}
/**
 * Throws if any element of `vals` is NaN or infinite (checked with the
 * coercing global `isNaN`/`isFinite`). `dtype` only appears in the message.
 */
function checkConversionForErrors(vals, dtype) {
    for (let k = 0; k < vals.length; k++) {
        const candidate = vals[k];
        const invalid = isNaN(candidate) || !isFinite(candidate);
        if (invalid) {
            throw Error(`A tensor of type ${dtype} being uploaded contains ${candidate}.`);
        }
    }
}
/** Returns true if `dtype` is one of: bool, complex64, float32, int32, string. */
function isValidDtype(dtype) {
    const knownDtypes = ['bool', 'complex64', 'float32', 'int32', 'string'];
    return knownDtypes.includes(dtype);
}
/**
 * Returns true if the new type can't encode the old type without loss of
 * precision.
 *
 * - complex64 accepts everything.
 * - float32 accepts everything except complex64.
 * - int32 loses information from float32 and complex64.
 * - bool only losslessly holds bool.
 * - any other target (e.g. string) is treated as lossy.
 */
function hasEncodingLoss(oldType, newType) {
    switch (newType) {
        case 'complex64':
            return false;
        case 'float32':
            return oldType === 'complex64';
        case 'int32':
            return oldType === 'float32' || oldType === 'complex64';
        case 'bool':
            return oldType !== 'bool';
        default:
            return true;
    }
}
/**
 * Bytes per element for a numeric dtype: 4 for float32/int32, 8 for
 * complex64, 1 for bool.
 *
 * @throws Error for any other dtype (including string).
 */
function bytesPerElement(dtype) {
    switch (dtype) {
        case 'float32':
        case 'int32':
            return 4;
        case 'complex64':
            return 8;
        case 'bool':
            return 1;
        default:
            throw new Error(`Unknown dtype ${dtype}`);
    }
}
/**
 * Approximates the memory held by a string array as the sum of each
 * element's `length` (0 for a null/undefined array).
 *
 * NOTE(review): this is an exact byte count only if elements are encoded
 * byte arrays (e.g. Uint8Array); for JS strings it counts UTF-16 code units,
 * not bytes. TODO confirm which form callers pass.
 */
function bytesFromStringArray(arr) {
    if (arr == null) {
        return 0;
    }
    return arr.reduce((total, entry) => total + entry.length, 0);
}
/** Returns true for string primitives and `String` wrapper objects. */
function isString(value) {
    if (typeof value === 'string') {
        return true;
    }
    return value instanceof String;
}
/** Returns true only for boolean primitives (not `Boolean` wrapper objects). */
function isBoolean(value) {
    const kind = typeof value;
    return kind === 'boolean';
}
/** Returns true only for number primitives (NaN included). */
function isNumber(value) {
    const kind = typeof value;
    return kind === 'number';
}
/**
 * Infers a dtype from a value, typed array, or (nested) array.
 *
 * Nested arrays are classified by their first leaf element. Float32Array and
 * JS numbers give float32; Int32Array, Uint8Array and Uint8ClampedArray give
 * int32; strings give string; booleans give bool; anything else (including
 * an empty array) falls back to float32.
 */
function inferDtype(values) {
    let sample = values;
    while (Array.isArray(sample)) {
        sample = sample[0];
    }
    if (sample instanceof Float32Array) {
        return 'float32';
    }
    const isIntTypedArray = sample instanceof Int32Array ||
        sample instanceof Uint8Array ||
        sample instanceof Uint8ClampedArray;
    if (isIntTypedArray) {
        return 'int32';
    }
    if (isNumber(sample)) {
        return 'float32';
    }
    if (isString(sample)) {
        return 'string';
    }
    if (isBoolean(sample)) {
        return 'bool';
    }
    return 'float32';
}
/** Duck-typed function check: truthy with truthy `constructor`, `call` and `apply`. */
function isFunction(f) {
    if (!f) {
        return false;
    }
    return Boolean(f.constructor && f.call && f.apply);
}
/**
 * Returns the smallest divisor of `size` that is >= `start` and < `size`,
 * or `size` itself when there is none.
 */
function nearestDivisor(size, start) {
    let candidate = start;
    while (candidate < size) {
        if (size % candidate === 0) {
            return candidate;
        }
        candidate++;
    }
    return size;
}
/**
 * Row-major strides for `shape`, omitting the innermost stride (which is
 * implicitly 1); ranks below 2 yield []. E.g. [2, 3, 4] -> [12, 4].
 */
function computeStrides(shape) {
    const rank = shape.length;
    if (rank < 2) {
        return [];
    }
    const strides = new Array(rank - 1);
    let running = shape[rank - 1];
    strides[rank - 2] = running;
    for (let dim = rank - 2; dim > 0; dim--) {
        running = running * shape[dim];
        strides[dim - 1] = running;
    }
    return strides;
}
/**
 * Recursively builds a nested JS array of `shape` from the flat array `a`,
 * starting at `offset`. When `isComplex` is true each innermost row holds
 * twice as many values (interleaved pairs).
 */
function createNestedArray(offset, shape, a, isComplex = false) {
    const nested = [];
    const complexFactor = isComplex ? 2 : 1;
    if (shape.length === 1) {
        const count = shape[0] * complexFactor;
        for (let k = 0; k < count; k++) {
            nested.push(a[offset + k]);
        }
        return nested;
    }
    const outer = shape[0];
    const inner = shape.slice(1);
    const innerSize = inner.reduce((acc, c) => acc * c) * complexFactor;
    for (let k = 0; k < outer; k++) {
        nested.push(createNestedArray(offset + k * innerSize, inner, a, isComplex));
    }
    return nested;
}
/**
 * Provide a nested array of TypedArray in given shape.
 *
 * Returns `a[0]` for a scalar (empty) shape and [] when the shape has zero
 * elements. Otherwise `a.length` must equal the element count (doubled when
 * `isComplex`), else an Error is thrown.
 */
function toNestedArray(shape, a, isComplex = false) {
    if (shape.length === 0) {
        // Scalar type should return a single number.
        return a[0];
    }
    const expected = shape.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
    if (expected === 0) {
        // A tensor with shape zero should be turned into empty list.
        return [];
    }
    if (expected !== a.length) {
        const complexNote = isComplex ? ' for a complex tensor' : '';
        throw new Error(`[${shape}] does not match the input size ${a.length}${complexNote}.`);
    }
    return createNestedArray(0, shape, a, isComplex);
}
/**
 * Normalizes backend values / an ArrayBuffer into the storage type for
 * `dtype`.
 *
 * Arrays (e.g. Uint8Array[] for strings) are returned as-is. float32 and
 * int32 inputs are returned unchanged when already of the right typed-array
 * type, otherwise wrapped/converted. For bool and string the data is read as
 * Int32 values and copied into a Uint8Array.
 *
 * @throws Error for any other dtype.
 */
function convertBackendValuesAndArrayBuffer(data, dtype) {
    if (Array.isArray(data)) {
        return data;
    }
    switch (dtype) {
        case 'float32':
            return data instanceof Float32Array ? data : new Float32Array(data);
        case 'int32':
            return data instanceof Int32Array ? data : new Int32Array(data);
        case 'bool':
        case 'string':
            return Uint8Array.from(new Int32Array(data));
        default:
            throw new Error(`Unknown dtype ${dtype}`);
    }
}
/** Allocates a typed array via `makeZerosTypedArray` and sets every element to 1. */
function makeOnesTypedArray(size, dtype) {
    const filled = makeZerosTypedArray(size, dtype);
    filled.fill(1);
    return filled;
}
/**
 * Allocates a zero-filled typed array of length `size`: Float32Array for
 * float32, complex64, or a null/undefined dtype; Int32Array for int32;
 * Uint8Array for bool.
 *
 * @throws Error for any other dtype (including string).
 */
function makeZerosTypedArray(size, dtype) {
    switch (dtype) {
        case undefined:
        case null:
        case 'float32':
        case 'complex64':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        default:
            throw new Error(`Unknown data type ${dtype}`);
    }
}
/**
 * Make nested `TypedArray` filled with zeros.
 *
 * Allocates a zero-filled flat typed array with as many elements as the
 * product of `shape` (Float32Array for float32 or a null/undefined dtype,
 * Int32Array for int32, Uint8Array for bool) and reshapes it with
 * `toNestedArray`.
 *
 * @param shape The shape information for the nested array.
 * @param dtype dtype of the array element.
 * @throws Error for any other dtype (complex64 and string included).
 */
function makeZerosNestedTypedArray(shape, dtype) {
    const size = shape.reduce((prev, curr) => prev * curr, 1);
    let flat;
    switch (dtype) {
        case undefined:
        case null:
        case 'float32':
            flat = new Float32Array(size);
            break;
        case 'int32':
            flat = new Int32Array(size);
            break;
        case 'bool':
            flat = new Uint8Array(size);
            break;
        default:
            throw new Error(`Unknown data type ${dtype}`);
    }
    return toNestedArray(shape, flat);
}
/**
 * Throws (via `assert`) on the first dimension that is not an integer >= 0.
 * Zero-sized dimensions are accepted, although the error text says
 * "positive integers".
 */
function assertNonNegativeIntegerDimensions(shape) {
    const describe = () => `Tensor must have a shape comprised of positive integers but got ` +
        `shape [${shape}].`;
    shape.forEach(dimSize => {
        const isValidDim = Number.isInteger(dimSize) && dimSize >= 0;
        assert(isValidDim, describe);
    });
}
/**
 * Computes flat index for a given location (multidimentionsal index) in a
 * Tensor/multidimensional array: the last coordinate plus
 * sum(strides[d] * locs[d]) over the leading coordinates.
 *
 * @param locs Location in the tensor.
 * @param rank Rank of the tensor.
 * @param strides Tensor strides (without the implicit innermost 1).
 */
function locToIndex(locs, rank, strides) {
    switch (rank) {
        case 0:
            return 0;
        case 1:
            return locs[0];
    }
    const lastAxis = locs.length - 1;
    let flat = locs[lastAxis];
    for (let d = 0; d < lastAxis; ++d) {
        flat += strides[d] * locs[d];
    }
    return flat;
}
/**
 * Computes the location (multidimensional index) in a
 * tensor/multidimentional array for a given flat index; inverse of
 * `locToIndex` for the same strides.
 *
 * @param index Index in flat array.
 * @param rank Rank of tensor.
 * @param strides Strides of tensor (without the implicit innermost 1).
 */
function indexToLoc(index, rank, strides) {
    if (rank === 0) {
        return [];
    }
    if (rank === 1) {
        return [index];
    }
    const locs = new Array(rank);
    let remainder = index;
    for (let d = 0; d < locs.length - 1; ++d) {
        const coord = Math.floor(remainder / strides[d]);
        locs[d] = coord;
        remainder -= coord * strides[d];
    }
    locs[locs.length - 1] = remainder;
    return locs;
}
/**
 * Duck-typed "thenable" check: truthy when `object.then` is a function.
 *
 * Returns the same values as `object && object.then && typeof ... ===
 * 'function'`: the falsy `object` or falsy `object.then` itself in those
 * cases, otherwise a boolean.
 *
 * @param object
 */
// tslint:disable-next-line: no-any
function isPromise(object) {
    // We chose to not use 'obj instanceOf Promise' for two reasons:
    // 1. It only reliably works for es6 Promise, not other Promise
    // implementations.
    // 2. It doesn't work with framework that uses zone.js. zone.js monkey
    // patch the async calls, so it is possible the obj (patched) is
    // comparing to a pre-patched Promise.
    if (!object) {
        return object;
    }
    const then = object.then;
    if (!then) {
        return then;
    }
    return typeof then === 'function';
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
// Name of the URL query parameter that `Environment.populateURLFlags` reads.
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * The environment contains evaluated flags as well as the registered platform.
 * This is always used as a global singleton and can be retrieved with
 * `tf.env()`.
 *
 * Flags are evaluated lazily: `get`/`getAsync` run the flag's registered
 * evaluation function on first access and cache the result in `flags`.
 * Overrides parsed from the `tfjsflags` URL query parameter are held in
 * `urlFlags` and applied when the matching flag is registered.
 *
 * @doc {heading: 'Environment'}
 */
class Environment {
    /**
     * @param global Object whose `location.search` is read for URL flag
     *     overrides; the URL lookup is skipped when `global`, its `location`,
     *     or `location.search` is undefined.
     */
    // tslint:disable-next-line: no-any
    constructor(global) {
        this.global = global;
        // Cached flag values, keyed by flag name.
        this.flags = {};
        // Flag name -> { evaluationFn, setHook }.
        this.flagRegistry = {};
        // Flag name -> value parsed from the URL.
        this.urlFlags = {};
        // Jasmine spies on this in 'environment_test.ts'
        this.getQueryParams = getQueryParams;
        this.populateURLFlags();
    }
    /**
     * Records the active platform. If one was already set, a warning is
     * logged unless the global environment (`env()`, not necessarily `this`)
     * has IS_TEST or PROD enabled; the new platform replaces the old one
     * either way.
     */
    setPlatform(platformName, platform) {
        if (this.platform != null) {
            if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
                console.warn(`Platform ${this.platformName} has already been set. ` +
                    `Overwriting the platform with ${platformName}.`);
            }
        }
        this.platformName = platformName;
        this.platform = platform;
    }
    /**
     * Registers how to compute `flagName` (`evaluationFn`) and an optional
     * `setHook` that `set` invokes with each newly set value. A URL override
     * for the flag, if one was parsed, is applied immediately via `set`.
     */
    registerFlag(flagName, evaluationFn, setHook) {
        this.flagRegistry[flagName] = { evaluationFn, setHook };
        // Override the flag value from the URL. This has to happen here because
        // the environment is initialized before flags get registered.
        if (this.urlFlags[flagName] != null) {
            const flagValue = this.urlFlags[flagName];
            if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
                console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`);
            }
            this.set(flagName, flagValue);
        }
    }
    /**
     * Returns the cached value of `flagName`; on first use awaits the
     * evaluation function and caches its resolved value.
     */
    async getAsync(flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        this.flags[flagName] = await this.evaluateFlag(flagName);
        return this.flags[flagName];
    }
    /**
     * Synchronous counterpart of `getAsync`. Throws (without caching) if the
     * evaluation function returns a promise.
     */
    get(flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        const flagValue = this.evaluateFlag(flagName);
        if (isPromise(flagValue)) {
            throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` +
                `Please use getAsync() instead.`);
        }
        this.flags[flagName] = flagValue;
        return this.flags[flagName];
    }
    // The typed getters below are plain aliases of `get`; no conversion or
    // validation is performed.
    getNumber(flagName) {
        return this.get(flagName);
    }
    getBool(flagName) {
        return this.get(flagName);
    }
    getString(flagName) {
        return this.get(flagName);
    }
    /** Returns the live cache of evaluated flag values (not a copy). */
    getFlags() {
        return this.flags;
    }
    // For backwards compatibility.
    get features() {
        return this.flags;
    }
    /**
     * Sets `flagName` to `value` and runs its `setHook`, if any.
     * Throws if the flag has not been registered.
     */
    set(flagName, value) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
        }
        this.flags[flagName] = value;
        if (this.flagRegistry[flagName].setHook != null) {
            this.flagRegistry[flagName].setHook(value);
        }
    }
    /**
     * Runs the registered evaluation function for `flagName` (its result may
     * be a promise). Throws if the flag has not been registered.
     */
    evaluateFlag(flagName) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
        }
        return this.flagRegistry[flagName].evaluationFn();
    }
    /**
     * Replaces all cached flag values with a shallow copy of `flags`.
     * Registration is not checked and no `setHook` is invoked.
     */
    setFlags(flags) {
        this.flags = Object.assign({}, flags);
    }
    /**
     * Clears cached flag values and URL overrides, then re-reads the URL.
     * The flag registry is kept.
     */
    reset() {
        this.flags = {};
        this.urlFlags = {};
        this.populateURLFlags();
    }
    /**
     * Parses `?tfjsflags=NAME:value,...` from `this.global.location.search`
     * into `urlFlags`. Each comma-separated entry is split on ':'; anything
     * after a second ':' is dropped, and an entry without ':' passes
     * `undefined` as the value to `parseValue`.
     */
    populateURLFlags() {
        if (typeof this.global === 'undefined' ||
            typeof this.global.location === 'undefined' ||
            typeof this.global.location.search === 'undefined') {
            return;
        }
        const urlParams = this.getQueryParams(this.global.location.search);
        if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
            const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
            keyValues.forEach(keyValue => {
                const [key, value] = keyValue.split(':');
                this.urlFlags[key] = parseValue(key, value);
            });
        }
    }
}
/**
 * Parses a query string such as `location.search` into a plain object.
 * Each `name[=value]` pair is URI-decoded and stored via `decodeParam`, so a
 * name without a value maps to '' and a repeated name keeps its last value.
 */
function getQueryParams(queryString) {
    const params = {};
    const pairPattern = /[?&]([^=?&]+)(?:=([^&]*))?/g;
    // `replace` is used only to visit every match; its result is discarded.
    queryString.replace(pairPattern, (fullMatch, name, value) => {
        decodeParam(params, name, value);
        return fullMatch;
    });
    return params;
}
/**
 * Stores the URI-decoded `value` (or '' when it is falsy) under the
 * URI-decoded `name` in `params`. Malformed escapes throw URIError.
 */
function decodeParam(params, name, value) {
    const key = decodeURIComponent(name);
    params[key] = decodeURIComponent(value || '');
}
/**
 * Converts a raw URL flag value into a typed value.
 *
 * 'true'/'false' (any case) become booleans; strings whose lower-cased form
 * round-trips through Number (e.g. '42', '0.5', but not '1e3') become
 * numbers; anything else is returned unchanged.
 *
 * @param flagName Name of the flag (not used by the conversion).
 * @param value Raw value, or null/undefined when the URL entry had no
 *     `:value` part.
 * @return The converted value, or `value` itself when it is null/undefined.
 */
function parseValue(flagName, value) {
    // A `tfjsflags` entry without ':' (e.g. from a trailing comma, or an empty
    // `?tfjsflags=`) arrives here as undefined; calling `.toLowerCase()` on it
    // used to throw during Environment construction. Pass it through instead:
    // `registerFlag` ignores URL overrides that are null/undefined.
    if (value == null) {
        return value;
    }
    const lowerCaseValue = value.toLowerCase();
    if (lowerCaseValue === 'true' || lowerCaseValue === 'false') {
        return lowerCaseValue === 'true';
    }
    else if (`${+lowerCaseValue}` === lowerCaseValue) {
        return +lowerCaseValue;
    }
    else {
        return value;
    }
}
let ENV$1 = null;
/**
 * Returns the active environment, a global singleton installed with
 * `setEnvironmentGlobal` (null until one is installed).
 *
 * The environment object holds the evaluated feature (flag) values as well as
 * the active platform.
 *
 * @doc {heading: 'Environment'}
 */
function env() {
  return ENV$1;
}
/**
 * Installs `environment` as the object returned by `env()`.
 */
function setEnvironmentGlobal(environment) {
  ENV$1 = environment;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// The identifier globalNameSpace is scoped to this module, but it always
// resolves to the same global object regardless of how the module is loaded.
// tslint:disable-next-line:no-any
let globalNameSpace;
/**
 * Returns the host's global object, detecting it on the first call and caching
 * it in `globalNameSpace` afterwards.
 *
 * Candidates are checked in this order: `window`, `global`, `process`, `self`.
 *
 * @throws Error if none of the candidates is defined.
 */
// tslint:disable-next-line:no-any
function getGlobalNamespace() {
  if (globalNameSpace != null) {
    return globalNameSpace;
  }
  let detected;
  if (typeof window !== 'undefined') {
    detected = window;
  } else if (typeof global !== 'undefined') {
    detected = global;
  } else if (typeof process !== 'undefined') {
    detected = process;
  } else if (typeof self !== 'undefined') {
    detected = self;
  } else {
    throw new Error('Could not find a global object');
  }
  globalNameSpace = detected;
  return globalNameSpace;
}
/**
 * Returns the `Map` stored as `_tfGlobals` on the global namespace object.
 * If that property is null or undefined, a new empty Map is attached first.
 */
// tslint:disable-next-line:no-any
function getGlobalMap() {
  const namespace = getGlobalNamespace();
  const isMissing = namespace._tfGlobals == null;
  if (isMissing) {
    namespace._tfGlobals = new Map();
  }
  return namespace._tfGlobals;
}
/**
 * Returns a globally accessible 'singleton' object, creating it on first use.
 *
 * The object is kept in the shared map returned by `getGlobalMap()` under
 * `key`. Later calls with the same key return the stored object without
 * calling `init` again.
 *
 * @param key The name of the object.
 * @param init A function that creates the object the first time `key` is
 *     fetched.
 */
function getGlobal(key, init) {
  const registry = getGlobalMap();
  if (!registry.has(key)) {
    registry.set(key, init());
  }
  return registry.get(key);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Forwards `msg` to `console.warn`, unless the current environment's
 * 'IS_TEST' or 'PROD' flag is set.
 */
function warn(...msg) {
  if (env().getBool('IS_TEST') || env().getBool('PROD')) {
    return;
  }
  console.warn(...msg);
}
/**
 * Forwards `msg` to `console.log`, unless the current environment's
 * 'IS_TEST' or 'PROD' flag is set.
 */
function log$1(...msg) {
  if (env().getBool('IS_TEST') || env().getBool('PROD')) {
    return;
  }
  console.log(...msg);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Global registries, fetched through getGlobal: if Maps already exist under
// these keys in the shared global map, they are reused instead of recreated.
// kernelRegistry: kernel configs keyed by makeKey(kernelName, backendName).
// gradRegistry: gradient configs keyed by kernelName alone.
const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
const gradRegistry = getGlobal('gradRegistry', () => new Map());
/**
 * Looks up the kernel config registered for a kernel/backend pair.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @returns The registered config, or undefined if none is registered.
 */
function getKernel(kernelName, backendName) {
  return kernelRegistry.get(makeKey(kernelName, backendName));
}
/**
 * Looks up the gradient config registered for a kernel.
 *
 * @param kernelName The official TF kernel name.
 * @returns The registered gradient config, or undefined if none is registered.
 */
function getGradient(kernelName) {
  const gradConfig = gradRegistry.get(kernelName);
  return gradConfig;
}
/**
 * Returns the configs of every kernel registered for `backendName`.
 *
 * Each config's own `backendName` is compared, rather than parsing the
 * registry key. Keys have the form `${backendName}_${kernelName}` (see
 * makeKey), and nothing prevents either name from containing '_'. Splitting
 * the key on '_' therefore never matches a backend such as 'my_backend', and
 * attributes its kernels to a backend named 'my' instead.
 *
 * @param backendName The official name of the backend.
 * @returns The matching kernel configs, in the registry's iteration order.
 */
function getKernelsForBackend(backendName) {
  const result = [];
  for (const config of kernelRegistry.values()) {
    if (config.backendName === backendName) {
      result.push(config);
    }
  }
  return result;
}
/**
 * Registers the function (forward pass) for the kernel in a global registry.
 * An existing registration for the same kernel/backend pair is replaced, and
 * a warning is passed to `warn`.
 *
 * @param config A config object with the following properties:
 * - `kernelName` The official name of the kernel.
 * - `backendName` The official name of the backend.
 * - `kernelFunc` The function to run during the forward pass of the kernel.
 * - `setupFunc` Optional. Gets called once, after the backend initializes.
 * - `disposeFunc` Optional. Gets called once, right before the backend is
 * disposed.
 */
function registerKernel(config) {
  const { kernelName, backendName } = config;
  const key = makeKey(kernelName, backendName);
  const alreadyRegistered = kernelRegistry.has(key);
  if (alreadyRegistered) {
    warn(`The kernel '${kernelName}' for backend '${backendName}' is already registered`);
  }
  kernelRegistry.set(key, config);
}
/**
 * Registers a gradient function for a given kernel in the global registry,
 * to be used during the back-propagation of that kernel. Any existing entry
 * is replaced. When the 'DEBUG' flag is on, a replacement is reported via
 * `warn`.
 *
 * @param config An object with the following properties:
 * - `kernelName` The name of the kernel that the gradient function is for.
 * - `gradFunc` The function to run during back-propagation.
 */
function registerGradient(config) {
  const { kernelName } = config;
  // TODO (yassogba) after 3.0 assess whether we need to keep this gated
  // to debug mode.
  const isOverride = gradRegistry.has(kernelName);
  if (isOverride && env().getBool('DEBUG')) {
    warn(`Overriding the gradient for '${kernelName}'`);
  }
  gradRegistry.set(kernelName, config);
}
/**
 * Removes the kernel function from the registry.
 *
 * @param kernelName The official name of the kernel.
 * @param backendName The official name of the backend.
 * @throws Error if no kernel is registered for this kernel/backend pair.
 */
function unregisterKernel(kernelName, backendName) {
  const key = makeKey(kernelName, backendName);
  // Map.prototype.delete reports whether the key was present.
  const removed = kernelRegistry.delete(key);
  if (!removed) {
    throw new Error(`The kernel '${kernelName}' for backend '${backendName}' is not registered`);
  }
}
/**
 * Removes the registered gradient for `kernelName` from the global registry.
 *
 * @throws Error if no gradient is registered for `kernelName`.
 */
function unregisterGradient(kernelName) {
  // Map.prototype.delete reports whether the key was present.
  const removed = gradRegistry.delete(kernelName);
  if (!removed) {
    throw new Error(`The gradient '${kernelName}' for backend is not registered`);
  }
}
/**
 * Re-registers every kernel of an already registered backend under a new
 * backend name. Useful for registering custom backends.
 *
 * Each kernel config is shallow-copied with only `backendName` replaced; the
 * original configs are left untouched.
 *
 * @param registeredBackendName Already registered backend.
 * @param newBackendName New backend.
 */
function copyRegisteredKernels(registeredBackendName, newBackendName) {
  for (const kernelConfig of getKernelsForBackend(registeredBackendName)) {
    const copied = Object.assign({}, kernelConfig, { backendName: newBackendName });
    registerKernel(copied);
  }
}
/**
 * Builds the kernel-registry key `${backendName}_${kernelName}`.
 */
function makeKey(kernelName, backendName) {
  const prefix = `${backendName}_`;
  return `${prefix}${kernelName}`;
}
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reports whether `a` is a Float32Array, Int32Array, Uint8Array or
 * Uint8ClampedArray. Other typed arrays (e.g. Float64Array) give false.
 */
function isTypedArrayBrowser(a) {
  const acceptedTypes = [Float32Array, Int32Array, Uint8Array, Uint8ClampedArray];
  return acceptedTypes.some((ArrayType) => a instanceof ArrayType);
}
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Resolves the `Long` class from the 'long' package. It prefers the namespace's
// `default` property (set by `_interopNamespaceDefault` to the raw
// `require('long')` export) and falls back to the namespace object itself.
// tslint:disable-next-line
const Long =
// tslint:disable-next-line
LongExports__namespace.default || LongExports__namespace;
/**
 * Parses a hexadecimal digit string (e.g. 'c3a5c85c97cb3127') into an
 * unsigned Long.
 */
function hexToLong(hex) {
  const unsigned = true;
  const radix = 16;
  return Long.fromString(hex, unsigned, radix);
}
// Some primes between 2^63 and 2^64 for various uses.
// Each is parsed with hexToLong, so it is an unsigned Long. They serve as the
// multiplication/mixing constants of the hashLen* and fingerPrint64 routines
// in this file (k2 is also the hash of empty input in hashLen0to16).
// Hex 0xc3a5c85c97cb3127
const k0 = hexToLong('c3a5c85c97cb3127');
// Hex 0xb492b66fbe98f273
const k1 = hexToLong('b492b66fbe98f273');
// Hex 0x9ae16a3b2f90404f
const k2 = hexToLong('9ae16a3b2f90404f');
/**
 * Folds the high bits of `val` into its low bits: returns
 * val XOR (val >>> 47), using an unsigned right shift.
 */
function shiftMix(val) {
  const highBits = val.shru(47);
  return val.xor(highBits);
}
/**
 * Reads up to `numBytes` bytes of `s`, starting at `offset`, into a Long.
 * The two `true` arguments to Long.fromBytes select an unsigned,
 * little-endian interpretation (per the long.js API).
 */
function fetch$2(s, offset, numBytes) {
  const end = offset + numBytes;
  const byteList = Array.from(s.slice(offset, end));
  return Long.fromBytes(byteList, true, true);
}
/** Reads an 8-byte word of `s` at `offset` (see fetch$2). */
function fetch64(s, offset) {
  const wordBytes = 8;
  return fetch$2(s, offset, wordBytes);
}
/** Reads a 4-byte word of `s` at `offset` (see fetch$2). */
function fetch32(s, offset) {
  const wordBytes = 4;
  return fetch$2(s, offset, wordBytes);
}
/**
 * Rotates the 64-bit value `val` right by `shift` bits (intended for
 * 0 <= shift < 64).
 */
function rotate64(val, shift) {
  // shift === 0 is handled separately so that `64 - shift` never becomes 64;
  // a shift by 64 would give an undefined result.
  if (shift === 0) {
    return val;
  }
  const lowPart = val.shru(shift);
  const highPart = val.shl(64 - shift);
  return lowPart.or(highPart);
}
/**
 * Mixes two 64-bit Longs into one 64-bit hash using multiply and
 * xor-shift (>>> 47) rounds.
 *
 * @param u First input word.
 * @param v Second input word.
 * @param mul Multiplier used in every round. It defaults to 0x9ddfea08eb382d69,
 *     which is re-parsed by hexToLong on each call that omits the argument.
 * @returns The mixed value as a Long.
 */
function hashLen16(u, v, mul = hexToLong('9ddfea08eb382d69')) {
// Murmur-inspired hashing.
let a = u.xor(v).mul(mul);
a = a.xor(a.shru(47));
let b = v.xor(a).mul(mul);
b = b.xor(b.shru(47));
b = b.mul(mul);
return b;
}
// Return a 16-byte hash for 48 bytes (four input words plus the two seeds).
// Quick and dirty.
// Callers should use "random-looking" values for a and b.
// @param w, x, y, z Four 64-bit input words (Longs).
// @param a, b Seed Longs; the parameter bindings are reassigned below.
// @returns A two-element array [a, b] of Longs (the 16-byte result).
function weakHashLen32WithSeeds(w, x, y, z, a, b) {
a = a.add(w);
b = rotate64(b.add(a).add(z), 21);
// Keep the first-stage value of `a` for the final `b` output.
const c = a;
a = a.add(x);
a = a.add(y);
b = b.add(rotate64(a, 44));
return [a.add(z), b.add(c)];
}
/**
 * Reads four consecutive 8-byte words of `s`, starting at `offset`, and
 * passes them with the seeds `a` and `b` to weakHashLen32WithSeeds.
 */
function weakHashLen32WithSeedsStr(s, offset, a, b) {
  const [w, x, y, z] = [0, 8, 16, 24].map((delta) => fetch64(s, offset + delta));
  return weakHashLen32WithSeeds(w, x, y, z, a, b);
}
/**
 * Hashes the first `len` bytes of `s` to a Long. fingerPrint64 uses it for
 * len <= 16.
 *
 * @param s Input bytes (presumably a Uint8Array — TODO confirm at callers).
 * @param len Number of bytes to hash; defaults to `s.length`.
 * @returns The hash as a Long; empty input hashes to the constant k2.
 */
function hashLen0to16(s, len = s.length) {
// 8+ bytes: combine the first and last 8-byte words (they overlap when
// len < 16).
if (len >= 8) {
const mul = k2.add(len * 2);
const a = fetch64(s, 0).add(k2);
const b = fetch64(s, len - 8);
const c = rotate64(b, 37).mul(mul).add(a);
const d = rotate64(a, 25).add(b).mul(mul);
return hashLen16(c, d, mul);
}
// 4..7 bytes: combine the first and last 4-byte words (overlapping).
if (len >= 4) {
const mul = k2.add(len * 2);
const a = fetch32(s, 0);
return hashLen16(a.shl(3).add(len), fetch32(s, len - 4), mul);
}
// 1..3 bytes: mix the first, middle and last byte (which may coincide).
if (len > 0) {
const a = s[0];
const b = s[len >> 1];
const c = s[len - 1];
const y = a + (b << 8);
const z = len + (c << 2);
return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
}
return k2;
}
/**
 * Hashes the first `len` bytes of `s` to a Long. fingerPrint64 uses it for
 * 17 <= len <= 32.
 *
 * Reads the 8-byte words at offsets 0, 8, len - 16 and len - 8 (overlapping
 * when len < 32) and folds them with hashLen16.
 *
 * @param s Input bytes (presumably a Uint8Array — TODO confirm at callers).
 * @param len Number of bytes to hash; defaults to `s.length`.
 */
function hashLen17to32(s, len = s.length) {
const mul = k2.add(len * 2);
const a = fetch64(s, 0).mul(k1);
const b = fetch64(s, 8);
const c = fetch64(s, len - 8).mul(mul);
const d = fetch64(s, len - 16).mul(k2);
return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
}
/**
 * Hashes the first `len` bytes of `s` to a Long. fingerPrint64 uses it for
 * 33 <= len <= 64.
 *
 * It runs two rounds. The first mixes the words at 0, 8, len - 16 and len - 8
 * into y/z. The second mixes the words at 16 and 24 with y/z perturbed by the
 * words at len - 32 and len - 24. Windows overlap when len < 64.
 *
 * @param s Input bytes (presumably a Uint8Array — TODO confirm at callers).
 * @param len Number of bytes to hash; defaults to `s.length`.
 */
function hashLen33to64(s, len = s.length) {
const mul = k2.add(len * 2);
const a = fetch64(s, 0).mul(k2);
const b = fetch64(s, 8);
const c = fetch64(s, len - 8).mul(mul);
const d = fetch64(s, len - 16).mul(k2);
// First round: the same shape as hashLen17to32, but a uses k2 instead of k1.
const y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
const z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
// Second round over the middle/tail words, seeded by y and z.
const e = fetch64(s, 16).mul(mul);
const f = fetch64(s, 24);
const g = y.add(fetch64(s, len - 32)).mul(mul);
const h = z.add(fetch64(s, len - 24)).mul(mul);
return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul);
}
function fingerPrint64(s, len = s.length) {
const seed = Long.fromNumber(81, true);
if (len <= 32) {
if (len <= 16) {
return hashLen0to16(s, len);
}
else {
return hashLen17to32(s, len);
}
}
else if (len <= 64) {
return hashLen33to64(s, len);
}
// For strings over 64 bytes we loop. Internal state consists of
// 56 bytes: v, w, x, y, and z.
let x = seed;
let y = seed.mul(k1).add(113);
let z = shiftMix(y.mul(k2).add(113)).mul(k2);
let v = [Long.UZERO, Long.UZERO];
let w = [Long.UZERO, Long.UZERO];
x = x.mul(k2).add(fetch64(s, 0));
let offset