@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.
64 lines • 10.9 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { backend_util, Dilation2DBackpropFilter } from '@tensorflow/tfjs-core';
import { CppDType } from './types';
let wasmDilation2DBackpropFilter;
function setup(backend) {
wasmDilation2DBackpropFilter =
backend.wasm.cwrap(Dilation2DBackpropFilter, null, [
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number',
'number', // padLeft
]);
}
export function dilation2DBackpropFilter(args) {
const { inputs, backend, attrs } = args;
const { x, filter, dy } = inputs;
const { strides, pad, dilations } = attrs;
if (x.dtype !== filter.dtype || x.dtype !== dy.dtype) {
throw new Error(`Dilation2DBackpropFilter error: x must have the same dtype as filter and dy. Got ${x.dtype}, ${filter.dtype}, and ${dy.dtype}`);
}
const dilationInfo = backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad,
/*dataFormat=*/ 'NHWC', dilations);
const gradients = backend.makeOutput(filter.shape, filter.dtype);
wasmDilation2DBackpropFilter(backend.dataIdMap.get(x.dataId).id, backend.dataIdMap.get(filter.dataId).id, backend.dataIdMap.get(dy.dataId).id, backend.dataIdMap.get(gradients.dataId).id, CppDType[x.dtype], dilationInfo.batchSize,
/*depth=*/ dilationInfo.inChannels, dilationInfo.inHeight, dilationInfo.inWidth, dilationInfo.outHeight, dilationInfo.outWidth, dilationInfo.strideHeight, dilationInfo.strideWidth, dilationInfo.dilationHeight, dilationInfo.dilationWidth, dilationInfo.filterHeight, dilationInfo.filterWidth, dilationInfo.padInfo.top, dilationInfo.padInfo.left);
return gradients;
}
export const dilation2DBackpropFilterConfig = {
kernelName: Dilation2DBackpropFilter,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: dilation2DBackpropFilter
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRGlsYXRpb24yREJhY2twcm9wRmlsdGVyLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdhc20vc3JjL2tlcm5lbHMvRGlsYXRpb24yREJhY2twcm9wRmlsdGVyLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQW1CLHdCQUF3QixFQUEyRCxNQUFNLHVCQUF1QixDQUFDO0FBSXhKLE9BQU8sRUFBQyxRQUFRLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFFakMsSUFBSSw0QkFNd0IsQ0FBQztBQUU3QixTQUFTLEtBQUssQ0FBQyxPQUFvQjtJQUNqQyw0QkFBNEI7UUFDeEIsT0FBTyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsd0JBQXdCLEVBQUUsSUFBSSxFQUFFO1lBQ2pELFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVEsRUFBRyxVQUFVO1NBQ3RCLENBQUMsQ0FBQztBQUNULENBQUM7QUFFRCxNQUFNLFVBQVUsd0JBQXdCLENBQUMsSUFJeEM7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUUsRUFBRSxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQy9CLE1BQU0sRUFBQyxPQUFPLEVBQUUsR0FBRyxFQUFFLFNBQVMsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUV4QyxJQUFJLENBQUMsQ0FBQyxLQUFLLEtBQUssTUFBTSxDQUFDLEtBQUssSUFBSSxDQUFDLENBQUMsS0FBSyxLQUFLLEVBQUUsQ0FBQyxLQUFLLEVBQUU7UUFDcEQsTUFBTSxJQUFJLEtBQUssQ0FDWCxvRkFDSSxDQUFDLENBQUMsS0FBSyxLQUFLLE1BQU0sQ0FBQyxLQUFLLFNBQVMsRUFBRSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7S0FDdEQ7SUFFRCxNQUFNLFlBQVksR0FBRyxZQUFZLENBQUMscUJBQXFCLENBQ25ELENBQUMsQ0FBQyxLQUF5QyxFQUMzQyxNQUFNLENBQUMsS0FBaUMsRUFBRSxPQUFPLEVBQUUsR0FBRztJQUN0RCxlQUFlLENBQUEsTUFBTSxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBRXRDLE1BQU0sU0FBUyxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsTUFBTSxDQUFDLEtBQUssRUFBRSxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUM7SUFFakUsNEJBQTRCLENBQ3hCLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQ2xDLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQ3ZDLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0
FBRyxDQUFDLEVBQUUsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQ25DLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLEVBQzFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQ2pCLFlBQVksQ0FBQyxTQUFTO0lBQ3RCLFVBQVUsQ0FBQSxZQUFZLENBQUMsVUFBVSxFQUNqQyxZQUFZLENBQUMsUUFBUSxFQUNyQixZQUFZLENBQUMsT0FBTyxFQUNwQixZQUFZLENBQUMsU0FBUyxFQUN0QixZQUFZLENBQUMsUUFBUSxFQUNyQixZQUFZLENBQUMsWUFBWSxFQUN6QixZQUFZLENBQUMsV0FBVyxFQUN4QixZQUFZLENBQUMsY0FBYyxFQUMzQixZQUFZLENBQUMsYUFBYSxFQUMxQixZQUFZLENBQUMsWUFBWSxFQUN6QixZQUFZLENBQUMsV0FBVyxFQUN4QixZQUFZLENBQUMsT0FBTyxDQUFDLEdBQUcsRUFDeEIsWUFBWSxDQUFDLE9BQU8sQ0FBQyxJQUFJLENBQzVCLENBQUM7SUFDRixPQUFPLFNBQVMsQ0FBQztBQUNuQixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sOEJBQThCLEdBQWlCO0lBQzFELFVBQVUsRUFBRSx3QkFBd0I7SUFDcEMsV0FBVyxFQUFFLE1BQU07SUFDbkIsU0FBUyxFQUFFLEtBQUs7SUFDaEIsVUFBVSxFQUFFLHdCQUFpRDtDQUM5RCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjMgR29vZ2xlIExMQy5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgRGlsYXRpb24yREF0dHJzLCBEaWxhdGlvbjJEQmFja3Byb3BGaWx0ZXIsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgVGVuc29yM0QsIFRlbnNvcjRELCBUZW5zb3JJbmZvfSBmcm9tICdAdGVuc29yZm
xvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge0JhY2tlbmRXYXNtfSBmcm9tICcuLi9iYWNrZW5kX3dhc20nO1xuXG5pbXBvcnQge0NwcERUeXBlfSBmcm9tICcuL3R5cGVzJztcblxubGV0IHdhc21EaWxhdGlvbjJEQmFja3Byb3BGaWx0ZXI6IChcbiAgICB4SWQ6IG51bWJlciwgZmlsdGVySWQ6IG51bWJlciwgZHlJZDogbnVtYmVyLCBncmFkSWQ6IG51bWJlciwgZHR5cGU6IG51bWJlcixcbiAgICBiYXRjaDogbnVtYmVyLCBkZXB0aDogbnVtYmVyLCBpbkhlaWdodDogbnVtYmVyLCBpbldpZHRoOiBudW1iZXIsXG4gICAgb3V0SGVpZ2h0OiBudW1iZXIsIG91dFdpZHRoOiBudW1iZXIsIHN0cmlkZUhlaWdodDogbnVtYmVyLFxuICAgIHN0cmlkZVdpZHRoOiBudW1iZXIsIGRpbGF0aW9uSGVpZ2h0OiBudW1iZXIsIGRpbGF0aW9uV2lkdGg6IG51bWJlcixcbiAgICBmaWx0ZXJIZWlnaHQ6IG51bWJlciwgZmlsdGVyV2lkdGg6IG51bWJlciwgcGFkVG9wOiBudW1iZXIsXG4gICAgcGFkTGVmdDogbnVtYmVyKSA9PiB2b2lkO1xuXG5mdW5jdGlvbiBzZXR1cChiYWNrZW5kOiBCYWNrZW5kV2FzbSkge1xuICB3YXNtRGlsYXRpb24yREJhY2twcm9wRmlsdGVyID1cbiAgICAgIGJhY2tlbmQud2FzbS5jd3JhcChEaWxhdGlvbjJEQmFja3Byb3BGaWx0ZXIsIG51bGwsIFtcbiAgICAgICAgJ251bWJlcicsICAvLyB4SWRcbiAgICAgICAgJ251bWJlcicsICAvLyBmaWx0ZXJJZFxuICAgICAgICAnbnVtYmVyJywgIC8vIGR5SWRcbiAgICAgICAgJ251bWJlcicsICAvLyBncmFkSWRcbiAgICAgICAgJ251bWJlcicsICAvLyBkdHlwZVxuICAgICAgICAnbnVtYmVyJywgIC8vIGJhdGNoXG4gICAgICAgICdudW1iZXInLCAgLy8gZGVwdGhcbiAgICAgICAgJ251bWJlcicsICAvLyBpbkhlaWdodFxuICAgICAgICAnbnVtYmVyJywgIC8vIGluV2lkdGhcbiAgICAgICAgJ251bWJlcicsICAvLyBvdXRIZWlnaHRcbiAgICAgICAgJ251bWJlcicsICAvLyBvdXRXaWR0aFxuICAgICAgICAnbnVtYmVyJywgIC8vIHN0cmlkZUhlaWdodFxuICAgICAgICAnbnVtYmVyJywgIC8vIHN0cmlkZVdpZHRoXG4gICAgICAgICdudW1iZXInLCAgLy8gZGlsYXRpb25IZWlnaHRcbiAgICAgICAgJ251bWJlcicsICAvLyBkaWxhdGlvbldpZHRoXG4gICAgICAgICdudW1iZXInLCAgLy8gZmlsdGVySGVpZ2h0XG4gICAgICAgICdudW1iZXInLCAgLy8gZmlsdGVyV2lkdGhcbiAgICAgICAgJ251bWJlcicsICAvLyBwYWRUb3BcbiAgICAgICAgJ251bWJlcicsICAvLyBwYWRMZWZ0XG4gICAgICBdKTtcbn1cblxuZXhwb3J0IGZ1bmN0aW9uIGRpbGF0aW9uMkRCYWNrcHJvcEZpbHRlcihhcmdzOiB7XG4gIGlucHV0czoge3g6IFRlbnNvcjRELCBmaWx0ZXI6IFRlbnNvcjNELCBkeTogVGVuc29yNER9LFxuICBhdHRyczogRGlsYXRpb24yREF0dHJzLFxuICBiYWNrZW5kOiBCYWNrZW5kV2FzbSxcbn0pOiBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYX
JncztcbiAgY29uc3Qge3gsIGZpbHRlciwgZHl9ID0gaW5wdXRzO1xuICBjb25zdCB7c3RyaWRlcywgcGFkLCBkaWxhdGlvbnN9ID0gYXR0cnM7XG5cbiAgaWYgKHguZHR5cGUgIT09IGZpbHRlci5kdHlwZSB8fCB4LmR0eXBlICE9PSBkeS5kdHlwZSkge1xuICAgIHRocm93IG5ldyBFcnJvcihcbiAgICAgICAgYERpbGF0aW9uMkRCYWNrcHJvcEZpbHRlciBlcnJvcjogeCBtdXN0IGhhdmUgdGhlIHNhbWUgZHR5cGUgYXMgZmlsdGVyIGFuZCBkeS4gR290ICR7XG4gICAgICAgICAgICB4LmR0eXBlfSwgJHtmaWx0ZXIuZHR5cGV9LCBhbmQgJHtkeS5kdHlwZX1gKTtcbiAgfVxuXG4gIGNvbnN0IGRpbGF0aW9uSW5mbyA9IGJhY2tlbmRfdXRpbC5jb21wdXRlRGlsYXRpb24yREluZm8oXG4gICAgICB4LnNoYXBlIGFzIFtudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXJdLFxuICAgICAgZmlsdGVyLnNoYXBlIGFzIFtudW1iZXIsIG51bWJlciwgbnVtYmVyXSwgc3RyaWRlcywgcGFkLFxuICAgICAgLypkYXRhRm9ybWF0PSovJ05IV0MnLCBkaWxhdGlvbnMpO1xuXG4gIGNvbnN0IGdyYWRpZW50cyA9IGJhY2tlbmQubWFrZU91dHB1dChmaWx0ZXIuc2hhcGUsIGZpbHRlci5kdHlwZSk7XG5cbiAgd2FzbURpbGF0aW9uMkRCYWNrcHJvcEZpbHRlcihcbiAgICAgIGJhY2tlbmQuZGF0YUlkTWFwLmdldCh4LmRhdGFJZCkuaWQsXG4gICAgICBiYWNrZW5kLmRhdGFJZE1hcC5nZXQoZmlsdGVyLmRhdGFJZCkuaWQsXG4gICAgICBiYWNrZW5kLmRhdGFJZE1hcC5nZXQoZHkuZGF0YUlkKS5pZCxcbiAgICAgIGJhY2tlbmQuZGF0YUlkTWFwLmdldChncmFkaWVudHMuZGF0YUlkKS5pZCxcbiAgICAgIENwcERUeXBlW3guZHR5cGVdLFxuICAgICAgZGlsYXRpb25JbmZvLmJhdGNoU2l6ZSxcbiAgICAgIC8qZGVwdGg9Ki9kaWxhdGlvbkluZm8uaW5DaGFubmVscyxcbiAgICAgIGRpbGF0aW9uSW5mby5pbkhlaWdodCxcbiAgICAgIGRpbGF0aW9uSW5mby5pbldpZHRoLFxuICAgICAgZGlsYXRpb25JbmZvLm91dEhlaWdodCxcbiAgICAgIGRpbGF0aW9uSW5mby5vdXRXaWR0aCxcbiAgICAgIGRpbGF0aW9uSW5mby5zdHJpZGVIZWlnaHQsXG4gICAgICBkaWxhdGlvbkluZm8uc3RyaWRlV2lkdGgsXG4gICAgICBkaWxhdGlvbkluZm8uZGlsYXRpb25IZWlnaHQsXG4gICAgICBkaWxhdGlvbkluZm8uZGlsYXRpb25XaWR0aCxcbiAgICAgIGRpbGF0aW9uSW5mby5maWx0ZXJIZWlnaHQsXG4gICAgICBkaWxhdGlvbkluZm8uZmlsdGVyV2lkdGgsXG4gICAgICBkaWxhdGlvbkluZm8ucGFkSW5mby50b3AsXG4gICAgICBkaWxhdGlvbkluZm8ucGFkSW5mby5sZWZ0LFxuICApO1xuICByZXR1cm4gZ3JhZGllbnRzO1xufVxuXG5leHBvcnQgY29uc3QgZGlsYXRpb24yREJhY2twcm9wRmlsdGVyQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IERpbGF0aW9uMkRCYWNrcHJvcEZpbHRlcixcbiAgYmFja2VuZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOi
BzZXR1cCxcbiAga2VybmVsRnVuYzogZGlsYXRpb24yREJhY2twcm9wRmlsdGVyIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==