@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.
1,231 lines (1,201 loc) • 414 kB
JavaScript
/**
* @license
* Copyright 2024 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { _FusedMatMul, broadcast_util, util, Abs, Acos, Acosh, backend_util, Add, AddN, Identity, tensor, Transpose, All, Any, ArgMax, ArgMin, Asin, Asinh, Atan, Atan2, Atanh, AvgPool, AvgPool3D, AvgPool3DGrad, AvgPoolGrad, Reshape, BatchMatMul, slice_util, buffer, TensorBuffer, Slice, BatchToSpaceND, Bincount, BitwiseAnd, BroadcastArgs, Cast, Ceil, ClipByValue, Concat, Conv2D, Conv2DBackpropInput, Conv3D, Conv3DBackpropFilterV2, Conv3DBackpropInputV2, Cos, Cosh, CropAndResize, Cumprod, Cumsum, DenseBincount, DepthToSpace, DepthwiseConv2dNative, Diag, Dilation2D, Dilation2DBackpropFilter, Dilation2DBackpropInput, Elu, EluGrad, Equal, Erf, Exp, ExpandDims, Expm1, Fill, FlipLeftRight, Floor, FloorDiv, FusedBatchNorm, FusedConv2D, FusedDepthwiseConv2D, GatherNd, gather_util, GatherV2, Greater, GreaterEqual, IsFinite, IsInf, IsNan, LeakyRelu, Less, LessEqual, LinSpace, Log, Log1p, LogicalAnd, LogicalNot, LogicalOr, LogicalXor, LRN, LRNGrad, Max, Maximum, MaxPool, MaxPool3D, MaxPool3DGrad, MaxPoolGrad, MaxPoolWithArgmax, Mean, Min, Minimum, MirrorPad, Softmax, Multinomial, Mod, Multiply, Neg, NonMaxSuppressionV3, NonMaxSuppressionV4, NonMaxSuppressionV5, NotEqual, OneHot, OnesLike, Pack, PadV2, Pow, Prelu, Prod, Range, RealDiv, Reciprocal, Relu, Relu6, ResizeBilinear, ResizeBilinearGrad, ResizeNearestNeighbor, ResizeNearestNeighborGrad, Reverse, RotateWithOffset, Round, Rsqrt, ScatterNd, scatter_util, SearchSorted, Select, Selu, Sigmoid, Sign, Sin, Sinh, Softplus, SpaceToBatchND, SparseFillEmptyRows, SparseReshape, SparseSegmentMean, SparseSegmentSum, SparseToDense, SplitV, Sqrt, Square, SquaredDifference, Step, StridedSlice, StringNGrams, StringSplit, StringToHashBucketFast, Sub, Sum, Tan, Tanh, TensorScatterUpdate, Tile, TopK, Transform, Unique, Unpack, ZerosLike, registerKernel, env, KernelBackend, DataStorage, engine, deprecationWarn, registerBackend } from '@tensorflow/tfjs-core';
import require$$0 from 'fs';
import require$$1 from 'path';
import require$$3 from 'perf_hooks';
import require$$4 from 'os';
/**
 * Merges the own enumerable keys of every object in `m` into `n`.
 *
 * Falsy entries, strings and arrays in `m` are ignored. The key 'default' and
 * any key already reachable on `n` are never copied. A property that already
 * has a getter keeps its descriptor; any other property is exposed through a
 * new enumerable getter that reads the current value from its source object.
 *
 * @param n Target object, mutated in place.
 * @param m Collection of source objects (anything exposing `forEach`).
 * @returns The same `n` that was passed in.
 */
function _mergeNamespaces(n, m) {
    m.forEach((source) => {
        if (!source || typeof source === 'string' || Array.isArray(source)) {
            return;
        }
        for (const key of Object.keys(source)) {
            if (key === 'default' || key in n) {
                continue;
            }
            const descriptor = Object.getOwnPropertyDescriptor(source, key);
            // Fallback accessor so later changes on the source stay visible.
            const liveBinding = {
                enumerable: true,
                get: () => source[key]
            };
            Object.defineProperty(n, key, descriptor.get ? descriptor : liveBinding);
        }
    });
    return n;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// This enum must align with the enum defined in cc/backend.h.
// Two-way lookup table: dtype name -> numeric code, and numeric code -> name.
var CppDType;
(function (table) {
    // The position of each name in this list is its numeric code.
    const names = ['float32', 'int32', 'bool', 'string', 'complex64'];
    names.forEach((name, code) => {
        table[name] = code;
        table[code] = name;
    });
})(CppDType || (CppDType = {}));
// Must match enum in cc/fusable_activations.h.
// Two-way lookup table: activation name -> numeric code, and code -> name.
var FusableActivation;
(function (table) {
    // The position of each name in this list is its numeric code.
    const names = ['linear', 'relu', 'relu6', 'prelu', 'leakyrelu', 'sigmoid', 'elu'];
    names.forEach((name, code) => {
        table[name] = code;
        table[code] = name;
    });
})(FusableActivation || (FusableActivation = {}));
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Bound wasm `_FusedMatMul` function; assigned by setup$1a().
let wasmFusedMatMul;
/**
 * Binds the wasm `_FusedMatMul` function through `backend.wasm.cwrap`.
 *
 * The argument types are listed in the same order as the arguments passed by
 * fusedBatchMatMul below.
 *
 * @param backend wasm backend instance exposing `wasm.cwrap`.
 */
function setup$1a(backend) {
    wasmFusedMatMul = backend.wasm.cwrap(_FusedMatMul, null /* void */, [
        'number', // a_id
        'array', // a_shape
        'number', // a_shape.length
        'number', // b_id
        'array', // b_shape
        'number', // b_shape.length
        'number', // transpose_a
        'number', // transpose_b
        'number', // activation
        'number', // bias_id
        'number', // prelu_activation_weights_id
        'number', // leakyrelu_alpha
        'number' // out_id
    ]);
}
/**
 * Kernel function for `_FusedMatMul` on the wasm backend.
 *
 * @param args `{ inputs, backend, attrs }` where `inputs` holds `a`, `b` and
 *     the optional `bias` / `preluActivationWeights`, and `attrs` holds
 *     `transposeA`, `transposeB`, `activation` and `leakyreluAlpha`.
 * @returns The output tensor info created via `backend.makeOutput`, shaped
 *     `[...broadcast batch dims, leftDim, rightDim]` with the dtype of `a`.
 * @throws Error if `a` or `b` is not float32, if `bias` is not rank 1, or if
 *     `activation` has no entry in `FusableActivation`.
 */
function fusedBatchMatMul(args) {
    const { inputs, backend, attrs } = args;
    const { a, b, bias, preluActivationWeights } = inputs;
    if (a.dtype !== 'float32' || b.dtype !== 'float32') {
        throw new Error(`_FusedMatMul for non-float32 tensors not yet supported.`);
    }
    const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
    const aId = backend.dataIdMap.get(a.dataId).id;
    const bId = backend.dataIdMap.get(b.dataId).id;
    // An id of 0 tells the wasm side that the optional tensor is absent.
    let biasId = 0;
    if (bias != null) {
        const biasData = backend.dataIdMap.get(bias.dataId);
        if (biasData.shape.length !== 1) {
            throw new Error(`_FusedMatMul only supports rank-1 bias but got ` +
                `rank ${biasData.shape.length}.`);
        }
        biasId = biasData.id;
    }
    const preluActivationWeightsId = preluActivationWeights == null ?
        0 :
        backend.dataIdMap.get(preluActivationWeights.dataId).id;
    const fusedActivation = FusableActivation[activation];
    if (fusedActivation == null) {
        throw new Error(`${activation} activation not yet supported for _FusedMatMul ` +
            `in the wasm backend.`);
    }
    // The indexing below reads dims 1 and 2, i.e. it treats a and b as rank 3.
    const leftDim = transposeA ? a.shape[2] : a.shape[1];
    const rightDim = transposeB ? b.shape[1] : b.shape[2];
    const batchDims = broadcast_util.assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
    const out = backend.makeOutput([...batchDims, leftDim, rightDim], a.dtype);
    const outId = backend.dataIdMap.get(out.dataId).id;
    // 'array' cwrap arguments are passed as bytes, so the int32 shapes are
    // re-viewed as Uint8Array.
    const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer);
    const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer);
    wasmFusedMatMul(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, transposeA, transposeB, fusedActivation, biasId, preluActivationWeightsId, leakyreluAlpha || 0, outId);
    return out;
}
/**
 * Kernel config tying the `_FusedMatMul` kernel name to its 'wasm' backend
 * implementation: `setup$1a` binds the wasm function and `fusedBatchMatMul`
 * executes the kernel.
 */
const _fusedMatMulConfig = {
kernelName: _FusedMatMul,
backendName: 'wasm',
setupFunc: setup$1a,
kernelFunc: fusedBatchMatMul
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a 'wasm' kernel config for a kernel that reads a single input `x`.
 *
 * @param kernelName Name handed to `backend.wasm.cwrap` and reported in the
 *     returned config.
 * @param outType Optional dtype for the output; when falsy the output uses the
 *     dtype of `x`.
 * @returns `{ kernelName, backendName: 'wasm', setupFunc, kernelFunc }`.
 */
function createUnaryKernelConfig(kernelName, outType) {
    // Bound wasm function; populated once `setupFunc` has run.
    let wasmUnaryOp;
    const setupFunc = (backend) => {
        const argTypes = [
            'number', // x_id
            'number', // dtype
            'number', // out_id
        ];
        wasmUnaryOp = backend.wasm.cwrap(kernelName, null /* void */, argTypes);
    };
    const kernelFunc = ({ backend, inputs }) => {
        const { x } = inputs;
        const { id: xId } = backend.dataIdMap.get(x.dataId);
        const out = backend.makeOutput(x.shape, outType || x.dtype);
        const { id: outId } = backend.dataIdMap.get(out.dataId);
        // Zero-sized tensors have nothing to compute, so the wasm call is skipped.
        const hasElements = util.sizeFromShape(out.shape) !== 0;
        if (hasElements) {
            wasmUnaryOp(xId, CppDType[x.dtype], outId);
        }
        return out;
    };
    return { kernelName, backendName: 'wasm', setupFunc, kernelFunc };
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `Abs`, built by createUnaryKernelConfig. No
 * `outType` is passed, so the output tensor uses the input's dtype.
 */
const absConfig = createUnaryKernelConfig(Abs);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `Acos`, built by createUnaryKernelConfig. No
 * `outType` is passed, so the output tensor uses the input's dtype.
 */
const acosConfig = createUnaryKernelConfig(Acos);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `Acosh`, built by createUnaryKernelConfig. No
 * `outType` is passed, so the output tensor uses the input's dtype.
 */
const acoshConfig = createUnaryKernelConfig(Acosh);
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a 'wasm' kernel config for a kernel that reads two inputs `a` and `b`.
 *
 * @param kernelName Name handed to `backend.wasm.cwrap` and reported in the
 *     returned config.
 * @param supportsFullBroadcast Accepted for signature compatibility; this
 *     function never reads it.
 * @param dtype Optional output dtype; when null/undefined the output uses the
 *     dtype of `a`.
 * @returns `{ kernelName, backendName: 'wasm', setupFunc, kernelFunc }`.
 */
function createBinaryKernelConfig(kernelName, supportsFullBroadcast, dtype) {
    // Bound wasm function; populated once `setupFunc` has run.
    let wasmBinaryOp;
    const setupFunc = (backend) => {
        const argTypes = [
            'number', // a_id
            'array', // a_shape
            'number', // a_shape.length
            'number', // b_id
            'array', // b_shape
            'number', // b_shape.length
            'number', // dtype
            'number' // out_id
        ];
        wasmBinaryOp = backend.wasm.cwrap(kernelName, null /* void */, argTypes);
    };
    // 'array' cwrap arguments are passed as bytes of an int32 buffer.
    const toShapeBytes = (shape) => new Uint8Array(new Int32Array(shape).buffer);
    const kernelFunc = ({ backend, inputs }) => {
        const { a, b } = inputs;
        const aId = backend.dataIdMap.get(a.dataId).id;
        const bId = backend.dataIdMap.get(b.dataId).id;
        const outShape = backend_util.assertAndGetBroadcastShape(a.shape, b.shape);
        const out = backend.makeOutput(outShape, dtype != null ? dtype : a.dtype);
        // Zero-sized outputs have nothing to compute, so the wasm call is skipped.
        if (util.sizeFromShape(outShape) !== 0) {
            const aShapeBytes = toShapeBytes(a.shape);
            const bShapeBytes = toShapeBytes(b.shape);
            const outId = backend.dataIdMap.get(out.dataId).id;
            wasmBinaryOp(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, CppDType[a.dtype], outId);
        }
        return out;
    };
    return { kernelName, backendName: 'wasm', setupFunc, kernelFunc };
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `Add`, built by createBinaryKernelConfig. No dtype
 * override is passed, so the output tensor uses the dtype of input `a`.
 */
const addConfig = createBinaryKernelConfig(Add);
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Bound wasm `AddN` function; assigned by setupFunc$1().
let wasmFunc$6;
/**
 * Binds the wasm `AddN` function through `backend.wasm.cwrap`.
 *
 * @param backend wasm backend instance exposing `wasm.cwrap`.
 */
function setupFunc$1(backend) {
    const argTypes = [
        'array', // input_ids
        'number', // input_ids.length
        'number', // dtype
        'number', // out_id
    ];
    wasmFunc$6 = backend.wasm.cwrap(AddN, null /* void */, argTypes);
}
/**
 * Kernel function for `AddN` on the wasm backend.
 *
 * @param args `{ inputs, backend }` where `inputs` is an array of tensor
 *     infos; the first entry supplies the output shape and dtype.
 * @returns The output tensor info created via `backend.makeOutput`.
 */
function addn(args) {
    const { inputs, backend } = args;
    const reference = inputs[0];
    const out = backend.makeOutput(reference.shape, reference.dtype);
    // Zero-sized outputs have nothing to compute, so the wasm call is skipped.
    if (util.sizeFromShape(out.shape) !== 0) {
        const inputIds = inputs.map(({ dataId }) => backend.dataIdMap.get(dataId).id);
        // 'array' cwrap arguments are passed as bytes of an int32 buffer.
        const inputIdsBytes = new Uint8Array(new Int32Array(inputIds).buffer);
        const outId = backend.dataIdMap.get(out.dataId).id;
        wasmFunc$6(inputIdsBytes, inputIds.length, CppDType[out.dtype], outId);
    }
    return out;
}
/**
 * Kernel config tying the `AddN` kernel name to its 'wasm' backend
 * implementation: `setupFunc$1` binds the wasm function and `addn` executes
 * the kernel.
 */
const addNConfig = {
kernelName: AddN,
backendName: 'wasm',
setupFunc: setupFunc$1,
kernelFunc: addn,
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Kernel function for `Identity` on the wasm backend: returns a copy of `x`.
 *
 * String tensors are rebuilt with `tensor(...)` from `backend.readSync`; every
 * other dtype gets a fresh output whose heap view is filled from the heap view
 * of `x`.
 *
 * @param args `{ inputs: { x }, backend }`.
 * @returns The copied tensor (string dtype) or output tensor info (otherwise).
 */
function identity(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    const isString = x.dtype === 'string';
    if (isString) {
        return tensor(backend.readSync(x.dataId), x.shape, x.dtype);
    }
    const out = backend.makeOutput(x.shape, x.dtype);
    const source = backend.typedArrayFromHeap(x);
    backend.typedArrayFromHeap(out).set(source);
    return out;
}
/**
 * Kernel config tying the `Identity` kernel name to the `identity` function on
 * the 'wasm' backend. No `setupFunc` is provided because `identity` does not
 * call a bound wasm function.
 */
const identityConfig = {
kernelName: Identity,
backendName: 'wasm',
kernelFunc: identity,
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Bound wasm `Transpose` function; assigned by setup$19().
let wasmTranspose;
/**
 * Binds the wasm `Transpose` function through `backend.wasm.cwrap`.
 *
 * @param backend wasm backend instance exposing `wasm.cwrap`.
 */
function setup$19(backend) {
    const argTypes = [
        'number', // x_id
        'array', // x.shape
        'number', // x.shape.length
        'number', // dtype
        'number', // out_id
        'array', // perm
        'number', // perm.length
    ];
    wasmTranspose = backend.wasm.cwrap(Transpose, null /* void */, argTypes);
}
/**
 * Kernel function for `Transpose` on the wasm backend.
 *
 * @param args `{ inputs: { x }, backend, attrs: { perm } }`.
 * @returns Output tensor info whose shape is `x.shape` permuted by `perm`.
 */
function transpose(args) {
    const { inputs, backend, attrs } = args;
    // Reduce any dimensions with size one. Lower-rank transpose kernel performs
    // better due to simpler memory access pattern.
    const [reducedShape, perm] = removeOneSizeDims(inputs.x.shape, attrs.perm);
    const permIsNoOp = perm.every((axis, position) => axis === position);
    const outShape = computeOutShape(inputs.x.shape, attrs.perm);
    if (permIsNoOp) {
        // The data order is unchanged, so a plain copy with the new shape suffices.
        const cloned = identity({ inputs, backend });
        cloned.shape = outShape;
        return cloned;
    }
    const { dataId, dtype } = inputs.x;
    const out = backend.makeOutput(outShape, dtype);
    const xId = backend.dataIdMap.get(dataId).id;
    const outId = backend.dataIdMap.get(out.dataId).id;
    // 'array' cwrap arguments are passed as bytes of an int32 buffer.
    const toBytes = (values) => new Uint8Array(new Int32Array(values).buffer);
    const permBytes = toBytes(perm);
    const xShapeBytes = toBytes(reducedShape);
    wasmTranspose(xId, xShapeBytes, reducedShape.length, CppDType[dtype], outId, permBytes, perm.length);
    return out;
}
/**
 * Applies `perm` to `inShape`: entry `i` of the result is `inShape[perm[i]]`.
 *
 * @param inShape Shape to permute; its length fixes the result length.
 * @param perm Axis permutation.
 * @returns A new array with the permuted dimensions.
 */
function computeOutShape(inShape, perm) {
    return Array.from({ length: inShape.length }, (_, i) => inShape[perm[i]]);
}
/**
 * Drops every size-1 dimension from `shape` and produces the matching
 * permutation over the remaining dimensions.
 *
 * @param shape Input shape.
 * @param perm Permutation over the axes of `shape`.
 * @returns `[newShape, newPerm]`: `newShape` is `shape` without its size-1
 *     entries; `newPerm` holds the surviving entries of `perm`, renumbered to
 *     the range `0..newPerm.length - 1` while keeping their relative order.
 */
function removeOneSizeDims(shape, perm) {
    const keptShape = [];
    const keptPerm = [];
    for (const [i, dim] of shape.entries()) {
        if (dim !== 1) {
            keptShape.push(dim);
        }
        const permutedAxis = perm[i];
        if (shape[permutedAxis] !== 1) {
            keptPerm.push(permutedAxis);
        }
    }
    // Renumber the surviving axes: on each pass the smallest not-yet-renumbered
    // axis (value >= rank) is replaced by `rank`.
    for (let rank = 0; rank < keptPerm.length; ++rank) {
        let smallestAt = -1;
        keptPerm.forEach((axis, j) => {
            if (!(axis >= rank)) {
                return;
            }
            if (smallestAt === -1 || keptPerm[smallestAt] > axis) {
                smallestAt = j;
            }
        });
        keptPerm[smallestAt] = rank;
    }
    return [keptShape, keptPerm];
}
/**
 * Kernel config tying the `Transpose` kernel name to its 'wasm' backend
 * implementation: `setup$19` binds the wasm function and `transpose` executes
 * the kernel.
 */
const transposeConfig = {
kernelName: Transpose,
backendName: 'wasm',
kernelFunc: transpose,
setupFunc: setup$19,
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Compute permutation axes and do a transpose if necessary.
 *
 * Used by reduction ops.
 * @param x input TensorInfo
 * @param axis reduction axes
 * @param backend wasm backend instance
 * @returns `{ transposed, originalAxes, axes, inputWasTransposed }`:
 *     `originalAxes` is `axis` parsed against `x.shape`. When
 *     `backend_util.getAxesPermutation` yields no permutation, `transposed` is
 *     null, `axes` equals `originalAxes` and `inputWasTransposed` is false.
 *     Otherwise `transposed` is the result of `transpose`, `axes` are the
 *     inner-most axes, and `inputWasTransposed` is true when the transposed
 *     tensor has a different backend id than `x`.
 */
function permuteAxesAndTranspose(x, axis, backend) {
    const xRank = x.shape.length;
    const originalAxes = util.parseAxisParam(axis, x.shape);
    const permutedAxes = backend_util.getAxesPermutation(originalAxes, xRank);
    if (permutedAxes == null) {
        // No permutation needed: the caller keeps working on x directly.
        return {
            transposed: null,
            originalAxes,
            axes: originalAxes,
            inputWasTransposed: false
        };
    }
    const axes = backend_util.getInnerMostAxes(originalAxes.length, xRank);
    const transposed = transpose({ inputs: { x }, attrs: { perm: permutedAxes }, backend });
    const xId = backend.dataIdMap.get(x.dataId).id;
    const transposedId = backend.dataIdMap.get(transposed.dataId).id;
    // Callers use this flag to decide whether `transposed` must be disposed.
    const inputWasTransposed = transposedId !== xId;
    return { transposed, originalAxes, axes, inputWasTransposed };
}
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Bound wasm `All` function; assigned by setup$18().
let wasmAll;
/**
 * Binds the wasm `All` function through `backend.wasm.cwrap`.
 *
 * The wasm function is called with three numeric arguments (input id, reduce
 * size, output id), so each is declared as its own 'number' entry.
 *
 * @param backend wasm backend instance exposing `wasm.cwrap`.
 */
function setup$18(backend) {
    wasmAll = backend.wasm.cwrap(All, null /*void*/, [
        'number', // input_id
        'number', // reduce_size
        'number', // out_id
    ]);
}
/**
 * Kernel function for the `All` reduction on the wasm backend.
 *
 * @param args `{ backend, inputs: { x }, attrs: { axis, keepDims } }`.
 * @returns Output tensor info with the dtype of `x`; the reduced axes are
 *     removed from the shape unless `keepDims` is truthy.
 */
function all(args) {
    const { backend, inputs, attrs } = args;
    const { axis, keepDims } = attrs;
    const { x } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    let inputId = xId;
    let input = x;
    const { transposed, axes, originalAxes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);
    if (inputWasTransposed) {
        // Reduce over the transposed copy, where the reduction axes are inner-most.
        const transposedId = backend.dataIdMap.get(transposed.dataId).id;
        input = transposed;
        inputId = transposedId;
    }
    const inputRank = input.shape.length;
    backend_util.assertAxesAreInnerMostDims('all', axes, inputRank);
    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(input.shape, axes);
    const reduceSize = util.sizeFromShape(reduceShape);
    const out = backend.makeOutput(outShape, x.dtype);
    // Zero-sized inputs skip the wasm call.
    if (util.sizeFromShape(input.shape) !== 0) {
        const outId = backend.dataIdMap.get(out.dataId).id;
        wasmAll(inputId, reduceSize, outId);
    }
    if (inputWasTransposed) {
        // dispose of the transposed tensor.
        backend.disposeData(transposed.dataId);
    }
    if (keepDims) {
        // reshape
        const newShape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
        out.shape = newShape;
    }
    return out;
}
/**
 * Kernel config tying the `All` kernel name to its 'wasm' backend
 * implementation: `setup$18` binds the wasm function and `all` executes the
 * kernel.
 */
const allConfig = {
kernelName: All,
backendName: 'wasm',
setupFunc: setup$18,
kernelFunc: all
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Bound wasm `Any` function; assigned by setup$17().
let wasmAny;
/**
 * Binds the wasm `Any` function through `backend.wasm.cwrap`.
 *
 * The wasm function is called with three numeric arguments (input id, reduce
 * size, output id), so each is declared as its own 'number' entry.
 *
 * @param backend wasm backend instance exposing `wasm.cwrap`.
 */
function setup$17(backend) {
    wasmAny = backend.wasm.cwrap(Any, null /*void*/, [
        'number', // input_id
        'number', // reduce_size
        'number', // out_id
    ]);
}
/**
 * Kernel function for the `Any` reduction on the wasm backend.
 *
 * @param args `{ backend, inputs: { x }, attrs: { axis, keepDims } }`.
 * @returns Output tensor info with the dtype of `x`; the reduced axes are
 *     removed from the shape unless `keepDims` is truthy.
 */
function any(args) {
    const { backend, inputs, attrs } = args;
    const { axis, keepDims } = attrs;
    const { x } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    let inputId = xId;
    let input = x;
    const { transposed, axes, originalAxes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);
    if (inputWasTransposed) {
        // Reduce over the transposed copy, where the reduction axes are inner-most.
        const transposedId = backend.dataIdMap.get(transposed.dataId).id;
        input = transposed;
        inputId = transposedId;
    }
    const inputRank = input.shape.length;
    backend_util.assertAxesAreInnerMostDims('any', axes, inputRank);
    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(input.shape, axes);
    const reduceSize = util.sizeFromShape(reduceShape);
    const out = backend.makeOutput(outShape, x.dtype);
    // Zero-sized inputs skip the wasm call.
    if (util.sizeFromShape(input.shape) !== 0) {
        const outId = backend.dataIdMap.get(out.dataId).id;
        wasmAny(inputId, reduceSize, outId);
    }
    if (inputWasTransposed) {
        // dispose of the transposed tensor.
        backend.disposeData(transposed.dataId);
    }
    if (keepDims) {
        // reshape
        const newShape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
        out.shape = newShape;
    }
    return out;
}
/**
 * Kernel config tying the `Any` kernel name to its 'wasm' backend
 * implementation: `setup$17` binds the wasm function and `any` executes the
 * kernel.
 */
const anyConfig = {
kernelName: Any,
backendName: 'wasm',
setupFunc: setup$17,
kernelFunc: any
};
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Builds a 'wasm' kernel config for an arg-reduction kernel (used in this file
 * for `ArgMax` and `ArgMin`).
 *
 * The kernel function produces an 'int32' output whose shape is the
 * (possibly transposed) input shape without its last dimension.
 *
 * @param kernelName Name handed to `backend.wasm.cwrap` and reported in the
 *     returned config.
 * @returns `{ kernelName, backendName: 'wasm', setupFunc, kernelFunc }`.
 */
function createArgMinMaxKernelConfig(kernelName) {
    // Bound wasm function; populated once `setupFunc` has run.
    let wasmArgOp;
    const setupFunc = (backend) => {
        const argTypes = [
            'number', // input_id
            'number', // dtype
            'number', // outer_size
            'number', // inner_size
            'number' // out_id
        ];
        wasmArgOp = backend.wasm.cwrap(kernelName, null /* void */, argTypes);
    };
    const kernelFunc = ({ backend, inputs, attrs }) => {
        const { axis } = attrs;
        const { x } = inputs;
        const xId = backend.dataIdMap.get(x.dataId).id;
        const { transposed, axes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);
        const transposedId = inputWasTransposed ?
            backend.dataIdMap.get(transposed.dataId).id :
            xId;
        // Only switch to the transposed tensor when the transpose was not a
        // no-op; it is disposed of once the wasm call is done.
        const useTransposed = inputWasTransposed && transposedId !== xId;
        const input = useTransposed ? transposed : x;
        const inputId = useTransposed ? transposedId : xId;
        const out = backend.makeOutput(input.shape.slice(0, -1), 'int32');
        const outId = backend.dataIdMap.get(out.dataId).id;
        const outerSize = util.sizeFromShape(out.shape);
        const innerSize = input.shape[axes[0]];
        wasmArgOp(inputId, CppDType[input.dtype], outerSize, innerSize, outId);
        if (inputWasTransposed) {
            backend.disposeData(transposed.dataId);
        }
        return out;
    };
    return {
        kernelName,
        backendName: 'wasm',
        setupFunc,
        kernelFunc,
    };
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `ArgMax`, built by createArgMinMaxKernelConfig; the
 * resulting kernel function returns an 'int32' output.
 */
const argMaxConfig = createArgMinMaxKernelConfig(ArgMax);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `ArgMin`, built by createArgMinMaxKernelConfig; the
 * resulting kernel function returns an 'int32' output.
 */
const argMinConfig = createArgMinMaxKernelConfig(ArgMin);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `Asin`, built by createUnaryKernelConfig. No
 * `outType` is passed, so the output tensor uses the input's dtype.
 */
const asinConfig = createUnaryKernelConfig(Asin);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `Asinh`, built by createUnaryKernelConfig. No
 * `outType` is passed, so the output tensor uses the input's dtype.
 */
const asinhConfig = createUnaryKernelConfig(Asinh);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `Atan`, built by createUnaryKernelConfig. No
 * `outType` is passed, so the output tensor uses the input's dtype.
 */
const atanConfig = createUnaryKernelConfig(Atan);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * 'wasm' kernel config for `Atan2`, built by createBinaryKernelConfig. No
 * dtype override is passed, so the output tensor uses the dtype of input `a`.
 */
const atan2Config = createBinaryKernelConfig(Atan2);
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Kernel config for Atanh, built by createUnaryKernelConfig (defined outside
// this section of the file).
const atanhConfig = createUnaryKernelConfig(Atanh);
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Handle to the native AvgPool routine; assigned by setup$16().
let wasmAvgPool;
/**
 * Binds the AvgPool function of the wasm module so that avgPool() can call it.
 *
 * @param backend Backend object exposing the wasm module as `backend.wasm`.
 */
function setup$16(backend) {
  // One entry per native argument, in the order avgPool() passes them.
  const argTypes = [
    'number', // xId
    'number', // batchSize (x.shape[0])
    'number', // inputHeight (x.shape[1])
    'number', // inputWidth (x.shape[2])
    'number', // filterHeight
    'number', // filterWidth
    'number', // padTop
    'number', // padRight
    'number', // padBottom
    'number', // padLeft
    'number', // strideHeight
    'number', // strideWidth
    'number', // channels
    'number', // outId
  ];
  wasmAvgPool = backend.wasm.cwrap(AvgPool, null /* void */, argTypes);
}
/**
 * AvgPool kernel for the wasm backend.
 *
 * Computes the pooling geometry with `backend_util.computePool2DInfo`,
 * allocates a float32 output and hands the work to the native routine bound
 * by setup$16().
 *
 * @param args Kernel arguments: `inputs.x` is the tensor to pool, `attrs`
 *     carries `filterSize`, `strides`, `pad` and `dimRoundingMode`, and
 *     `backend` is the backend instance.
 * @returns The output tensor info created by `backend.makeOutput`.
 * @throws {Error} If the computed data format is not 'channelsLast' or the
 *     computed dilation is not [1, 1].
 */
function avgPool(args) {
  const { inputs, attrs, backend } = args;
  const x = inputs.x;
  const xId = backend.dataIdMap.get(x.dataId).id;
  const { filterSize, strides, pad, dimRoundingMode } = attrs;
  const convInfo = backend_util.computePool2DInfo(
      x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
  const {
    filterHeight,
    filterWidth,
    strideHeight,
    strideWidth,
    inChannels: channels,
  } = convInfo;
  const {
    top: padTop,
    right: padRight,
    bottom: padBottom,
    left: padLeft,
  } = convInfo.padInfo;
  if (convInfo.dataFormat !== 'channelsLast') {
    throw new Error(`wasm backend does not support dataFormat:'` +
        `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
  }
  if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) {
    // Message previously read "was backend"; it now names the backend
    // consistently with the dataFormat error above.
    throw new Error(`wasm backend only supports average pooling with dilation = [1, 1], ` +
        `got [${convInfo.dilationHeight}, ${convInfo.dilationWidth}].`);
  }
  // The output is allocated only after validation so a rejected call does not
  // create a tensor.
  const out = backend.makeOutput(convInfo.outShape, 'float32');
  const outId = backend.dataIdMap.get(out.dataId).id;
  wasmAvgPool(xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, strideHeight, strideWidth, channels, outId);
  return out;
}
/**
 * Kernel config for AvgPool on the 'wasm' backend: setup$16 binds the native
 * function and avgPool is the kernel implementation.
 */
const avgPoolConfig = {
kernelName: AvgPool,
backendName: 'wasm',
setupFunc: setup$16,
kernelFunc: avgPool
};
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Handle to the native AvgPool3D routine; assigned by setup$15().
let wasmAvgPool3D;
/**
 * Binds the 'AvgPool3D' function of the wasm module so that avgPool3D() can
 * call it.
 *
 * @param backend Backend object exposing the wasm module as `backend.wasm`.
 */
function setup$15(backend) {
  // One entry per native argument, in the order avgPool3D() passes them.
  const argTypes = [
    'number', // xId
    'number', // outId
    'number', // batchSize
    'number', // channelSize
    'number', // inDepth
    'number', // inHeight
    'number', // inWidth
    'number', // outDepth
    'number', // outHeight
    'number', // outWidth
    'number', // strideDepth
    'number', // strideHeight
    'number', // strideWidth
    'number', // dilationDepth
    'number', // dilationHeight
    'number', // dilationWidth
    'number', // effectiveFilterDepth
    'number', // effectiveFilterHeight
    'number', // effectiveFilterWidth
    'number', // padFront
    'number', // padTop
    'number', // padLeft
  ];
  wasmAvgPool3D = backend.wasm.cwrap('AvgPool3D', null, argTypes);
}
/**
 * AvgPool3D kernel for the wasm backend.
 *
 * @param args Kernel arguments: `inputs.x` is the tensor to pool, `attrs`
 *     carries `filterSize`, `strides`, `pad`, `dimRoundingMode` and
 *     `dataFormat`, and `backend` is the backend instance.
 * @returns The output tensor info created by `backend.makeOutput`, using the
 *     input's dtype.
 */
function avgPool3D(args) {
  const { inputs: { x }, attrs, backend } = args;
  const poolInfo = backend_util.computePool3DInfo(
      x.shape, attrs.filterSize, attrs.strides, /*dilations=*/ 1, attrs.pad,
      attrs.dimRoundingMode, attrs.dataFormat);
  const out = backend.makeOutput(poolInfo.outShape, x.dtype);
  const nativeArgs = [
    backend.dataIdMap.get(x.dataId).id,
    backend.dataIdMap.get(out.dataId).id,
    poolInfo.batchSize,
    // Pool3D ops (AvgPool3D and MaxPool3D) only take a 3D filter, so the
    // input channel count doubles as the output channel count.
    /*channelSize=*/ poolInfo.inChannels,
    poolInfo.inDepth,
    poolInfo.inHeight,
    poolInfo.inWidth,
    poolInfo.outDepth,
    poolInfo.outHeight,
    poolInfo.outWidth,
    poolInfo.strideDepth,
    poolInfo.strideHeight,
    poolInfo.strideWidth,
    poolInfo.dilationDepth,
    poolInfo.dilationHeight,
    poolInfo.dilationWidth,
    poolInfo.effectiveFilterDepth,
    poolInfo.effectiveFilterHeight,
    poolInfo.effectiveFilterWidth,
    poolInfo.padInfo.front,
    poolInfo.padInfo.top,
    poolInfo.padInfo.left,
  ];
  wasmAvgPool3D(...nativeArgs);
  return out;
}
/**
 * Kernel config for AvgPool3D on the 'wasm' backend: setup$15 binds the native
 * function and avgPool3D is the kernel implementation.
 */
const avgPool3DConfig = {
kernelName: AvgPool3D,
backendName: 'wasm',
setupFunc: setup$15,
kernelFunc: avgPool3D
};
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Handle to the native AvgPool3DGrad routine; assigned by setup$14().
let wasmAvgPool3DGrad;
/**
 * Binds the 'AvgPool3DGrad' function of the wasm module so that
 * avgPool3DGrad() can call it.
 *
 * @param backend Backend object exposing the wasm module as `backend.wasm`.
 */
function setup$14(backend) {
  // One entry per native argument, in the order avgPool3DGrad() passes them.
  const argTypes = [
    'number', // dyId
    'number', // dxId
    'number', // batchSize
    'number', // channelSize
    'number', // inDepth
    'number', // inHeight
    'number', // inWidth
    'number', // outDepth
    'number', // outHeight
    'number', // outWidth
    'number', // strideDepth
    'number', // strideHeight
    'number', // strideWidth
    'number', // dilationDepth
    'number', // dilationHeight
    'number', // dilationWidth
    'number', // effectiveFilterDepth
    'number', // effectiveFilterHeight
    'number', // effectiveFilterWidth
    'number', // padFront
    'number', // padTop
    'number', // padLeft
    'number', // filterDepth
    'number', // filterHeight
    'number', // filterWidth
  ];
  wasmAvgPool3DGrad = backend.wasm.cwrap('AvgPool3DGrad', null, argTypes);
}
/**
 * AvgPool3DGrad kernel for the wasm backend.
 *
 * @param args Kernel arguments: `inputs.dy` and `inputs.input` are the two
 *     input tensors, `attrs` carries `filterSize`, `strides`, `pad` and
 *     `dimRoundingMode`, and `backend` is the backend instance.
 * @returns The `dx` tensor info created by `backend.makeOutput`, with the
 *     shape and dtype of `inputs.input`.
 */
function avgPool3DGrad(args) {
  const { inputs: { dy, input }, attrs, backend } = args;
  const poolInfo = backend_util.computePool3DInfo(
      input.shape, attrs.filterSize, attrs.strides, /*dilations=*/ 1,
      attrs.pad, attrs.dimRoundingMode);
  const dx = backend.makeOutput(input.shape, input.dtype);
  const nativeArgs = [
    backend.dataIdMap.get(dy.dataId).id,
    backend.dataIdMap.get(dx.dataId).id,
    poolInfo.batchSize,
    // Pool3D ops (AvgPool3D and MaxPool3D) only take a 3D filter, so the
    // input channel count doubles as the output channel count.
    /*channelSize=*/ poolInfo.inChannels,
    poolInfo.inDepth,
    poolInfo.inHeight,
    poolInfo.inWidth,
    poolInfo.outDepth,
    poolInfo.outHeight,
    poolInfo.outWidth,
    poolInfo.strideDepth,
    poolInfo.strideHeight,
    poolInfo.strideWidth,
    poolInfo.dilationDepth,
    poolInfo.dilationHeight,
    poolInfo.dilationWidth,
    poolInfo.effectiveFilterDepth,
    poolInfo.effectiveFilterHeight,
    poolInfo.effectiveFilterWidth,
    poolInfo.padInfo.front,
    poolInfo.padInfo.top,
    poolInfo.padInfo.left,
    poolInfo.filterDepth,
    poolInfo.filterHeight,
    poolInfo.filterWidth,
  ];
  wasmAvgPool3DGrad(...nativeArgs);
  return dx;
}
/**
 * Kernel config for AvgPool3DGrad on the 'wasm' backend: setup$14 binds the
 * native function and avgPool3DGrad is the kernel implementation.
 */
const avgPool3DGradConfig = {
kernelName: AvgPool3DGrad,
backendName: 'wasm',
setupFunc: setup$14,
kernelFunc: avgPool3DGrad
};
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Handle to the native AvgPoolGrad routine; assigned by setup$13().
let wasmAvgPoolGrad;
/**
 * Binds the 'AvgPoolGrad' function of the wasm module so that avgPoolGrad()
 * can call it.
 *
 * @param backend Backend object exposing the wasm module as `backend.wasm`.
 */
function setup$13(backend) {
  // One entry per native argument, in the order avgPoolGrad() passes them.
  const argTypes = [
    'number', // dyId
    'number', // dxId
    'number', // batchSize
    'number', // channelSize
    'number', // inHeight
    'number', // inWidth
    'number', // outHeight
    'number', // outWidth
    'number', // strideHeight
    'number', // strideWidth
    'number', // dilationHeight
    'number', // dilationWidth
    'number', // effectiveFilterHeight
    'number', // effectiveFilterWidth
    'number', // padTop
    'number', // padLeft
    'number', // filterHeight
    'number', // filterWidth
  ];
  wasmAvgPoolGrad = backend.wasm.cwrap('AvgPoolGrad', null, argTypes);
}
/**
 * AvgPoolGrad kernel for the wasm backend.
 *
 * @param args Kernel arguments: `inputs.dy` and `inputs.input` are the two
 *     input tensors, `attrs` carries `filterSize`, `strides` and `pad`, and
 *     `backend` is the backend instance.
 * @returns The `dx` tensor info created by `backend.makeOutput`, with the
 *     shape and dtype of `inputs.input`.
 */
function avgPoolGrad(args) {
  const { inputs: { dy, input }, attrs, backend } = args;
  const poolInfo = backend_util.computePool2DInfo(
      input.shape, attrs.filterSize, attrs.strides, /*dilations=*/ 1,
      attrs.pad);
  const dx = backend.makeOutput(input.shape, input.dtype);
  const nativeArgs = [
    backend.dataIdMap.get(dy.dataId).id,
    backend.dataIdMap.get(dx.dataId).id,
    poolInfo.batchSize,
    // Pool ops (AvgPool and MaxPool) only take a 2D filter, so the input
    // channel count doubles as the output channel count.
    /*channelSize=*/ poolInfo.inChannels,
    poolInfo.inHeight,
    poolInfo.inWidth,
    poolInfo.outHeight,
    poolInfo.outWidth,
    poolInfo.strideHeight,
    poolInfo.strideWidth,
    poolInfo.dilationHeight,
    poolInfo.dilationWidth,
    poolInfo.effectiveFilterHeight,
    poolInfo.effectiveFilterWidth,
    poolInfo.padInfo.top,
    poolInfo.padInfo.left,
    poolInfo.filterHeight,
    poolInfo.filterWidth,
  ];
  wasmAvgPoolGrad(...nativeArgs);
  return dx;
}
/**
 * Kernel config for AvgPoolGrad on the 'wasm' backend: setup$13 binds the
 * native function and avgPoolGrad is the kernel implementation.
 */
const avgPoolGradConfig = {
kernelName: AvgPoolGrad,
backendName: 'wasm',
setupFunc: setup$13,
kernelFunc: avgPoolGrad
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Reshape kernel: returns tensor info that reuses the input's `dataId` under
 * a new shape, without copying any data.
 *
 * @param args Kernel arguments: `inputs.x` is the tensor to reshape,
 *     `attrs.shape` is the requested shape (resolved through
 *     `util.inferFromImplicitShape`), and `backend` must provide `incRef`.
 * @returns `{dataId, shape, dtype}` sharing `x.dataId` and `x.dtype`.
 * @throws Whatever `util.assert` throws when the element counts differ.
 */
function reshape(args) {
  const { inputs: { x }, attrs: { shape: requestedShape }, backend } = args;
  const elementCount = util.sizeFromShape(x.shape);
  const resolvedShape = util.inferFromImplicitShape(requestedShape, elementCount);
  const describeMismatch = () => `new shape: ${resolvedShape}, old shape: ${x.shape}. New shape and old ` +
      `shape must have the same number of elements.`;
  util.assert(elementCount === util.sizeFromShape(resolvedShape), describeMismatch);
  // The result aliases x's data, so the backend has to count one more
  // reference to that dataId.
  backend.incRef(x.dataId);
  return { dataId: x.dataId, shape: resolvedShape, dtype: x.dtype };
}
/**
 * Kernel config for Reshape on the 'wasm' backend. It has no setupFunc;
 * reshape is the kernel implementation.
 */
const reshapeConfig = {
kernelName: Reshape,
backendName: 'wasm',
kernelFunc: reshape
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Handle to the native BatchMatMul routine; assigned by setup$12().
let wasmBatchMatMul;
/**
 * Binds the BatchMatMul function of the wasm module so that batchMatMul() can
 * call it.
 *
 * @param backend Backend object exposing the wasm module as `backend.wasm`.
 */
function setup$12(backend) {
  // One entry per native argument, in the order batchMatMul() passes them.
  const argTypes = [
    'number', // a_id
    'array', // a_shape
    'number', // a_shape.length
    'number', // b_id
    'array', // b_shape
    'number', // b_shape.length
    'number', // transpose_a
    'number', // transpose_b
    'number', // out_id
  ];
  wasmBatchMatMul = backend.wasm.cwrap(BatchMatMul, null /* void */, argTypes);
}
/**
 * BatchMatMul kernel for the wasm backend.
 *
 * Both operands are viewed as rank-3 tensors via reshape(), multiplied by the
 * native routine bound by setup$12(), and the result is relabelled with the
 * broadcast output shape.
 *
 * @param args Kernel arguments: `inputs.a` and `inputs.b` are the operands,
 *     `attrs.transposeA` / `attrs.transposeB` select transposition of the two
 *     innermost dimensions, and `backend` is the backend instance.
 * @returns The output tensor info, with shape
 *     `[...broadcast batch dims, outerShapeA, outerShapeB]`.
 * @throws {Error} If either operand is not float32. Shape mismatches are
 *     reported by `broadcast_util.assertAndGetBroadcastShape` and
 *     `util.assert`.
 */
function batchMatMul(args) {
  const { inputs, backend, attrs } = args;
  const { a, b } = inputs;
  const { transposeA, transposeB } = attrs;
  if (a.dtype !== 'float32' || b.dtype !== 'float32') {
    throw new Error(`BatchMatMul for non-float32 tensors not yet supported.`);
  }
  const aRank = a.shape.length;
  const bRank = b.shape.length;
  const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
  const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
  const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
  const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
  const outerDimsA = a.shape.slice(0, -2);
  const outerDimsB = b.shape.slice(0, -2);
  const batchDimA = util.sizeFromShape(outerDimsA);
  const batchDimB = util.sizeFromShape(outerDimsB);
  const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape(outerDimsA, outerDimsB);
  const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
  util.assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
      `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
      `${b.shape} and transposeA=${transposeA}` +
      ` and transposeB=${transposeB} must match.`);
  const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] :
      [batchDimA, outerShapeA, innerShapeA];
  const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] :
      [batchDimB, innerShapeB, outerShapeB];
  // The rest of the implementation is designed to operate on rank-3 tensors.
  const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
  const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
  let out;
  try {
    const a3dId = backend.dataIdMap.get(a3d.dataId).id;
    const b3dId = backend.dataIdMap.get(b3d.dataId).id;
    const leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
    const rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
    const batchDim = Math.max(batchDimA, batchDimB);
    out = backend.makeOutput([batchDim, leftDim, rightDim], a3d.dtype);
    const outId = backend.dataIdMap.get(out.dataId).id;
    // Shapes cross into wasm as the raw bytes of an Int32Array.
    const aShapeBytes = new Uint8Array(new Int32Array(a3d.shape).buffer);
    const bShapeBytes = new Uint8Array(new Int32Array(b3d.shape).buffer);
    wasmBatchMatMul(a3dId, aShapeBytes, a3d.shape.length, b3dId, bShapeBytes, b3d.shape.length, transposeA, transposeB, outId);
  } finally {
    // reshape() took an extra reference on each operand's data; release it
    // even when allocation or the native call throws, so the counts stay
    // balanced.
    backend.disposeData(a3d.dataId);
    backend.disposeData(b3d.dataId);
  }
  out.shape = outShape;
  return out;
}
/**
 * Kernel config for BatchMatMul on the 'wasm' backend: setup$12 binds the
 * native function and batchMatMul is the kernel implementation.
 */
const batchMatMulConfig = {
kernelName: BatchMatMul,
backendName: 'wasm',
setupFunc: setup$12,
kernelFunc: batchMatMul
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the Lic