@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.
71 lines • 10.5 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { AvgPool, backend_util } from '@tensorflow/tfjs-core';
/**
 * Callable returned by `backend.wasm.cwrap` for the AvgPool kernel.
 * It is `undefined` until `setup` has run.
 */
let wasmAvgPool;

/**
 * Kernel setup hook: asks the backend's wasm module for a wrapper around the
 * AvgPool function and stores it in `wasmAvgPool`.
 *
 * @param backend Backend object exposing the wasm module as `backend.wasm`.
 */
function setup(backend) {
  // One 'number' entry per argument, listed in the order in which `avgPool`
  // passes them to `wasmAvgPool`.
  const argTypes = [
    'number', // xId
    'number', // x.shape[0]
    'number', // x.shape[1]
    'number', // x.shape[2]
    'number', // filterHeight
    'number', // filterWidth
    'number', // padTop
    'number', // padRight
    'number', // padBottom
    'number', // padLeft
    'number', // strideHeight
    'number', // strideWidth
    'number', // channels
    'number', // outId
  ];
  const returnType = null; // void
  wasmAvgPool = backend.wasm.cwrap(AvgPool, returnType, argTypes);
}
/**
 * AvgPool kernel function for the wasm backend.
 *
 * Derives the pooling geometry with `backend_util.computePool2DInfo`, creates
 * a float32 output through `backend.makeOutput`, and passes the input id, the
 * first three entries of `x.shape`, the filter/padding/stride values, the
 * channel count and the output id to `wasmAvgPool`.
 *
 * @param args Kernel arguments.
 * @param args.inputs Object holding the input tensor as `x`.
 * @param args.attrs Object holding `filterSize`, `strides`, `pad` and
 *     `dimRoundingMode`.
 * @param args.backend Backend providing `dataIdMap` and `makeOutput`.
 * @returns The output created by `backend.makeOutput`.
 * @throws {Error} If the computed data format is not 'channelsLast', or if
 *     either computed dilation differs from 1.
 */
function avgPool(args) {
  const { inputs, attrs, backend } = args;
  const { x } = inputs;
  const xId = backend.dataIdMap.get(x.dataId).id;

  const { filterSize, strides, pad, dimRoundingMode } = attrs;
  const convInfo = backend_util.computePool2DInfo(
      x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);

  const {
    filterHeight,
    filterWidth,
    strideHeight,
    strideWidth,
    inChannels: channels,
  } = convInfo;
  const {
    top: padTop,
    right: padRight,
    bottom: padBottom,
    left: padLeft,
  } = convInfo.padInfo;

  if (convInfo.dataFormat !== 'channelsLast') {
    throw new Error(`wasm backend does not support dataFormat:'` +
        `${convInfo.dataFormat}'. Please use 'channelsLast'.`);
  }

  if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) {
    // Message previously read "was backend"; it now names the wasm backend
    // like the data-format error above.
    throw new Error(
        `wasm backend only supports average pooling with dilation = [1, 1], ` +
        `got [${convInfo.dilationHeight}, ${convInfo.dilationWidth}].`);
  }

  // The output is only created once both validations have passed.
  const out = backend.makeOutput(convInfo.outShape, 'float32');
  const outId = backend.dataIdMap.get(out.dataId).id;

  wasmAvgPool(
      xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth,
      padTop, padRight, padBottom, padLeft, strideHeight, strideWidth,
      channels, outId);
  return out;
}
/**
 * Kernel configuration object for AvgPool on the 'wasm' backend: it pairs the
 * kernel name with the `setup` hook and the `avgPool` kernel function.
 */
export const avgPoolConfig = {
  kernelName: AvgPool,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: avgPool,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQXZnUG9vbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL0F2Z1Bvb2wudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLE9BQU8sRUFBK0IsWUFBWSxFQUFxQyxNQUFNLHVCQUF1QixDQUFDO0FBSTdILElBQUksV0FJNkQsQ0FBQztBQUVsRSxTQUFTLEtBQUssQ0FBQyxPQUFvQjtJQUNqQyxXQUFXLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsT0FBTyxFQUFFLElBQUksQ0FBQyxVQUFVLEVBQUU7UUFDekQsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVEsRUFBRyxRQUFRO0tBQ3BCLENBQUMsQ0FBQztBQUNMLENBQUM7QUFFRCxTQUFTLE9BQU8sQ0FDWixJQUF3RTtJQUMxRSxNQUFNLEVBQUMsTUFBTSxFQUFFLEtBQUssRUFBRSxPQUFPLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFFdEMsTUFBTSxDQUFDLEdBQUcsTUFBTSxDQUFDLENBQWEsQ0FBQztJQUMvQixNQUFNLEdBQUcsR0FBRyxPQUFPLENBQUMsU0FBUyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUMsRUFBRSxDQUFDO0lBRS9DLE1BQU0sRUFBQyxVQUFVLEVBQUUsT0FBTyxFQUFFLEdBQUcsRUFBRSxlQUFlLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFDMUQsTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLGlCQUFpQixDQUMzQyxDQUFDLENBQUMsS0FBSyxFQUFFLFVBQVUsRUFBRSxPQUFPLEVBQUUsQ0FBQyxDQUFDLGVBQWUsRUFBRSxHQUFHLEVBQUUsZUFBZSxDQUFDLENBQUM7SUFFM0UsTUFBTSxZQUFZLEdBQUcsUUFBUSxDQUFDLFlBQVksQ0FBQztJQUMzQyxNQUFNLFdBQVcsR0FBRyxRQUFRLENBQUMsV0FBVyxDQUFDO0lBQ3pDLE1BQU0sTUFBTSxHQUFHLFFBQVEsQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDO0lBQ3BDLE1BQU0sUUFBUSxHQUFHLFFBQVEsQ0FBQyxPQUFPLENBQUMsS0FBSyxDQUFDO0lBQ3hDLE1BQU0sU0FBUyxHQUFHLFFBQVEsQ0FBQyxPQUFPLENBQUMsTUFBTSxDQUFDO0lBQzFDLE1BQU0sT0FBTyxHQUFHLFFBQVEsQ0FBQyxPQUFPLENBQUMsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sWUFBWSxHQUFHLFFBQVEsQ0FBQyxZQUFZLENBQUM7SUFDM0MsTUFBTSxXQUFXLEdBQUcsUUFBUSxDQUFDLFdBQVcsQ0FBQztJQUN6QyxNQUFNLFFBQVEsR0FBRyxRQUFRLENBQUMsVUFBVSxDQUFDO0lBRXJDLElBQUksUUFBUSxDQUFDLFVBQVUsS0FBSyxjQUFjLEVBQUU7UUFDMUMsTUFBTSxJQUFJLEtBQUssQ0FDWCw0Q0FBNEM7WUFDNUMsR0FBRyxRQUFRLENBQUMsVUFBVSwrQkFBK0IsQ0FBQyxDQU
FDO0tBQzVEO0lBRUQsSUFBSSxRQUFRLENBQUMsYUFBYSxLQUFLLENBQUMsSUFBSSxRQUFRLENBQUMsY0FBYyxLQUFLLENBQUMsRUFBRTtRQUNqRSxNQUFNLElBQUksS0FBSyxDQUNYLG9FQUFvRTtZQUNwRSxRQUFRLFFBQVEsQ0FBQyxjQUFjLEtBQUssUUFBUSxDQUFDLGFBQWEsSUFBSSxDQUFDLENBQUM7S0FDckU7SUFFRCxNQUFNLEdBQUcsR0FBRyxPQUFPLENBQUMsVUFBVSxDQUFDLFFBQVEsQ0FBQyxRQUFRLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFDN0QsTUFBTSxLQUFLLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUVuRCxXQUFXLENBQ1AsR0FBRyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLFlBQVksRUFBRSxXQUFXLEVBQ2xFLE1BQU0sRUFBRSxRQUFRLEVBQUUsU0FBUyxFQUFFLE9BQU8sRUFBRSxZQUFZLEVBQUUsV0FBVyxFQUFFLFFBQVEsRUFDekUsS0FBSyxDQUFDLENBQUM7SUFDWCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxhQUFhLEdBQWlCO0lBQ3pDLFVBQVUsRUFBRSxPQUFPO0lBQ25CLFdBQVcsRUFBRSxNQUFNO0lBQ25CLFNBQVMsRUFBRSxLQUFLO0lBQ2hCLFVBQVUsRUFBRSxPQUFnQztDQUM3QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTkgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0F2Z1Bvb2wsIEF2Z1
Bvb2xBdHRycywgQXZnUG9vbElucHV0cywgYmFja2VuZF91dGlsLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvcjREfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge0JhY2tlbmRXYXNtfSBmcm9tICcuLi9iYWNrZW5kX3dhc20nO1xuXG5sZXQgd2FzbUF2Z1Bvb2w6IChcbiAgICB4SWQ6IG51bWJlciwgYmF0Y2hTaXplOiBudW1iZXIsIGlucHV0SGVpZ2h0OiBudW1iZXIsIGlucHV0V2lkdGg6IG51bWJlcixcbiAgICBmaWx0ZXJIZWlnaHQ6IG51bWJlciwgZmlsdGVyV2lkdGg6IG51bWJlciwgcGFkVG9wOiBudW1iZXIsIHBhZFJpZ2h0OiBudW1iZXIsXG4gICAgcGFkQm90dG9tOiBudW1iZXIsIHBhZExlZnQ6IG51bWJlciwgc3RyaWRlSGVpZ2h0OiBudW1iZXIsXG4gICAgc3RyaWRlV2lkdGg6IG51bWJlciwgY2hhbm5lbHM6IG51bWJlciwgb3V0SWQ6IG51bWJlcikgPT4gdm9pZDtcblxuZnVuY3Rpb24gc2V0dXAoYmFja2VuZDogQmFja2VuZFdhc20pIHtcbiAgd2FzbUF2Z1Bvb2wgPSBiYWNrZW5kLndhc20uY3dyYXAoQXZnUG9vbCwgbnVsbCAvKiB2b2lkICovLCBbXG4gICAgJ251bWJlcicsICAvLyB4SWRcbiAgICAnbnVtYmVyJywgIC8vIGJhdGNoU2l6ZVxuICAgICdudW1iZXInLCAgLy8gaW5wdXRIZWlnaHRcbiAgICAnbnVtYmVyJywgIC8vIGlucHV0V2lkdGhcbiAgICAnbnVtYmVyJywgIC8vIGZpbHRlckhlaWdodFxuICAgICdudW1iZXInLCAgLy8gZmlsdGVyV2lkdGhcbiAgICAnbnVtYmVyJywgIC8vIHBhZFRvcFxuICAgICdudW1iZXInLCAgLy8gcGFkUmlnaHRcbiAgICAnbnVtYmVyJywgIC8vIHBhZEJvdHRvbVxuICAgICdudW1iZXInLCAgLy8gcGFkTGVmdFxuICAgICdudW1iZXInLCAgLy8gc3RyaWRlSGVpZ2h0XG4gICAgJ251bWJlcicsICAvLyBzdHJpZGVXaWR0aFxuICAgICdudW1iZXInLCAgLy8gY2hhbm5lbHNcbiAgICAnbnVtYmVyJywgIC8vIG91dElkXG4gIF0pO1xufVxuXG5mdW5jdGlvbiBhdmdQb29sKFxuICAgIGFyZ3M6IHtpbnB1dHM6IEF2Z1Bvb2xJbnB1dHMsIGJhY2tlbmQ6IEJhY2tlbmRXYXNtLCBhdHRyczogQXZnUG9vbEF0dHJzfSkge1xuICBjb25zdCB7aW5wdXRzLCBhdHRycywgYmFja2VuZH0gPSBhcmdzO1xuXG4gIGNvbnN0IHggPSBpbnB1dHMueCBhcyBUZW5zb3I0RDtcbiAgY29uc3QgeElkID0gYmFja2VuZC5kYXRhSWRNYXAuZ2V0KHguZGF0YUlkKS5pZDtcblxuICBjb25zdCB7ZmlsdGVyU2l6ZSwgc3RyaWRlcywgcGFkLCBkaW1Sb3VuZGluZ01vZGV9ID0gYXR0cnM7XG4gIGNvbnN0IGNvbnZJbmZvID0gYmFja2VuZF91dGlsLmNvbXB1dGVQb29sMkRJbmZvKFxuICAgICAgeC5zaGFwZSwgZmlsdGVyU2l6ZSwgc3RyaWRlcywgMSAvKiBkaWxhdGlvbnMgKi8sIHBhZCwgZGltUm91bmRpbmdNb2RlKTtcblxuICBjb25zdCBmaWx0ZXJIZWlnaHQgPSBjb252SW5mby5maWx0ZXJIZWlnaHQ7XG4gIGNvbnN0IGZpbHRlcldpZHRoID0gY29udkluZm8uZmlsdGVyV2
lkdGg7XG4gIGNvbnN0IHBhZFRvcCA9IGNvbnZJbmZvLnBhZEluZm8udG9wO1xuICBjb25zdCBwYWRSaWdodCA9IGNvbnZJbmZvLnBhZEluZm8ucmlnaHQ7XG4gIGNvbnN0IHBhZEJvdHRvbSA9IGNvbnZJbmZvLnBhZEluZm8uYm90dG9tO1xuICBjb25zdCBwYWRMZWZ0ID0gY29udkluZm8ucGFkSW5mby5sZWZ0O1xuICBjb25zdCBzdHJpZGVIZWlnaHQgPSBjb252SW5mby5zdHJpZGVIZWlnaHQ7XG4gIGNvbnN0IHN0cmlkZVdpZHRoID0gY29udkluZm8uc3RyaWRlV2lkdGg7XG4gIGNvbnN0IGNoYW5uZWxzID0gY29udkluZm8uaW5DaGFubmVscztcblxuICBpZiAoY29udkluZm8uZGF0YUZvcm1hdCAhPT0gJ2NoYW5uZWxzTGFzdCcpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgIGB3YXNtIGJhY2tlbmQgZG9lcyBub3Qgc3VwcG9ydCBkYXRhRm9ybWF0OidgICtcbiAgICAgICAgYCR7Y29udkluZm8uZGF0YUZvcm1hdH0nLiBQbGVhc2UgdXNlICdjaGFubmVsc0xhc3QnLmApO1xuICB9XG5cbiAgaWYgKGNvbnZJbmZvLmRpbGF0aW9uV2lkdGggIT09IDEgfHwgY29udkluZm8uZGlsYXRpb25IZWlnaHQgIT09IDEpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgIGB3YXMgYmFja2VuZCBvbmx5IHN1cHBvcnRzIGF2ZXJhZ2UgcG9vbGluZyB3aXRoIGRpbGF0aW9uID0gWzEsIDFdLCBgICtcbiAgICAgICAgYGdvdCBbJHtjb252SW5mby5kaWxhdGlvbkhlaWdodH0sICR7Y29udkluZm8uZGlsYXRpb25XaWR0aH1dLmApO1xuICB9XG5cbiAgY29uc3Qgb3V0ID0gYmFja2VuZC5tYWtlT3V0cHV0KGNvbnZJbmZvLm91dFNoYXBlLCAnZmxvYXQzMicpO1xuICBjb25zdCBvdXRJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldChvdXQuZGF0YUlkKS5pZDtcblxuICB3YXNtQXZnUG9vbChcbiAgICAgIHhJZCwgeC5zaGFwZVswXSwgeC5zaGFwZVsxXSwgeC5zaGFwZVsyXSwgZmlsdGVySGVpZ2h0LCBmaWx0ZXJXaWR0aCxcbiAgICAgIHBhZFRvcCwgcGFkUmlnaHQsIHBhZEJvdHRvbSwgcGFkTGVmdCwgc3RyaWRlSGVpZ2h0LCBzdHJpZGVXaWR0aCwgY2hhbm5lbHMsXG4gICAgICBvdXRJZCk7XG4gIHJldHVybiBvdXQ7XG59XG5cbmV4cG9ydCBjb25zdCBhdmdQb29sQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEF2Z1Bvb2wsXG4gIGJhY2tlbmROYW1lOiAnd2FzbScsXG4gIHNldHVwRnVuYzogc2V0dXAsXG4gIGtlcm5lbEZ1bmM6IGF2Z1Bvb2wgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19