@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.
52 lines • 7.55 kB
JavaScript
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { TopK } from '@tensorflow/tfjs-core';
import { CppDType } from './types';
// Callable for the native TopK kernel; undefined until setup() has run.
let wasmTopK;

/**
 * Wraps the native TopK entry point via the backend's `cwrap` and stores the
 * resulting function in the module-level `wasmTopK`.
 *
 * @param backend Object exposing `wasm.cwrap`.
 */
function setup(backend) {
  const returnType = null; /* void */
  // Argument types, listed in the order the kernel function passes them.
  const argTypes = [
    'number', // xId
    'array', // x.shape (as raw bytes)
    'number', // x.shape.length
    'number', // x.dtype
    'number', // k
    'bool', // sorted
    'number', // outValuesId
    'number', // outIndicesId
  ];
  wasmTopK = backend.wasm.cwrap(TopK, returnType, argTypes);
}
/**
 * TopK kernel for the wasm backend.
 *
 * Allocates two outputs shaped like `x` with the last dimension replaced by
 * `k` (values with `x.dtype`, indices as 'int32'), then calls the wrapped
 * native kernel `wasmTopK` to fill them.
 *
 * @param args.inputs Object holding the input tensor info `x`.
 * @param args.backend Backend providing `dataIdMap` and `makeOutput`.
 * @param args.attrs Object holding `k` and `sorted`; both are forwarded to
 *     the native kernel unchanged.
 * @returns `[outValues, outIndices]`.
 * @throws {Error} If `x` has rank 0, if `k` is negative, or if `k` exceeds
 *     the size of the last dimension of `x`.
 */
export const topk = ({ inputs, backend, attrs }) => {
  const { x } = inputs;
  const { k, sorted } = attrs;

  // Validate before allocating any output so a rejected call leaves nothing
  // behind. The messages mirror the checks of the public tf.topk op; this
  // kernel hands shape and k to the native code without further checks here.
  const rank = x.shape.length;
  if (rank === 0) {
    // Without this guard the shape update below would write index -1.
    throw new Error('topk() expects the input to be of rank 1 or higher');
  }
  const lastDim = x.shape[rank - 1];
  if (k < 0) {
    throw new Error(`'k' passed to topk() must be >= 0 but got ${k}`);
  }
  if (k > lastDim) {
    throw new Error(
      `'k' passed to topk() must be <= the last dimension (${lastDim}) ` +
        `but got ${k}`,
    );
  }

  const xId = backend.dataIdMap.get(x.dataId).id;
  // The shape is declared as an 'array' argument in setup(), so the int32
  // shape values are handed over as a byte view of their buffer.
  const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer);

  // Copy the shape so the input's own shape array is never mutated.
  const outputShape = x.shape.slice();
  outputShape[rank - 1] = k;

  const outValues = backend.makeOutput(outputShape, x.dtype);
  const outValuesId = backend.dataIdMap.get(outValues.dataId).id;
  const outIndices = backend.makeOutput(outputShape, 'int32');
  const outIndicesId = backend.dataIdMap.get(outIndices.dataId).id;

  wasmTopK(
    xId,
    xShapeBytes,
    rank,
    CppDType[x.dtype],
    k,
    sorted,
    outValuesId,
    outIndicesId,
  );
  return [outValues, outIndices];
};
/**
 * Configuration object that bundles the TopK kernel pieces defined in this
 * file: the kernel name imported from tfjs-core, the backend name 'wasm',
 * and the setup and kernel functions.
 * NOTE(review): presumably consumed by the tfjs-core kernel registry; confirm
 * against the code that imports this export.
 */
export const topKConfig = {
kernelName: TopK,
backendName: 'wasm',
// setup assigns the wrapped native function that topk later calls.
setupFunc: setup,
kernelFunc: topk,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiVG9wSy5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL1RvcEsudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUF1QyxJQUFJLEVBQXdCLE1BQU0sdUJBQXVCLENBQUM7QUFHeEcsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUVqQyxJQUFJLFFBRzZCLENBQUM7QUFFbEMsU0FBUyxLQUFLLENBQUMsT0FBb0I7SUFDakMsUUFBUSxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsS0FBSyxDQUFDLElBQUksRUFBRSxJQUFJLENBQUMsVUFBVSxFQUFFO1FBQ25ELFFBQVE7UUFDUixPQUFPO1FBQ1AsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsTUFBTTtRQUNOLFFBQVE7UUFDUixRQUFRLEVBQUcsZUFBZTtLQUMzQixDQUFDLENBQUM7QUFDTCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sSUFBSSxHQUVtQixDQUFDLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsRUFBRSxFQUFFO0lBQ3ZELE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsTUFBTSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFMUIsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUMvQyxNQUFNLFdBQVcsR0FBRyxJQUFJLFVBQVUsQ0FBQyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDbkUsTUFBTSxXQUFXLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxLQUFLLEVBQUUsQ0FBQztJQUNwQyxXQUFXLENBQUMsV0FBVyxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7SUFDeEMsTUFBTSxTQUFTLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxXQUFXLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBQzNELE1BQU0sV0FBVyxHQUFHLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLENBQUM7SUFDL0QsTUFBTSxVQUFVLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxXQUFXLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFDNUQsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUVqRSxRQUFRLENBQ0osR0FBRyxFQUFFLFdBQVcsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFBRSxRQUFRLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxFQUFFLENBQUMsRUFBRSxNQUFNLEVBQzlELFdBQVcsRUFBRSxZQUFZLENBQUMsQ0FBQztJQUUvQixPQUFPLENBQUMsU0FBUyxFQUFFLFVBQVUsQ0FBQyxDQUFDO0FBQ2
pDLENBQUMsQ0FBQztBQUVWLE1BQU0sQ0FBQyxNQUFNLFVBQVUsR0FBaUI7SUFDdEMsVUFBVSxFQUFFLElBQUk7SUFDaEIsV0FBVyxFQUFFLE1BQU07SUFDbkIsU0FBUyxFQUFFLEtBQUs7SUFDaEIsVUFBVSxFQUFFLElBQTZCO0NBQzFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7S2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBUZW5zb3JJbmZvLCBUb3BLLCBUb3BLQXR0cnMsIFRvcEtJbnB1dHN9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5pbXBvcnQge0NwcERUeXBlfSBmcm9tICcuL3R5cGVzJztcblxubGV0IHdhc21Ub3BLOiAoXG4gICAgeElkOiBudW1iZXIsIHhTaGFwZUJ5dGVzOiBVaW50OEFycmF5LCB4U2hhcGVMZW5ndGg6IG51bWJlcixcbiAgICB4RHR5cGU6IENwcERUeXBlLCBrOiBudW1iZXIsIHNvcnRlZDogYm9vbGVhbiwgb3V0VmFsdWVzSWQ6IG51bWJlcixcbiAgICBvdXRJbmRpY2VzSWQ6IG51bWJlcikgPT4gdm9pZDtcblxuZnVuY3Rpb24gc2V0dXAoYmFja2VuZDogQmFja2VuZFdhc20pIHtcbiAgd2FzbVRvcEsgPSBiYWNrZW5kLndhc20uY3dyYXAoVG9wSywgbnVsbCAvKiB2b2lkICovLCBbXG4gICAgJ251bWJlcicsICAvLyB4SWRcbiAgICAnYXJyYXknLCAgIC8vIHguc2hhcGVcbiAgICAnbnVtYmVyJywgIC8vIHguc2hhcGUubGVuZ3RoXG4gICAgJ251bWJlcicsICAvLyB4LmR0eXBlXG4gICAgJ251bWJlcicsICAvLyBrXG4gIC
AgJ2Jvb2wnLCAgICAvLyBzb3J0ZWRcbiAgICAnbnVtYmVyJywgIC8vIG91dFZhbHVlc0lkXG4gICAgJ251bWJlcicsICAvLyBvdXRJbmRpY2VzSWRcbiAgXSk7XG59XG5cbmV4cG9ydCBjb25zdCB0b3BrOlxuICAgIChhcmdzOiB7aW5wdXRzOiBUb3BLSW5wdXRzLCBiYWNrZW5kOiBCYWNrZW5kV2FzbSwgYXR0cnM6IFRvcEtBdHRyc30pID0+XG4gICAgICAgIFRlbnNvckluZm9bXSB8IFRlbnNvckluZm8gPSAoe2lucHV0cywgYmFja2VuZCwgYXR0cnN9KSA9PiB7XG4gICAgICAgICAgY29uc3Qge3h9ID0gaW5wdXRzO1xuICAgICAgICAgIGNvbnN0IHtrLCBzb3J0ZWR9ID0gYXR0cnM7XG5cbiAgICAgICAgICBjb25zdCB4SWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQoeC5kYXRhSWQpLmlkO1xuICAgICAgICAgIGNvbnN0IHhTaGFwZUJ5dGVzID0gbmV3IFVpbnQ4QXJyYXkobmV3IEludDMyQXJyYXkoeC5zaGFwZSkuYnVmZmVyKTtcbiAgICAgICAgICBjb25zdCBvdXRwdXRTaGFwZSA9IHguc2hhcGUuc2xpY2UoKTtcbiAgICAgICAgICBvdXRwdXRTaGFwZVtvdXRwdXRTaGFwZS5sZW5ndGggLSAxXSA9IGs7XG4gICAgICAgICAgY29uc3Qgb3V0VmFsdWVzID0gYmFja2VuZC5tYWtlT3V0cHV0KG91dHB1dFNoYXBlLCB4LmR0eXBlKTtcbiAgICAgICAgICBjb25zdCBvdXRWYWx1ZXNJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldChvdXRWYWx1ZXMuZGF0YUlkKS5pZDtcbiAgICAgICAgICBjb25zdCBvdXRJbmRpY2VzID0gYmFja2VuZC5tYWtlT3V0cHV0KG91dHB1dFNoYXBlLCAnaW50MzInKTtcbiAgICAgICAgICBjb25zdCBvdXRJbmRpY2VzSWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0SW5kaWNlcy5kYXRhSWQpLmlkO1xuXG4gICAgICAgICAgd2FzbVRvcEsoXG4gICAgICAgICAgICAgIHhJZCwgeFNoYXBlQnl0ZXMsIHguc2hhcGUubGVuZ3RoLCBDcHBEVHlwZVt4LmR0eXBlXSwgaywgc29ydGVkLFxuICAgICAgICAgICAgICBvdXRWYWx1ZXNJZCwgb3V0SW5kaWNlc0lkKTtcblxuICAgICAgICAgIHJldHVybiBbb3V0VmFsdWVzLCBvdXRJbmRpY2VzXTtcbiAgICAgICAgfTtcblxuZXhwb3J0IGNvbnN0IHRvcEtDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogVG9wSyxcbiAgYmFja2VuZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOiBzZXR1cCxcbiAga2VybmVsRnVuYzogdG9wayBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmMsXG59O1xuIl19