@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, KNN classifier
82 lines • 12 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { backend_util, MaxPool, util } from '@tensorflow/tfjs-core';
/**
 * Native MaxPool entry point. Assigned by `setup` and called by `maxPool`.
 * It is bound with a `null` (void) return type, so results are presumably
 * written into the tensor identified by its final `outId` argument.
 */
let wasmMaxPool;

/**
 * Binds the `MaxPool` function exported by the WASM module and stores the
 * resulting JS callable in `wasmMaxPool`.
 *
 * @param {BackendWasm} backend - backend whose `wasm` module provides `cwrap`.
 */
function setup(backend) {
  // Positional parameters of the native kernel, in the same order `maxPool`
  // passes them. Every one crosses the JS/WASM boundary as a plain number.
  const paramNames = [
    'xId',
    'batchSize',
    'inputHeight',
    'inputWidth',
    'filterHeight',
    'filterWidth',
    'padTop',
    'padRight',
    'padBottom',
    'padLeft',
    'dilationHeight',
    'dilationWidth',
    'strideHeight',
    'strideWidth',
    'inputChannels',
    'outputChannels',
    'outId',
  ];
  const argTypes = paramNames.map(() => 'number');
  wasmMaxPool = backend.wasm.cwrap(MaxPool, null /* void */, argTypes);
}
/**
 * MaxPool kernel for the WASM backend.
 *
 * Looks up the WASM id of `inputs.x` and requires it to be float32. It then
 * derives the pooling geometry with `backend_util.computePool2DInfo`
 * (dilation fixed to 1) and rejects any layout other than 'channelsLast'.
 * Finally it allocates a float32 output of `convInfo.outShape` and forwards
 * every scalar parameter to the native kernel bound by `setup`.
 *
 * @param {{inputs: {x: Tensor4D}, attrs: MaxPoolAttrs, backend: BackendWasm}} args
 * @returns the output tensor created by `backend.makeOutput`.
 * @throws {Error} if `x` is not float32, or if the computed data format is
 *     not 'channelsLast'.
 */
function maxPool(args) {
  const { backend, inputs, attrs } = args;
  const { x } = inputs;
  const xId = backend.dataIdMap.get(x.dataId).id;

  // The TF API, and the CPU and WebGL backends, accept int32 input. This
  // backend relies on xnnpack, which only handles float32, so the check lives
  // here rather than at the core-op level.
  //
  // TODO: add support for int32 input.
  util.assert(
      x.dtype === 'float32',
      () => `Error in MaxPool: only float32 input is supported. Got ${x.dtype}.`);

  const { filterSize, strides, pad, dimRoundingMode } = attrs;
  const convInfo = backend_util.computePool2DInfo(
      x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
  const {
    filterHeight,
    filterWidth,
    padInfo: { top: padTop, right: padRight, bottom: padBottom, left: padLeft },
    dilationHeight,
    dilationWidth,
    strideHeight,
    strideWidth,
    inChannels: inputChannels,
    outChannels: outputChannels,
    dataFormat,
    outShape,
  } = convInfo;

  // Only the channels-last layout is accepted by this backend.
  if (dataFormat !== 'channelsLast') {
    throw new Error(
        `wasm backend does not support dataFormat:'${dataFormat}'. ` +
        `Please use 'channelsLast'.`);
  }

  const out = backend.makeOutput(outShape, 'float32');
  const outId = backend.dataIdMap.get(out.dataId).id;

  // Leading three dims of the (channels-last) input: batch, height, width.
  const [batchSize, inputHeight, inputWidth] = x.shape;
  wasmMaxPool(
      xId, batchSize, inputHeight, inputWidth, filterHeight, filterWidth,
      padTop, padRight, padBottom, padLeft, dilationHeight, dilationWidth,
      strideHeight, strideWidth, inputChannels, outputChannels, outId);
  return out;
}
/**
 * Kernel registration entry for MaxPool on the 'wasm' backend.
 * `setupFunc` binds the native function and `kernelFunc` executes the op.
 * Both are presumably invoked by the tfjs-core kernel registry; confirm
 * against the registration code.
 */
export const maxPoolConfig = {
  kernelName: MaxPool,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: maxPool,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWF4UG9vbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL01heFBvb2wudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBNEIsT0FBTyxFQUF5QyxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUluSSxJQUFJLFdBS3FFLENBQUM7QUFFMUUsU0FBUyxLQUFLLENBQUMsT0FBb0I7SUFDakMsV0FBVyxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsS0FBSyxDQUFDLE9BQU8sRUFBRSxJQUFJLENBQUMsVUFBVSxFQUFFO1FBQ3pELFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRLEVBQUcsUUFBUTtLQUNwQixDQUFDLENBQUM7QUFDTCxDQUFDO0FBRUQsU0FBUyxPQUFPLENBQ1osSUFBd0U7SUFDMUUsTUFBTSxFQUFDLE1BQU0sRUFBRSxLQUFLLEVBQUUsT0FBTyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBRXRDLE1BQU0sQ0FBQyxHQUFHLE1BQU0sQ0FBQyxDQUFhLENBQUM7SUFDL0IsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUUvQyx3RUFBd0U7SUFDeEUsNEVBQTRFO0lBQzVFLG9CQUFvQjtJQUNwQixFQUFFO0lBQ0YsMkVBQTJFO0lBQzNFLFNBQVM7SUFDVCxFQUFFO0lBQ0YscUNBQXFDO0lBQ3JDLElBQUksQ0FBQyxNQUFNLENBQ1AsQ0FBQyxDQUFDLEtBQUssS0FBSyxTQUFTLEVBQ3JCLEdBQUcsRUFBRSxDQUNELDBEQUEwRCxDQUFDLENBQUMsS0FBSyxHQUFHLENBQUMsQ0FBQztJQUU5RSxNQUFNLEVBQUMsVUFBVSxFQUFFLE9BQU8sRUFBRSxHQUFHLEVBQUUsZUFBZSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBQzFELE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxpQkFBaUIsQ0FDM0MsQ0FBQyxDQUFDLEtBQUssRUFBRSxVQUFVLEVBQUUsT0FBTyxFQUFFLENBQUMsQ0FBQyxlQUFlLEVBQUUsR0FBRyxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBRTNFLE1BQU0sWUFBWSxHQUFHLFFBQVEsQ0FBQyxZQUFZLENBQUM7SUFDM0MsTUFBTSxXQUFXLEdBQUcsUUFBUSxDQUFDLFdBQVcsQ0FBQztJQUN6QyxNQUFNLE1BQU0sR0FBRyxRQUFRLENBQUMsT0FBTyxDQUFDLEdBQUcsQ0FBQztJQUNwQyxNQUFNLFFBQVEsR0FBRyxRQUFRLENBQUMsT0FBTyxDQUFDLEtBQUssQ0FBQztJQUN4QyxNQUFNLFNBQVMsR0FBRyxRQUFRLENBQUMsT0FBTyxDQUFDLE1BQU0sQ0FBQztJQUMxQyxNQUFNLE9BQU8sR0FBRyxRQUFRLENBQUMsT0FBTyxDQUFDLElBQU
ksQ0FBQztJQUN0QyxNQUFNLGNBQWMsR0FBRyxRQUFRLENBQUMsY0FBYyxDQUFDO0lBQy9DLE1BQU0sYUFBYSxHQUFHLFFBQVEsQ0FBQyxhQUFhLENBQUM7SUFDN0MsTUFBTSxZQUFZLEdBQUcsUUFBUSxDQUFDLFlBQVksQ0FBQztJQUMzQyxNQUFNLFdBQVcsR0FBRyxRQUFRLENBQUMsV0FBVyxDQUFDO0lBQ3pDLE1BQU0sYUFBYSxHQUFHLFFBQVEsQ0FBQyxVQUFVLENBQUM7SUFDMUMsTUFBTSxjQUFjLEdBQUcsUUFBUSxDQUFDLFdBQVcsQ0FBQztJQUU1QyxJQUFJLFFBQVEsQ0FBQyxVQUFVLEtBQUssY0FBYyxFQUFFO1FBQzFDLE1BQU0sSUFBSSxLQUFLLENBQ1gsNENBQTRDO1lBQzVDLEdBQUcsUUFBUSxDQUFDLFVBQVUsK0JBQStCLENBQUMsQ0FBQztLQUM1RDtJQUVELE1BQU0sR0FBRyxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsUUFBUSxDQUFDLFFBQVEsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUM3RCxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsU0FBUyxDQUFDLEdBQUcsQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUFDLENBQUMsRUFBRSxDQUFDO0lBRW5ELFdBQVcsQ0FDUCxHQUFHLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUUsWUFBWSxFQUFFLFdBQVcsRUFDbEUsTUFBTSxFQUFFLFFBQVEsRUFBRSxTQUFTLEVBQUUsT0FBTyxFQUFFLGNBQWMsRUFBRSxhQUFhLEVBQ25FLFlBQVksRUFBRSxXQUFXLEVBQUUsYUFBYSxFQUFFLGNBQWMsRUFBRSxLQUFLLENBQUMsQ0FBQztJQUNyRSxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxhQUFhLEdBQWlCO0lBQ3pDLFVBQVUsRUFBRSxPQUFPO0lBQ25CLFdBQVcsRUFBRSxNQUFNO0lBQ25CLFNBQVMsRUFBRSxLQUFLO0lBQ2hCLFVBQVUsRUFBRSxPQUFnQztDQUM3QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTkgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3
IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBNYXhQb29sLCBNYXhQb29sQXR0cnMsIE1heFBvb2xJbnB1dHMsIFRlbnNvcjRELCB1dGlsfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge0JhY2tlbmRXYXNtfSBmcm9tICcuLi9iYWNrZW5kX3dhc20nO1xuXG5sZXQgd2FzbU1heFBvb2w6IChcbiAgICB4SWQ6IG51bWJlciwgYmF0Y2hTaXplOiBudW1iZXIsIGlucHV0SGVpZ2h0OiBudW1iZXIsIGlucHV0V2lkdGg6IG51bWJlcixcbiAgICBmaWx0ZXJIZWlnaHQ6IG51bWJlciwgZmlsdGVyV2lkdGg6IG51bWJlciwgcGFkVG9wOiBudW1iZXIsIHBhZFJpZ2h0OiBudW1iZXIsXG4gICAgcGFkQm90dG9tOiBudW1iZXIsIHBhZExlZnQ6IG51bWJlciwgZGlsYXRpb25IZWlnaHQ6IG51bWJlcixcbiAgICBkaWxhdGlvbldpZHRoOiBudW1iZXIsIHN0cmlkZUhlaWdodDogbnVtYmVyLCBzdHJpZGVXaWR0aDogbnVtYmVyLFxuICAgIGlucHV0Q2hhbm5lbHM6IG51bWJlciwgb3V0cHV0Q2hhbm5lbHM6IG51bWJlciwgb3V0SWQ6IG51bWJlcikgPT4gdm9pZDtcblxuZnVuY3Rpb24gc2V0dXAoYmFja2VuZDogQmFja2VuZFdhc20pIHtcbiAgd2FzbU1heFBvb2wgPSBiYWNrZW5kLndhc20uY3dyYXAoTWF4UG9vbCwgbnVsbCAvKiB2b2lkICovLCBbXG4gICAgJ251bWJlcicsICAvLyB4SWRcbiAgICAnbnVtYmVyJywgIC8vIGJhdGNoU2l6ZVxuICAgICdudW1iZXInLCAgLy8gaW5wdXRIZWlnaHRcbiAgICAnbnVtYmVyJywgIC8vIGlucHV0V2lkdGhcbiAgICAnbnVtYmVyJywgIC8vIGZpbHRlckhlaWdodFxuICAgICdudW1iZXInLCAgLy8gZmlsdGVyV2lkdGhcbiAgICAnbnVtYmVyJywgIC8vIHBhZFRvcFxuICAgICdudW1iZXInLCAgLy8gcGFkUmlnaHRcbiAgICAnbnVtYmVyJywgIC8vIHBhZEJvdHRvbVxuICAgICdudW1iZXInLCAgLy8gcGFkTGVmdFxuICAgICdudW1iZXInLCAgLy8gZGlsYXRpb25IZWlnaHRcbiAgICAnbnVtYmVyJywgIC8vIGRpbGF0aW9uV2lkdGhcbiAgICAnbnVtYmVyJywgIC8vIHN0cmlkZUhlaWdodFxuICAgICdudW1iZXInLCAgLy8gc3RyaWRlV2lkdGhcbiAgICAnbnVtYmVyJywgIC8vIGlucHV0Q2hhbm5lbHNcbiAgICAnbnVtYmVyJywgIC8vIG91dHB1dENoYW5uZWxzXG4gICAgJ251bWJlcicsICAvLyBvdXRJZFxuICBdKTtcbn1cblxuZnVuY3Rpb24gbWF4UG9vbChcbiAgICBhcmdzOiB7aW5wdXRzOiBNYXhQb29sSW5wdXRzLCBiYWNrZW5kOiBCYWNrZW5kV2FzbSwgYXR0cnM6IE1heFBvb2xBdH
Ryc30pIHtcbiAgY29uc3Qge2lucHV0cywgYXR0cnMsIGJhY2tlbmR9ID0gYXJncztcblxuICBjb25zdCB4ID0gaW5wdXRzLnggYXMgVGVuc29yNEQ7XG4gIGNvbnN0IHhJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldCh4LmRhdGFJZCkuaWQ7XG5cbiAgLy8gVEYgQVBJIHN1cHBvcnRzIGludDMyIGlucHV0LiBDUFUgYW5kIFdlYkdMIGJhY2tlbmQgYWxzbyBzdXBwb3J0IGludDMyXG4gIC8vIGlucHV0LiBXQVNNIGJhY2tlbmQgZG9lc24ndCBzdXBwb3J0IGl0IGJlY2F1c2UgaXQgdXNlcyB4bm5wYWNrIHdoaWNoIG9ubHlcbiAgLy8gc3VwcG9ydHMgZmxvYXQzMi5cbiAgLy9cbiAgLy8gQWRkIHRoZSBmb2xsb3dpbmcgYXNzZXJ0IG9ubHkgZm9yIHRoZSBXQVNNIGJhY2tlbmQgaW5zdGVhZCBvZiBhdCBjb3JlIG9wXG4gIC8vIGxldmVsLlxuICAvL1xuICAvLyBUT0RPOiBhZGQgc3VwcG9ydCBmb3IgaW50MzIgaW5wdXQuXG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgeC5kdHlwZSA9PT0gJ2Zsb2F0MzInLFxuICAgICAgKCkgPT5cbiAgICAgICAgICBgRXJyb3IgaW4gTWF4UG9vbDogb25seSBmbG9hdDMyIGlucHV0IGlzIHN1cHBvcnRlZC4gR290ICR7eC5kdHlwZX0uYCk7XG5cbiAgY29uc3Qge2ZpbHRlclNpemUsIHN0cmlkZXMsIHBhZCwgZGltUm91bmRpbmdNb2RlfSA9IGF0dHJzO1xuICBjb25zdCBjb252SW5mbyA9IGJhY2tlbmRfdXRpbC5jb21wdXRlUG9vbDJESW5mbyhcbiAgICAgIHguc2hhcGUsIGZpbHRlclNpemUsIHN0cmlkZXMsIDEgLyogZGlsYXRpb25zICovLCBwYWQsIGRpbVJvdW5kaW5nTW9kZSk7XG5cbiAgY29uc3QgZmlsdGVySGVpZ2h0ID0gY29udkluZm8uZmlsdGVySGVpZ2h0O1xuICBjb25zdCBmaWx0ZXJXaWR0aCA9IGNvbnZJbmZvLmZpbHRlcldpZHRoO1xuICBjb25zdCBwYWRUb3AgPSBjb252SW5mby5wYWRJbmZvLnRvcDtcbiAgY29uc3QgcGFkUmlnaHQgPSBjb252SW5mby5wYWRJbmZvLnJpZ2h0O1xuICBjb25zdCBwYWRCb3R0b20gPSBjb252SW5mby5wYWRJbmZvLmJvdHRvbTtcbiAgY29uc3QgcGFkTGVmdCA9IGNvbnZJbmZvLnBhZEluZm8ubGVmdDtcbiAgY29uc3QgZGlsYXRpb25IZWlnaHQgPSBjb252SW5mby5kaWxhdGlvbkhlaWdodDtcbiAgY29uc3QgZGlsYXRpb25XaWR0aCA9IGNvbnZJbmZvLmRpbGF0aW9uV2lkdGg7XG4gIGNvbnN0IHN0cmlkZUhlaWdodCA9IGNvbnZJbmZvLnN0cmlkZUhlaWdodDtcbiAgY29uc3Qgc3RyaWRlV2lkdGggPSBjb252SW5mby5zdHJpZGVXaWR0aDtcbiAgY29uc3QgaW5wdXRDaGFubmVscyA9IGNvbnZJbmZvLmluQ2hhbm5lbHM7XG4gIGNvbnN0IG91dHB1dENoYW5uZWxzID0gY29udkluZm8ub3V0Q2hhbm5lbHM7XG5cbiAgaWYgKGNvbnZJbmZvLmRhdGFGb3JtYXQgIT09ICdjaGFubmVsc0xhc3QnKSB7XG4gICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICBgd2FzbSBiYWNrZW5kIGRvZXMgbm90IHN1cHBvcnQgZGF0YUZvcm1hdDonYCArXG4gICAgICAgIGAke2NvbnZJbm
ZvLmRhdGFGb3JtYXR9Jy4gUGxlYXNlIHVzZSAnY2hhbm5lbHNMYXN0Jy5gKTtcbiAgfVxuXG4gIGNvbnN0IG91dCA9IGJhY2tlbmQubWFrZU91dHB1dChjb252SW5mby5vdXRTaGFwZSwgJ2Zsb2F0MzInKTtcbiAgY29uc3Qgb3V0SWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0LmRhdGFJZCkuaWQ7XG5cbiAgd2FzbU1heFBvb2woXG4gICAgICB4SWQsIHguc2hhcGVbMF0sIHguc2hhcGVbMV0sIHguc2hhcGVbMl0sIGZpbHRlckhlaWdodCwgZmlsdGVyV2lkdGgsXG4gICAgICBwYWRUb3AsIHBhZFJpZ2h0LCBwYWRCb3R0b20sIHBhZExlZnQsIGRpbGF0aW9uSGVpZ2h0LCBkaWxhdGlvbldpZHRoLFxuICAgICAgc3RyaWRlSGVpZ2h0LCBzdHJpZGVXaWR0aCwgaW5wdXRDaGFubmVscywgb3V0cHV0Q2hhbm5lbHMsIG91dElkKTtcbiAgcmV0dXJuIG91dDtcbn1cblxuZXhwb3J0IGNvbnN0IG1heFBvb2xDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogTWF4UG9vbCxcbiAgYmFja2VuZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOiBzZXR1cCxcbiAga2VybmVsRnVuYzogbWF4UG9vbCBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=