UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

67 lines 9 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { util } from '@tensorflow/tfjs-core';
import { permuteAxesAndTranspose } from './kernel_utils';
import { CppDType } from './types';
/**
 * Builds the WASM-backend kernel config shared by the arg-min / arg-max
 * kernels.
 *
 * The native implementation is looked up by `kernelName` through
 * `backend.wasm.cwrap` in `setupFunc`. It is then called with five numeric
 * arguments: (x_id, dtype, outer_size, inner_size, out_id).
 *
 * @param {string} kernelName - Name of the kernel to register, also used as
 *     the exported wasm symbol name. NOTE(review): presumably 'ArgMin' or
 *     'ArgMax'; not validated here.
 * @returns {{kernelName: string, backendName: string, setupFunc: Function,
 *     kernelFunc: Function}} Kernel config targeting the 'wasm' backend.
 */
export function createArgMinMaxKernelConfig(kernelName) {
    // Native function handle; assigned by setupFunc.
    let wasmArgMinMax;
    // Looks up the backend-internal numeric id registered for a tensor's data.
    const idOf = (backend, tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;
    function setupFunc(backend) {
        const paramTypes = [
            'number',
            'number',
            'number',
            'number',
            'number', // out_id
        ];
        wasmArgMinMax = backend.wasm.cwrap(kernelName, null /* void */, paramTypes);
    }
    function kernelFunc({ backend, inputs, attrs }) {
        const { axis } = attrs;
        const { x } = inputs;
        const xId = idOf(backend, x);
        const { transposed, axes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);
        // Operate on the transposed tensor only when it owns a different data
        // id from x; otherwise keep reading x directly.
        const sourceId = inputWasTransposed ? idOf(backend, transposed) : xId;
        const source = sourceId === xId ? x : transposed;
        // Output drops the source's last dimension and holds int32 indices.
        // NOTE(review): assumes permuteAxesAndTranspose leaves the reduced axis
        // innermost, so axes[0] names the last dimension — confirm in helper.
        const out = backend.makeOutput(source.shape.slice(0, -1), 'int32');
        const outId = idOf(backend, out);
        const rowCount = util.sizeFromShape(out.shape);
        const rowLength = source.shape[axes[0]];
        wasmArgMinMax(sourceId, CppDType[source.dtype], rowCount, rowLength, outId);
        if (inputWasTransposed) {
            // Release the temporary produced by permuteAxesAndTranspose.
            backend.disposeData(transposed.dataId);
        }
        return out;
    }
    return {
        kernelName,
        backendName: 'wasm',
        setupFunc,
        kernelFunc,
    };
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYXJnbWlubWF4X2tlcm5lbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL2FyZ21pbm1heF9rZXJuZWwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsT0FBTyxFQUE2RixJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUl2SSxPQUFPLEVBQUMsdUJBQXVCLEVBQUMsTUFBTSxnQkFBZ0IsQ0FBQztBQUN2RCxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRWpDLE1BQU0sVUFBVSwyQkFBMkIsQ0FBQyxVQUNRO0lBQ2xELElBQUksUUFFc0IsQ0FBQztJQUUzQixTQUFTLFNBQVMsQ0FBQyxPQUFvQjtRQUNyQyxRQUFRLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsVUFBVSxFQUFFLElBQUksQ0FBQyxVQUFVLEVBQUU7WUFDekQsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVEsQ0FBRyxTQUFTO1NBQ3JCLENBQUMsQ0FBQztJQUNMLENBQUM7SUFFRCxTQUFTLFVBQVUsQ0FBQyxJQUluQjtRQUNDLE1BQU0sRUFBQyxPQUFPLEVBQUUsTUFBTSxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztRQUN0QyxNQUFNLEVBQUMsSUFBSSxFQUFDLEdBQUcsS0FBSyxDQUFDO1FBQ3JCLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7UUFDbkIsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztRQUMvQyxJQUFJLE9BQU8sR0FBRyxHQUFHLENBQUM7UUFDbEIsSUFBSSxLQUFLLEdBQUcsQ0FBQyxDQUFDO1FBRWQsTUFBTSxFQUFDLFVBQVUsRUFBRSxJQUFJLEVBQUUsa0JBQWtCLEVBQUMsR0FDeEMsdUJBQXVCLENBQUMsQ0FBQyxFQUFFLElBQUksRUFBRSxPQUFPLENBQUMsQ0FBQztRQUU5QyxJQUFJLGtCQUFrQixFQUFFO1lBQ3RCLE1BQU0sWUFBWSxHQUFHLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLFVBQVUsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLENBQUM7WUFDakUsSUFBSSxZQUFZLEtBQUssR0FBRyxFQUFFO2dCQUN4Qiw2REFBNkQ7Z0JBQzdELG9CQUFvQjtnQkFDcEIsS0FBSyxHQUFHLFVBQVUsQ0FBQztnQkFDbkIsT0FBTyxHQUFHLFlBQVksQ0FBQzthQUN4QjtTQUNGO1FBRUQsTUFBTSxRQUFRLEdBQUcsS0FBSyxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDMUMsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxRQUFRLEVBQUUsT0FBTyxDQUFDLENBQUM7UUFD
bEQsTUFBTSxLQUFLLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztRQUNuRCxNQUFNLFNBQVMsR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQztRQUNoRCxNQUFNLFNBQVMsR0FBRyxLQUFLLENBQUMsS0FBSyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQ3ZDLFFBQVEsQ0FBQyxPQUFPLEVBQUUsUUFBUSxDQUFDLEtBQUssQ0FBQyxLQUFLLENBQUMsRUFBRSxTQUFTLEVBQUUsU0FBUyxFQUFFLEtBQUssQ0FBQyxDQUFDO1FBRXRFLElBQUksa0JBQWtCLEVBQUU7WUFDdEIsb0NBQW9DO1lBQ3BDLE9BQU8sQ0FBQyxXQUFXLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDO1NBQ3hDO1FBRUQsT0FBTyxHQUFHLENBQUM7SUFDYixDQUFDO0lBRUQsT0FBTztRQUNMLFVBQVU7UUFDVixXQUFXLEVBQUUsTUFBTTtRQUNuQixTQUFTO1FBQ1QsVUFBVSxFQUFFLFVBQW1DO0tBQ2hELENBQUM7QUFDSixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjMgR29vZ2xlIExMQy5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtBcmdNYXhBdHRycywgQXJnTWF4SW5wdXRzLCBBcmdNaW5BdHRycywgQXJnTWluSW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5cbmltcG9ydCB7cGVybXV0ZUF4ZXNBbmRUcmFuc3Bvc2V9IGZyb20gJy4va2VybmVsX3V0aWxzJztcbmlt
cG9ydCB7Q3BwRFR5cGV9IGZyb20gJy4vdHlwZXMnO1xuXG5leHBvcnQgZnVuY3Rpb24gY3JlYXRlQXJnTWluTWF4S2VybmVsQ29uZmlnKGtlcm5lbE5hbWU6ICdBcmdNaW4nfFxuICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnQXJnTWF4Jyk6IEtlcm5lbENvbmZpZyB7XG4gIGxldCB3YXNtRnVuYzogKFxuICAgICAgeElkOiBudW1iZXIsIGR0eXBlOiBudW1iZXIsIG91dGVyU2l6ZTogbnVtYmVyLCBpbm5lclNpemU6IG51bWJlcixcbiAgICAgIG91dElkOiBudW1iZXIpID0+IHZvaWQ7XG5cbiAgZnVuY3Rpb24gc2V0dXBGdW5jKGJhY2tlbmQ6IEJhY2tlbmRXYXNtKSB7XG4gICAgd2FzbUZ1bmMgPSBiYWNrZW5kLndhc20uY3dyYXAoa2VybmVsTmFtZSwgbnVsbCAvKiB2b2lkICovLCBbXG4gICAgICAnbnVtYmVyJywgIC8vIHhfaWRcbiAgICAgICdudW1iZXInLCAgLy8gZHR5cGVcbiAgICAgICdudW1iZXInLCAgLy8gb3V0ZXJfc2l6ZVxuICAgICAgJ251bWJlcicsICAvLyBpbm5lcl9zaXplXG4gICAgICAnbnVtYmVyJyAgIC8vIG91dF9pZFxuICAgIF0pO1xuICB9XG5cbiAgZnVuY3Rpb24ga2VybmVsRnVuYyhhcmdzOiB7XG4gICAgYmFja2VuZDogQmFja2VuZFdhc20sXG4gICAgaW5wdXRzOiBBcmdNaW5JbnB1dHMmQXJnTWF4SW5wdXRzLFxuICAgIGF0dHJzOiBBcmdNaW5BdHRycyZBcmdNYXhBdHRycyxcbiAgfSk6IFRlbnNvckluZm8ge1xuICAgIGNvbnN0IHtiYWNrZW5kLCBpbnB1dHMsIGF0dHJzfSA9IGFyZ3M7XG4gICAgY29uc3Qge2F4aXN9ID0gYXR0cnM7XG4gICAgY29uc3Qge3h9ID0gaW5wdXRzO1xuICAgIGNvbnN0IHhJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldCh4LmRhdGFJZCkuaWQ7XG4gICAgbGV0IGlucHV0SWQgPSB4SWQ7XG4gICAgbGV0IGlucHV0ID0geDtcblxuICAgIGNvbnN0IHt0cmFuc3Bvc2VkLCBheGVzLCBpbnB1dFdhc1RyYW5zcG9zZWR9ID1cbiAgICAgICAgcGVybXV0ZUF4ZXNBbmRUcmFuc3Bvc2UoeCwgYXhpcywgYmFja2VuZCk7XG5cbiAgICBpZiAoaW5wdXRXYXNUcmFuc3Bvc2VkKSB7XG4gICAgICBjb25zdCB0cmFuc3Bvc2VkSWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQodHJhbnNwb3NlZC5kYXRhSWQpLmlkO1xuICAgICAgaWYgKHRyYW5zcG9zZWRJZCAhPT0geElkKSB7XG4gICAgICAgIC8vIHRyYW5zcG9zZSB3YXMgbm90IGEgbm8tb3AuIFdlIHdpbGwgbmVlZCB0byBkaXNwb3NlIG9mIHRoaXNcbiAgICAgICAgLy8gb25jZSB3ZSBhcmUgZG9uZS5cbiAgICAgICAgaW5wdXQgPSB0cmFuc3Bvc2VkO1xuICAgICAgICBpbnB1dElkID0gdHJhbnNwb3NlZElkO1xuICAgICAgfVxuICAgIH1cblxuICAgIGNvbnN0IG91dFNoYXBlID0gaW5wdXQuc2hhcGUuc2xpY2UoMCwgLTEpO1xuICAgIGNvbnN0IG91dCA9IGJhY2tlbmQubWFrZU91dHB1dChvdXRTaGFwZSwgJ2ludDMyJyk7XG4gICAgY29uc3Qgb3V0SWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0LmRhdGFJ
ZCkuaWQ7XG4gICAgY29uc3Qgb3V0ZXJTaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKG91dC5zaGFwZSk7XG4gICAgY29uc3QgaW5uZXJTaXplID0gaW5wdXQuc2hhcGVbYXhlc1swXV07XG4gICAgd2FzbUZ1bmMoaW5wdXRJZCwgQ3BwRFR5cGVbaW5wdXQuZHR5cGVdLCBvdXRlclNpemUsIGlubmVyU2l6ZSwgb3V0SWQpO1xuXG4gICAgaWYgKGlucHV0V2FzVHJhbnNwb3NlZCkge1xuICAgICAgLy8gZGlzcG9zZSBvZiB0aGUgdHJhbnNwb3NlZCB0ZW5zb3IuXG4gICAgICBiYWNrZW5kLmRpc3Bvc2VEYXRhKHRyYW5zcG9zZWQuZGF0YUlkKTtcbiAgICB9XG5cbiAgICByZXR1cm4gb3V0O1xuICB9XG5cbiAgcmV0dXJuIHtcbiAgICBrZXJuZWxOYW1lLFxuICAgIGJhY2tlbmROYW1lOiAnd2FzbScsXG4gICAgc2V0dXBGdW5jLFxuICAgIGtlcm5lbEZ1bmM6IGtlcm5lbEZ1bmMgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jLFxuICB9O1xufVxuIl19