@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:
- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier
60 lines • 9.17 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { AvgPoolGrad, backend_util } from '@tensorflow/tfjs-core';
/**
 * Handle to the native `AvgPoolGrad` export. It stays `undefined` until
 * `setup` has bound it through `backend.wasm.cwrap`.
 */
let wasmAvgPoolGrad;

/**
 * Binds the native `AvgPoolGrad` function for the given backend.
 *
 * The wrapped function is declared with a `null` return type and 18 arguments,
 * all marshalled as 'number', in the order listed in `argNames` below.
 *
 * @param {object} backend - presumably a BackendWasm instance; only
 *     `backend.wasm.cwrap` is used here.
 */
function setup(backend) {
  // The names only document the call order used by `avgPoolGrad`;
  // cwrap itself only sees the resulting list of 'number' types.
  const argNames = [
    'dyId',
    'dxId',
    'batchSize',
    'channelSize',
    'inHeight',
    'inWidth',
    'outHeight',
    'outWidth',
    'strideHeight',
    'strideWidth',
    'dilationHeight',
    'dilationWidth',
    'effectiveFilterHeight',
    'effectiveFilterWidth',
    'padTop',
    'padLeft',
    'filterHeight',
    'filterWidth',
  ];
  const argTypes = argNames.map(() => 'number');
  wasmAvgPoolGrad = backend.wasm.cwrap('AvgPoolGrad', null, argTypes);
}
/**
 * Kernel function for `AvgPoolGrad` on the WASM backend.
 *
 * Computes the pooling geometry with `backend_util.computePool2DInfo`, with
 * dilation fixed at 1. It then allocates `dx` with the shape and dtype of
 * `input` and hands both tensor ids plus the geometry to the native function
 * bound in `setup`.
 *
 * @param {{inputs: {dy: object, input: object}, backend: object,
 *     attrs: {filterSize: *, strides: *, pad: *}}} args
 * @returns {object} The tensor info of the newly allocated `dx`.
 */
export function avgPoolGrad(args) {
  const {
    inputs: { dy, input },
    backend,
    attrs: { filterSize, strides, pad },
  } = args;

  const convInfo = backend_util.computePool2DInfo(
    input.shape, filterSize, strides, /* dilations= */ 1, pad);
  const dx = backend.makeOutput(input.shape, input.dtype);

  const dyId = backend.dataIdMap.get(dy.dataId).id;
  const dxId = backend.dataIdMap.get(dx.dataId).id;

  const {
    batchSize,
    inChannels,
    inHeight,
    inWidth,
    outHeight,
    outWidth,
    strideHeight,
    strideWidth,
    dilationHeight,
    dilationWidth,
    effectiveFilterHeight,
    effectiveFilterWidth,
    padInfo: { top: padTop, left: padLeft },
    filterHeight,
    filterWidth,
  } = convInfo;

  // Pool ops (AvgPool and MaxPool) only support 2D filters, so the number of
  // input channels always equals the number of output channels.
  const channelSize = inChannels;

  wasmAvgPoolGrad(
    dyId, dxId, batchSize, channelSize,
    inHeight, inWidth, outHeight, outWidth,
    strideHeight, strideWidth, dilationHeight, dilationWidth,
    effectiveFilterHeight, effectiveFilterWidth, padTop, padLeft,
    filterHeight, filterWidth);
  return dx;
}
/**
 * Kernel registration entry that binds the `AvgPoolGrad` kernel name to
 * the WASM backend.
 *
 * `setupFunc` binds the native function and `kernelFunc` runs the op.
 */
export const avgPoolGradConfig = {
  kernelName: AvgPoolGrad,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: avgPoolGrad,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQXZnUG9vbEdyYWQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2FzbS9zcmMva2VybmVscy9BdmdQb29sR3JhZC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsV0FBVyxFQUF1QyxZQUFZLEVBQXVDLE1BQU0sdUJBQXVCLENBQUM7QUFJM0ksSUFBSSxlQU1rRCxDQUFDO0FBRXZELFNBQVMsS0FBSyxDQUFDLE9BQW9CO0lBQ2pDLGVBQWUsR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxhQUFhLEVBQUUsSUFBSSxFQUFFO1FBQ3hELFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUSxFQUFHLGNBQWM7S0FDMUIsQ0FBQyxDQUFDO0FBQ0wsQ0FBQztBQUVELE1BQU0sVUFBVSxXQUFXLENBQUMsSUFJM0I7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLEVBQUUsRUFBRSxLQUFLLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDM0IsTUFBTSxFQUFDLFVBQVUsRUFBRSxPQUFPLEVBQUUsR0FBRyxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRXpDLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxpQkFBaUIsQ0FDM0MsS0FBSyxDQUFDLEtBQXlDLEVBQUUsVUFBVSxFQUFFLE9BQU87SUFDcEUsY0FBYyxDQUFBLENBQUMsRUFBRSxHQUFHLENBQUMsQ0FBQztJQUMxQixNQUFNLEVBQUUsR0FBRyxPQUFPLENBQUMsVUFBVSxDQUFDLEtBQUssQ0FBQyxLQUFLLEVBQUUsS0FBSyxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBRXhELGVBQWUsQ0FDWCxPQUFPLENBQUMsU0FBUyxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsTUFBTSxDQUFDLENBQUMsRUFBRSxFQUNuQyxPQUFPLENBQUMsU0FBUyxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsTUFBTSxDQUFDLENBQUMsRUFBRSxFQUNuQyxRQUFRLENBQUMsU0FBUztJQUNsQixrRUFBa0U7SUFDbEUsZ0RBQWdEO0lBQ2hELGdCQUFnQixDQUFBLFFBQVEsQ0FBQyxVQUFVLEVBQ25DLFFBQVEsQ0FBQyxRQUFRLEVBQ2pCLFFBQVEsQ0FBQyxPQUFPLEVBQ2hCLFFBQVEsQ0FBQyxTQUFTLEVBQ2xCLFFBQVEsQ0FBQyxRQUFRLEVBQ2pCLFFBQVEsQ0FBQyxZQUFZLEVBQ3JCLFFBQVEsQ0FBQyxXQUFXLEVBQ3BCLFFBQVEsQ0FBQyxjQUFjLEVBQ3ZCLFFBQVEsQ0FBQyxhQUFhLEVBQ3RCLFFBQVEsQ0FBQyxxQkFBcUIsRUFDOUIsUUFBUSxDQUFDLG9CQUFvQixFQUM3QixRQUFRLENBQUMsT0FBTyxDQUFDLEdBQUcsRUFDcEIsUUFBUSxDQUFDLE9BQU8sQ0
FBQyxJQUFJLEVBQ3JCLFFBQVEsQ0FBQyxZQUFZLEVBQ3JCLFFBQVEsQ0FBQyxXQUFXLENBQ3ZCLENBQUM7SUFDRixPQUFPLEVBQUUsQ0FBQztBQUNaLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxpQkFBaUIsR0FBaUI7SUFDN0MsVUFBVSxFQUFFLFdBQVc7SUFDdkIsV0FBVyxFQUFFLE1BQU07SUFDbkIsU0FBUyxFQUFFLEtBQUs7SUFDaEIsVUFBVSxFQUFFLFdBQW9DO0NBQ2pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7QXZnUG9vbEdyYWQsIEF2Z1Bvb2xHcmFkQXR0cnMsIEF2Z1Bvb2xHcmFkSW5wdXRzLCBiYWNrZW5kX3V0aWwsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgVGVuc29ySW5mb30gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtCYWNrZW5kV2FzbX0gZnJvbSAnLi4vYmFja2VuZF93YXNtJztcblxubGV0IHdhc21BdmdQb29sR3JhZDogKFxuICAgIGR5SWQ6IG51bWJlciwgZHhJZDogbnVtYmVyLCBiYXRjaFNpemU6IG51bWJlciwgY2hhbm5lbFNpemU6IG51bWJlcixcbiAgICBpbkhlaWdodDogbnVtYmVyLCBpbldpZHRoOiBudW1iZXIsIG91dEhlaWdodDogbnVtYmVyLCBvdXRXaWR0aDogbnVtYmVyLFxuICAgIHN0cmlkZUhlaWdodDogbnVtYmVyLCBzdHJpZGVXaWR0aDogbnVtYmVyLCBkaWxhdGlvbkhlaWdodDogbnVtYmVyLFxuICAgIGRpbGF0aW9uV2lkdGg6IG51bWJlciwgZWZmZWN0aXZlRmlsdGVySGVpZ2h0OiBudW1iZXIsXG4gICAgZWZmZWN0aXZlRmlsdGVyV2lkdGg6IG51bWJlciwgcGFkVG9wOiBudW1iZXIsIHBhZExlZnQ6IG
51bWJlcixcbiAgICBmaWx0ZXJIZWlnaHQ6IG51bWJlciwgZmlsdGVyV2lkdGg6IG51bWJlcikgPT4gdm9pZDtcblxuZnVuY3Rpb24gc2V0dXAoYmFja2VuZDogQmFja2VuZFdhc20pIHtcbiAgd2FzbUF2Z1Bvb2xHcmFkID0gYmFja2VuZC53YXNtLmN3cmFwKCdBdmdQb29sR3JhZCcsIG51bGwsIFtcbiAgICAnbnVtYmVyJywgIC8vIGR5SWRcbiAgICAnbnVtYmVyJywgIC8vIGR4SWRcbiAgICAnbnVtYmVyJywgIC8vIGJhdGNoU2l6ZVxuICAgICdudW1iZXInLCAgLy8gY2hhbm5lbFNpemVcbiAgICAnbnVtYmVyJywgIC8vIGluSGVpZ2h0XG4gICAgJ251bWJlcicsICAvLyBpbldpZHRoXG4gICAgJ251bWJlcicsICAvLyBvdXRIZWlnaHRcbiAgICAnbnVtYmVyJywgIC8vIG91dFdpZHRoXG4gICAgJ251bWJlcicsICAvLyBzdHJpZGVIZWlnaHRcbiAgICAnbnVtYmVyJywgIC8vIHN0cmlkZVdpZHRoXG4gICAgJ251bWJlcicsICAvLyBkaWxhdGlvbkhlaWdodFxuICAgICdudW1iZXInLCAgLy8gZGlsYXRpb25XaWR0aFxuICAgICdudW1iZXInLCAgLy8gZWZmZWN0aXZlRmlsdGVySGVpZ2h0XG4gICAgJ251bWJlcicsICAvLyBlZmZlY3RpdmVGaWx0ZXJXaWR0aFxuICAgICdudW1iZXInLCAgLy8gcGFkVG9wXG4gICAgJ251bWJlcicsICAvLyBwYWRMZWZ0XG4gICAgJ251bWJlcicsICAvLyBmaWx0ZXJIZWlnaHRcbiAgICAnbnVtYmVyJywgIC8vIGZpbHRlcldpZHRoXG4gIF0pO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gYXZnUG9vbEdyYWQoYXJnczoge1xuICBpbnB1dHM6IEF2Z1Bvb2xHcmFkSW5wdXRzLFxuICBhdHRyczogQXZnUG9vbEdyYWRBdHRycyxcbiAgYmFja2VuZDogQmFja2VuZFdhc20sXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHtkeSwgaW5wdXR9ID0gaW5wdXRzO1xuICBjb25zdCB7ZmlsdGVyU2l6ZSwgc3RyaWRlcywgcGFkfSA9IGF0dHJzO1xuXG4gIGNvbnN0IGNvbnZJbmZvID0gYmFja2VuZF91dGlsLmNvbXB1dGVQb29sMkRJbmZvKFxuICAgICAgaW5wdXQuc2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIGZpbHRlclNpemUsIHN0cmlkZXMsXG4gICAgICAvKmRpbGF0aW9ucz0qLzEsIHBhZCk7XG4gIGNvbnN0IGR4ID0gYmFja2VuZC5tYWtlT3V0cHV0KGlucHV0LnNoYXBlLCBpbnB1dC5kdHlwZSk7XG5cbiAgd2FzbUF2Z1Bvb2xHcmFkKFxuICAgICAgYmFja2VuZC5kYXRhSWRNYXAuZ2V0KGR5LmRhdGFJZCkuaWQsXG4gICAgICBiYWNrZW5kLmRhdGFJZE1hcC5nZXQoZHguZGF0YUlkKS5pZCxcbiAgICAgIGNvbnZJbmZvLmJhdGNoU2l6ZSxcbiAgICAgIC8vIFNpbmNlIFBvb2wgb3BzIChBdmdQb29sIGFuZCBNYXhQb29sKSBzdXBwb3J0IDJEIGZpbHRlciBvbmx5LCBpblxuICAgICAgLy8gY2hhbm5lbHMgc2hvdWxkIGFsd2F5cyBlcXVhbCB0byBvdXQgY2hhbm5lbHMuXG4gICAgICAvKmNoYW5uZWxTaXplPSovY29udk
luZm8uaW5DaGFubmVscyxcbiAgICAgIGNvbnZJbmZvLmluSGVpZ2h0LFxuICAgICAgY29udkluZm8uaW5XaWR0aCxcbiAgICAgIGNvbnZJbmZvLm91dEhlaWdodCxcbiAgICAgIGNvbnZJbmZvLm91dFdpZHRoLFxuICAgICAgY29udkluZm8uc3RyaWRlSGVpZ2h0LFxuICAgICAgY29udkluZm8uc3RyaWRlV2lkdGgsXG4gICAgICBjb252SW5mby5kaWxhdGlvbkhlaWdodCxcbiAgICAgIGNvbnZJbmZvLmRpbGF0aW9uV2lkdGgsXG4gICAgICBjb252SW5mby5lZmZlY3RpdmVGaWx0ZXJIZWlnaHQsXG4gICAgICBjb252SW5mby5lZmZlY3RpdmVGaWx0ZXJXaWR0aCxcbiAgICAgIGNvbnZJbmZvLnBhZEluZm8udG9wLFxuICAgICAgY29udkluZm8ucGFkSW5mby5sZWZ0LFxuICAgICAgY29udkluZm8uZmlsdGVySGVpZ2h0LFxuICAgICAgY29udkluZm8uZmlsdGVyV2lkdGgsXG4gICk7XG4gIHJldHVybiBkeDtcbn1cblxuZXhwb3J0IGNvbnN0IGF2Z1Bvb2xHcmFkQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEF2Z1Bvb2xHcmFkLFxuICBiYWNrZW5kTmFtZTogJ3dhc20nLFxuICBzZXR1cEZ1bmM6IHNldHVwLFxuICBrZXJuZWxGdW5jOiBhdmdQb29sR3JhZCBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=