UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

61 lines 9.1 kB
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { scatter_util, TensorScatterUpdate, util } from '@tensorflow/tfjs-core';
import { CppDType } from './types';

/**
 * JS callable bound to the WASM export named by `TensorScatterUpdate`.
 * It is assigned only by `setup`, so it stays undefined until `setup` runs.
 */
let wasmTensorScatterUpdate;

/**
 * Binds the WASM export for this kernel using `backend.wasm.cwrap`.
 *
 * The argument-type list below must stay in the same order as the call in
 * `tensorScatterUpdate`.
 *
 * @param {object} backend - WASM backend whose `wasm.cwrap` performs the binding.
 */
function setup(backend) {
  wasmTensorScatterUpdate = backend.wasm.cwrap(TensorScatterUpdate, null /* void */, [
    'number', // indicesId
    'number', // updatesId
    'number', // dtype (CppDType value of updates.dtype)
    'number', // sliceRank
    'number', // numUpdates
    'number', // sliceSize
    'array',  // strides (int32 values passed as raw bytes)
    'number', // outputSize
    'number', // outId
    'number', // tensorId
  ]);
}

/** Returns the id that `backend.dataIdMap` stores for the given tensor info. */
function wasmIdOf(backend, info) {
  return backend.dataIdMap.get(info.dataId).id;
}

/**
 * Kernel function for TensorScatterUpdate on the 'wasm' backend.
 *
 * It allocates an output with the same shape and dtype as `tensor`. It then
 * passes four ids (`tensor`, `indices`, `updates` and the output) to the WASM
 * function, along with the slicing geometry from
 * `scatter_util.calculateShapes`. The per-element scatter runs inside the
 * WASM module and is not visible in this file.
 *
 * @param {{backend: object,
 *          inputs: {tensor: object, indices: object, updates: object},
 *          attrs: object}} args
 *     Kernel arguments. This kernel reads no attributes.
 * @returns {object} Tensor info for the output. A zero-size `tensor` returns
 *     the freshly allocated output without calling into WASM.
 */
function tensorScatterUpdate(args) {
  // `attrs` is deliberately not destructured because this kernel has no attributes.
  const { backend, inputs } = args;
  const { tensor, indices, updates } = inputs;

  const out = backend.makeOutput(tensor.shape, tensor.dtype);
  if (util.sizeFromShape(tensor.shape) === 0) {
    return out;
  }

  const { sliceRank, numUpdates, sliceSize, strides, outputSize } =
    scatter_util.calculateShapes(updates, indices, tensor.shape);

  // The binding declares strides as a cwrap 'array' argument, which is copied
  // byte by byte. So pack the strides as int32 and pass their raw bytes.
  const stridesBytes = new Uint8Array(new Int32Array(strides).buffer);

  wasmTensorScatterUpdate(
    wasmIdOf(backend, indices),
    wasmIdOf(backend, updates),
    CppDType[updates.dtype],
    sliceRank,
    numUpdates,
    sliceSize,
    stridesBytes,
    outputSize,
    wasmIdOf(backend, out),
    wasmIdOf(backend, tensor)
  );

  return out;
}

/** Registration record that binds the TensorScatterUpdate kernel to the 'wasm' backend. */
export const tensorScatterUpdateConfig = {
  kernelName: TensorScatterUpdate,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: tensorScatterUpdate,
};

// NOTE(review): the inline `//# sourceMappingURL=data:...` comment that used to
// end this file has been dropped, for two reasons. Its `mappings` described the
// previous generated code. Its base64 payload had also been split across
// lines, which left the tail outside the comment. If a map is needed,
// regenerate one from tfjs-backend-wasm/src/kernels/TensorScatterUpdate.ts.