UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

52 lines 7.55 kB
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { TopK } from '@tensorflow/tfjs-core';
import { CppDType } from './types';

/**
 * JS-callable wrapper around the native `TopK` export.
 * It is assigned by `setup`, so `topk` relies on `setup` having run first.
 */
let wasmTopK;

/**
 * Binds the native `TopK` function via the backend's Emscripten `cwrap`.
 *
 * @param {object} backend - WASM backend whose `wasm.cwrap` is used.
 */
function setup(backend) {
  // One entry per native parameter, in the order `topk` passes them.
  // NOTE(review): Emscripten documents 'boolean' rather than 'bool'; an
  // unrecognised type appears to be passed through as-is. Kept unchanged
  // to match upstream. Confirm before touching.
  const argTypes = [
    'number', // xId
    'array',  // x.shape, as a Uint8Array view over int32 data
    'number', // x.shape.length
    'number', // x.dtype (CppDType enum value)
    'number', // k
    'bool',   // sorted
    'number', // outValuesId
    'number', // outIndicesId
  ];
  wasmTopK = backend.wasm.cwrap(TopK, null /* void */, argTypes);
}

/**
 * WASM kernel for TopK.
 * Allocates the two outputs, then delegates the actual selection to the
 * native `TopK` routine bound in `setup`.
 *
 * @param {{inputs: {x: object}, backend: object, attrs: {k: number, sorted: boolean}}} args
 * @returns {object[]} `[values, indices]`. Both are shaped like `x` with the
 *     innermost dimension replaced by `k`. `values` keeps `x.dtype`, and
 *     `indices` is 'int32'.
 */
export const topk = ({ inputs, backend, attrs }) => {
  const { x } = inputs;
  const { k, sorted } = attrs;

  // Resolves a tensor's backend-side numeric id.
  const idOf = (info) => backend.dataIdMap.get(info.dataId).id;

  const xId = idOf(x);
  const rank = x.shape.length;

  // The shape crosses into WASM as the raw bytes of an int32 array.
  const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer);

  // Output shape: a copy of x's shape with the last dimension set to k.
  const outputShape = [...x.shape];
  outputShape[rank - 1] = k;

  const outValues = backend.makeOutput(outputShape, x.dtype);
  const outValuesId = idOf(outValues);
  const outIndices = backend.makeOutput(outputShape, 'int32');
  const outIndicesId = idOf(outIndices);

  wasmTopK(
    xId,
    xShapeBytes,
    rank,
    CppDType[x.dtype],
    k,
    sorted,
    outValuesId,
    outIndicesId,
  );

  return [outValues, outIndices];
};

/** Kernel config pairing `topk` (and its `setup`) with TopK on the 'wasm' backend. */
export const topKConfig = {
  kernelName: TopK,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: topk,
};
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiVG9wSy5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL1RvcEsudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUF1QyxJQUFJLEVBQXdCLE1BQU0sdUJBQXVCLENBQUM7QUFHeEcsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUVqQyxJQUFJLFFBRzZCLENBQUM7QUFFbEMsU0FBUyxLQUFLLENBQUMsT0FBb0I7SUFDakMsUUFBUSxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsS0FBSyxDQUFDLElBQUksRUFBRSxJQUFJLENBQUMsVUFBVSxFQUFFO1FBQ25ELFFBQVE7UUFDUixPQUFPO1FBQ1AsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsTUFBTTtRQUNOLFFBQVE7UUFDUixRQUFRLEVBQUcsZUFBZTtLQUMzQixDQUFDLENBQUM7QUFDTCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sSUFBSSxHQUVtQixDQUFDLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsRUFBRSxFQUFFO0lBQ3ZELE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsTUFBTSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFMUIsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUMvQyxNQUFNLFdBQVcsR0FBRyxJQUFJLFVBQVUsQ0FBQyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDbkUsTUFBTSxXQUFXLEdBQUcsQ0FBQyxDQUFDLEtBQUssQ0FBQyxLQUFLLEVBQUUsQ0FBQztJQUNwQyxXQUFXLENBQUMsV0FBVyxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7SUFDeEMsTUFBTSxTQUFTLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxXQUFXLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDO0lBQzNELE1BQU0sV0FBVyxHQUFHLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLENBQUM7SUFDL0QsTUFBTSxVQUFVLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxXQUFXLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFDNUQsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUVqRSxRQUFRLENBQ0osR0FBRyxFQUFFLFdBQVcsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFBRSxRQUFRLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxFQUFFLENBQUMsRUFBRSxNQUFNLEVBQzlELFdBQVcsRUFBRSxZQUFZLENBQUMsQ0FBQztJQUUvQixPQUFPLENBQUMsU0FBUyxFQUFFLFVBQVUsQ0FBQyxDQUFDO0FBQ2pDLE
NBQUMsQ0FBQztBQUVWLE1BQU0sQ0FBQyxNQUFNLFVBQVUsR0FBaUI7SUFDdEMsVUFBVSxFQUFFLElBQUk7SUFDaEIsV0FBVyxFQUFFLE1BQU07SUFDbkIsU0FBUyxFQUFFLEtBQUs7SUFDaEIsVUFBVSxFQUFFLElBQTZCO0NBQzFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7S2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBUZW5zb3JJbmZvLCBUb3BLLCBUb3BLQXR0cnMsIFRvcEtJbnB1dHN9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7QmFja2VuZFdhc219IGZyb20gJy4uL2JhY2tlbmRfd2FzbSc7XG5pbXBvcnQge0NwcERUeXBlfSBmcm9tICcuL3R5cGVzJztcblxubGV0IHdhc21Ub3BLOiAoXG4gICAgeElkOiBudW1iZXIsIHhTaGFwZUJ5dGVzOiBVaW50OEFycmF5LCB4U2hhcGVMZW5ndGg6IG51bWJlcixcbiAgICB4RHR5cGU6IENwcERUeXBlLCBrOiBudW1iZXIsIHNvcnRlZDogYm9vbGVhbiwgb3V0VmFsdWVzSWQ6IG51bWJlcixcbiAgICBvdXRJbmRpY2VzSWQ6IG51bWJlcikgPT4gdm9pZDtcblxuZnVuY3Rpb24gc2V0dXAoYmFja2VuZDogQmFja2VuZFdhc20pIHtcbiAgd2FzbVRvcEsgPSBiYWNrZW5kLndhc20uY3dyYXAoVG9wSywgbnVsbCAvKiB2b2lkICovLCBbXG4gICAgJ251bWJlcicsICAvLyB4SWRcbiAgICAnYXJyYXknLCAgIC8vIHguc2hhcGVcbiAgICAnbnVtYmVyJywgIC8vIHguc2hhcGUubGVuZ3RoXG4gICAgJ251bWJlcicsICAvLyB4LmR0eXBlXG4gICAgJ251bWJlcicsICAvLyBrXG4gICAgJ2
Jvb2wnLCAgICAvLyBzb3J0ZWRcbiAgICAnbnVtYmVyJywgIC8vIG91dFZhbHVlc0lkXG4gICAgJ251bWJlcicsICAvLyBvdXRJbmRpY2VzSWRcbiAgXSk7XG59XG5cbmV4cG9ydCBjb25zdCB0b3BrOlxuICAgIChhcmdzOiB7aW5wdXRzOiBUb3BLSW5wdXRzLCBiYWNrZW5kOiBCYWNrZW5kV2FzbSwgYXR0cnM6IFRvcEtBdHRyc30pID0+XG4gICAgICAgIFRlbnNvckluZm9bXSB8IFRlbnNvckluZm8gPSAoe2lucHV0cywgYmFja2VuZCwgYXR0cnN9KSA9PiB7XG4gICAgICAgICAgY29uc3Qge3h9ID0gaW5wdXRzO1xuICAgICAgICAgIGNvbnN0IHtrLCBzb3J0ZWR9ID0gYXR0cnM7XG5cbiAgICAgICAgICBjb25zdCB4SWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQoeC5kYXRhSWQpLmlkO1xuICAgICAgICAgIGNvbnN0IHhTaGFwZUJ5dGVzID0gbmV3IFVpbnQ4QXJyYXkobmV3IEludDMyQXJyYXkoeC5zaGFwZSkuYnVmZmVyKTtcbiAgICAgICAgICBjb25zdCBvdXRwdXRTaGFwZSA9IHguc2hhcGUuc2xpY2UoKTtcbiAgICAgICAgICBvdXRwdXRTaGFwZVtvdXRwdXRTaGFwZS5sZW5ndGggLSAxXSA9IGs7XG4gICAgICAgICAgY29uc3Qgb3V0VmFsdWVzID0gYmFja2VuZC5tYWtlT3V0cHV0KG91dHB1dFNoYXBlLCB4LmR0eXBlKTtcbiAgICAgICAgICBjb25zdCBvdXRWYWx1ZXNJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldChvdXRWYWx1ZXMuZGF0YUlkKS5pZDtcbiAgICAgICAgICBjb25zdCBvdXRJbmRpY2VzID0gYmFja2VuZC5tYWtlT3V0cHV0KG91dHB1dFNoYXBlLCAnaW50MzInKTtcbiAgICAgICAgICBjb25zdCBvdXRJbmRpY2VzSWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0SW5kaWNlcy5kYXRhSWQpLmlkO1xuXG4gICAgICAgICAgd2FzbVRvcEsoXG4gICAgICAgICAgICAgIHhJZCwgeFNoYXBlQnl0ZXMsIHguc2hhcGUubGVuZ3RoLCBDcHBEVHlwZVt4LmR0eXBlXSwgaywgc29ydGVkLFxuICAgICAgICAgICAgICBvdXRWYWx1ZXNJZCwgb3V0SW5kaWNlc0lkKTtcblxuICAgICAgICAgIHJldHVybiBbb3V0VmFsdWVzLCBvdXRJbmRpY2VzXTtcbiAgICAgICAgfTtcblxuZXhwb3J0IGNvbnN0IHRvcEtDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogVG9wSyxcbiAgYmFja2VuZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOiBzZXR1cCxcbiAga2VybmVsRnVuYzogdG9wayBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmMsXG59O1xuIl19