UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

73 lines 9.98 kB
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Sum, util } from '@tensorflow/tfjs-core';
import { permuteAxesAndTranspose } from './kernel_utils';
import { CppDType } from './types';

/**
 * Wrapper around the wasm `Sum` kernel. Undefined until `setup` has run;
 * called as `wasmSum(inputId, reduceSize, dtype, outId)`.
 */
let wasmSum;

/**
 * Binds the wasm module's `Sum` function through `cwrap` and stores the
 * resulting wrapper in `wasmSum`.
 *
 * @param backend Backend object exposing the wasm module as `backend.wasm`.
 */
function setup(backend) {
  // All four kernel arguments are passed as plain numbers.
  const argTypes = [
    'number', // input_id
    'number', // reduce_size
    'number', // dtype
    'number', // out_id
  ];
  wasmSum = backend.wasm.cwrap(Sum, null /* void */, argTypes);
}

/**
 * Kernel function for `Sum` on the wasm backend.
 *
 * @param args Object holding `backend`, `inputs` (with tensor `x`) and
 *     `attrs` (with `axis` and `keepDims`).
 * @returns The output tensor info created by `backend.makeOutput`; when
 *     `keepDims` is truthy its `shape` is replaced by the result of
 *     `backend_util.expandShapeToKeepDim`.
 */
function sum(args) {
  const { backend } = args;
  const { axis, keepDims } = args.attrs;
  const { x } = args.inputs;

  // Looks up the numeric id the backend associates with a tensor's data.
  const idOf = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;

  const xId = idOf(x);
  const permuted = permuteAxesAndTranspose(x, axis, backend);
  const { transposed, originalAxes, inputWasTransposed } = permuted;

  // Unless a real transpose happened, reduce `x` itself over the axes
  // reported by the helper.
  let source = x;
  let sourceId = xId;
  let reductionAxes = permuted.axes;
  if (inputWasTransposed) {
    const transposedId = idOf(transposed);
    if (transposedId !== xId) {
      // The transpose produced different data, so reduce over it instead
      // and use the innermost axes of the transposed tensor. It has to be
      // disposed once the kernel has run (see below).
      source = transposed;
      sourceId = transposedId;
      reductionAxes = backend_util.getInnerMostAxes(
        reductionAxes.length,
        source.shape.length,
      );
    }
  }

  const rank = source.shape.length;
  backend_util.assertAxesAreInnerMostDims('sum', reductionAxes, rank);

  const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(
    source.shape,
    reductionAxes,
  );
  const reduceSize = util.sizeFromShape(reduceShape);
  const out = backend.makeOutput(outShape, source.dtype);

  // The kernel is only invoked for non-empty inputs.
  // NOTE(review): for an empty input with a non-empty output shape, `out`
  // is returned exactly as `makeOutput` allocated it — confirm that its
  // contents are well defined in that case.
  const inputIsEmpty = util.sizeFromShape(source.shape) === 0;
  if (!inputIsEmpty) {
    const outId = idOf(out);
    wasmSum(sourceId, reduceSize, CppDType[out.dtype], outId);
  }

  if (inputWasTransposed) {
    // Release the temporary transposed tensor.
    backend.disposeData(transposed.dataId);
  }

  if (keepDims) {
    // Reshape in place so the reduced axes are kept in the output shape.
    out.shape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
  }

  return out;
}

/** Kernel registration entry: `Sum` on the 'wasm' backend. */
export const sumConfig = {
  kernelName: Sum,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: sum,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3VtLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdhc20vc3JjL2tlcm5lbHMvU3VtLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQTRCLEdBQUcsRUFBbUMsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJekgsT0FBTyxFQUFDLHVCQUF1QixFQUFDLE1BQU0sZ0JBQWdCLENBQUM7QUFDdkQsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUVqQyxJQUFJLE9BQ0ksQ0FBQztBQUVULFNBQVMsS0FBSyxDQUFDLE9BQW9CO0lBQ2pDLE9BQU8sR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxHQUFHLEVBQUUsSUFBSSxDQUFDLFFBQVEsRUFBRTtRQUMvQyxRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRLEVBQUcsU0FBUztLQUNyQixDQUFDLENBQUM7QUFDTCxDQUFDO0FBRUQsU0FBUyxHQUFHLENBQUMsSUFBZ0U7SUFFM0UsTUFBTSxFQUFDLE9BQU8sRUFBRSxNQUFNLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxJQUFJLEVBQUUsUUFBUSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBQy9CLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQ
UFDLEVBQUUsQ0FBQztJQUMvQyxJQUFJLE9BQU8sR0FBRyxHQUFHLENBQUM7SUFDbEIsSUFBSSxLQUFLLEdBQUcsQ0FBQyxDQUFDO0lBRWQsTUFBTSxFQUFDLFVBQVUsRUFBRSxJQUFJLEVBQUUsWUFBWSxFQUFFLGtCQUFrQixFQUFDLEdBQ3RELHVCQUF1QixDQUFDLENBQUMsRUFBRSxJQUFJLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFFOUMsSUFBSSxhQUFhLEdBQUcsSUFBSSxDQUFDO0lBQ3pCLElBQUksa0JBQWtCLEVBQUU7UUFDdEIsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztRQUNqRSxJQUFJLFlBQVksS0FBSyxHQUFHLEVBQUU7WUFDeEIsNkRBQTZEO1lBQzdELG9CQUFvQjtZQUNwQixLQUFLLEdBQUcsVUFBVSxDQUFDO1lBQ25CLE9BQU8sR0FBRyxZQUFZLENBQUM7WUFDdkIsYUFBYSxHQUFHLFlBQVksQ0FBQyxnQkFBZ0IsQ0FDekMsYUFBYSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQyxDQUFDO1NBQy9DO0tBQ0Y7SUFFRCxZQUFZLENBQUMsMEJBQTBCLENBQ25DLEtBQUssRUFBRSxhQUFhLEVBQUUsS0FBSyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsQ0FBQztJQUM5QyxNQUFNLENBQUMsUUFBUSxFQUFFLFdBQVcsQ0FBQyxHQUN6QixZQUFZLENBQUMseUJBQXlCLENBQUMsS0FBSyxDQUFDLEtBQUssRUFBRSxhQUFhLENBQUMsQ0FBQztJQUN2RSxNQUFNLFVBQVUsR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLFdBQVcsQ0FBQyxDQUFDO0lBRW5ELE1BQU0sR0FBRyxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsUUFBUSxFQUFFLEtBQUssQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUN0RCxJQUFJLElBQUksQ0FBQyxhQUFhLENBQUMsS0FBSyxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQUMsRUFBRTtRQUN6QyxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsU0FBUyxDQUFDLEdBQUcsQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUFDLENBQUMsRUFBRSxDQUFDO1FBQ25ELE9BQU8sQ0FBQyxPQUFPLEVBQUUsVUFBVSxFQUFFLFFBQVEsQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7S0FDMUQ7SUFFRCxJQUFJLGtCQUFrQixFQUFFO1FBQ3RCLG9DQUFvQztRQUNwQyxPQUFPLENBQUMsV0FBVyxDQUFDLFVBQVUsQ0FBQyxNQUFNLENBQUMsQ0FBQztLQUN4QztJQUVELElBQUksUUFBUSxFQUFFO1FBQ1osVUFBVTtRQUNWLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxvQkFBb0IsQ0FBQyxHQUFHLENBQUMsS0FBSyxFQUFFLFlBQVksQ0FBQyxDQUFDO1FBQzVFLEdBQUcsQ0FBQyxLQUFLLEdBQUcsUUFBUSxDQUFDO0tBQ3RCO0lBRUQsT0FBTyxHQUFHLENBQUM7QUFDYixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sU0FBUyxHQUFpQjtJQUNyQyxVQUFVLEVBQUUsR0FBRztJQUNmLFdBQVcsRUFBRSxNQUFNO0lBQ25CLFNBQVMsRUFBRSxLQUFLO0lBQ2hCLFVBQVUsRUFBRSxHQUE0QjtDQUN6QyxDQUFDIiwic291c
mNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTkgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTdW0sIFN1bUF0dHJzLCBTdW1JbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5cbmltcG9ydCB7cGVybXV0ZUF4ZXNBbmRUcmFuc3Bvc2V9IGZyb20gJy4va2VybmVsX3V0aWxzJztcbmltcG9ydCB7Q3BwRFR5cGV9IGZyb20gJy4vdHlwZXMnO1xuXG5sZXQgd2FzbVN1bTogKHhJZDogbnVtYmVyLCByZWR1Y2VTaXplOiBudW1iZXIsIGR0eXBlOiBudW1iZXIsIG91dElkOiBudW1iZXIpID0+XG4gICAgdm9pZDtcblxuZnVuY3Rpb24gc2V0dXAoYmFja2VuZDogQmFja2VuZFdhc20pOiB2b2lkIHtcbiAgd2FzbVN1bSA9IGJhY2tlbmQud2FzbS5jd3JhcChTdW0sIG51bGwgLyp2b2lkKi8sIFtcbiAgICAnbnVtYmVyJywgIC8vIGlucHV0X2lkXG4gICAgJ251bWJlcicsICAvLyByZWR1Y2Vfc2l6ZVxuICAgICdudW1iZXInLCAgLy8gZHR5cGVcbiAgICAnbnVtYmVyJywgIC8vIG91dF9pZFxuICBdKTtcbn1cblxuZnVuY3Rpb24gc3VtKGFyZ3M6IHtiYWNrZW5kOiBCYWNrZW5kV2FzbSwgaW5wdXRzOiBTdW1JbnB1dHMsIGF0dHJzOiBTdW1BdHRyc30pOlxuICAgIFRlbnNvckluZm8ge1xuICBjb25zdCB7YmFja2VuZCwgaW5wdXRzLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7YXhpcywga2VlcERpbXN9ID0gYXR0cnM7XG4gI
GNvbnN0IHt4fSA9IGlucHV0cztcbiAgY29uc3QgeElkID0gYmFja2VuZC5kYXRhSWRNYXAuZ2V0KHguZGF0YUlkKS5pZDtcbiAgbGV0IGlucHV0SWQgPSB4SWQ7XG4gIGxldCBpbnB1dCA9IHg7XG5cbiAgY29uc3Qge3RyYW5zcG9zZWQsIGF4ZXMsIG9yaWdpbmFsQXhlcywgaW5wdXRXYXNUcmFuc3Bvc2VkfSA9XG4gICAgICBwZXJtdXRlQXhlc0FuZFRyYW5zcG9zZSh4LCBheGlzLCBiYWNrZW5kKTtcblxuICBsZXQgcmVkdWN0aW9uQXhlcyA9IGF4ZXM7XG4gIGlmIChpbnB1dFdhc1RyYW5zcG9zZWQpIHtcbiAgICBjb25zdCB0cmFuc3Bvc2VkSWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQodHJhbnNwb3NlZC5kYXRhSWQpLmlkO1xuICAgIGlmICh0cmFuc3Bvc2VkSWQgIT09IHhJZCkge1xuICAgICAgLy8gdHJhbnNwb3NlIHdhcyBub3QgYSBuby1vcC4gV2Ugd2lsbCBuZWVkIHRvIGRpc3Bvc2Ugb2YgdGhpc1xuICAgICAgLy8gb25jZSB3ZSBhcmUgZG9uZS5cbiAgICAgIGlucHV0ID0gdHJhbnNwb3NlZDtcbiAgICAgIGlucHV0SWQgPSB0cmFuc3Bvc2VkSWQ7XG4gICAgICByZWR1Y3Rpb25BeGVzID0gYmFja2VuZF91dGlsLmdldElubmVyTW9zdEF4ZXMoXG4gICAgICAgICAgcmVkdWN0aW9uQXhlcy5sZW5ndGgsIGlucHV0LnNoYXBlLmxlbmd0aCk7XG4gICAgfVxuICB9XG5cbiAgYmFja2VuZF91dGlsLmFzc2VydEF4ZXNBcmVJbm5lck1vc3REaW1zKFxuICAgICAgJ3N1bScsIHJlZHVjdGlvbkF4ZXMsIGlucHV0LnNoYXBlLmxlbmd0aCk7XG4gIGNvbnN0IFtvdXRTaGFwZSwgcmVkdWNlU2hhcGVdID1cbiAgICAgIGJhY2tlbmRfdXRpbC5jb21wdXRlT3V0QW5kUmVkdWNlU2hhcGVzKGlucHV0LnNoYXBlLCByZWR1Y3Rpb25BeGVzKTtcbiAgY29uc3QgcmVkdWNlU2l6ZSA9IHV0aWwuc2l6ZUZyb21TaGFwZShyZWR1Y2VTaGFwZSk7XG5cbiAgY29uc3Qgb3V0ID0gYmFja2VuZC5tYWtlT3V0cHV0KG91dFNoYXBlLCBpbnB1dC5kdHlwZSk7XG4gIGlmICh1dGlsLnNpemVGcm9tU2hhcGUoaW5wdXQuc2hhcGUpICE9PSAwKSB7XG4gICAgY29uc3Qgb3V0SWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0LmRhdGFJZCkuaWQ7XG4gICAgd2FzbVN1bShpbnB1dElkLCByZWR1Y2VTaXplLCBDcHBEVHlwZVtvdXQuZHR5cGVdLCBvdXRJZCk7XG4gIH1cblxuICBpZiAoaW5wdXRXYXNUcmFuc3Bvc2VkKSB7XG4gICAgLy8gZGlzcG9zZSBvZiB0aGUgdHJhbnNwb3NlZCB0ZW5zb3IuXG4gICAgYmFja2VuZC5kaXNwb3NlRGF0YSh0cmFuc3Bvc2VkLmRhdGFJZCk7XG4gIH1cblxuICBpZiAoa2VlcERpbXMpIHtcbiAgICAvLyByZXNoYXBlXG4gICAgY29uc3QgbmV3U2hhcGUgPSBiYWNrZW5kX3V0aWwuZXhwYW5kU2hhcGVUb0tlZXBEaW0ob3V0LnNoYXBlLCBvcmlnaW5hbEF4ZXMpO1xuICAgIG91dC5zaGFwZSA9IG5ld1NoYXBlO1xuICB9XG5cbiAgcmV0dXJuIG91dDtcbn1cblxuZXhwb3J0IGNvbnN0IHN1bUNvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuI
CBrZXJuZWxOYW1lOiBTdW0sXG4gIGJhY2tlbmROYW1lOiAnd2FzbScsXG4gIHNldHVwRnVuYzogc2V0dXAsXG4gIGtlcm5lbEZ1bmM6IHN1bSBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=