@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, KNN classifier
63 lines • 11.2 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { backend_util, MaxPoolWithArgmax, util } from '@tensorflow/tfjs-core';
import { CppDType } from './types';
/**
 * Handle to the native `MaxPoolWithArgmax` function. It is populated by
 * `setup` via `cwrap` and stays `undefined` until `setup` has run.
 */
let wasmMaxPoolWithArgmax;

/**
 * Binds the exported `MaxPoolWithArgmax` WASM function for this backend.
 *
 * The declared signature has 19 parameters and return type `null`. The
 * parameters are four numbers (tensor ids and dtype), one boolean flag, then
 * fourteen numbers of pooling geometry. They are listed in exactly the order
 * `maxPoolWithArgmax` passes them.
 *
 * @param {Object} backend Backend whose `wasm` module exposes `cwrap`.
 */
function setup(backend) {
  const argTypes = [
    'number',  // xId
    'number',  // pooledId
    'number',  // indexesId
    'number',  // dtype (CppDType value)
    'boolean', // includeBatchInIndex
    'number',  // batchSize
    'number',  // inChannels
    'number',  // inHeight
    'number',  // inWidth
    'number',  // outHeight
    'number',  // outWidth
    'number',  // strideHeight
    'number',  // strideWidth
    'number',  // dilationHeight
    'number',  // dilationWidth
    'number',  // effectiveFilterHeight
    'number',  // effectiveFilterWidth
    'number',  // padTop
    'number',  // padLeft
  ];
  const { wasm } = backend;
  wasmMaxPoolWithArgmax = wasm.cwrap('MaxPoolWithArgmax', null, argTypes);
}
/**
 * WASM implementation of the `MaxPoolWithArgmax` kernel. It performs 2-D max
 * pooling and also produces, for every output element, the index of the
 * selected input element.
 *
 * Two outputs are allocated with `backend.makeOutput`: the pooled values
 * first, then the indices. Both are filled by the native function bound in
 * `setup`.
 *
 * @param {Object} args Kernel arguments.
 * @param {{x: Object}} args.inputs `x` must have a rank-4 `shape`.
 * @param {Object} args.backend WASM backend; provides `makeOutput` and
 *     `dataIdMap`.
 * @param {Object} args.attrs `filterSize`, `strides`, `pad` and
 *     `includeBatchInIndex`.
 * @returns {Array<Object>} `[pooled, indexes]`. `pooled` has `x.dtype` and
 *     `indexes` is `'int32'`. Both have shape `convInfo.outShape`.
 * @throws Via `util.assert` when `x` is not rank 4.
 */
export function maxPoolWithArgmax(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { filterSize, strides, pad, includeBatchInIndex } = attrs;

  util.assert(
    x.shape.length === 4,
    () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`,
  );

  // The kernel has no dilation attribute, so dilation is fixed at 1. The same
  // array feeds both the validation and computePool2DInfo, so they cannot
  // disagree. As written, dilations are always 1, so this check should never
  // fire. It stays as a guard in case dilations become configurable.
  const dilations = [1, 1];
  util.assert(
    backend_util.eitherStridesOrDilationsAreOne(strides, dilations),
    () => 'Error in maxPool: Either strides or dilations must be 1. ' +
      `Got strides ${strides} and dilations '${dilations}'`,
  );

  const convInfo = backend_util.computePool2DInfo(
    x.shape, filterSize, strides, dilations, pad);

  const pooled = backend.makeOutput(convInfo.outShape, x.dtype);
  const indexes = backend.makeOutput(convInfo.outShape, 'int32');

  // The native function takes the backend's numeric ids rather than dataIds.
  const idOf = (info) => backend.dataIdMap.get(info.dataId).id;
  const { padInfo } = convInfo;

  wasmMaxPoolWithArgmax(
    idOf(x),
    idOf(pooled),
    idOf(indexes),
    CppDType[x.dtype],
    includeBatchInIndex,
    convInfo.batchSize,
    convInfo.inChannels,
    convInfo.inHeight,
    convInfo.inWidth,
    convInfo.outHeight,
    convInfo.outWidth,
    convInfo.strideHeight,
    convInfo.strideWidth,
    convInfo.dilationHeight,
    convInfo.dilationWidth,
    convInfo.effectiveFilterHeight,
    convInfo.effectiveFilterWidth,
    padInfo.top,
    padInfo.left,
  );
  return [pooled, indexes];
}
/**
 * Kernel registration entry. It associates the `MaxPoolWithArgmax` kernel
 * name on the 'wasm' backend with `setup` (which binds the native function)
 * and `maxPoolWithArgmax` (which runs it).
 */
export const maxPoolWithArgmaxConfig = {
  kernelName: MaxPoolWithArgmax,  // kernel name imported from tfjs-core
  backendName: 'wasm',
  setupFunc: setup,               // binds the native function via cwrap
  kernelFunc: maxPoolWithArgmax,  // executes the kernel
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWF4UG9vbFdpdGhBcmdtYXguanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2FzbS9zcmMva2VybmVscy9NYXhQb29sV2l0aEFyZ21heC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUE0QixpQkFBaUIsRUFBK0QsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJbkssT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUVqQyxJQUFJLHFCQU1zRSxDQUFDO0FBRTNFLFNBQVMsS0FBSyxDQUFDLE9BQW9CO0lBQ2pDLHFCQUFxQixHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsS0FBSyxDQUFDLG1CQUFtQixFQUFFLElBQUksRUFBRTtRQUNwRSxRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsU0FBUztRQUNULFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRLEVBQUksVUFBVTtLQUN2QixDQUFDLENBQUM7QUFDTCxDQUFDO0FBRUQsTUFBTSxVQUFVLGlCQUFpQixDQUFDLElBSWpDO0lBQ0MsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsTUFBTSxFQUFDLFVBQVUsRUFBRSxPQUFPLEVBQUUsR0FBRyxFQUFFLG1CQUFtQixFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRTlELElBQUksQ0FBQyxNQUFNLENBQ1AsQ0FBQyxDQUFDLEtBQUssQ0FBQyxNQUFNLEtBQUssQ0FBQyxFQUNwQixHQUFHLEVBQUUsQ0FBQyx1REFDRixDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUM7SUFDM0IsTUFBTSxTQUFTLEdBQXFCLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO0lBQzNDLElBQUksQ0FBQyxNQUFNLENBQ1AsWUFBWSxDQUFDLDhCQUE4QixDQUFDLE9BQU8sRUFBRSxTQUFTLENBQUMsRUFDL0QsR0FBRyxFQUFFLENBQUMsMkRBQTJEO1FBQzdELGVBQWUsT0FBTyxtQkFBbUIsU0FBUyxHQUFHLENBQUMsQ0FBQztJQUUvRCxNQUFNLFFBQVEsR0FBRyxZQUFZLENBQUMsaUJBQWlCLENBQzNDLENBQUMsQ0FBQyxLQUF5QyxFQUFFLFVBQVUsRUFBRSxPQUFPLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQ3hFLEdBQUcsQ0FBQyxDQUFDO0lBRVQsTUFBTSxNQUFNLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxRQUFRLENBQUMsUUFBUSxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUM5RCxNQUFNLE9BQU8sR0FBRyxPQUFPLENBQUMsVUFBVSxDQUFDLFFBQVEsQ0FBQyxRQUFRLEVBQUUsT0FBTyxDQUFDLENBQUM7SU
FFL0QscUJBQXFCLENBQ2pCLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQ2xDLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQ3ZDLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQ3hDLFFBQVEsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQ2pCLG1CQUFtQixFQUNuQixRQUFRLENBQUMsU0FBUyxFQUNsQixRQUFRLENBQUMsVUFBVSxFQUNuQixRQUFRLENBQUMsUUFBUSxFQUNqQixRQUFRLENBQUMsT0FBTyxFQUNoQixRQUFRLENBQUMsU0FBUyxFQUNsQixRQUFRLENBQUMsUUFBUSxFQUNqQixRQUFRLENBQUMsWUFBWSxFQUNyQixRQUFRLENBQUMsV0FBVyxFQUNwQixRQUFRLENBQUMsY0FBYyxFQUN2QixRQUFRLENBQUMsYUFBYSxFQUN0QixRQUFRLENBQUMscUJBQXFCLEVBQzlCLFFBQVEsQ0FBQyxvQkFBb0IsRUFDN0IsUUFBUSxDQUFDLE9BQU8sQ0FBQyxHQUFHLEVBQ3BCLFFBQVEsQ0FBQyxPQUFPLENBQUMsSUFBSSxDQUN4QixDQUFDO0lBQ0YsT0FBTyxDQUFDLE1BQU0sRUFBRSxPQUFPLENBQUMsQ0FBQztBQUMzQixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sdUJBQXVCLEdBQWlCO0lBQ25ELFVBQVUsRUFBRSxpQkFBaUI7SUFDN0IsV0FBVyxFQUFFLE1BQU07SUFDbkIsU0FBUyxFQUFFLEtBQUs7SUFDaEIsVUFBVSxFQUFFLGlCQUEwQztDQUN2RCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjMgR29vZ2xlIExMQy5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcn
Qge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBNYXhQb29sV2l0aEFyZ21heCwgTWF4UG9vbFdpdGhBcmdtYXhBdHRycywgTWF4UG9vbFdpdGhBcmdtYXhJbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5cbmltcG9ydCB7Q3BwRFR5cGV9IGZyb20gJy4vdHlwZXMnO1xuXG5sZXQgd2FzbU1heFBvb2xXaXRoQXJnbWF4OiAoXG4gICAgeElkOiBudW1iZXIsIHBvb2xlZElkOiBudW1iZXIsIGluZGV4ZXNJZDogbnVtYmVyLCBkdHlwZTogbnVtYmVyLFxuICAgIGluY2x1ZGVCYXRjaEluZGV4OiBib29sZWFuLCBiYXRjaFNpemU6IG51bWJlciwgY2hhbm5lbFNpemU6IG51bWJlcixcbiAgICBpbkhlaWdodDogbnVtYmVyLCBpbldpZHRoOiBudW1iZXIsIG91dEhlaWdodDogbnVtYmVyLCBvdXRXaWR0aDogbnVtYmVyLFxuICAgIHN0cmlkZUhlaWdodDogbnVtYmVyLCBzdHJpZGVXaWR0aDogbnVtYmVyLCBkaWxhdGlvbkhlaWdodDogbnVtYmVyLFxuICAgIGRpbGF0aW9uV2lkdGg6IG51bWJlciwgZWZmZWN0aXZlRmlsdGVySGVpZ2h0OiBudW1iZXIsXG4gICAgZWZmZWN0aXZlRmlsdGVyV2lkdGg6IG51bWJlciwgcGFkVG9wOiBudW1iZXIsIHBhZExlZnQ6IG51bWJlcikgPT4gdm9pZDtcblxuZnVuY3Rpb24gc2V0dXAoYmFja2VuZDogQmFja2VuZFdhc20pIHtcbiAgd2FzbU1heFBvb2xXaXRoQXJnbWF4ID0gYmFja2VuZC53YXNtLmN3cmFwKCdNYXhQb29sV2l0aEFyZ21heCcsIG51bGwsIFtcbiAgICAnbnVtYmVyJywgICAvLyB4SWRcbiAgICAnbnVtYmVyJywgICAvLyBwb29sZWRJZFxuICAgICdudW1iZXInLCAgIC8vIGluZGV4ZXNJZFxuICAgICdudW1iZXInLCAgIC8vIGR0eXBlXG4gICAgJ2Jvb2xlYW4nLCAgLy8gaW5jbHVkZUJhdGNoSW5kZXhcbiAgICAnbnVtYmVyJywgICAvLyBiYXRjaFNpemVcbiAgICAnbnVtYmVyJywgICAvLyBjaGFubmVsU2l6ZVxuICAgICdudW1iZXInLCAgIC8vIGluSGVpZ2h0XG4gICAgJ251bWJlcicsICAgLy8gaW5XaWR0aFxuICAgICdudW1iZXInLCAgIC8vIG91dEhlaWdodFxuICAgICdudW1iZXInLCAgIC8vIG91dFdpZHRoXG4gICAgJ251bWJlcicsICAgLy8gc3RyaWRlSGVpZ2h0XG4gICAgJ251bWJlcicsICAgLy8gc3RyaWRlV2lkdGhcbiAgICAnbnVtYmVyJywgICAvLyBkaWxhdGlvbkhlaWdodFxuICAgICdudW1iZXInLCAgIC8vIGRpbGF0aW9uV2lkdGhcbiAgICAnbnVtYmVyJywgICAvLyBlZmZlY3RpdmVGaWx0ZXJIZWlnaHRcbiAgICAnbnVtYmVyJywgICAvLyBlZmZlY3RpdmVGaWx0ZXJXaWR0aFxuICAgICdudW1iZXInLCAgIC8vIHBhZFRvcFxuICAgICdudW1iZXInLCAgIC8vIHBhZExlZnRcbiAgXSk7XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBtYXhQb29sV2l0aEFyZ21heChhcmdzOiB7XG4gIGlucHV0czogTWF4UG9vbFdpdGhBcmdtYX
hJbnB1dHMsXG4gIGF0dHJzOiBNYXhQb29sV2l0aEFyZ21heEF0dHJzLFxuICBiYWNrZW5kOiBCYWNrZW5kV2FzbSxcbn0pOiBUZW5zb3JJbmZvW10ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHtmaWx0ZXJTaXplLCBzdHJpZGVzLCBwYWQsIGluY2x1ZGVCYXRjaEluSW5kZXh9ID0gYXR0cnM7XG5cbiAgdXRpbC5hc3NlcnQoXG4gICAgICB4LnNoYXBlLmxlbmd0aCA9PT0gNCxcbiAgICAgICgpID0+IGBFcnJvciBpbiBtYXhQb29sOiBpbnB1dCBtdXN0IGJlIHJhbmsgNCBidXQgZ290IHJhbmsgJHtcbiAgICAgICAgICB4LnNoYXBlLmxlbmd0aH0uYCk7XG4gIGNvbnN0IGRpbGF0aW9uczogW251bWJlciwgbnVtYmVyXSA9IFsxLCAxXTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICBiYWNrZW5kX3V0aWwuZWl0aGVyU3RyaWRlc09yRGlsYXRpb25zQXJlT25lKHN0cmlkZXMsIGRpbGF0aW9ucyksXG4gICAgICAoKSA9PiAnRXJyb3IgaW4gbWF4UG9vbDogRWl0aGVyIHN0cmlkZXMgb3IgZGlsYXRpb25zIG11c3QgYmUgMS4gJyArXG4gICAgICAgICAgYEdvdCBzdHJpZGVzICR7c3RyaWRlc30gYW5kIGRpbGF0aW9ucyAnJHtkaWxhdGlvbnN9J2ApO1xuXG4gIGNvbnN0IGNvbnZJbmZvID0gYmFja2VuZF91dGlsLmNvbXB1dGVQb29sMkRJbmZvKFxuICAgICAgeC5zaGFwZSBhcyBbbnVtYmVyLCBudW1iZXIsIG51bWJlciwgbnVtYmVyXSwgZmlsdGVyU2l6ZSwgc3RyaWRlcywgWzEsIDFdLFxuICAgICAgcGFkKTtcblxuICBjb25zdCBwb29sZWQgPSBiYWNrZW5kLm1ha2VPdXRwdXQoY29udkluZm8ub3V0U2hhcGUsIHguZHR5cGUpO1xuICBjb25zdCBpbmRleGVzID0gYmFja2VuZC5tYWtlT3V0cHV0KGNvbnZJbmZvLm91dFNoYXBlLCAnaW50MzInKTtcblxuICB3YXNtTWF4UG9vbFdpdGhBcmdtYXgoXG4gICAgICBiYWNrZW5kLmRhdGFJZE1hcC5nZXQoeC5kYXRhSWQpLmlkLFxuICAgICAgYmFja2VuZC5kYXRhSWRNYXAuZ2V0KHBvb2xlZC5kYXRhSWQpLmlkLFxuICAgICAgYmFja2VuZC5kYXRhSWRNYXAuZ2V0KGluZGV4ZXMuZGF0YUlkKS5pZCxcbiAgICAgIENwcERUeXBlW3guZHR5cGVdLFxuICAgICAgaW5jbHVkZUJhdGNoSW5JbmRleCxcbiAgICAgIGNvbnZJbmZvLmJhdGNoU2l6ZSxcbiAgICAgIGNvbnZJbmZvLmluQ2hhbm5lbHMsXG4gICAgICBjb252SW5mby5pbkhlaWdodCxcbiAgICAgIGNvbnZJbmZvLmluV2lkdGgsXG4gICAgICBjb252SW5mby5vdXRIZWlnaHQsXG4gICAgICBjb252SW5mby5vdXRXaWR0aCxcbiAgICAgIGNvbnZJbmZvLnN0cmlkZUhlaWdodCxcbiAgICAgIGNvbnZJbmZvLnN0cmlkZVdpZHRoLFxuICAgICAgY29udkluZm8uZGlsYXRpb25IZWlnaHQsXG4gICAgICBjb252SW5mby5kaWxhdGlvbldpZHRoLFxuICAgICAgY29udkluZm8uZWZmZWN0aXZlRmlsdGVySGVpZ2h0LFxuICAgICAgY29udkluZm8uZWZmZWN0aXZlRmlsdGVyV2
lkdGgsXG4gICAgICBjb252SW5mby5wYWRJbmZvLnRvcCxcbiAgICAgIGNvbnZJbmZvLnBhZEluZm8ubGVmdCxcbiAgKTtcbiAgcmV0dXJuIFtwb29sZWQsIGluZGV4ZXNdO1xufVxuXG5leHBvcnQgY29uc3QgbWF4UG9vbFdpdGhBcmdtYXhDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogTWF4UG9vbFdpdGhBcmdtYXgsXG4gIGJhY2tlbmROYW1lOiAnd2FzbScsXG4gIHNldHVwRnVuYzogc2V0dXAsXG4gIGtlcm5lbEZ1bmM6IG1heFBvb2xXaXRoQXJnbWF4IGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==