@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.
61 lines • 9.1 kB
JavaScript
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { scatter_util, TensorScatterUpdate, util } from '@tensorflow/tfjs-core';
import { CppDType } from './types';
// Holds the bound WASM kernel function. It stays undefined until setup()
// assigns the wrapper returned by backend.wasm.cwrap.
let wasmTensorScatterUpdate;
/**
 * Binds the WASM `TensorScatterUpdate` export and stores the resulting
 * wrapper in the module-level `wasmTensorScatterUpdate` variable.
 *
 * @param {Object} backend - Backend whose `wasm.cwrap` creates the wrapper.
 * @returns {void}
 */
function setup(backend) {
  // One entry per argument, in the order the kernel function passes them.
  const argTypes = [
    'number', // indicesId
    'number', // updatesId
    'number', // dtype
    'number', // sliceRank
    'number', // numUpdates
    'number', // sliceSize
    'array', // strides
    'number', // outputSize
    'number', // outId
    'number', // tensorId
  ];
  const returnType = null; // void

  wasmTensorScatterUpdate =
      backend.wasm.cwrap(TensorScatterUpdate, returnType, argTypes);
}
/**
 * WASM kernel wrapper for `TensorScatterUpdate`.
 *
 * Allocates an output with the same shape and dtype as `inputs.tensor`, then
 * passes the backend ids of `indices`, `updates`, `tensor` and the output to
 * the bound WASM function, together with the slice layout computed by
 * `scatter_util.calculateShapes`.
 *
 * @param {Object} args
 * @param {Object} args.backend - Backend exposing `makeOutput` and
 *     `dataIdMap`.
 * @param {Object} args.inputs - Holds the `tensor`, `indices` and `updates`
 *     tensor infos.
 * @param {Object} [args.attrs] - Ignored; this kernel reads no attributes.
 * @returns {Object} The tensor info returned by `backend.makeOutput`.
 */
function tensorScatterUpdate(args) {
  // `attrs` is deliberately not destructured: the kernel uses no attributes,
  // and an empty pattern would throw when `attrs` is null or undefined.
  const { backend, inputs } = args;
  const { tensor, indices, updates } = inputs;

  const out = backend.makeOutput(tensor.shape, tensor.dtype);
  // An empty tensor needs no WASM call; return the fresh output as is.
  if (util.sizeFromShape(tensor.shape) === 0) {
    return out;
  }

  const { sliceRank, numUpdates, sliceSize, strides, outputSize } =
      scatter_util.calculateShapes(updates, indices, tensor.shape);

  // Resolves a tensor info to the numeric id stored in the backend's data map.
  const idOf = (info) => backend.dataIdMap.get(info.dataId).id;

  const indicesId = idOf(indices);
  const updatesId = idOf(updates);
  const tensorId = idOf(tensor);

  // Strides are packed as 32-bit ints and handed over as a raw byte view,
  // matching the 'array' argument type registered for this slot in setup().
  const stridesBytes = new Uint8Array(new Int32Array(strides).buffer);

  const outId = idOf(out);
  wasmTensorScatterUpdate(
      indicesId, updatesId, CppDType[updates.dtype], sliceRank, numUpdates,
      sliceSize, stridesBytes, outputSize, outId, tensorId);

  return out;
}
/**
 * Kernel config that ties the `TensorScatterUpdate` kernel name to the
 * 'wasm' backend, along with its one-time setup function and its kernel
 * function.
 */
export const tensorScatterUpdateConfig = {
  kernelName: TensorScatterUpdate,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: tensorScatterUpdate,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiVGVuc29yU2NhdHRlclVwZGF0ZS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL1RlbnNvclNjYXR0ZXJVcGRhdGUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUEyQixZQUFZLEVBQWMsbUJBQW1CLEVBQXVELElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBSXpLLE9BQU8sRUFBQyxRQUFRLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFFakMsSUFBSSx1QkFHNEQsQ0FBQztBQUVqRSxTQUFTLEtBQUssQ0FBQyxPQUFvQjtJQUNqQyx1QkFBdUI7UUFDbkIsT0FBTyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsbUJBQW1CLEVBQUUsSUFBSSxDQUFDLFFBQVEsRUFBRTtZQUNyRCxRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRO1lBQ1IsUUFBUTtZQUNSLFFBQVE7WUFDUixPQUFPO1lBQ1AsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRLEVBQUcsV0FBVztTQUN2QixDQUFDLENBQUM7QUFDVCxDQUFDO0FBRUQsU0FBUyxtQkFBbUIsQ0FBQyxJQUk1QjtJQUNDLE1BQU0sRUFBQyxPQUFPLEVBQUUsTUFBTSxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxPQUFPLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDMUMsTUFBTSxFQUFFLEdBQUcsS0FBSyxDQUFDO0lBRWpCLE1BQU0sR0FBRyxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsTUFBTSxDQUFDLEtBQUssRUFBRSxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUM7SUFDM0QsSUFBSSxJQUFJLENBQUMsYUFBYSxDQUFDLE1BQU0sQ0FBQyxLQUFLLENBQUMsS0FBSyxDQUFDLEVBQUU7UUFDMUMsT0FBTyxHQUFHLENBQUM7S0FDWjtJQUVELE1BQU0sRUFBQyxTQUFTLEVBQUUsVUFBVSxFQUFFLFNBQVMsRUFBRSxPQUFPLEVBQUUsVUFBVSxFQUFDLEdBQ3pELFlBQVksQ0FBQyxlQUFlLENBQUMsT0FBTyxFQUFFLE9BQU8sRUFBRSxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUM7SUFFakUsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsT0FBTyxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBQzFELE1BQU0sU0FBUyxHQUFHLFdBQVcsQ0FBQyxFQUFFLENBQUM7SUFFakMsTUFBTSxXQUFXLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsT0FBTyxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBQzFELE1BQU0sU0FBUyxHQUFHLFdBQVcsQ0FBQyxFQUFFLENBQUM7SUFFakMsTUFBTSxVQUFVLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBQ3hELE1BQU0sUUFBUSxHQUFHLFVBQVUsQ0FBQyxFQUFFLENBQUM7SUFFL0IsTUFBTSxZQUFZLEdBQUcsSUFBSSxVQUFVLENBQUMsSUFBSSxVQUFVLENBQU
MsT0FBTyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUM7SUFFcEUsTUFBTSxLQUFLLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUNuRCx1QkFBdUIsQ0FDbkIsU0FBUyxFQUFFLFNBQVMsRUFBRSxRQUFRLENBQUMsT0FBTyxDQUFDLEtBQUssQ0FBQyxFQUFFLFNBQVMsRUFBRSxVQUFVLEVBQ3BFLFNBQVMsRUFBRSxZQUFZLEVBQUUsVUFBVSxFQUFFLEtBQUssRUFBRSxRQUFRLENBQUMsQ0FBQztJQUUxRCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSx5QkFBeUIsR0FBaUI7SUFDckQsVUFBVSxFQUFFLG1CQUFtQjtJQUMvQixXQUFXLEVBQUUsTUFBTTtJQUNuQixTQUFTLEVBQUUsS0FBSztJQUNoQixVQUFVLEVBQUUsbUJBQTRDO0NBQ3pELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMiBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7S2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBzY2F0dGVyX3V0aWwsIFRlbnNvckluZm8sIFRlbnNvclNjYXR0ZXJVcGRhdGUsIFRlbnNvclNjYXR0ZXJVcGRhdGVBdHRycywgVGVuc29yU2NhdHRlclVwZGF0ZUlucHV0cywgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtCYWNrZW5kV2FzbX0gZnJvbSAnLi4vYmFja2VuZF93YXNtJztcblxuaW1wb3J0IHtDcHBEVHlwZX0gZnJvbSAnLi90eXBlcyc7XG5cbmxldCB3YXNtVGVuc29yU2NhdHRlclVwZGF0ZTogKFxuICAgIGluZGljZXNJZDogbnVtYmVyLCB1cGRhdGVzSWQ6IG
51bWJlciwgZHR5cGU6IENwcERUeXBlLCBzbGljZVJhbms6IG51bWJlcixcbiAgICBudW1VcGRhdGVzOiBudW1iZXIsIHNsaWNlU2l6ZTogbnVtYmVyLCBzdHJpZGVzOiBVaW50OEFycmF5LFxuICAgIG91dHB1dFNpemU6IG51bWJlciwgb3V0SWQ6IG51bWJlciwgdGVuc29ySWQ6IG51bWJlcikgPT4gdm9pZDtcblxuZnVuY3Rpb24gc2V0dXAoYmFja2VuZDogQmFja2VuZFdhc20pOiB2b2lkIHtcbiAgd2FzbVRlbnNvclNjYXR0ZXJVcGRhdGUgPVxuICAgICAgYmFja2VuZC53YXNtLmN3cmFwKFRlbnNvclNjYXR0ZXJVcGRhdGUsIG51bGwgLyp2b2lkKi8sIFtcbiAgICAgICAgJ251bWJlcicsICAvLyBpbmRpY2VzSWRcbiAgICAgICAgJ251bWJlcicsICAvLyB1cGRhdGVzSWRcbiAgICAgICAgJ251bWJlcicsICAvLyBkdHlwZVxuICAgICAgICAnbnVtYmVyJywgIC8vIHNsaWNlUmFua1xuICAgICAgICAnbnVtYmVyJywgIC8vIG51bVVwZGF0ZXNcbiAgICAgICAgJ251bWJlcicsICAvLyBzbGljZVNpemVcbiAgICAgICAgJ2FycmF5JywgICAvLyBzdHJpZGVzXG4gICAgICAgICdudW1iZXInLCAgLy8gb3V0cHV0U2l6ZVxuICAgICAgICAnbnVtYmVyJywgIC8vIG91dElkXG4gICAgICAgICdudW1iZXInLCAgLy8gdGVuc29ySWRcbiAgICAgIF0pO1xufVxuXG5mdW5jdGlvbiB0ZW5zb3JTY2F0dGVyVXBkYXRlKGFyZ3M6IHtcbiAgYmFja2VuZDogQmFja2VuZFdhc20sXG4gIGlucHV0czogVGVuc29yU2NhdHRlclVwZGF0ZUlucHV0cyxcbiAgYXR0cnM6IFRlbnNvclNjYXR0ZXJVcGRhdGVBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7YmFja2VuZCwgaW5wdXRzLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7dGVuc29yLCBpbmRpY2VzLCB1cGRhdGVzfSA9IGlucHV0cztcbiAgY29uc3Qge30gPSBhdHRycztcblxuICBjb25zdCBvdXQgPSBiYWNrZW5kLm1ha2VPdXRwdXQodGVuc29yLnNoYXBlLCB0ZW5zb3IuZHR5cGUpO1xuICBpZiAodXRpbC5zaXplRnJvbVNoYXBlKHRlbnNvci5zaGFwZSkgPT09IDApIHtcbiAgICByZXR1cm4gb3V0O1xuICB9XG5cbiAgY29uc3Qge3NsaWNlUmFuaywgbnVtVXBkYXRlcywgc2xpY2VTaXplLCBzdHJpZGVzLCBvdXRwdXRTaXplfSA9XG4gICAgICBzY2F0dGVyX3V0aWwuY2FsY3VsYXRlU2hhcGVzKHVwZGF0ZXMsIGluZGljZXMsIHRlbnNvci5zaGFwZSk7XG5cbiAgY29uc3QgaW5kaWNlc0RhdGEgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQoaW5kaWNlcy5kYXRhSWQpO1xuICBjb25zdCBpbmRpY2VzSWQgPSBpbmRpY2VzRGF0YS5pZDtcblxuICBjb25zdCB1cGRhdGVzRGF0YSA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldCh1cGRhdGVzLmRhdGFJZCk7XG4gIGNvbnN0IHVwZGF0ZXNJZCA9IHVwZGF0ZXNEYXRhLmlkO1xuXG4gIGNvbnN0IHRlbnNvckRhdGEgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQodGVuc29yLmRhdGFJZCk7XG4gIGNvbnN0IHRlbnNvcklkID0gdGVuc29yRGF0YS5pZDtcblxuICBjb25zdC
BzdHJpZGVzQnl0ZXMgPSBuZXcgVWludDhBcnJheShuZXcgSW50MzJBcnJheShzdHJpZGVzKS5idWZmZXIpO1xuXG4gIGNvbnN0IG91dElkID0gYmFja2VuZC5kYXRhSWRNYXAuZ2V0KG91dC5kYXRhSWQpLmlkO1xuICB3YXNtVGVuc29yU2NhdHRlclVwZGF0ZShcbiAgICAgIGluZGljZXNJZCwgdXBkYXRlc0lkLCBDcHBEVHlwZVt1cGRhdGVzLmR0eXBlXSwgc2xpY2VSYW5rLCBudW1VcGRhdGVzLFxuICAgICAgc2xpY2VTaXplLCBzdHJpZGVzQnl0ZXMsIG91dHB1dFNpemUsIG91dElkLCB0ZW5zb3JJZCk7XG5cbiAgcmV0dXJuIG91dDtcbn1cblxuZXhwb3J0IGNvbnN0IHRlbnNvclNjYXR0ZXJVcGRhdGVDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogVGVuc29yU2NhdHRlclVwZGF0ZSxcbiAgYmFja2VuZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOiBzZXR1cCxcbiAga2VybmVsRnVuYzogdGVuc29yU2NhdHRlclVwZGF0ZSBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=