UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: - BlazeFace - BodyPix - CocoSSD - Face landmarks detection - HandPose - KNN classifier

1,153 lines (1,124 loc) 450 kB
/**
 * @license
 * Copyright 2024 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// UMD loader. Picks one of three ways to hand `factory` its dependencies:
//   1. CommonJS (`exports` + `module` present): dependencies come from require().
//   2. AMD (`define.amd` present): dependencies are declared to define().
//   3. Otherwise: dependencies are read off the global object and the exports
//      object is `global.tf.wasm` (created, along with `global.tf`, if absent).
(function (global, factory) {
    typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports, require('@tensorflow/tfjs-core'), require('fs'), require('path'), require('perf_hooks'), require('os')) :
    typeof define === 'function' && define.amd ? define(['exports', '@tensorflow/tfjs-core', 'fs', 'path', 'perf_hooks', 'os'], factory) :
    (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory((global.tf = global.tf || {}, global.tf.wasm = global.tf.wasm || {}), global.tf, global.fs, global.path, global.perf_hooks, global.require$$4));
})(this, (function (exports, tfjsCore, require$$0, require$$1, require$$3, require$$4) { 'use strict';

    /**
     * Copies the own enumerable keys of every object in `m` onto `n` and
     * returns `n`.
     *
     * Falsy entries, strings and arrays in `m` are ignored. The key
     * `'default'` and any key already present in `n` (checked with `in`) are
     * skipped, so earlier definitions win.
     *
     * @param n target namespace object; mutated in place.
     * @param m array of source objects to merge into `n`.
     * @return the same `n` object.
     */
    function _mergeNamespaces(n, m) {
        m.forEach(function (e) {
            e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) {
                if (k !== 'default' && !(k in n)) {
                    var d = Object.getOwnPropertyDescriptor(e, k);
                    // Reuse an existing accessor descriptor as-is; otherwise
                    // install a getter that reads through to the source
                    // object on every access.
                    Object.defineProperty(n, k, d.get ? d : {
                        enumerable: true,
                        get: function () { return e[k]; }
                    });
                }
            });
        });
        return n;
    }

    /**
     * @license
     * Copyright 2019 Google LLC. All Rights Reserved.
     * Licensed under the Apache License, Version 2.0 (the "License");
     * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // This enum must align with the enum defined in cc/backend.h. var CppDType; (function (CppDType) { CppDType[CppDType["float32"] = 0] = "float32"; CppDType[CppDType["int32"] = 1] = "int32"; CppDType[CppDType["bool"] = 2] = "bool"; CppDType[CppDType["string"] = 3] = "string"; CppDType[CppDType["complex64"] = 4] = "complex64"; })(CppDType || (CppDType = {})); // Must match enum in cc/fusable_activations.h. var FusableActivation; (function (FusableActivation) { FusableActivation[FusableActivation["linear"] = 0] = "linear"; FusableActivation[FusableActivation["relu"] = 1] = "relu"; FusableActivation[FusableActivation["relu6"] = 2] = "relu6"; FusableActivation[FusableActivation["prelu"] = 3] = "prelu"; FusableActivation[FusableActivation["leakyrelu"] = 4] = "leakyrelu"; FusableActivation[FusableActivation["sigmoid"] = 5] = "sigmoid"; FusableActivation[FusableActivation["elu"] = 6] = "elu"; })(FusableActivation || (FusableActivation = {})); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ let wasmFusedMatMul; function setup$1a(backend) { wasmFusedMatMul = backend.wasm.cwrap(tfjsCore._FusedMatMul, null /* void */, [ 'number', 'array', 'number', 'number', 'array', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number' // out_id ]); } function fusedBatchMatMul(args) { const { inputs, backend, attrs } = args; const { a, b, bias, preluActivationWeights } = inputs; if (a.dtype !== 'float32' || b.dtype !== 'float32') { throw new Error(`_FusedMatMul for non non-float32 tensors not yet supported.`); } const { transposeA, transposeB, activation, leakyreluAlpha } = attrs; const aId = backend.dataIdMap.get(a.dataId).id; const bId = backend.dataIdMap.get(b.dataId).id; let biasId = 0; if (bias != null) { const biasData = backend.dataIdMap.get(bias.dataId); if (biasData.shape.length !== 1) { throw new Error(`_FusedMatMul only supports rank-1 bias but got ` + `rank ${biasData.shape.length}.`); } biasId = biasData.id; } const preluActivationWeightsId = preluActivationWeights == null ? 0 : backend.dataIdMap.get(preluActivationWeights.dataId).id; const fusedActivation = FusableActivation[activation]; if (fusedActivation == null) { throw new Error(`${activation} activation not yet supported for FusedConv2D ` + `in the wasm backend.`); } const leftDim = transposeA ? a.shape[2] : a.shape[1]; const rightDim = transposeB ? 
b.shape[1] : b.shape[2]; const batchDims = tfjsCore.broadcast_util.assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2)); const out = backend.makeOutput([...batchDims, leftDim, rightDim], a.dtype); const outId = backend.dataIdMap.get(out.dataId).id; const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer); const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer); wasmFusedMatMul(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, transposeA, transposeB, fusedActivation, biasId, preluActivationWeightsId, leakyreluAlpha || 0, outId); return out; } const _fusedMatMulConfig = { kernelName: tfjsCore._FusedMatMul, backendName: 'wasm', setupFunc: setup$1a, kernelFunc: fusedBatchMatMul }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function createUnaryKernelConfig(kernelName, outType) { let wasmFunc; function setupFunc(backend) { wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, [ 'number', 'number', 'number', // out_id ]); } function kernelFunc(args) { const { backend, inputs: { x } } = args; const xId = backend.dataIdMap.get(x.dataId).id; const out = backend.makeOutput(x.shape, outType || x.dtype); const outId = backend.dataIdMap.get(out.dataId).id; // Short-circuit zero-sized tensors. 
if (tfjsCore.util.sizeFromShape(out.shape) === 0) { return out; } wasmFunc(xId, CppDType[x.dtype], outId); return out; } return { kernelName, backendName: 'wasm', setupFunc, kernelFunc }; } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const absConfig = createUnaryKernelConfig(tfjsCore.Abs); /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const acosConfig = createUnaryKernelConfig(tfjsCore.Acos); /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const acoshConfig = createUnaryKernelConfig(tfjsCore.Acosh); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function createBinaryKernelConfig(kernelName, supportsFullBroadcast, dtype) { let wasmFunc; function setupFunc(backend) { wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, [ 'number', 'array', 'number', 'number', 'array', 'number', 'number', 'number' // out_id ]); } function kernelFunc(args) { const { backend, inputs } = args; const { a, b } = inputs; const aId = backend.dataIdMap.get(a.dataId).id; const bId = backend.dataIdMap.get(b.dataId).id; const outputType = dtype != null ? dtype : a.dtype; const newShape = tfjsCore.backend_util.assertAndGetBroadcastShape(a.shape, b.shape); const out = backend.makeOutput(newShape, outputType); // Short-circuit zero-sized tensors. 
if (tfjsCore.util.sizeFromShape(newShape) === 0) { return out; } const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer); const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer); const outId = backend.dataIdMap.get(out.dataId).id; const kernelFunc = () => wasmFunc(aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, CppDType[a.dtype], outId); kernelFunc(); return out; } return { kernelName, backendName: 'wasm', setupFunc, kernelFunc }; } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const addConfig = createBinaryKernelConfig(tfjsCore.Add); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ let wasmFunc$6; function setupFunc$1(backend) { wasmFunc$6 = backend.wasm.cwrap(tfjsCore.AddN, null /* void */, [ 'array', 'number', 'number', 'number', // out_id ]); } function addn(args) { const { inputs, backend } = args; const out = backend.makeOutput(inputs[0].shape, inputs[0].dtype); // Short-circuit zero-sized tensors. if (tfjsCore.util.sizeFromShape(out.shape) === 0) { return out; } const inputIds = inputs.map(x => backend.dataIdMap.get(x.dataId).id); const inputIdsBytes = new Uint8Array(new Int32Array(inputIds).buffer); const outId = backend.dataIdMap.get(out.dataId).id; wasmFunc$6(inputIdsBytes, inputIds.length, CppDType[out.dtype], outId); return out; } const addNConfig = { kernelName: tfjsCore.AddN, backendName: 'wasm', setupFunc: setupFunc$1, kernelFunc: addn, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function identity(args) { const { inputs: { x }, backend } = args; if (x.dtype === 'string') { return tfjsCore.tensor(backend.readSync(x.dataId), x.shape, x.dtype); } const out = backend.makeOutput(x.shape, x.dtype); const inVals = backend.typedArrayFromHeap(x); const outVals = backend.typedArrayFromHeap(out); outVals.set(inVals); return out; } const identityConfig = { kernelName: tfjsCore.Identity, backendName: 'wasm', kernelFunc: identity, }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ let wasmTranspose; function setup$19(backend) { wasmTranspose = backend.wasm.cwrap(tfjsCore.Transpose, null /* void */, [ 'number', 'array', 'number', 'number', 'number', 'array', 'number', // perm.length ]); } function transpose(args) { const { inputs, backend, attrs } = args; // Reduce any dimensions with size one. Lower-rank transpose kernel performs // better due to simpler memory access pattern. 
const [reducedShape, perm] = removeOneSizeDims(inputs.x.shape, attrs.perm); let permIsNoOp = true; for (let i = 0; i < perm.length; i++) { if (perm[i] !== i) { permIsNoOp = false; } } const outShape = computeOutShape(inputs.x.shape, attrs.perm); const x = { dataId: inputs.x.dataId, shape: reducedShape, dtype: inputs.x.dtype }; if (permIsNoOp) { const cloned = identity({ inputs, backend }); cloned.shape = outShape; return cloned; } const out = backend.makeOutput(outShape, x.dtype); const xId = backend.dataIdMap.get(x.dataId).id; const outId = backend.dataIdMap.get(out.dataId).id; const permBytes = new Uint8Array(new Int32Array(perm).buffer); const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); wasmTranspose(xId, xShapeBytes, x.shape.length, CppDType[x.dtype], outId, permBytes, perm.length); return out; } function computeOutShape(inShape, perm) { const outShape = new Array(inShape.length); for (let i = 0; i < outShape.length; i++) { outShape[i] = inShape[perm[i]]; } return outShape; } function removeOneSizeDims(shape, perm) { const newShape = []; const newPerm = []; for (let i = 0; i < shape.length; ++i) { if (shape[i] !== 1) { newShape.push(shape[i]); } if (shape[perm[i]] !== 1) { newPerm.push(perm[i]); } } for (let i = 0; i < newPerm.length; ++i) { let minValIdx = -1; for (let j = 0; j < newPerm.length; ++j) { if (newPerm[j] >= i && (minValIdx === -1 || newPerm[minValIdx] > newPerm[j])) { minValIdx = j; } } newPerm[minValIdx] = i; } return [newShape, newPerm]; } const transposeConfig = { kernelName: tfjsCore.Transpose, backendName: 'wasm', kernelFunc: transpose, setupFunc: setup$19, }; /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Compute permutation axes and do a transpose if necessary. * * Used by reduction ops. * @param x input TensorInfo * @param axis reduction axes * @param backend wasm backend instance */ function permuteAxesAndTranspose(x, axis, backend) { const xShape = x.shape; const xRank = x.shape.length; const originalAxes = tfjsCore.util.parseAxisParam(axis, xShape); let axes = originalAxes; const permutedAxes = tfjsCore.backend_util.getAxesPermutation(axes, xRank); let xTransposed = null; let inputWasTransposed = false; if (permutedAxes != null) { const newShape = new Array(xRank); for (let i = 0; i < newShape.length; i++) { newShape[i] = xShape[permutedAxes[i]]; } axes = tfjsCore.backend_util.getInnerMostAxes(axes.length, xRank); xTransposed = transpose({ inputs: { x }, attrs: { perm: permutedAxes }, backend }); const xId = backend.dataIdMap.get(x.dataId).id; const transposedId = backend.dataIdMap.get(xTransposed.dataId).id; if (transposedId !== xId) { inputWasTransposed = true; } } return { transposed: xTransposed, originalAxes, axes, inputWasTransposed }; } /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ let wasmAll; function setup$18(backend) { wasmAll = backend.wasm.cwrap(tfjsCore.All, null /*void*/, ['number, number, number']); } function all(args) { const { backend, inputs, attrs } = args; const { axis, keepDims } = attrs; const { x } = inputs; const xId = backend.dataIdMap.get(x.dataId).id; let inputId = xId; let input = x; const { transposed, axes, originalAxes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend); if (inputWasTransposed) { const transposedId = backend.dataIdMap.get(transposed.dataId).id; input = transposed; inputId = transposedId; } const inputRank = input.shape.length; tfjsCore.backend_util.assertAxesAreInnerMostDims('all', axes, inputRank); const [outShape, reduceShape] = tfjsCore.backend_util.computeOutAndReduceShapes(input.shape, axes); const reduceSize = tfjsCore.util.sizeFromShape(reduceShape); const out = backend.makeOutput(outShape, x.dtype); if (tfjsCore.util.sizeFromShape(input.shape) !== 0) { const outId = backend.dataIdMap.get(out.dataId).id; wasmAll(inputId, reduceSize, outId); } if (inputWasTransposed) { // dispose of the transposed tensor. backend.disposeData(transposed.dataId); } if (keepDims) { // reshape const newShape = tfjsCore.backend_util.expandShapeToKeepDim(out.shape, originalAxes); out.shape = newShape; } return out; } const allConfig = { kernelName: tfjsCore.All, backendName: 'wasm', setupFunc: setup$18, kernelFunc: all }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ let wasmAny; function setup$17(backend) { wasmAny = backend.wasm.cwrap(tfjsCore.Any, null /*void*/, ['number, number, number']); } function any(args) { const { backend, inputs, attrs } = args; const { axis, keepDims } = attrs; const { x } = inputs; const xId = backend.dataIdMap.get(x.dataId).id; let inputId = xId; let input = x; const { transposed, axes, originalAxes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend); if (inputWasTransposed) { const transposedId = backend.dataIdMap.get(transposed.dataId).id; input = transposed; inputId = transposedId; } const inputRank = input.shape.length; tfjsCore.backend_util.assertAxesAreInnerMostDims('any', axes, inputRank); const [outShape, reduceShape] = tfjsCore.backend_util.computeOutAndReduceShapes(input.shape, axes); const reduceSize = tfjsCore.util.sizeFromShape(reduceShape); const out = backend.makeOutput(outShape, x.dtype); if (tfjsCore.util.sizeFromShape(input.shape) !== 0) { const outId = backend.dataIdMap.get(out.dataId).id; wasmAny(inputId, reduceSize, outId); } if (inputWasTransposed) { // dispose of the transposed tensor. 
backend.disposeData(transposed.dataId); } if (keepDims) { // reshape const newShape = tfjsCore.backend_util.expandShapeToKeepDim(out.shape, originalAxes); out.shape = newShape; } return out; } const anyConfig = { kernelName: tfjsCore.Any, backendName: 'wasm', setupFunc: setup$17, kernelFunc: any }; /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function createArgMinMaxKernelConfig(kernelName) { let wasmFunc; function setupFunc(backend) { wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, [ 'number', 'number', 'number', 'number', 'number' // out_id ]); } function kernelFunc(args) { const { backend, inputs, attrs } = args; const { axis } = attrs; const { x } = inputs; const xId = backend.dataIdMap.get(x.dataId).id; let inputId = xId; let input = x; const { transposed, axes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend); if (inputWasTransposed) { const transposedId = backend.dataIdMap.get(transposed.dataId).id; if (transposedId !== xId) { // transpose was not a no-op. We will need to dispose of this // once we are done. 
input = transposed; inputId = transposedId; } } const outShape = input.shape.slice(0, -1); const out = backend.makeOutput(outShape, 'int32'); const outId = backend.dataIdMap.get(out.dataId).id; const outerSize = tfjsCore.util.sizeFromShape(out.shape); const innerSize = input.shape[axes[0]]; wasmFunc(inputId, CppDType[input.dtype], outerSize, innerSize, outId); if (inputWasTransposed) { // dispose of the transposed tensor. backend.disposeData(transposed.dataId); } return out; } return { kernelName, backendName: 'wasm', setupFunc, kernelFunc: kernelFunc, }; } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const argMaxConfig = createArgMinMaxKernelConfig(tfjsCore.ArgMax); /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* =============================================================================
*/
// Wasm kernel config for ArgMin, built by the shared arg-min/max factory.
const argMinConfig = createArgMinMaxKernelConfig(tfjsCore.ArgMin);
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Wasm kernel config for Asin. No outType is passed, so the unary factory
// uses the input dtype for the output.
const asinConfig = createUnaryKernelConfig(tfjsCore.Asin);
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Wasm kernel config for Asinh. No outType is passed, so the unary factory
// uses the input dtype for the output.
const asinhConfig = createUnaryKernelConfig(tfjsCore.Asinh);
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Wasm kernel config for Atan. No outType is passed, so the unary factory
// uses the input dtype for the output.
const atanConfig = createUnaryKernelConfig(tfjsCore.Atan);
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Wasm kernel config for Atan2. Only the kernel name is passed, so the binary
// factory uses the dtype of input `a` for the output.
const atan2Config = createBinaryKernelConfig(tfjsCore.Atan2);
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ const atanhConfig = createUnaryKernelConfig(tfjsCore.Atanh); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ let wasmAvgPool; function setup$16(backend) { wasmAvgPool = backend.wasm.cwrap(tfjsCore.AvgPool, null /* void */, [ 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', // outId ]); } function avgPool(args) { const { inputs, attrs, backend } = args; const x = inputs.x; const xId = backend.dataIdMap.get(x.dataId).id; const { filterSize, strides, pad, dimRoundingMode } = attrs; const convInfo = tfjsCore.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode); const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const padTop = convInfo.padInfo.top; const padRight = convInfo.padInfo.right; const padBottom = convInfo.padInfo.bottom; const padLeft = convInfo.padInfo.left; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const channels = convInfo.inChannels; if (convInfo.dataFormat !== 'channelsLast') { throw new Error(`wasm backend does not support dataFormat:'` + `${convInfo.dataFormat}'. 
Please use 'channelsLast'.`); } if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) { throw new Error(`was backend only supports average pooling with dilation = [1, 1], ` + `got [${convInfo.dilationHeight}, ${convInfo.dilationWidth}].`); } const out = backend.makeOutput(convInfo.outShape, 'float32'); const outId = backend.dataIdMap.get(out.dataId).id; wasmAvgPool(xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, strideHeight, strideWidth, channels, outId); return out; } const avgPoolConfig = { kernelName: tfjsCore.AvgPool, backendName: 'wasm', setupFunc: setup$16, kernelFunc: avgPool }; /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ let wasmAvgPool3D; function setup$15(backend) { wasmAvgPool3D = backend.wasm.cwrap('AvgPool3D', null, [ 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', // padLeft ]); } function avgPool3D(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs; const convInfo = tfjsCore.backend_util.computePool3DInfo(x.shape, filterSize, strides, /*dilations=*/ 1, pad, dimRoundingMode, dataFormat); const out = backend.makeOutput(convInfo.outShape, x.dtype); wasmAvgPool3D(backend.dataIdMap.get(x.dataId).id, backend.dataIdMap.get(out.dataId).id, convInfo.batchSize, // Since Pool3D ops (AvgPool3D and MaxPool3D) support 3D filter only, in // channels should always equal to out channels. /*channelSize=*/ convInfo.inChannels, convInfo.inDepth, convInfo.inHeight, convInfo.inWidth, convInfo.outDepth, convInfo.outHeight, convInfo.outWidth, convInfo.strideDepth, convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationDepth, convInfo.dilationHeight, convInfo.dilationWidth, convInfo.effectiveFilterDepth, convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth, convInfo.padInfo.front, convInfo.padInfo.top, convInfo.padInfo.left); return out; } const avgPool3DConfig = { kernelName: tfjsCore.AvgPool3D, backendName: 'wasm', setupFunc: setup$15, kernelFunc: avgPool3D }; /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ let wasmAvgPool3DGrad; function setup$14(backend) { wasmAvgPool3DGrad = backend.wasm.cwrap('AvgPool3DGrad', null, [ 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', // filterWidth ]); } function avgPool3DGrad(args) { const { inputs, backend, attrs } = args; const { dy, input } = inputs; const { filterSize, strides, pad, dimRoundingMode } = attrs; const convInfo = tfjsCore.backend_util.computePool3DInfo(input.shape, filterSize, strides, /*dilations=*/ 1, pad, dimRoundingMode); const dx = backend.makeOutput(input.shape, input.dtype); wasmAvgPool3DGrad(backend.dataIdMap.get(dy.dataId).id, backend.dataIdMap.get(dx.dataId).id, convInfo.batchSize, // Since Pool3D ops (AvgPool3D and MaxPool3D) support 3D filter only, in // channels should always equal to out channels. 
/*channelSize=*/ convInfo.inChannels, convInfo.inDepth, convInfo.inHeight, convInfo.inWidth, convInfo.outDepth, convInfo.outHeight, convInfo.outWidth, convInfo.strideDepth, convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationDepth, convInfo.dilationHeight, convInfo.dilationWidth, convInfo.effectiveFilterDepth, convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth, convInfo.padInfo.front, convInfo.padInfo.top, convInfo.padInfo.left, convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth); return dx; } const avgPool3DGradConfig = { kernelName: tfjsCore.AvgPool3DGrad, backendName: 'wasm', setupFunc: setup$14, kernelFunc: avgPool3DGrad }; /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ let wasmAvgPoolGrad; function setup$13(backend) { wasmAvgPoolGrad = backend.wasm.cwrap('AvgPoolGrad', null, [ 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', 'number', // filterWidth ]); } function avgPoolGrad(args) { const { inputs, backend, attrs } = args; const { dy, input } = inputs; const { filterSize, strides, pad } = attrs; const convInfo = tfjsCore.backend_util.computePool2DInfo(input.shape, filterSize, strides, /*dilations=*/ 1, pad); const dx = backend.makeOutput(input.shape, input.dtype); wasmAvgPoolGrad(backend.dataIdMap.get(dy.dataId).id, backend.dataIdMap.get(dx.dataId).id, convInfo.batchSize, // Since Pool ops (AvgPool and MaxPool) support 2D filter only, in // channels should always equal to out channels. /*channelSize=*/ convInfo.inChannels, convInfo.inHeight, convInfo.inWidth, convInfo.outHeight, convInfo.outWidth, convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationHeight, convInfo.dilationWidth, convInfo.effectiveFilterHeight, convInfo.effectiveFilterWidth, convInfo.padInfo.top, convInfo.padInfo.left, convInfo.filterHeight, convInfo.filterWidth); return dx; } const avgPoolGradConfig = { kernelName: tfjsCore.AvgPoolGrad, backendName: 'wasm', setupFunc: setup$13, kernelFunc: avgPoolGrad }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function reshape(args) { const { inputs, attrs } = args; const { x } = inputs; const { shape } = attrs; const xSize = tfjsCore.util.sizeFromShape(x.shape); const $shape = tfjsCore.util.inferFromImplicitShape(shape, xSize); tfjsCore.util.assert(xSize === tfjsCore.util.sizeFromShape($shape), () => `new shape: ${$shape}, old shape: ${x.shape}. New shape and old ` + `shape must have the same number of elements.`); // Backend needs to track refCount for the dataId for reshape op args.backend.incRef(x.dataId); return { dataId: x.dataId, shape: $shape, dtype: x.dtype }; } const reshapeConfig = { kernelName: tfjsCore.Reshape, backendName: 'wasm', kernelFunc: reshape }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * =======