UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

73 lines 9.94 kB
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Prod, util } from '@tensorflow/tfjs-core';
import { permuteAxesAndTranspose } from './kernel_utils';
import { CppDType } from './types';

/**
 * JS handle to the wasm `Prod` export, bound lazily by `setup`.
 * Invoked as wasmProd(inputId, reduceSize, cppDType, outId); returns nothing.
 */
let wasmProd;

/**
 * Binds the wasm `Prod` function through `cwrap`: four numeric arguments and
 * a void (null) return type.
 *
 * @param {object} backend - wasm backend; only `backend.wasm.cwrap` is used.
 */
function setup(backend) {
  const argTypes = ['number', 'number', 'number', 'number'];
  wasmProd = backend.wasm.cwrap(Prod, null /* void */, argTypes);
}

/**
 * Product-reduction kernel for the wasm backend.
 *
 * Obtains reduction axes (and possibly a transposed copy of `x`) from
 * `permuteAxesAndTranspose`. The assertion below requires those axes to be the
 * innermost dims of the tensor being reduced. It then allocates the output
 * and calls the wasm routine, skipping the call entirely when the input has
 * zero elements. When `keepDims` is set, the output shape is rewritten with
 * `backend_util.expandShapeToKeepDim` using the original (pre-permutation)
 * axes.
 *
 * @param {{backend: object, inputs: {x: object}, attrs: {axis: *, keepDims: boolean}}} args
 * @returns {object} The tensor info produced by `backend.makeOutput`.
 */
function prod({ backend, inputs: { x }, attrs: { axis, keepDims } }) {
  const lookupId = (info) => backend.dataIdMap.get(info.dataId).id;

  const xId = lookupId(x);
  const { transposed, axes, originalAxes, inputWasTransposed } =
      permuteAxesAndTranspose(x, axis, backend);

  // Default to reducing x itself. Switch to the transposed tensor only when it
  // is backed by a different data id than x.
  let source = x;
  let sourceId = xId;
  let axesToReduce = axes;

  if (inputWasTransposed) {
    const transposedId = lookupId(transposed);
    if (transposedId !== xId) {
      source = transposed;
      sourceId = transposedId;
      axesToReduce = backend_util.getInnerMostAxes(
          axesToReduce.length, source.shape.length);
    }
  }

  const rank = source.shape.length;
  backend_util.assertAxesAreInnerMostDims('prod', axesToReduce, rank);
  const [outShape, reduceShape] =
      backend_util.computeOutAndReduceShapes(source.shape, axesToReduce);
  const reduceSize = util.sizeFromShape(reduceShape);

  const out = backend.makeOutput(outShape, source.dtype);
  const hasElements = util.sizeFromShape(source.shape) !== 0;
  if (hasElements) {
    const outId = lookupId(out);
    wasmProd(sourceId, reduceSize, CppDType[out.dtype], outId);
  }

  // Release the transposed tensor.
  // NOTE(review): this fires whenever inputWasTransposed is true, including
  // the case above where `transposed` shares x's data id and x was reduced
  // directly. Whether that is balanced depends on how the transpose kernel
  // ref-counts a no-op result, which is not visible here. Confirm.
  if (inputWasTransposed) {
    backend.disposeData(transposed.dataId);
  }

  if (keepDims) {
    out.shape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
  }
  return out;
}

/** Kernel registration entry: `Prod` on the 'wasm' backend. */
export const prodConfig = {
  kernelName: Prod,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: prod,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiUHJvZC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL1Byb2QudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBNEIsSUFBSSxFQUFxQyxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUk1SCxPQUFPLEVBQUMsdUJBQXVCLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUV2RCxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRWpDLElBQUksUUFFcUMsQ0FBQztBQUUxQyxTQUFTLEtBQUssQ0FBQyxPQUFvQjtJQUNqQyxRQUFRLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsSUFBSSxFQUFFLElBQUksQ0FBQyxRQUFRLEVBQUU7UUFDakQsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtLQUNULENBQUMsQ0FBQztBQUNMLENBQUM7QUFFRCxTQUFTLElBQUksQ0FBQyxJQUliO0lBQ0MsTUFBTSxFQUFDLE9BQU8sRUFBRSxNQUFNLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxJQUFJLEVBQUUsUUFBUSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBQy9CLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVB
QUUsQ0FBQztJQUMvQyxJQUFJLE9BQU8sR0FBRyxHQUFHLENBQUM7SUFDbEIsSUFBSSxLQUFLLEdBQUcsQ0FBQyxDQUFDO0lBRWQsTUFBTSxFQUFDLFVBQVUsRUFBRSxJQUFJLEVBQUUsWUFBWSxFQUFFLGtCQUFrQixFQUFDLEdBQ3RELHVCQUF1QixDQUFDLENBQUMsRUFBRSxJQUFJLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFFOUMsSUFBSSxhQUFhLEdBQUcsSUFBSSxDQUFDO0lBQ3pCLElBQUksa0JBQWtCLEVBQUU7UUFDdEIsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztRQUNqRSxJQUFJLFlBQVksS0FBSyxHQUFHLEVBQUU7WUFDeEIsNkRBQTZEO1lBQzdELG9CQUFvQjtZQUNwQixLQUFLLEdBQUcsVUFBVSxDQUFDO1lBQ25CLE9BQU8sR0FBRyxZQUFZLENBQUM7WUFDdkIsYUFBYSxHQUFHLFlBQVksQ0FBQyxnQkFBZ0IsQ0FDekMsYUFBYSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQyxDQUFDO1NBQy9DO0tBQ0Y7SUFFRCxZQUFZLENBQUMsMEJBQTBCLENBQ25DLE1BQU0sRUFBRSxhQUFhLEVBQUUsS0FBSyxDQUFDLEtBQUssQ0FBQyxNQUFNLENBQUMsQ0FBQztJQUMvQyxNQUFNLENBQUMsUUFBUSxFQUFFLFdBQVcsQ0FBQyxHQUN6QixZQUFZLENBQUMseUJBQXlCLENBQUMsS0FBSyxDQUFDLEtBQUssRUFBRSxhQUFhLENBQUMsQ0FBQztJQUN2RSxNQUFNLFVBQVUsR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLFdBQVcsQ0FBQyxDQUFDO0lBRW5ELE1BQU0sR0FBRyxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsUUFBUSxFQUFFLEtBQUssQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUN0RCxJQUFJLElBQUksQ0FBQyxhQUFhLENBQUMsS0FBSyxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQUMsRUFBRTtRQUN6QyxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsU0FBUyxDQUFDLEdBQUcsQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUFDLENBQUMsRUFBRSxDQUFDO1FBQ25ELFFBQVEsQ0FBQyxPQUFPLEVBQUUsVUFBVSxFQUFFLFFBQVEsQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLEVBQUUsS0FBSyxDQUFDLENBQUM7S0FDM0Q7SUFFRCxJQUFJLGtCQUFrQixFQUFFO1FBQ3RCLG9DQUFvQztRQUNwQyxPQUFPLENBQUMsV0FBVyxDQUFDLFVBQVUsQ0FBQyxNQUFNLENBQUMsQ0FBQztLQUN4QztJQUVELElBQUksUUFBUSxFQUFFO1FBQ1osVUFBVTtRQUNWLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxvQkFBb0IsQ0FBQyxHQUFHLENBQUMsS0FBSyxFQUFFLFlBQVksQ0FBQyxDQUFDO1FBQzVFLEdBQUcsQ0FBQyxLQUFLLEdBQUcsUUFBUSxDQUFDO0tBQ3RCO0lBRUQsT0FBTyxHQUFHLENBQUM7QUFDYixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sVUFBVSxHQUFpQjtJQUN0QyxVQUFVLEVBQUUsSUFBSTtJQUNoQixXQUFXLEVBQUUsTUFBTTtJQUNuQixTQUFTLEVBQUUsS0FBSztJQUNoQixVQUFVLEVBQUUsSUFBNkI7Q0FDMUMsQ0FBQyIsInNvdXJjZXND
b250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgUHJvZCwgUHJvZEF0dHJzLCBQcm9kSW5wdXRzLCBUZW5zb3JJbmZvLCB1dGlsfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge0JhY2tlbmRXYXNtfSBmcm9tICcuLi9iYWNrZW5kX3dhc20nO1xuXG5pbXBvcnQge3Blcm11dGVBeGVzQW5kVHJhbnNwb3NlfSBmcm9tICcuL2tlcm5lbF91dGlscyc7XG5cbmltcG9ydCB7Q3BwRFR5cGV9IGZyb20gJy4vdHlwZXMnO1xuXG5sZXQgd2FzbVByb2Q6IChcbiAgICB4SWQ6IG51bWJlciwgcmVkdWNlU2l6ZTogbnVtYmVyLFxuICAgIGR0eXBlOiBudW1iZXIsIG91dElkOiBudW1iZXIpID0+IHZvaWQ7XG5cbmZ1bmN0aW9uIHNldHVwKGJhY2tlbmQ6IEJhY2tlbmRXYXNtKTogdm9pZCB7XG4gIHdhc21Qcm9kID0gYmFja2VuZC53YXNtLmN3cmFwKFByb2QsIG51bGwgLyp2b2lkKi8sIFtcbiAgICAnbnVtYmVyJyxcbiAgICAnbnVtYmVyJyxcbiAgICAnbnVtYmVyJyxcbiAgICAnbnVtYmVyJ1xuICBdKTtcbn1cblxuZnVuY3Rpb24gcHJvZChhcmdzOiB7XG4gIGJhY2tlbmQ6IEJhY2tlbmRXYXNtLFxuICBpbnB1dHM6IFByb2RJbnB1dHMsXG4gIGF0dHJzOiBQcm9kQXR0cnNcbn0pOiBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2JhY2tlbmQsIGlucHV0cywgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge2F4aXMsIGtlZXBEaW1zfSA9IGF0dHJzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHhJ
ZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldCh4LmRhdGFJZCkuaWQ7XG4gIGxldCBpbnB1dElkID0geElkO1xuICBsZXQgaW5wdXQgPSB4O1xuXG4gIGNvbnN0IHt0cmFuc3Bvc2VkLCBheGVzLCBvcmlnaW5hbEF4ZXMsIGlucHV0V2FzVHJhbnNwb3NlZH0gPVxuICAgICAgcGVybXV0ZUF4ZXNBbmRUcmFuc3Bvc2UoeCwgYXhpcywgYmFja2VuZCk7XG5cbiAgbGV0IHJlZHVjdGlvbkF4ZXMgPSBheGVzO1xuICBpZiAoaW5wdXRXYXNUcmFuc3Bvc2VkKSB7XG4gICAgY29uc3QgdHJhbnNwb3NlZElkID0gYmFja2VuZC5kYXRhSWRNYXAuZ2V0KHRyYW5zcG9zZWQuZGF0YUlkKS5pZDtcbiAgICBpZiAodHJhbnNwb3NlZElkICE9PSB4SWQpIHtcbiAgICAgIC8vIHRyYW5zcG9zZSB3YXMgbm90IGEgbm8tb3AuIFdlIHdpbGwgbmVlZCB0byBkaXNwb3NlIG9mIHRoaXNcbiAgICAgIC8vIG9uY2Ugd2UgYXJlIGRvbmUuXG4gICAgICBpbnB1dCA9IHRyYW5zcG9zZWQ7XG4gICAgICBpbnB1dElkID0gdHJhbnNwb3NlZElkO1xuICAgICAgcmVkdWN0aW9uQXhlcyA9IGJhY2tlbmRfdXRpbC5nZXRJbm5lck1vc3RBeGVzKFxuICAgICAgICAgIHJlZHVjdGlvbkF4ZXMubGVuZ3RoLCBpbnB1dC5zaGFwZS5sZW5ndGgpO1xuICAgIH1cbiAgfVxuXG4gIGJhY2tlbmRfdXRpbC5hc3NlcnRBeGVzQXJlSW5uZXJNb3N0RGltcyhcbiAgICAgICdwcm9kJywgcmVkdWN0aW9uQXhlcywgaW5wdXQuc2hhcGUubGVuZ3RoKTtcbiAgY29uc3QgW291dFNoYXBlLCByZWR1Y2VTaGFwZV0gPVxuICAgICAgYmFja2VuZF91dGlsLmNvbXB1dGVPdXRBbmRSZWR1Y2VTaGFwZXMoaW5wdXQuc2hhcGUsIHJlZHVjdGlvbkF4ZXMpO1xuICBjb25zdCByZWR1Y2VTaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKHJlZHVjZVNoYXBlKTtcblxuICBjb25zdCBvdXQgPSBiYWNrZW5kLm1ha2VPdXRwdXQob3V0U2hhcGUsIGlucHV0LmR0eXBlKTtcbiAgaWYgKHV0aWwuc2l6ZUZyb21TaGFwZShpbnB1dC5zaGFwZSkgIT09IDApIHtcbiAgICBjb25zdCBvdXRJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldChvdXQuZGF0YUlkKS5pZDtcbiAgICB3YXNtUHJvZChpbnB1dElkLCByZWR1Y2VTaXplLCBDcHBEVHlwZVtvdXQuZHR5cGVdLCBvdXRJZCk7XG4gIH1cblxuICBpZiAoaW5wdXRXYXNUcmFuc3Bvc2VkKSB7XG4gICAgLy8gZGlzcG9zZSBvZiB0aGUgdHJhbnNwb3NlZCB0ZW5zb3IuXG4gICAgYmFja2VuZC5kaXNwb3NlRGF0YSh0cmFuc3Bvc2VkLmRhdGFJZCk7XG4gIH1cblxuICBpZiAoa2VlcERpbXMpIHtcbiAgICAvLyByZXNoYXBlXG4gICAgY29uc3QgbmV3U2hhcGUgPSBiYWNrZW5kX3V0aWwuZXhwYW5kU2hhcGVUb0tlZXBEaW0ob3V0LnNoYXBlLCBvcmlnaW5hbEF4ZXMpO1xuICAgIG91dC5zaGFwZSA9IG5ld1NoYXBlO1xuICB9XG5cbiAgcmV0dXJuIG91dDtcbn1cblxuZXhwb3J0IGNvbnN0IHByb2RDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogUHJvZCxcbiAgYmFja2Vu
ZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOiBzZXR1cCxcbiAga2VybmVsRnVuYzogcHJvZCBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=