UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

66 lines 10.2 kB
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Cumprod, util } from '@tensorflow/tfjs-core';
import { CppDType } from './types';
import { transpose } from './Transpose';

/** Wrapped WASM entry point for the Cumprod kernel; assigned by `setup`. */
let wasmCumprod;

/**
 * Binds `wasmCumprod` to the function returned by
 * `backend.wasm.cwrap` for the `Cumprod` kernel name.
 *
 * The argument labels below follow the order in which `cumprod` passes
 * values to the wrapped function.
 *
 * @param backend Backend object exposing `wasm.cwrap`.
 */
function setup(backend) {
  const argTypes = [
    'number', // input data id
    'number', // exclusive flag (0 or 1)
    'number', // reverse flag (0 or 1)
    'number', // size of the innermost dimension
    'number', // output data id
    'number', // dtype
  ];
  wasmCumprod = backend.wasm.cwrap(Cumprod, null /* void */, argTypes);
}

/**
 * Kernel function for `Cumprod` on the WASM backend.
 *
 * The requested axis is moved to the innermost position (via `transpose`)
 * when `getAxesPermutation` reports that a permutation is needed, the WASM
 * kernel is invoked on that layout, and the result is transposed back. The
 * two intermediate tensors created for the permuted path are disposed before
 * returning.
 *
 * @param args Object with `inputs` (`{x}`), `backend`, and `attrs`
 *     (`{axis, exclusive, reverse}`). `exclusive` and `reverse` are
 *     forwarded to the WASM kernel as 0/1 flags.
 * @returns The output tensor info produced by `backend.makeOutput`, or by
 *     the final `transpose` when the input had to be permuted.
 * @throws Via `util.assert` when `x.dtype` is neither 'float32' nor 'int32'.
 */
export function cumprod(args) {
  const {
    inputs: { x },
    backend,
    attrs: { axis, exclusive, reverse },
  } = args;
  const rank = x.shape.length;

  const dtypeIsSupported = x.dtype === 'float32' || x.dtype === 'int32';
  util.assert(
    dtypeIsSupported,
    () => `cumprod does not support ${x.dtype} tensors in the WASM backend`,
  );

  // A null permutation means no transpose is performed on the way in or out.
  const perm = backend_util.getAxesPermutation([axis], rank);
  const wasPermuted = perm !== null;
  const source = wasPermuted
    ? transpose({ inputs: { x }, attrs: { perm }, backend })
    : x;

  const innerAxis = backend_util.getInnerMostAxes(1, rank)[0];
  backend_util.assertAxesAreInnerMostDims('cumprod', [innerAxis], rank);

  const result = backend.makeOutput(source.shape, source.dtype);
  const finalDim = source.shape[innerAxis];
  const sourceId = backend.dataIdMap.get(source.dataId).id;
  const resultId = backend.dataIdMap.get(result.dataId).id;
  wasmCumprod(
    sourceId,
    exclusive ? 1 : 0,
    reverse ? 1 : 0,
    finalDim,
    resultId,
    CppDType[x.dtype],
  );

  if (!wasPermuted) {
    return result;
  }

  // Restore the caller's axis order, then release both permuted temporaries.
  const inversePerm = backend_util.getUndoAxesPermutation(perm);
  const restored = transpose({
    inputs: { x: result },
    attrs: { perm: inversePerm },
    backend,
  });
  backend.disposeData(source.dataId);
  backend.disposeData(result.dataId);
  return restored;
}

/** Kernel registration entry for `Cumprod` on the 'wasm' backend. */
export const cumprodConfig = {
  kernelName: Cumprod,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: cumprod,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQ3VtcHJvZC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL0N1bXByb2QudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBNEIsT0FBTyxFQUEyQyxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUlySSxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRWpDLE9BQU8sRUFBQyxTQUFTLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFdEMsSUFBSSxXQUNzRSxDQUFDO0FBRTNFLFNBQVMsS0FBSyxDQUFDLE9BQW9CO0lBQ2pDLFdBQVcsR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxPQUFPLEVBQUUsSUFBSSxDQUFDLFVBQVUsRUFBRTtRQUN6RCxRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVEsQ0FBRSxRQUFRO0tBQ25CLENBQUMsQ0FBQztBQUNMLENBQUM7QUFFRCxNQUFNLFVBQVUsT0FBTyxDQUNyQixJQUF3RTtJQUV4RSxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFDekMsTUFBTSxLQUFLLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUM7SUFFN0IsSUFBSSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsS0FBSyxLQUFLLFNBQVMsSUFBSSxDQUFDLENBQUMsS0FBSyxLQUFLLE9BQU8sRUFDdEQsR0FBRyxFQUFFLENBQUMsNEJBQTRCLENBQUMsQ0FBQyxLQUFLLDhCQUE4QixDQUFDLENBQUM7SUFDM0UsMkNBQTJDO0lBQzNDLE1BQU0sV0FBVyxHQUFHLFlBQVksQ0FBQyxrQkFBa0IsQ0FBQyxDQUFDLElBQUksQ0FBQyxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBQ25FLElBQUksU0FBUyxHQUFHLENBQUMsQ0FBQztJQUNsQixJQUFJLFdBQVcsS0FB
BSyxJQUFJLEVBQUU7UUFDeEIsU0FBUyxHQUFHLFNBQVMsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLEtBQUssRUFBRSxFQUFDLElBQUksRUFBRSxXQUFXLEVBQUMsRUFBRSxPQUFPLEVBQUMsQ0FBQyxDQUFDO0tBQzNFO0lBQ0QsTUFBTSxZQUFZLEdBQUcsWUFBWSxDQUFDLGdCQUFnQixDQUFDLENBQUMsRUFBRSxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUNoRSxZQUFZLENBQUMsMEJBQTBCLENBQUMsU0FBUyxFQUFFLENBQUMsWUFBWSxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFMUUsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxTQUFTLENBQUMsS0FBSyxFQUFFLFNBQVMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUN6RSxNQUFNLFFBQVEsR0FBRyxTQUFTLENBQUMsS0FBSyxDQUFDLFlBQVksQ0FBQyxDQUFDO0lBQy9DLE1BQU0sV0FBVyxHQUFHLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLENBQUM7SUFDL0QsTUFBTSxhQUFhLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsV0FBVyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUNuRSxXQUFXLENBQUMsV0FBVyxFQUFFLFNBQVMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsT0FBTyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsRUFBRSxRQUFRLEVBQ3pELGFBQWEsRUFBRSxRQUFRLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUM7SUFFOUMsa0NBQWtDO0lBQ2xDLElBQUksR0FBRyxHQUFHLFdBQVcsQ0FBQztJQUN0QixJQUFJLFdBQVcsS0FBSyxJQUFJLEVBQUU7UUFDeEIsTUFBTSxlQUFlLEdBQUcsWUFBWSxDQUFDLHNCQUFzQixDQUFDLFdBQVcsQ0FBQyxDQUFDO1FBQ3pFLEdBQUcsR0FBRyxTQUFTLENBQ2IsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsV0FBVyxFQUFDLEVBQUUsS0FBSyxFQUFFLEVBQUMsSUFBSSxFQUFFLGVBQWUsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7UUFDdkUsT0FBTyxDQUFDLFdBQVcsQ0FBQyxTQUFTLENBQUMsTUFBTSxDQUFDLENBQUM7UUFDdEMsT0FBTyxDQUFDLFdBQVcsQ0FBQyxXQUFXLENBQUMsTUFBTSxDQUFDLENBQUM7S0FDekM7SUFDRCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxhQUFhLEdBQWlCO0lBQ3pDLFVBQVUsRUFBRSxPQUFPO0lBQ25CLFdBQVcsRUFBRSxNQUFNO0lBQ25CLFNBQVMsRUFBRSxLQUFLO0lBQ2hCLFVBQVUsRUFBRSxPQUFnQztDQUM3QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjIgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB
1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBDdW1wcm9kLCBDdW1wcm9kQXR0cnMsIEN1bXByb2RJbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5cbmltcG9ydCB7Q3BwRFR5cGV9IGZyb20gJy4vdHlwZXMnO1xuXG5pbXBvcnQge3RyYW5zcG9zZX0gZnJvbSAnLi9UcmFuc3Bvc2UnO1xuXG5sZXQgd2FzbUN1bXByb2Q6ICh4SWQ6IG51bWJlciwgZXhjbHVzaXZlOiBudW1iZXIsIHJldmVyc2U6IG51bWJlcixcbiAgICAgICAgICAgICAgICAgZmluYWxEaW06IG51bWJlciwgb3V0SWQ6IG51bWJlciwgZHR5cGU6IENwcERUeXBlKSA9PiB2b2lkO1xuXG5mdW5jdGlvbiBzZXR1cChiYWNrZW5kOiBCYWNrZW5kV2FzbSkge1xuICB3YXNtQ3VtcHJvZCA9IGJhY2tlbmQud2FzbS5jd3JhcChDdW1wcm9kLCBudWxsIC8qIHZvaWQgKi8sIFtcbiAgICAnbnVtYmVyJywgLy8geF9pZFxuICAgICdudW1iZXInLCAvLyBleGNsdXNpdmVcbiAgICAnbnVtYmVyJywgLy8gcmV2ZXJzZVxuICAgICdudW1iZXInLCAvLyBmaW5hbF9kaW1cbiAgICAnbnVtYmVyJywgLy8gb3V0X2lkXG4gICAgJ251bWJlcicgIC8vIGR0eXBlXG4gIF0pO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gY3VtcHJvZChcbiAgYXJnczoge2lucHV0czogQ3VtcHJvZElucHV0cywgYmFja2VuZDogQmFja2VuZFdhc20sIGF0dHJzOiBDdW1wcm9kQXR0cnN9KTpcblRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHtheGlzLCBleGNsdXNpdmUsIHJldmVyc2V9ID0gYXR0cnM7XG4gIGNvbnN0IHhSYW5rID0geC5zaGF
wZS5sZW5ndGg7XG5cbiAgdXRpbC5hc3NlcnQoeC5kdHlwZSA9PT0gJ2Zsb2F0MzInIHx8IHguZHR5cGUgPT09ICdpbnQzMicsXG4gICAgKCkgPT4gYGN1bXByb2QgZG9lcyBub3Qgc3VwcG9ydCAke3guZHR5cGV9IHRlbnNvcnMgaW4gdGhlIFdBU00gYmFja2VuZGApO1xuICAvLyBwZXJtdXRlIHJlcXVpcmVkIGF4aXMgdG8gaW5uZXIgbW9zdCBheGlzXG4gIGNvbnN0IHBlcm11dGF0aW9uID0gYmFja2VuZF91dGlsLmdldEF4ZXNQZXJtdXRhdGlvbihbYXhpc10sIHhSYW5rKTtcbiAgbGV0IHBlcm11dGVkWCA9IHg7XG4gIGlmIChwZXJtdXRhdGlvbiAhPT0gbnVsbCkge1xuICAgIHBlcm11dGVkWCA9IHRyYW5zcG9zZSh7aW5wdXRzOiB7eH0sIGF0dHJzOiB7cGVybTogcGVybXV0YXRpb259LCBiYWNrZW5kfSk7XG4gIH1cbiAgY29uc3QgcGVybXV0ZWRBeGlzID0gYmFja2VuZF91dGlsLmdldElubmVyTW9zdEF4ZXMoMSwgeFJhbmspWzBdO1xuICBiYWNrZW5kX3V0aWwuYXNzZXJ0QXhlc0FyZUlubmVyTW9zdERpbXMoJ2N1bXByb2QnLCBbcGVybXV0ZWRBeGlzXSwgeFJhbmspO1xuXG4gIGNvbnN0IHBlcm11dGVkT3V0ID0gYmFja2VuZC5tYWtlT3V0cHV0KHBlcm11dGVkWC5zaGFwZSwgcGVybXV0ZWRYLmR0eXBlKTtcbiAgY29uc3QgZmluYWxEaW0gPSBwZXJtdXRlZFguc2hhcGVbcGVybXV0ZWRBeGlzXTtcbiAgY29uc3QgcGVybXV0ZWRYSWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQocGVybXV0ZWRYLmRhdGFJZCkuaWQ7XG4gIGNvbnN0IHBlcm11dGVkT3V0SWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQocGVybXV0ZWRPdXQuZGF0YUlkKS5pZDtcbiAgd2FzbUN1bXByb2QocGVybXV0ZWRYSWQsIGV4Y2x1c2l2ZSA/IDEgOiAwLCByZXZlcnNlID8gMSA6IDAsIGZpbmFsRGltLFxuICAgICAgICAgICAgICBwZXJtdXRlZE91dElkLCBDcHBEVHlwZVt4LmR0eXBlXSk7XG5cbiAgLy8gdHJhbnNwb3NlIGRhdGEgYmFjayBpZiBwZXJtdXRlZFxuICBsZXQgb3V0ID0gcGVybXV0ZWRPdXQ7XG4gIGlmIChwZXJtdXRhdGlvbiAhPT0gbnVsbCkge1xuICAgIGNvbnN0IHVuZG9QZXJtdXRhdGlvbiA9IGJhY2tlbmRfdXRpbC5nZXRVbmRvQXhlc1Blcm11dGF0aW9uKHBlcm11dGF0aW9uKTtcbiAgICBvdXQgPSB0cmFuc3Bvc2UoXG4gICAgICB7aW5wdXRzOiB7eDogcGVybXV0ZWRPdXR9LCBhdHRyczoge3Blcm06IHVuZG9QZXJtdXRhdGlvbn0sIGJhY2tlbmR9KTtcbiAgICBiYWNrZW5kLmRpc3Bvc2VEYXRhKHBlcm11dGVkWC5kYXRhSWQpO1xuICAgIGJhY2tlbmQuZGlzcG9zZURhdGEocGVybXV0ZWRPdXQuZGF0YUlkKTtcbiAgfVxuICByZXR1cm4gb3V0O1xufVxuXG5leHBvcnQgY29uc3QgY3VtcHJvZENvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBDdW1wcm9kLFxuICBiYWNrZW5kTmFtZTogJ3dhc20nLFxuICBzZXR1cEZ1bmM6IHNldHVwLFxuICBrZXJuZWxGdW5jOiBjdW1wcm9kIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1x
ufTtcbiJdfQ==