UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

63 lines 11.2 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, MaxPoolWithArgmax, util } from '@tensorflow/tfjs-core';
import { CppDType } from './types';

/**
 * JS handle to the WASM `MaxPoolWithArgmax` export. Assigned by `setup`
 * (registered below as this kernel's `setupFunc`) and invoked by
 * `maxPoolWithArgmax`.
 */
let wasmMaxPoolWithArgmax;

/**
 * Binds the WASM `MaxPoolWithArgmax` export through `cwrap`.
 *
 * The argument-type list must stay in the same order as the arguments passed
 * to `wasmMaxPoolWithArgmax` inside `maxPoolWithArgmax`.
 *
 * @param {object} backend - The WASM backend; only `backend.wasm.cwrap` is used.
 */
function setup(backend) {
  wasmMaxPoolWithArgmax = backend.wasm.cwrap('MaxPoolWithArgmax', null, [
    'number',   // xId
    'number',   // pooledId
    'number',   // indexesId
    'number',   // dtype
    'boolean',  // includeBatchIndex
    'number',   // batchSize
    'number',   // channelSize
    'number',   // inHeight
    'number',   // inWidth
    'number',   // outHeight
    'number',   // outWidth
    'number',   // strideHeight
    'number',   // strideWidth
    'number',   // dilationHeight
    'number',   // dilationWidth
    'number',   // effectiveFilterHeight
    'number',   // effectiveFilterWidth
    'number',   // padTop
    'number',   // padLeft
  ]);
}

/**
 * WASM kernel for `MaxPoolWithArgmax`: 2D max pooling that additionally
 * produces an int32 `indexes` tensor (presumably the position of each pooled
 * maximum; see the tfjs-core `maxPoolWithArgmax` op for the exact encoding).
 *
 * @param {object} args
 * @param {{x: object}} args.inputs - `x` must be a rank-4 tensor.
 * @param {object} args.backend - The WASM backend that owns the tensor data.
 * @param {{filterSize: *, strides: *, pad: *, includeBatchInIndex: boolean}} args.attrs
 *     Pooling attributes, forwarded to `backend_util.computePool2DInfo`
 *     (`includeBatchInIndex` is forwarded to the WASM call).
 * @returns {object[]} `[pooled, indexes]`: `pooled` has `x.dtype`, `indexes`
 *     is int32, and both have `convInfo.outShape`.
 * @throws {Error} Via `util.assert` when `x` is not rank 4.
 */
export function maxPoolWithArgmax(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { filterSize, strides, pad, includeBatchInIndex } = attrs;

  util.assert(
    x.shape.length === 4,
    () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`);

  // This kernel has no dilation attribute, so pooling is always undilated.
  const dilations = [1, 1];
  util.assert(
    backend_util.eitherStridesOrDilationsAreOne(strides, dilations),
    () => 'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);

  // Reuse the validated `dilations` instead of repeating the `[1, 1]` literal
  // so the check above and the pooling geometry cannot drift apart.
  const convInfo = backend_util.computePool2DInfo(
    x.shape, filterSize, strides, dilations, pad);

  const pooled = backend.makeOutput(convInfo.outShape, x.dtype);
  const indexes = backend.makeOutput(convInfo.outShape, 'int32');

  // Look up the backend-side id for a tensor's data buffer.
  const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;

  // Argument order must match the type list bound in `setup`.
  wasmMaxPoolWithArgmax(
    idOf(x),
    idOf(pooled),
    idOf(indexes),
    CppDType[x.dtype],
    includeBatchInIndex,
    convInfo.batchSize,
    convInfo.inChannels,
    convInfo.inHeight,
    convInfo.inWidth,
    convInfo.outHeight,
    convInfo.outWidth,
    convInfo.strideHeight,
    convInfo.strideWidth,
    convInfo.dilationHeight,
    convInfo.dilationWidth,
    convInfo.effectiveFilterHeight,
    convInfo.effectiveFilterWidth,
    convInfo.padInfo.top,
    convInfo.padInfo.left,
  );

  return [pooled, indexes];
}

/** Registers `maxPoolWithArgmax` as the `MaxPoolWithArgmax` kernel on 'wasm'. */
export const maxPoolWithArgmaxConfig = {
  kernelName: MaxPoolWithArgmax,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: maxPoolWithArgmax,
};

// NOTE(review): the inline `//# sourceMappingURL=data:...` map emitted by the
// TypeScript build (source: tfjs-backend-wasm/src/kernels/MaxPoolWithArgmax.ts)
// was dropped because its mappings described the previous generated layout;
// regenerate it from the .ts source if source-level debugging is needed.