@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, KNN classifier.
1,152 lines (1,122 loc) • 484 kB
JavaScript
/**
* @license
* Copyright 2024 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
(function (global, factory) {
typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports, require('@tensorflow/tfjs-core')) :
typeof define === 'function' && define.amd ? define(['exports', '@tensorflow/tfjs-core'], factory) :
(global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.tf = global.tf || {}, global.tf));
})(this, (function (exports, tfjsCore) { 'use strict';
/**
 * Copies the named exports of every namespace object in `m` onto `n`.
 * `default` and keys already present on `n` are skipped. Accessor properties
 * are copied with their original descriptor; data properties are exposed as
 * enumerable live getters that read through to the source object.
 *
 * @param {Object} n Target namespace (mutated and returned).
 * @param {Array} m Source namespaces; falsy entries, strings and arrays are ignored.
 * @returns {Object} `n`.
 */
function _mergeNamespaces(n, m) {
    function reexport(source, key) {
        var descriptor = Object.getOwnPropertyDescriptor(source, key);
        var binding = descriptor.get ? descriptor : {
            enumerable: true,
            get: function () { return source[key]; }
        };
        Object.defineProperty(n, key, binding);
    }
    m.forEach(function (source) {
        var isNamespace = Boolean(source) && typeof source !== 'string' && !Array.isArray(source);
        if (!isNamespace) {
            return;
        }
        Object.keys(source).forEach(function (key) {
            if (key === 'default' || key in n) {
                return;
            }
            reexport(source, key);
        });
    });
    return n;
}
/******************************************************************************
Copyright (c) Microsoft Corporation.
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
***************************************************************************** */
/* global Reflect, Promise */
// Copies static members from base constructor `b` onto derived constructor `d`.
// On its first call the variable is reassigned to the best available strategy:
// `Object.setPrototypeOf`, else `__proto__` assignment (when the engine honours
// `{ __proto__: [] }`), else copying `b`'s own enumerable properties onto `d`.
// Later calls go straight to the chosen strategy.
var extendStatics = function (d, b) {
extendStatics = Object.setPrototypeOf ||
({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||
function (d, b) { for (var p in b)
if (Object.prototype.hasOwnProperty.call(b, p))
d[p] = b[p]; };
return extendStatics(d, b);
};
/**
 * Downlevel `class d extends b` helper. It links `d.prototype` to
 * `b.prototype` and inherits statics via `extendStatics`. When `b` is null,
 * `d.prototype` gets a null prototype. The new prototype carries an own
 * `constructor` property pointing at `d`.
 *
 * @param {Function} d Derived constructor.
 * @param {Function|null} b Base constructor or null.
 * @throws {TypeError} If `b` is neither a function nor null.
 */
function __extends(d, b) {
    var isValidBase = typeof b === "function" || b === null;
    if (!isValidBase) {
        throw new TypeError("Class extends value " + String(b) + " is not a constructor or null");
    }
    extendStatics(d, b);
    if (b === null) {
        d.prototype = Object.create(b);
        return;
    }
    function Surrogate() { this.constructor = d; }
    Surrogate.prototype = b.prototype;
    d.prototype = new Surrogate();
}
/**
 * Downlevel driver for `async` functions. It calls `generator` with `thisArg`
 * and `_arguments` and steps the resulting iterator. Each yielded value is
 * waited on and its outcome is fed back in (fulfilment via `next`, rejection
 * via `throw`). The returned promise settles with the final return value or
 * the first uncaught error.
 *
 * @param thisArg `this` for the generator call.
 * @param _arguments Arguments for the generator call (defaults to []).
 * @param P Promise constructor to use; the global `Promise` when falsy.
 * @param generator Generator function implementing the async body.
 */
function __awaiter(thisArg, _arguments, P, generator) {
// Wrap non-P values in an already-resolved P so every yielded value can be `.then`-ed.
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
return new (P || (P = Promise))(function (resolve, reject) {
// Resume the body with a fulfilled value; synchronous throws reject the outer promise.
function fulfilled(value) { try {
step(generator.next(value));
}
catch (e) {
reject(e);
} }
// Resume the body by throwing a rejection reason into it.
function rejected(value) { try {
step(generator["throw"](value));
}
catch (e) {
reject(e);
} }
// Done: resolve with the return value. Otherwise wait on the yielded value.
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
}
/**
 * Downlevel state machine behind generator and async function bodies.
 * `body` is called repeatedly with the state object `_`. That object holds
 * the resume `label`, the `sent()` accessor, and the stacks `trys` (active
 * try-regions) and `ops` (operations deferred across a finally block).
 * Each call returns an instruction `[opcode, value]`. The function returns an
 * iterator object with `next`, `throw` and `return`, plus `Symbol.iterator`
 * when Symbol exists.
 */
function __generator(thisArg, body) {
// `sent()` returns the value given to next(), or throws the value given to throw().
var _ = { label: 0, sent: function () { if (t[0] & 1)
throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
// next / throw / return are encoded as opcodes 0 / 1 / 2.
return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function () { return this; }), g;
function verb(n) { return function (v) { return step([n, v]); }; }
function step(op) {
// `f` is set while the body runs; re-entrant calls are rejected.
if (f)
throw new TypeError("Generator is already executing.");
// `_` is replaced with 0 once the generator has completed.
while (_)
try {
// While delegating (`y` set by opcode 5), forward the operation to the inner iterator
// and return its result as long as it is not done.
if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done)
return t;
// The inner iterator finished: stop delegating and continue with its final value.
if (y = 0, t)
op = [op[0] & 2, t.value];
switch (op[0]) {
// next / throw: remember the operation (read back through `_.sent()`) and resume the body.
case 0:
case 1:
t = op;
break;
// yield: advance the label and hand the value to the caller.
case 4:
_.label++;
return { value: op[1], done: false };
// yield*: advance the label and start delegating to the iterator in op[1].
case 5:
_.label++;
y = op[1];
op = [0];
continue;
// Leave the innermost try-region: resume the deferred op and pop the region.
case 7:
op = _.ops.pop();
_.trys.pop();
continue;
// Remaining opcodes: 2 (return), 3 (jump to label), 6 (exception).
// Try-region records appear to be [tryStart, catchLabel, finallyLabel, end],
// judging by how t[0]..t[3] are compared below.
default:
// return / exception with no enclosing try-region: the generator is finished.
if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) {
_ = 0;
continue;
}
// Jump: taken directly when there is no region or the target lies inside it.
if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) {
_.label = op[1];
break;
}
// Exception before reaching the catch label: continue at the catch block.
if (op[0] === 6 && _.label < t[1]) {
_.label = t[1];
t = op;
break;
}
// Otherwise run the finally block first, deferring `op` until it ends.
if (t && _.label < t[2]) {
_.label = t[2];
_.ops.push(op);
break;
}
// Region fully handled: discard it and re-dispatch `op` against the next region.
if (t[2])
_.ops.pop();
_.trys.pop();
continue;
}
op = body.call(thisArg, _);
}
catch (e) {
// An exception escaping `body` becomes an exception (6) instruction.
op = [6, e];
y = 0;
}
finally {
f = t = 0;
}
// Finished: throw-like completions rethrow; others report done with the returned value.
if (op[0] & 5)
throw op[1];
return { value: op[0] ? op[1] : void 0, done: true };
}
}
/**
 * Returns an iterator over `o`. If `o` has a `Symbol.iterator` method, that
 * method's iterator is used. Otherwise, if `o` is array-like (numeric
 * `length`), a minimal index-based iterator is built over it.
 *
 * @param o Iterable or array-like value.
 * @returns {{next: Function}} Iterator.
 * @throws {TypeError} When `o` is neither iterable nor array-like.
 */
function __values(o) {
    var iteratorKey = typeof Symbol === "function" && Symbol.iterator;
    var nativeIterator = iteratorKey && o[iteratorKey];
    if (nativeIterator) {
        return nativeIterator.call(o);
    }
    if (!o || typeof o.length !== "number") {
        throw new TypeError(iteratorKey ? "Object is not iterable." : "Symbol.iterator is not defined.");
    }
    var source = o;
    var index = 0;
    return {
        next: function () {
            // Once exhausted, drop the source so every later call reports done.
            if (source && index >= source.length) {
                source = void 0;
            }
            return { value: source && source[index++], done: !source };
        }
    };
}
/**
 * Downlevel helper for destructuring. It reads up to `n` values (all values
 * when `n` is undefined) from iterable `o` into a new array. A value without
 * `Symbol.iterator` is returned unchanged. If iteration stops before the
 * iterator is done, its `return` method is called. An error thrown by
 * `next()` is rethrown only after that cleanup.
 */
function __read(o, n) {
var m = typeof Symbol === "function" && o[Symbol.iterator];
if (!m)
return o;
var i = m.call(o), r, ar = [], e;
try {
while ((n === void 0 || n-- > 0) && !(r = i.next()).done)
ar.push(r.value);
}
catch (error) {
// Defer the error until the iterator has been closed below.
e = { error: error };
}
finally {
try {
// Close an iterator that still has values left.
if (r && !r.done && (m = i["return"]))
m.call(i);
}
finally {
if (e)
throw e.error;
}
}
return ar;
}
/**
 * Downlevel helper for array spread: returns `to` concatenated with the
 * elements of `from`. When packing (`pack` truthy, or only two arguments
 * given), holes in `from` become explicit `undefined` entries. Otherwise
 * holes are preserved.
 *
 * @param {Array} to Leading elements.
 * @param {ArrayLike} from Elements to append.
 * @param {boolean=} pack Whether to densify `from`.
 * @returns {Array} A new array.
 */
function __spreadArray(to, from, pack) {
    var shouldPack = pack || arguments.length === 2;
    var dense;
    if (shouldPack) {
        var length = from.length;
        for (var idx = 0; idx < length; idx++) {
            // Copying only starts at the first hole; before that, slice is enough.
            if (dense === undefined && idx in from) {
                continue;
            }
            if (dense === undefined) {
                dense = Array.prototype.slice.call(from, 0, idx);
            }
            dense[idx] = from[idx];
        }
    }
    return to.concat(dense !== undefined ? dense : Array.prototype.slice.call(from));
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// This enum must align with the enum defined in cc/backend.h.
// Bidirectional map in the shape TypeScript emits for a numeric enum:
// dtype name -> native code, and native code -> dtype name.
var CppDType = (function () {
    var table = {};
    ['float32', 'int32', 'bool', 'string', 'complex64'].forEach(function (name, code) {
        table[name] = code;
        table[code] = name;
    });
    return table;
})();
// Must match enum in cc/fusable_activations.h.
// Bidirectional map: activation name -> native code, and native code -> name.
var FusableActivation = (function () {
    var lookup = {};
    var names = ['linear', 'relu', 'relu6', 'prelu', 'leakyrelu', 'sigmoid', 'elu'];
    for (var code = 0; code < names.length; code++) {
        lookup[names[code]] = code;
        lookup[code] = names[code];
    }
    return lookup;
})();
var wasmFusedMatMul;
/**
 * Binds the `_FusedMatMul` WASM export into `wasmFusedMatMul`. The 13 argument
 * types follow the call made in `fusedBatchMatMul`.
 *
 * @param backend WASM backend exposing `wasm.cwrap`.
 */
function setup$1a(backend) {
    var argTypes = [
        'number', 'array', 'number', // a: id, shape bytes, rank
        'number', 'array', 'number', // b: id, shape bytes, rank
        'number', 'number',          // transposeA, transposeB
        'number',                    // fused activation code
        'number', 'number',          // bias id, prelu weights id
        'number',                    // leakyrelu alpha
        'number'                     // out_id
    ];
    wasmFusedMatMul = backend.wasm.cwrap(tfjsCore._FusedMatMul, null /* void */, argTypes);
}
/**
 * WASM kernel for `_FusedMatMul`. It performs a batched matrix multiply of
 * `a` and `b`, with an optional bias, optional PReLU weights and a fused
 * activation, all handed to a single `wasmFusedMatMul` call.
 *
 * NOTE(review): row/column sizes are read from shape indices 1 and 2, so `a`
 * and `b` are presumably rank 3 (`[batch, rows, cols]`). TODO: confirm that
 * callers always reshape to rank 3.
 *
 * @param {Object} args Kernel arguments.
 * @param {Object} args.inputs `{a, b, bias?, preluActivationWeights?}`.
 * @param {Object} args.backend WASM backend (`dataIdMap`, `makeOutput`).
 * @param {Object} args.attrs `{transposeA, transposeB, activation, leakyreluAlpha}`.
 * @returns {Object} Output TensorInfo with shape `[...batchDims, leftDim, rightDim]`
 *     and the dtype of `a`.
 * @throws {Error} If `a` or `b` is not float32, the bias is not rank 1, or
 *     `activation` does not name a `FusableActivation` code.
 */
function fusedBatchMatMul(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var attrs = args.attrs;
    var a = inputs.a;
    var b = inputs.b;
    var bias = inputs.bias;
    var preluActivationWeights = inputs.preluActivationWeights;
    if (a.dtype !== 'float32' || b.dtype !== 'float32') {
        throw new Error('_FusedMatMul for non-float32 tensors not yet supported.');
    }
    var transposeA = attrs.transposeA;
    var transposeB = attrs.transposeB;
    var activation = attrs.activation;
    var leakyreluAlpha = attrs.leakyreluAlpha;
    var aId = backend.dataIdMap.get(a.dataId).id;
    var bId = backend.dataIdMap.get(b.dataId).id;
    // Id 0 is passed for absent optional inputs (bias, PReLU weights).
    var biasId = 0;
    if (bias != null) {
        var biasData = backend.dataIdMap.get(bias.dataId);
        if (biasData.shape.length !== 1) {
            throw new Error('_FusedMatMul only supports rank-1 bias but got ' +
                'rank ' + biasData.shape.length + '.');
        }
        biasId = biasData.id;
    }
    var preluActivationWeightsId = preluActivationWeights == null ?
        0 :
        backend.dataIdMap.get(preluActivationWeights.dataId).id;
    // Only names that map to a numeric code are valid. Reverse-map keys ('0')
    // and inherited keys ('toString') must not reach the WASM call.
    var fusedActivation = FusableActivation[activation];
    if (typeof fusedActivation !== 'number') {
        throw new Error(activation + ' activation not yet supported for _FusedMatMul ' +
            'in the wasm backend.');
    }
    var leftDim = transposeA ? a.shape[2] : a.shape[1];
    var rightDim = transposeB ? b.shape[1] : b.shape[2];
    var batchDims = tfjsCore.broadcast_util.assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
    var out = backend.makeOutput(batchDims.concat([leftDim, rightDim]), a.dtype);
    var outId = backend.dataIdMap.get(out.dataId).id;
    var aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer);
    var bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer);
    wasmFusedMatMul(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, transposeA, transposeB, fusedActivation, biasId, preluActivationWeightsId, leakyreluAlpha || 0, outId);
    return out;
}
// Registration record for `_FusedMatMul` on the 'wasm' backend. `setup$1a`
// binds the WASM export and `fusedBatchMatMul` runs the kernel.
var _fusedMatMulConfig = {
kernelName: tfjsCore._FusedMatMul,
backendName: 'wasm',
setupFunc: setup$1a,
kernelFunc: fusedBatchMatMul
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a 'wasm' kernel config for an elementwise unary op. The WASM export
 * named `kernelName` is called as (x_id, dtype code, out_id).
 *
 * @param {string} kernelName Kernel name, also used as the WASM export name.
 * @param {string=} outType Output dtype; when falsy the input's dtype is used.
 * @returns {{kernelName: string, backendName: string, setupFunc: Function, kernelFunc: Function}}
 */
function createUnaryKernelConfig(kernelName, outType) {
    var wasmFunc;
    var config = {
        kernelName: kernelName,
        backendName: 'wasm',
        setupFunc: function (backend) {
            // x_id, dtype, out_id
            wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, ['number', 'number', 'number']);
        },
        kernelFunc: function (args) {
            var backend = args.backend;
            var x = args.inputs.x;
            var xId = backend.dataIdMap.get(x.dataId).id;
            var out = backend.makeOutput(x.shape, outType || x.dtype);
            var outId = backend.dataIdMap.get(out.dataId).id;
            // Nothing to compute for a zero-sized tensor.
            var isEmpty = tfjsCore.util.sizeFromShape(out.shape) === 0;
            if (!isEmpty) {
                wasmFunc(xId, CppDType[x.dtype], outId);
            }
            return out;
        }
    };
    return config;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for Abs (unary; no outType, so output dtype = input dtype).
var absConfig = createUnaryKernelConfig(tfjsCore.Abs);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for Acos (unary; no outType, so output dtype = input dtype).
var acosConfig = createUnaryKernelConfig(tfjsCore.Acos);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for Acosh (unary; no outType, so output dtype = input dtype).
var acoshConfig = createUnaryKernelConfig(tfjsCore.Acosh);
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a 'wasm' kernel config for a broadcasting binary op backed by the
 * WASM export named `kernelName`.
 *
 * @param {string} kernelName Kernel name, also used as the WASM export name.
 * @param {boolean=} supportsFullBroadcast Accepted for call-site compatibility;
 *     not read by this implementation.
 * @param {string=} dtype Output dtype; when null/undefined, `a`'s dtype is used.
 *     The dtype code passed to WASM is always `a`'s.
 * @returns {{kernelName: string, backendName: string, setupFunc: Function, kernelFunc: Function}}
 */
function createBinaryKernelConfig(kernelName, supportsFullBroadcast, dtype) {
    var wasmFunc;
    function setupFunc(backend) {
        wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, [
            'number', 'array', 'number', // a: id, shape bytes, rank
            'number', 'array', 'number', // b: id, shape bytes, rank
            'number',                    // input dtype code
            'number' // out_id
        ]);
    }
    function kernelFunc(args) {
        var backend = args.backend;
        var a = args.inputs.a;
        var b = args.inputs.b;
        var idOf = function (info) { return backend.dataIdMap.get(info.dataId).id; };
        var shapeBytes = function (shape) { return new Uint8Array(new Int32Array(shape).buffer); };
        var aId = idOf(a);
        var bId = idOf(b);
        var outShape = tfjsCore.backend_util.assertAndGetBroadcastShape(a.shape, b.shape);
        var out = backend.makeOutput(outShape, dtype == null ? a.dtype : dtype);
        // Nothing to compute when the broadcast output is empty.
        if (tfjsCore.util.sizeFromShape(outShape) === 0) {
            return out;
        }
        var aShapeBytes = shapeBytes(a.shape);
        var bShapeBytes = shapeBytes(b.shape);
        var outId = idOf(out);
        wasmFunc(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, CppDType[a.dtype], outId);
        return out;
    }
    return { kernelName: kernelName, backendName: 'wasm', setupFunc: setupFunc, kernelFunc: kernelFunc };
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for Add (broadcasting binary; output dtype = a's dtype).
var addConfig = createBinaryKernelConfig(tfjsCore.Add);
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var wasmFunc$6;
/**
 * Binds the `AddN` WASM export into `wasmFunc$6`. Its arguments follow the
 * call in `addn`: input-id bytes, input count, dtype code, out_id.
 *
 * @param backend WASM backend exposing `wasm.cwrap`.
 */
function setupFunc$1(backend) {
    var argTypes = ['array', 'number', 'number', 'number'];
    wasmFunc$6 = backend.wasm.cwrap(tfjsCore.AddN, null /* void */, argTypes);
}
/**
 * WASM kernel for AddN. The output takes the shape and dtype of `inputs[0]`.
 * The ids of all inputs are packed as int32 bytes and passed to the native
 * export in one call. Zero-sized outputs are returned without calling WASM.
 *
 * @param {Object} args `{inputs: TensorInfo[], backend}`.
 * @returns {Object} Output TensorInfo.
 */
function addn(args) {
    var backend = args.backend;
    var tensors = args.inputs;
    var first = tensors[0];
    var out = backend.makeOutput(first.shape, first.dtype);
    if (tfjsCore.util.sizeFromShape(out.shape) === 0) {
        return out;
    }
    var ids = [];
    for (var i = 0; i < tensors.length; i++) {
        ids.push(backend.dataIdMap.get(tensors[i].dataId).id);
    }
    var idBytes = new Uint8Array(new Int32Array(ids).buffer);
    var outId = backend.dataIdMap.get(out.dataId).id;
    wasmFunc$6(idBytes, ids.length, CppDType[out.dtype], outId);
    return out;
}
// Registration record for AddN on the 'wasm' backend. `setupFunc$1` binds the
// WASM export and `addn` runs the kernel.
var addNConfig = {
kernelName: tfjsCore.AddN,
backendName: 'wasm',
setupFunc: setupFunc$1,
kernelFunc: addn,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * WASM Identity kernel. For non-string dtypes it allocates a new output and
 * copies `x`'s heap view into it. String tensors take a separate path: their
 * values are read with `backend.readSync` and wrapped with `tfjsCore.tensor`
 * (presumably because strings are not stored in the WASM heap).
 *
 * @param {Object} args `{inputs: {x}, backend}`.
 * @returns {Object} A copy of `x`.
 */
function identity(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    if (x.dtype !== 'string') {
        var out = backend.makeOutput(x.shape, x.dtype);
        var source = backend.typedArrayFromHeap(x);
        var target = backend.typedArrayFromHeap(out);
        target.set(source);
        return out;
    }
    return tfjsCore.tensor(backend.readSync(x.dataId), x.shape, x.dtype);
}
// Registration record for Identity on the 'wasm' backend. It has no setupFunc:
// `identity` calls no WASM export directly.
var identityConfig = {
kernelName: tfjsCore.Identity,
backendName: 'wasm',
kernelFunc: identity,
};
var wasmTranspose;
/**
 * Binds the `Transpose` WASM export into `wasmTranspose`. Its arguments follow
 * the call in `transpose`: x id, x shape bytes, x rank, dtype code, out id,
 * perm bytes, perm length.
 *
 * @param backend WASM backend exposing `wasm.cwrap`.
 */
function setup$19(backend) {
    var xArgs = ['number', 'array', 'number']; // x id, shape bytes, rank
    var metaArgs = ['number', 'number'];       // dtype, out id
    var permArgs = ['array', 'number'];        // perm bytes, perm.length
    wasmTranspose = backend.wasm.cwrap(tfjsCore.Transpose, null /* void */, xArgs.concat(metaArgs, permArgs));
}
/**
 * WASM Transpose kernel. Size-1 dimensions are removed first
 * (`removeOneSizeDims`) so the native kernel sees a lower-rank problem. If
 * the remaining permutation is the identity, the data is copied with
 * `identity` and only the reported shape is changed.
 *
 * @param {Object} args `{inputs: {x}, backend, attrs: {perm}}`.
 * @returns {Object} TensorInfo whose shape is `x.shape` permuted by `attrs.perm`.
 */
function transpose(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var attrs = args.attrs;
    // Reduce any dimensions with size one. Lower-rank transpose kernel performs
    // better due to simpler memory access pattern.
    var reduced = removeOneSizeDims(inputs.x.shape, attrs.perm);
    var reducedShape = reduced[0];
    var perm = reduced[1];
    var permIsNoOp = perm.every(function (axis, position) { return axis === position; });
    var outShape = computeOutShape(inputs.x.shape, attrs.perm);
    if (permIsNoOp) {
        var cloned = identity({ inputs: inputs, backend: backend });
        cloned.shape = outShape;
        return cloned;
    }
    var dtype = inputs.x.dtype;
    var out = backend.makeOutput(outShape, dtype);
    var xId = backend.dataIdMap.get(inputs.x.dataId).id;
    var outId = backend.dataIdMap.get(out.dataId).id;
    var permBytes = new Uint8Array(new Int32Array(perm).buffer);
    var shapeBytes = new Uint8Array(new Int32Array(reducedShape).buffer);
    wasmTranspose(xId, shapeBytes, reducedShape.length, CppDType[dtype], outId, permBytes, perm.length);
    return out;
}
/**
 * Returns `inShape` permuted by `perm`, i.e. `out[i] = inShape[perm[i]]` for
 * every axis `i` of `inShape`.
 *
 * @param {number[]} inShape Input shape.
 * @param {number[]} perm Axis permutation.
 * @returns {number[]} The permuted shape (a new array).
 */
function computeOutShape(inShape, perm) {
    var rank = inShape.length;
    var permuted = [];
    for (var axis = 0; axis < rank; axis++) {
        permuted.push(inShape[perm[axis]]);
    }
    return permuted;
}
/**
 * Drops size-1 dimensions from `shape` and the `perm` entries that point at
 * them. The surviving permutation entries are then renumbered to 0..k-1,
 * keeping their relative order, so the pair describes an equivalent
 * lower-rank transpose.
 *
 * @param {number[]} shape Input shape.
 * @param {number[]} perm Permutation of `shape`'s axes.
 * @returns {Array} `[newShape, newPerm]`.
 */
function removeOneSizeDims(shape, perm) {
    var keptShape = [];
    var keptPerm = [];
    for (var axis = 0; axis < shape.length; ++axis) {
        var dim = shape[axis];
        if (dim !== 1) {
            keptShape.push(dim);
        }
        var source = perm[axis];
        if (shape[source] !== 1) {
            keptPerm.push(source);
        }
    }
    // Compact the surviving entries: on pass `next`, the smallest entry that is
    // still >= `next` (i.e. not yet renumbered) becomes `next`.
    for (var next = 0; next < keptPerm.length; ++next) {
        var smallestAt = -1;
        for (var j = 0; j < keptPerm.length; ++j) {
            var candidate = keptPerm[j];
            var isEligible = candidate >= next;
            if (isEligible && (smallestAt === -1 || keptPerm[smallestAt] > candidate)) {
                smallestAt = j;
            }
        }
        keptPerm[smallestAt] = next;
    }
    return [keptShape, keptPerm];
}
// Registration record for Transpose on the 'wasm' backend. `setup$19` binds
// the WASM export and `transpose` runs the kernel.
var transposeConfig = {
kernelName: tfjsCore.Transpose,
backendName: 'wasm',
kernelFunc: transpose,
setupFunc: setup$19,
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Compute permutation axes and do a transpose if necessary.
 *
 * Used by reduction ops. When `backend_util.getAxesPermutation` reports that
 * the reduction axes are not already innermost, `x` is transposed so they
 * are, and `axes` is rewritten to the matching innermost axes.
 *
 * @param x input TensorInfo
 * @param axis reduction axes (anything `util.parseAxisParam` accepts)
 * @param backend wasm backend instance
 * @returns `{transposed, originalAxes, axes, inputWasTransposed}`:
 *   - `transposed`: the transposed TensorInfo, or null when no permutation was needed.
 *   - `originalAxes`: the parsed input axes.
 *   - `axes`: the axes to reduce on the (possibly transposed) input.
 *   - `inputWasTransposed`: true when `transposed` has a different data id
 *     than `x`. The reduction kernels in this file dispose it in that case.
 */
function permuteAxesAndTranspose(x, axis, backend) {
    var xRank = x.shape.length;
    var originalAxes = tfjsCore.util.parseAxisParam(axis, x.shape);
    var axes = originalAxes;
    var permutedAxes = tfjsCore.backend_util.getAxesPermutation(axes, xRank);
    var xTransposed = null;
    var inputWasTransposed = false;
    if (permutedAxes != null) {
        axes = tfjsCore.backend_util.getInnerMostAxes(axes.length, xRank);
        xTransposed =
            transpose({ inputs: { x: x }, attrs: { perm: permutedAxes }, backend: backend });
        var xId = backend.dataIdMap.get(x.dataId).id;
        var transposedId = backend.dataIdMap.get(xTransposed.dataId).id;
        inputWasTransposed = transposedId !== xId;
    }
    return { transposed: xTransposed, originalAxes: originalAxes, axes: axes, inputWasTransposed: inputWasTransposed };
}
var wasmAll;
/**
 * Binds the `All` WASM export into `wasmAll`. It is called from `all` as
 * (input_id, reduce_size, out_id).
 *
 * @param backend WASM backend exposing `wasm.cwrap`.
 */
function setup$18(backend) {
    // cwrap expects one type name per argument, like every other binding here.
    wasmAll = backend.wasm.cwrap(tfjsCore.All, null /*void*/, ['number', 'number', 'number']);
}
/**
 * WASM kernel for the `All` reduction over `attrs.axis`. If the reduction
 * axes are not innermost, `x` is first transposed via
 * `permuteAxesAndTranspose`; that temporary is disposed before returning.
 * With `keepDims`, reduced axes are re-inserted into the output shape.
 *
 * @param {Object} args `{backend, inputs: {x}, attrs: {axis, keepDims}}`.
 * @returns {Object} Output TensorInfo with `x`'s dtype.
 */
function all(args) {
    var backend = args.backend;
    var axis = args.attrs.axis;
    var keepDims = args.attrs.keepDims;
    var x = args.inputs.x;
    var inputId = backend.dataIdMap.get(x.dataId).id;
    var input = x;
    var prep = permuteAxesAndTranspose(x, axis, backend);
    var transposed = prep.transposed;
    var axes = prep.axes;
    if (prep.inputWasTransposed) {
        input = transposed;
        inputId = backend.dataIdMap.get(transposed.dataId).id;
    }
    tfjsCore.backend_util.assertAxesAreInnerMostDims('all', axes, input.shape.length);
    var shapes = tfjsCore.backend_util.computeOutAndReduceShapes(input.shape, axes);
    var outShape = shapes[0];
    var reduceSize = tfjsCore.util.sizeFromShape(shapes[1]);
    var out = backend.makeOutput(outShape, x.dtype);
    if (tfjsCore.util.sizeFromShape(input.shape) !== 0) {
        wasmAll(inputId, reduceSize, backend.dataIdMap.get(out.dataId).id);
    }
    if (prep.inputWasTransposed) {
        // Release the temporary transposed copy.
        backend.disposeData(transposed.dataId);
    }
    if (keepDims) {
        out.shape = tfjsCore.backend_util.expandShapeToKeepDim(out.shape, prep.originalAxes);
    }
    return out;
}
// Registration record for All on the 'wasm' backend. `setup$18` binds the
// WASM export and `all` runs the kernel.
var allConfig = {
kernelName: tfjsCore.All,
backendName: 'wasm',
setupFunc: setup$18,
kernelFunc: all
};
var wasmAny;
/**
 * Binds the `Any` WASM export into `wasmAny`. It is called from `any` as
 * (input_id, reduce_size, out_id).
 *
 * @param backend WASM backend exposing `wasm.cwrap`.
 */
function setup$17(backend) {
    // cwrap expects one type name per argument, like every other binding here.
    wasmAny = backend.wasm.cwrap(tfjsCore.Any, null /*void*/, ['number', 'number', 'number']);
}
/**
 * WASM kernel for the `Any` reduction over `attrs.axis`. If the reduction
 * axes are not innermost, `x` is first transposed via
 * `permuteAxesAndTranspose`; that temporary is disposed before returning.
 * With `keepDims`, reduced axes are re-inserted into the output shape.
 *
 * @param {Object} args `{backend, inputs: {x}, attrs: {axis, keepDims}}`.
 * @returns {Object} Output TensorInfo with `x`'s dtype.
 */
function any(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var x = args.inputs.x;
    var source = x;
    var sourceId = backend.dataIdMap.get(x.dataId).id;
    var permuted = permuteAxesAndTranspose(x, attrs.axis, backend);
    var wasTransposed = permuted.inputWasTransposed;
    if (wasTransposed) {
        source = permuted.transposed;
        sourceId = backend.dataIdMap.get(source.dataId).id;
    }
    tfjsCore.backend_util.assertAxesAreInnerMostDims('any', permuted.axes, source.shape.length);
    var outAndReduce = tfjsCore.backend_util.computeOutAndReduceShapes(source.shape, permuted.axes);
    var reduceSize = tfjsCore.util.sizeFromShape(outAndReduce[1]);
    var out = backend.makeOutput(outAndReduce[0], x.dtype);
    var hasData = tfjsCore.util.sizeFromShape(source.shape) !== 0;
    if (hasData) {
        var outId = backend.dataIdMap.get(out.dataId).id;
        wasmAny(sourceId, reduceSize, outId);
    }
    if (wasTransposed) {
        // Release the temporary transposed copy.
        backend.disposeData(permuted.transposed.dataId);
    }
    if (attrs.keepDims) {
        out.shape = tfjsCore.backend_util.expandShapeToKeepDim(out.shape, permuted.originalAxes);
    }
    return out;
}
// Registration record for Any on the 'wasm' backend. `setup$17` binds the
// WASM export and `any` runs the kernel.
var anyConfig = {
kernelName: tfjsCore.Any,
backendName: 'wasm',
setupFunc: setup$17,
kernelFunc: any
};
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a 'wasm' kernel config for an arg-reduction (ArgMax / ArgMin). The
 * input is transposed when needed so that the reduced axis is last. The int32
 * output drops that last dimension. The WASM export is called as
 * (input_id, dtype code, outer_size, inner_size, out_id).
 *
 * NOTE(review): only `axes[0]` is read, so `attrs.axis` is presumably a single
 * axis. TODO: confirm.
 *
 * @param {string} kernelName Kernel name, also used as the WASM export name.
 * @returns {{kernelName: string, backendName: string, setupFunc: Function, kernelFunc: Function}}
 */
function createArgMinMaxKernelConfig(kernelName) {
    var wasmFunc;
    function setupFunc(backend) {
        var argTypes = ['number', 'number', 'number', 'number', 'number']; // last: out_id
        wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, argTypes);
    }
    function kernelFunc(args) {
        var backend = args.backend;
        var x = args.inputs.x;
        var xId = backend.dataIdMap.get(x.dataId).id;
        var prep = permuteAxesAndTranspose(x, args.attrs.axis, backend);
        var source = x;
        var sourceId = xId;
        if (prep.inputWasTransposed) {
            var transposedId = backend.dataIdMap.get(prep.transposed.dataId).id;
            // Read from the transposed copy only when it is a distinct tensor.
            if (transposedId !== xId) {
                source = prep.transposed;
                sourceId = transposedId;
            }
        }
        var out = backend.makeOutput(source.shape.slice(0, -1), 'int32');
        var outId = backend.dataIdMap.get(out.dataId).id;
        var outerSize = tfjsCore.util.sizeFromShape(out.shape);
        var innerSize = source.shape[prep.axes[0]];
        wasmFunc(sourceId, CppDType[source.dtype], outerSize, innerSize, outId);
        if (prep.inputWasTransposed) {
            // Release the temporary transposed copy.
            backend.disposeData(prep.transposed.dataId);
        }
        return out;
    }
    return {
        kernelName: kernelName,
        backendName: 'wasm',
        setupFunc: setupFunc,
        kernelFunc: kernelFunc,
    };
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for ArgMax (int32 indices along the reduced axis).
var argMaxConfig = createArgMinMaxKernelConfig(tfjsCore.ArgMax);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for ArgMin (int32 indices along the reduced axis).
var argMinConfig = createArgMinMaxKernelConfig(tfjsCore.ArgMin);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for Asin (unary; no outType, so output dtype = input dtype).
var asinConfig = createUnaryKernelConfig(tfjsCore.Asin);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for Asinh (unary; no outType, so output dtype = input dtype).
var asinhConfig = createUnaryKernelConfig(tfjsCore.Asinh);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// 'wasm' kernel config for Atan (unary; no outType, so output dtype = input dtype).
var atanConfig = createUnaryKernelConfig(tfjsCore.Atan);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Kernel config for tfjsCore.Atan2 (a two-input op), built by the shared
// binary-kernel factory createBinaryKernelConfig (defined outside this chunk).
var atan2Config = createBinaryKernelConfig(tfjsCore.Atan2);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Kernel config for tfjsCore.Atanh, built by the shared unary-kernel factory
// createUnaryKernelConfig (defined outside this chunk of the file).
var atanhConfig = createUnaryKernelConfig(tfjsCore.Atanh);
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var wasmAvgPool;
/**
 * Binds the native AvgPool export of the wasm module to `wasmAvgPool`.
 *
 * The native function returns nothing and takes only numeric arguments. The
 * names below exist purely as documentation: they list those arguments in the
 * exact order avgPool() passes them, and each name maps to the cwrap type
 * 'number'.
 *
 * @param backend The wasm backend; its `wasm.cwrap` performs the binding.
 */
function setup$16(backend) {
  var argNames = [
    'xId',
    'batchSize', 'inHeight', 'inWidth',
    'filterHeight', 'filterWidth',
    'padTop', 'padRight', 'padBottom', 'padLeft',
    'strideHeight', 'strideWidth',
    'channels',
    'outId'
  ];
  var argTypes = argNames.map(function () {
    return 'number';
  });
  wasmAvgPool = backend.wasm.cwrap(tfjsCore.AvgPool, null /* void */, argTypes);
}
/**
 * AvgPool (2D average pooling) kernel function for the wasm backend.
 *
 * Computes the pooling geometry with `backend_util.computePool2DInfo`
 * (dilation requested as 1) and rejects geometries the native kernel does
 * not support. It then allocates a float32 output of `convInfo.outShape` and
 * calls the native function bound by setup$16.
 *
 * @param args Kernel invocation with `inputs.x` (the tensor to pool), `attrs`
 *     (`filterSize`, `strides`, `pad`, `dimRoundingMode`) and `backend` (the
 *     wasm backend, which provides `dataIdMap` and `makeOutput`).
 * @returns Tensor info of the newly allocated output.
 * @throws {Error} If `convInfo.dataFormat` is not 'channelsLast', or if the
 *     computed dilation is not [1, 1].
 */
function avgPool(args) {
  var inputs = args.inputs, attrs = args.attrs, backend = args.backend;
  var x = inputs.x;
  var xId = backend.dataIdMap.get(x.dataId).id;
  var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad,
      dimRoundingMode = attrs.dimRoundingMode;
  var convInfo = tfjsCore.backend_util.computePool2DInfo(
      x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
  var padInfo = convInfo.padInfo;
  // Both checks run before makeOutput, so nothing is allocated when they fail.
  if (convInfo.dataFormat !== 'channelsLast') {
    throw new Error("wasm backend does not support dataFormat:'" +
        "".concat(convInfo.dataFormat, "'. Please use 'channelsLast'."));
  }
  // Dilation is requested as 1 above. This guards against the computed
  // geometry reporting anything else.
  if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) {
    throw new Error("wasm backend only supports average pooling with dilation = [1, 1], " +
        "got [".concat(convInfo.dilationHeight, ", ").concat(convInfo.dilationWidth, "]."));
  }
  var out = backend.makeOutput(convInfo.outShape, 'float32');
  var outId = backend.dataIdMap.get(out.dataId).id;
  // Given the channelsLast check above, x.shape[0..2] are passed as batch,
  // height and width (assumes a rank-4 NHWC input; TODO confirm at callers).
  wasmAvgPool(xId, x.shape[0], x.shape[1], x.shape[2],
      convInfo.filterHeight, convInfo.filterWidth,
      padInfo.top, padInfo.right, padInfo.bottom, padInfo.left,
      convInfo.strideHeight, convInfo.strideWidth,
      convInfo.inChannels, outId);
  return out;
}
/**
 * Kernel config for tfjsCore.AvgPool on the 'wasm' backend. setup$16 binds
 * the native function and avgPool is the kernel function. Presumably this
 * record is handed to kernel registration elsewhere in the file (confirm).
 */
var avgPoolConfig = {
kernelName: tfjsCore.AvgPool,
backendName: 'wasm',
setupFunc: setup$16,
kernelFunc: avgPool
};
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var wasmAvgPool3D;
/**
 * Binds the native 'AvgPool3D' export of the wasm module to `wasmAvgPool3D`.
 *
 * The native function returns nothing and takes 22 numeric arguments. The
 * names below document those arguments in the order avgPool3D() passes them;
 * each name maps to the cwrap type 'number'.
 *
 * @param backend The wasm backend; its `wasm.cwrap` performs the binding.
 */
function setup$15(backend) {
  var argNames = [
    'xId', 'outId', 'batchSize', 'channelSize',
    'inDepth', 'inHeight', 'inWidth',
    'outDepth', 'outHeight', 'outWidth',
    'strideDepth', 'strideHeight', 'strideWidth',
    'dilationDepth', 'dilationHeight', 'dilationWidth',
    'effectiveFilterDepth', 'effectiveFilterHeight', 'effectiveFilterWidth',
    'padFront', 'padTop', 'padLeft'
  ];
  var argTypes = argNames.map(function () {
    return 'number';
  });
  wasmAvgPool3D = backend.wasm.cwrap('AvgPool3D', null, argTypes);
}
/**
 * AvgPool3D kernel function for the wasm backend.
 *
 * Derives the 3D pooling geometry from the input shape and attributes, then
 * allocates an output of `outShape` that keeps the input's dtype. Every
 * geometry value is passed to the native function bound by setup$15.
 *
 * @param args Kernel invocation: `inputs.x`, `backend`, and `attrs`
 *     (`filterSize`, `strides`, `pad`, `dimRoundingMode`, `dataFormat`).
 * @returns Tensor info of the newly allocated output.
 */
function avgPool3D(args) {
  var backend = args.backend;
  var poolAttrs = args.attrs;
  var x = args.inputs.x;
  var info = tfjsCore.backend_util.computePool3DInfo(
      x.shape, poolAttrs.filterSize, poolAttrs.strides, /*dilations=*/ 1,
      poolAttrs.pad, poolAttrs.dimRoundingMode, poolAttrs.dataFormat);
  var out = backend.makeOutput(info.outShape, x.dtype);
  var idMap = backend.dataIdMap;
  var pads = info.padInfo;
  // Pool3D ops only support 3D filters, so the input and output channel
  // counts are equal and one channel-size argument covers both.
  var channelSize = info.inChannels;
  wasmAvgPool3D(
      idMap.get(x.dataId).id, idMap.get(out.dataId).id,
      info.batchSize, channelSize,
      info.inDepth, info.inHeight, info.inWidth,
      info.outDepth, info.outHeight, info.outWidth,
      info.strideDepth, info.strideHeight, info.strideWidth,
      info.dilationDepth, info.dilationHeight, info.dilationWidth,
      info.effectiveFilterDepth, info.effectiveFilterHeight,
      info.effectiveFilterWidth,
      pads.front, pads.top, pads.left);
  return out;
}
/**
 * Kernel config for tfjsCore.AvgPool3D on the 'wasm' backend. setup$15
 * binds the native 'AvgPool3D' function and avgPool3D is the kernel function.
 * Presumably this record is handed to kernel registration elsewhere in the
 * file (confirm).
 */
var avgPool3DConfig = {
kernelName: tfjsCore.AvgPool3D,
backendName: 'wasm',
setupFunc: setup$15,
kernelFunc: avgPool3D
};
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var wasmAvgPool3DGrad;
/**
 * Binds the native 'AvgPool3DGrad' export of the wasm module to
 * `wasmAvgPool3DGrad`.
 *
 * The native function returns nothing and takes 25 numeric arguments. The
 * names below document those arguments in the order avgPool3DGrad() passes
 * them; each name maps to the cwrap type 'number'.
 *
 * @param backend The wasm backend; its `wasm.cwrap` performs the binding.
 */
function setup$14(backend) {
  var argNames = [
    'dyId', 'dxId', 'batchSize', 'channelSize',
    'inDepth', 'inHeight', 'inWidth',
    'outDepth', 'outHeight', 'outWidth',
    'strideDepth', 'strideHeight', 'strideWidth',
    'dilationDepth', 'dilationHeight', 'dilationWidth',
    'effectiveFilterDepth', 'effectiveFilterHeight', 'effectiveFilterWidth',
    'padFront', 'padTop', 'padLeft',
    'filterDepth', 'filterHeight', 'filterWidth'
  ];
  var argTypes = argNames.map(function () {
    return 'number';
  });
  wasmAvgPool3DGrad = backend.wasm.cwrap('AvgPool3DGrad', null, argTypes);
}
/**
 * AvgPool3DGrad kernel function for the wasm backend.
 *
 * Recomputes the forward 3D pooling geometry from the original `input` shape
 * and attributes. It then allocates `dx` with the same shape and dtype as
 * `input`, and passes `dy`, `dx` and the geometry to the native function
 * bound by setup$14.
 *
 * @param args Kernel invocation: `inputs.dy`, `inputs.input`, `backend`, and
 *     `attrs` (`filterSize`, `strides`, `pad`, `dimRoundingMode`).
 * @returns Tensor info of the newly allocated gradient `dx`.
 */
function avgPool3DGrad(args) {
  var backend = args.backend;
  var poolAttrs = args.attrs;
  var dy = args.inputs.dy;
  var input = args.inputs.input;
  var info = tfjsCore.backend_util.computePool3DInfo(
      input.shape, poolAttrs.filterSize, poolAttrs.strides, /*dilations=*/ 1,
      poolAttrs.pad, poolAttrs.dimRoundingMode);
  var dx = backend.makeOutput(input.shape, input.dtype);
  var idMap = backend.dataIdMap;
  var pads = info.padInfo;
  // Pool3D ops only support 3D filters, so the input and output channel
  // counts are equal and one channel-size argument covers both.
  var channelSize = info.inChannels;
  wasmAvgPool3DGrad(
      idMap.get(dy.dataId).id, idMap.get(dx.dataId).id,
      info.batchSize, channelSize,
      info.inDepth, info.inHeight, info.inWidth,
      info.outDepth, info.outHeight, info.outWidth,
      info.strideDepth, info.strideHeight, info.strideWidth,
      info.dilationDepth, info.dilationHeight, info.dilationWidth,
      info.effectiveFilterDepth, info.effectiveFilterHeight,
      info.effectiveFilterWidth,
      pads.front, pads.top, pads.left,
      info.filterDepth, info.filterHeight, info.filterWidth);
  return dx;
}
/**
 * Kernel config for tfjsCore.AvgPool3DGrad on the 'wasm' backend. setup$14
 * binds the native 'AvgPool3DGrad' function and avgPool3DGrad is the kernel
 * function. Presumably this record is handed to kernel registration
 * elsewhere in the file (confirm).
 */
var avgPool3DGradConfig = {
kernelName: tfjsCore.AvgPool3DGrad,
backendName: 'wasm',
setupFunc: setup$14,
kernelFunc: avgPool3DGrad
};
/**
* @license
* Copyright