@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.
67 lines • 9 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { util } from '@tensorflow/tfjs-core';
import { permuteAxesAndTranspose } from './kernel_utils';
import { CppDType } from './types';
/**
 * Builds the wasm-backend kernel config shared by the arg-min and arg-max
 * kernels (the TypeScript source types `kernelName` as 'ArgMin' | 'ArgMax').
 *
 * `setupFunc` binds the WASM export named `kernelName`; `kernelFunc` runs it
 * over `attrs.axis` of `inputs.x` and returns a new int32 tensor of indices.
 *
 * @param {string} kernelName Kernel name, also used as the name of the WASM
 *     export passed to `cwrap`.
 * @returns {{kernelName: string, backendName: string, setupFunc: Function,
 *     kernelFunc: Function}} Kernel config for the 'wasm' backend.
 */
export function createArgMinMaxKernelConfig(kernelName) {
  // Bound in setupFunc once the backend's WASM module is available.
  let wasmFunc;

  /**
   * Wraps the native function `kernelName`, which returns void and takes
   * (x_id, dtype, outer_size, inner_size, out_id).
   *
   * @param {object} backend The wasm backend exposing `wasm.cwrap`.
   */
  function setupFunc(backend) {
    wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, [
      'number', // x_id
      'number', // dtype
      'number', // outer_size
      'number', // inner_size
      'number', // out_id
    ]);
  }

  /**
   * Computes the arg-min/arg-max of `inputs.x` along `attrs.axis`.
   *
   * @param {{backend: object, inputs: {x: object}, attrs: {axis: *}}} args
   * @returns {object} TensorInfo for the int32 output, whose shape is the
   *     (possibly transposed) input shape without its last dimension.
   */
  function kernelFunc(args) {
    const { backend, inputs, attrs } = args;
    const { axis } = attrs;
    const { x } = inputs;
    const xId = backend.dataIdMap.get(x.dataId).id;
    let inputId = xId;
    let input = x;

    // The native call only receives an (outer, inner) split, so the reduced
    // axis presumably has to be innermost; this helper permutes x if needed.
    const { transposed, axes, inputWasTransposed } =
        permuteAxesAndTranspose(x, axis, backend);

    if (inputWasTransposed) {
      const transposedId = backend.dataIdMap.get(transposed.dataId).id;
      if (transposedId !== xId) {
        // The transpose produced its own buffer: read from it instead of x.
        input = transposed;
        inputId = transposedId;
      }
    }

    try {
      const outShape = input.shape.slice(0, -1);
      const out = backend.makeOutput(outShape, 'int32');
      const outId = backend.dataIdMap.get(out.dataId).id;
      const outerSize = util.sizeFromShape(out.shape);
      const innerSize = input.shape[axes[0]];
      wasmFunc(inputId, CppDType[input.dtype], outerSize, innerSize, outId);
      return out;
    } finally {
      // Release the temporary transpose even when the native call throws.
      // NOTE(review): this keys on inputWasTransposed rather than on whether
      // `input` was swapped above; if the helper ever flagged a transpose that
      // shares x's data id, this would release x's data — confirm the helper
      // only sets the flag when the ids differ.
      if (inputWasTransposed) {
        backend.disposeData(transposed.dataId);
      }
    }
  }

  return {
    kernelName,
    backendName: 'wasm',
    setupFunc,
    kernelFunc,
  };
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYXJnbWlubWF4X2tlcm5lbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL2FyZ21pbm1heF9rZXJuZWwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsT0FBTyxFQUE2RixJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUl2SSxPQUFPLEVBQUMsdUJBQXVCLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUN2RCxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRWpDLE1BQU0sVUFBVSwyQkFBMkIsQ0FBQyxVQUNRO0lBQ2xELElBQUksUUFFc0IsQ0FBQztJQUUzQixTQUFTLFNBQVMsQ0FBQyxPQUFvQjtRQUNyQyxRQUFRLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsVUFBVSxFQUFFLElBQUksQ0FBQyxVQUFVLEVBQUU7WUFDekQsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVEsQ0FBRyxTQUFTO1NBQ3JCLENBQUMsQ0FBQztJQUNMLENBQUM7SUFFRCxTQUFTLFVBQVUsQ0FBQyxJQUluQjtRQUNDLE1BQU0sRUFBQyxPQUFPLEVBQUUsTUFBTSxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztRQUN0QyxNQUFNLEVBQUMsSUFBSSxFQUFDLEdBQUcsS0FBSyxDQUFDO1FBQ3JCLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7UUFDbkIsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztRQUMvQyxJQUFJLE9BQU8sR0FBRyxHQUFHLENBQUM7UUFDbEIsSUFBSSxLQUFLLEdBQUcsQ0FBQyxDQUFDO1FBRWQsTUFBTSxFQUFDLFVBQVUsRUFBRSxJQUFJLEVBQUUsa0JBQWtCLEVBQUMsR0FDeEMsdUJBQXVCLENBQUMsQ0FBQyxFQUFFLElBQUksRUFBRSxPQUFPLENBQUMsQ0FBQztRQUU5QyxJQUFJLGtCQUFrQixFQUFFO1lBQ3RCLE1BQU0sWUFBWSxHQUFHLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLFVBQVUsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLENBQUM7WUFDakUsSUFBSSxZQUFZLEtBQUssR0FBRyxFQUFFO2dCQUN4Qiw2REFBNkQ7Z0JBQzdELG9CQUFvQjtnQkFDcEIsS0FBSyxHQUFHLFVBQVUsQ0FBQztnQkFDbkIsT0FBTyxHQUFHLFlBQVksQ0FBQzthQUN4QjtTQUNGO1FBRUQsTUFBTSxRQUFRLEdBQUcsS0FBSyxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDMUMsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxRQUFRLEVBQUUsT0FBTyxDQUFDLENBQUM7UUFDbEQsTUFBTSxLQUFLLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztRQUNuRCxNQUFNLFNBQVMsR0FBRy
xJQUFJLENBQUMsYUFBYSxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQztRQUNoRCxNQUFNLFNBQVMsR0FBRyxLQUFLLENBQUMsS0FBSyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQ3ZDLFFBQVEsQ0FBQyxPQUFPLEVBQUUsUUFBUSxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQUMsRUFBRSxTQUFTLEVBQUUsU0FBUyxFQUFFLEtBQUssQ0FBQyxDQUFDO1FBRXRFLElBQUksa0JBQWtCLEVBQUU7WUFDdEIsb0NBQW9DO1lBQ3BDLE9BQU8sQ0FBQyxXQUFXLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDO1NBQ3hDO1FBRUQsT0FBTyxHQUFHLENBQUM7SUFDYixDQUFDO0lBRUQsT0FBTztRQUNMLFVBQVU7UUFDVixXQUFXLEVBQUUsTUFBTTtRQUNuQixTQUFTO1FBQ1QsVUFBVSxFQUFFLFVBQW1DO0tBQ2hELENBQUM7QUFDSixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjMgR29vZ2xlIExMQy5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtBcmdNYXhBdHRycywgQXJnTWF4SW5wdXRzLCBBcmdNaW5BdHRycywgQXJnTWluSW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5cbmltcG9ydCB7cGVybXV0ZUF4ZXNBbmRUcmFuc3Bvc2V9IGZyb20gJy4va2VybmVsX3V0aWxzJztcbmltcG9ydCB7Q3BwRFR5cGV9IGZyb20gJy4vdHlwZXMnO1xuXG5leHBvcnQgZnVuY3Rpb24gY3JlYXRlQXJnTWluTWF4S2VybmVsQ29uZmlnKGtlcm5lbE5hbWU6ICdBcmdNaW4nfFxuIC
AgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnQXJnTWF4Jyk6IEtlcm5lbENvbmZpZyB7XG4gIGxldCB3YXNtRnVuYzogKFxuICAgICAgeElkOiBudW1iZXIsIGR0eXBlOiBudW1iZXIsIG91dGVyU2l6ZTogbnVtYmVyLCBpbm5lclNpemU6IG51bWJlcixcbiAgICAgIG91dElkOiBudW1iZXIpID0+IHZvaWQ7XG5cbiAgZnVuY3Rpb24gc2V0dXBGdW5jKGJhY2tlbmQ6IEJhY2tlbmRXYXNtKSB7XG4gICAgd2FzbUZ1bmMgPSBiYWNrZW5kLndhc20uY3dyYXAoa2VybmVsTmFtZSwgbnVsbCAvKiB2b2lkICovLCBbXG4gICAgICAnbnVtYmVyJywgIC8vIHhfaWRcbiAgICAgICdudW1iZXInLCAgLy8gZHR5cGVcbiAgICAgICdudW1iZXInLCAgLy8gb3V0ZXJfc2l6ZVxuICAgICAgJ251bWJlcicsICAvLyBpbm5lcl9zaXplXG4gICAgICAnbnVtYmVyJyAgIC8vIG91dF9pZFxuICAgIF0pO1xuICB9XG5cbiAgZnVuY3Rpb24ga2VybmVsRnVuYyhhcmdzOiB7XG4gICAgYmFja2VuZDogQmFja2VuZFdhc20sXG4gICAgaW5wdXRzOiBBcmdNaW5JbnB1dHMmQXJnTWF4SW5wdXRzLFxuICAgIGF0dHJzOiBBcmdNaW5BdHRycyZBcmdNYXhBdHRycyxcbiAgfSk6IFRlbnNvckluZm8ge1xuICAgIGNvbnN0IHtiYWNrZW5kLCBpbnB1dHMsIGF0dHJzfSA9IGFyZ3M7XG4gICAgY29uc3Qge2F4aXN9ID0gYXR0cnM7XG4gICAgY29uc3Qge3h9ID0gaW5wdXRzO1xuICAgIGNvbnN0IHhJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldCh4LmRhdGFJZCkuaWQ7XG4gICAgbGV0IGlucHV0SWQgPSB4SWQ7XG4gICAgbGV0IGlucHV0ID0geDtcblxuICAgIGNvbnN0IHt0cmFuc3Bvc2VkLCBheGVzLCBpbnB1dFdhc1RyYW5zcG9zZWR9ID1cbiAgICAgICAgcGVybXV0ZUF4ZXNBbmRUcmFuc3Bvc2UoeCwgYXhpcywgYmFja2VuZCk7XG5cbiAgICBpZiAoaW5wdXRXYXNUcmFuc3Bvc2VkKSB7XG4gICAgICBjb25zdCB0cmFuc3Bvc2VkSWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQodHJhbnNwb3NlZC5kYXRhSWQpLmlkO1xuICAgICAgaWYgKHRyYW5zcG9zZWRJZCAhPT0geElkKSB7XG4gICAgICAgIC8vIHRyYW5zcG9zZSB3YXMgbm90IGEgbm8tb3AuIFdlIHdpbGwgbmVlZCB0byBkaXNwb3NlIG9mIHRoaXNcbiAgICAgICAgLy8gb25jZSB3ZSBhcmUgZG9uZS5cbiAgICAgICAgaW5wdXQgPSB0cmFuc3Bvc2VkO1xuICAgICAgICBpbnB1dElkID0gdHJhbnNwb3NlZElkO1xuICAgICAgfVxuICAgIH1cblxuICAgIGNvbnN0IG91dFNoYXBlID0gaW5wdXQuc2hhcGUuc2xpY2UoMCwgLTEpO1xuICAgIGNvbnN0IG91dCA9IGJhY2tlbmQubWFrZU91dHB1dChvdXRTaGFwZSwgJ2ludDMyJyk7XG4gICAgY29uc3Qgb3V0SWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0LmRhdGFJZCkuaWQ7XG4gICAgY29uc3Qgb3V0ZXJTaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKG91dC5zaGFwZSk7XG4gICAgY29uc3QgaW5uZXJTaXplID0gaW5wdXQuc2hhcGVbYXhlc1swXV
07XG4gICAgd2FzbUZ1bmMoaW5wdXRJZCwgQ3BwRFR5cGVbaW5wdXQuZHR5cGVdLCBvdXRlclNpemUsIGlubmVyU2l6ZSwgb3V0SWQpO1xuXG4gICAgaWYgKGlucHV0V2FzVHJhbnNwb3NlZCkge1xuICAgICAgLy8gZGlzcG9zZSBvZiB0aGUgdHJhbnNwb3NlZCB0ZW5zb3IuXG4gICAgICBiYWNrZW5kLmRpc3Bvc2VEYXRhKHRyYW5zcG9zZWQuZGF0YUlkKTtcbiAgICB9XG5cbiAgICByZXR1cm4gb3V0O1xuICB9XG5cbiAgcmV0dXJuIHtcbiAgICBrZXJuZWxOYW1lLFxuICAgIGJhY2tlbmROYW1lOiAnd2FzbScsXG4gICAgc2V0dXBGdW5jLFxuICAgIGtlcm5lbEZ1bmM6IGtlcm5lbEZ1bmMgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jLFxuICB9O1xufVxuIl19