@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:
- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier
57 lines • 7.76 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { gather_util, GatherNd } from '@tensorflow/tfjs-core';
import { CppDType } from './types';
/**
 * Wrapped native GatherNd kernel. It is assigned by `setup` and invoked by
 * `gatherNd`.
 */
let wasmGatherNd;

/**
 * Wraps the WASM module's exported GatherNd kernel via `cwrap` and stores it
 * in `wasmGatherNd`.
 *
 * The wrapped function has no return value (`null` return type). Its
 * arguments, in call order, are: xId, dtype, indicesId, numSlices, sliceRank
 * and sliceSize (all 'number'), then strides (an 'array' argument; the caller
 * passes a Uint8Array), then outId ('number').
 *
 * @param {Object} backend WASM backend whose `wasm.cwrap` performs the binding.
 */
function setup(backend) {
  const signature = [
    'number', // xId
    'number', // dtype
    'number', // indicesId
    'number', // numSlices
    'number', // sliceRank
    'number', // sliceSize
    'array', // strides
    'number', // outId
  ];
  wasmGatherNd = backend.wasm.cwrap(GatherNd, null /* void */, signature);
}
/**
 * GatherNd kernel function for the WASM backend.
 *
 * Validates the inputs and computes the output geometry with
 * `gather_util.prepareAndValidate`, then allocates the output tensor. When
 * there is at least one slice to gather, it calls the native kernel with the
 * backend ids of the input and output tensors, the slice geometry and the
 * strides.
 *
 * @param {{backend: Object, inputs: {params: Object, indices: Object}}} args
 *     The WASM backend and the `params` / `indices` tensor infos.
 * @return {Object} The tensor info created by `backend.makeOutput`, which has
 *     `params.dtype`.
 */
function gatherNd(args) {
  const { backend, inputs: { params, indices } } = args;
  const [resultShape, numSlices, sliceSize, strides] =
      gather_util.prepareAndValidate(params, indices);

  const out = backend.makeOutput(resultShape, params.dtype);
  // No slices to gather: skip the native call and hand back the allocation.
  if (numSlices === 0) {
    return out;
  }

  // The last dimension of `indices` is the length of each index tuple.
  const [sliceRank] = indices.shape.slice(-1);

  // Map a tensor info to the backend-internal id the native kernel expects.
  const lookupId = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;
  const xId = lookupId(params);
  const indicesId = lookupId(indices);

  // The native signature takes strides as an 'array' argument, so the int32
  // stride values are passed as their underlying bytes.
  const stridesAsBytes = new Uint8Array(new Int32Array(strides).buffer);

  const outId = lookupId(out);
  wasmGatherNd(
      xId, CppDType[params.dtype], indicesId, numSlices, sliceRank, sliceSize,
      stridesAsBytes, outId);

  return out;
}
/**
 * Kernel configuration for GatherNd on the 'wasm' backend. `setup` is the
 * setup function that binds the native kernel, and `gatherNd` is the kernel
 * function.
 */
export const gatherNdConfig = {
  kernelName: GatherNd,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: gatherNd,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiR2F0aGVyTmQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2FzbS9zcmMva2VybmVscy9HYXRoZXJOZC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsV0FBVyxFQUFFLFFBQVEsRUFBMkMsTUFBTSx1QkFBdUIsQ0FBQztBQUl0RyxPQUFPLEVBQUMsUUFBUSxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRWpDLElBQUksWUFHSSxDQUFDO0FBRVQsU0FBUyxLQUFLLENBQUMsT0FBb0I7SUFDakMsWUFBWSxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsS0FBSyxDQUFDLFFBQVEsRUFBRSxJQUFJLENBQUMsUUFBUSxFQUFFO1FBQ3pELFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRO1FBQ1IsUUFBUTtRQUNSLE9BQU87UUFDUCxRQUFRLENBQUcsUUFBUTtLQUNwQixDQUFDLENBQUM7QUFDTCxDQUFDO0FBRUQsU0FBUyxRQUFRLENBQUMsSUFBb0Q7SUFFcEUsTUFBTSxFQUFDLE9BQU8sRUFBRSxNQUFNLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDL0IsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFFakMsTUFBTSxDQUFDLFdBQVcsRUFBRSxTQUFTLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBQyxHQUM5QyxXQUFXLENBQUMsa0JBQWtCLENBQUMsTUFBTSxFQUFFLE9BQU8sQ0FBQyxDQUFDO0lBRXBELE1BQU0sR0FBRyxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsV0FBVyxFQUFFLE1BQU0sQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUMxRCxJQUFJLFNBQVMsS0FBSyxDQUFDLEVBQUU7UUFDbkIsT0FBTyxHQUFHLENBQUM7S0FDWjtJQUVELE1BQU0sWUFBWSxHQUFHLE9BQU8sQ0FBQyxLQUFLLENBQUM7SUFDbkMsTUFBTSxTQUFTLEdBQUcsWUFBWSxDQUFDLFlBQVksQ0FBQyxNQUFNLEdBQUcsQ0FBQyxDQUFDLENBQUM7SUFFeEQsTUFBTSxLQUFLLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBQ25ELE1BQU0sR0FBRyxHQUFHLEtBQUssQ0FBQyxFQUFFLENBQUM7SUFDckIsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsT0FBTyxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBQzFELE1BQU0sU0FBUyxHQUFHLFdBQVcsQ0FBQyxFQUFFLENBQUM7SUFFakMsTUFBTSxZQUFZLEdBQUcsSUFBSSxVQUFVLENBQUMsSUFBSSxVQUFVLENBQUMsT0FBTyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUM7SUFFcEUsTUFBTSxLQUFLLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUNuRCxZQUFZLENBQ1IsR0FBRyxFQUFFLFFBQVEsQ0FBQyxNQUFNLENBQUMsS0FBSyxDQUFDLEVBQUUsU0FBUyxFQUFFLFNBQVMsRUFBRS
xTQUFTLEVBQUUsU0FBUyxFQUN2RSxZQUFZLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFekIsT0FBTyxHQUFHLENBQUM7QUFDYixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sY0FBYyxHQUFpQjtJQUMxQyxVQUFVLEVBQUUsUUFBUTtJQUNwQixXQUFXLEVBQUUsTUFBTTtJQUNuQixTQUFTLEVBQUUsS0FBSztJQUNoQixVQUFVLEVBQUUsUUFBUTtDQUNyQixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTkgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2dhdGhlcl91dGlsLCBHYXRoZXJOZCwgR2F0aGVyTmRJbnB1dHMsIEtlcm5lbENvbmZpZywgVGVuc29ySW5mb30gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtCYWNrZW5kV2FzbX0gZnJvbSAnLi4vYmFja2VuZF93YXNtJztcblxuaW1wb3J0IHtDcHBEVHlwZX0gZnJvbSAnLi90eXBlcyc7XG5cbmxldCB3YXNtR2F0aGVyTmQ6IChcbiAgICB4SWQ6IG51bWJlciwgZHR5cGU6IENwcERUeXBlLCBpbmRpY2VzSWQ6IG51bWJlciwgbnVtU2xpY2VzOiBudW1iZXIsXG4gICAgc2xpY2VSYW5rOiBudW1iZXIsIHNsaWNlU2l6ZTogbnVtYmVyLCBzdHJpZGVzOiBVaW50OEFycmF5LCBvdXRJZDogbnVtYmVyKSA9PlxuICAgIHZvaWQ7XG5cbmZ1bmN0aW9uIHNldHVwKGJhY2tlbmQ6IEJhY2tlbmRXYXNtKTogdm9pZCB7XG4gIHdhc21HYXRoZXJOZCA9IGJhY2tlbmQud2FzbS5jd3JhcChHYXRoZXJOZCwgbnVsbCAvKnZvaWQqLywgW1xuICAgICdudW1iZXInLCAgLy8geElkXG4gICAgJ251bWJlcicsICAvLyBkdHlwZVxuICAgICdudW1iZXInLC
AgLy8gaW5kaWNlc0lkXG4gICAgJ251bWJlcicsICAvLyBudW1TbGljZXNcbiAgICAnbnVtYmVyJywgIC8vIHNsaWNlUmFua1xuICAgICdudW1iZXInLCAgLy8gc2xpY2VTaXplXG4gICAgJ2FycmF5JywgICAvLyBzdHJpZGVzXG4gICAgJ251bWJlcicgICAvLyBvdXRJZFxuICBdKTtcbn1cblxuZnVuY3Rpb24gZ2F0aGVyTmQoYXJnczoge2JhY2tlbmQ6IEJhY2tlbmRXYXNtLCBpbnB1dHM6IEdhdGhlck5kSW5wdXRzfSk6XG4gICAgVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtiYWNrZW5kLCBpbnB1dHN9ID0gYXJncztcbiAgY29uc3Qge3BhcmFtcywgaW5kaWNlc30gPSBpbnB1dHM7XG5cbiAgY29uc3QgW3Jlc3VsdFNoYXBlLCBudW1TbGljZXMsIHNsaWNlU2l6ZSwgc3RyaWRlc10gPVxuICAgICAgZ2F0aGVyX3V0aWwucHJlcGFyZUFuZFZhbGlkYXRlKHBhcmFtcywgaW5kaWNlcyk7XG5cbiAgY29uc3Qgb3V0ID0gYmFja2VuZC5tYWtlT3V0cHV0KHJlc3VsdFNoYXBlLCBwYXJhbXMuZHR5cGUpO1xuICBpZiAobnVtU2xpY2VzID09PSAwKSB7XG4gICAgcmV0dXJuIG91dDtcbiAgfVxuXG4gIGNvbnN0IGluZGljZXNTaGFwZSA9IGluZGljZXMuc2hhcGU7XG4gIGNvbnN0IHNsaWNlUmFuayA9IGluZGljZXNTaGFwZVtpbmRpY2VzU2hhcGUubGVuZ3RoIC0gMV07XG5cbiAgY29uc3QgeERhdGEgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQocGFyYW1zLmRhdGFJZCk7XG4gIGNvbnN0IHhJZCA9IHhEYXRhLmlkO1xuICBjb25zdCBpbmRpY2VzRGF0YSA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldChpbmRpY2VzLmRhdGFJZCk7XG4gIGNvbnN0IGluZGljZXNJZCA9IGluZGljZXNEYXRhLmlkO1xuXG4gIGNvbnN0IHN0cmlkZXNCeXRlcyA9IG5ldyBVaW50OEFycmF5KG5ldyBJbnQzMkFycmF5KHN0cmlkZXMpLmJ1ZmZlcik7XG5cbiAgY29uc3Qgb3V0SWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQob3V0LmRhdGFJZCkuaWQ7XG4gIHdhc21HYXRoZXJOZChcbiAgICAgIHhJZCwgQ3BwRFR5cGVbcGFyYW1zLmR0eXBlXSwgaW5kaWNlc0lkLCBudW1TbGljZXMsIHNsaWNlUmFuaywgc2xpY2VTaXplLFxuICAgICAgc3RyaWRlc0J5dGVzLCBvdXRJZCk7XG5cbiAgcmV0dXJuIG91dDtcbn1cblxuZXhwb3J0IGNvbnN0IGdhdGhlck5kQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEdhdGhlck5kLFxuICBiYWNrZW5kTmFtZTogJ3dhc20nLFxuICBzZXR1cEZ1bmM6IHNldHVwLFxuICBrZXJuZWxGdW5jOiBnYXRoZXJOZFxufTtcbiJdfQ==