UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.

54 lines 8.3 kB
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, util } from '@tensorflow/tfjs-core';
import { CppDType } from './types';

/**
 * Builds a kernel config object for a two-input (`a`, `b`) kernel whose
 * implementation is the wasm export named `kernelName`.
 *
 * @param kernelName Name handed to `backend.wasm.cwrap` and echoed back in
 *     the returned config.
 * @param supportsFullBroadcast Accepted for interface compatibility; it is
 *     not read anywhere in this function.
 * @param dtype Optional dtype for the output tensor. When it is null or
 *     undefined, the output takes the dtype of input `a`.
 * @returns An object with `kernelName`, `backendName` ('wasm'), `setupFunc`
 *     and `kernelFunc`.
 */
export function createBinaryKernelConfig(kernelName, supportsFullBroadcast, dtype) {
  // Assigned by setupFunc; kernelFunc calls it, so setupFunc has to run first.
  let wasmFunc;

  /** Wraps the wasm export `kernelName` and stores the resulting callable. */
  function setupFunc(backend) {
    // Positions line up with the arguments passed to wasmFunc in kernelFunc.
    const argTypes = [
      'number', // a_id
      'array', // a_shape
      'number', // a_shape.length
      'number', // b_id
      'array', // b_shape
      'number', // b_shape.length
      'number', // dtype
      'number', // out_id
    ];
    wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, argTypes);
  }

  /**
   * Runs the kernel on `args.inputs.a` and `args.inputs.b` and returns the
   * output created through `args.backend.makeOutput`.
   */
  function kernelFunc(args) {
    const { backend, inputs: { a, b } } = args;

    // Looks up the backend's id for a tensor through its dataId.
    const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
    // Reinterprets a shape as the raw bytes of an Int32Array holding it.
    const shapeAsBytes = (shape) => new Uint8Array(new Int32Array(shape).buffer);

    const aId = idOf(a);
    const bId = idOf(b);

    const resultDType = dtype == null ? a.dtype : dtype;
    const resultShape = backend_util.assertAndGetBroadcastShape(a.shape, b.shape);
    const out = backend.makeOutput(resultShape, resultDType);

    // An output with zero elements needs no wasm call.
    if (util.sizeFromShape(resultShape) === 0) {
      return out;
    }

    const aShapeBytes = shapeAsBytes(a.shape);
    const bShapeBytes = shapeAsBytes(b.shape);
    const outId = idOf(out);

    // The dtype argument is taken from input `a`, not from the output dtype.
    wasmFunc(
      aId, aShapeBytes, a.shape.length,
      bId, bShapeBytes, b.shape.length,
      CppDType[a.dtype], outId);

    return out;
  }

  return { kernelName, backendName: 'wasm', setupFunc, kernelFunc };
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmluYXJ5X2tlcm5lbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13YXNtL3NyYy9rZXJuZWxzL2JpbmFyeV9rZXJuZWwudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBb0QsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJM0csT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUVqQyxNQUFNLFVBQVUsd0JBQXdCLENBQ3BDLFVBQWtCLEVBQUUscUJBQThCLEVBQ2xELEtBQWdCO0lBQ2xCLElBQUksUUFHUSxDQUFDO0lBRWIsU0FBUyxTQUFTLENBQUMsT0FBb0I7UUFDckMsUUFBUSxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsS0FBSyxDQUFDLFVBQVUsRUFBRSxJQUFJLENBQUMsVUFBVSxFQUFFO1lBQ3pELFFBQVE7WUFDUixPQUFPO1lBQ1AsUUFBUTtZQUNSLFFBQVE7WUFDUixPQUFPO1lBQ1AsUUFBUTtZQUNSLFFBQVE7WUFDUixRQUFRLENBQUcsU0FBUztTQUNyQixDQUFDLENBQUM7SUFDTCxDQUFDO0lBRUQsU0FBUyxVQUFVLENBQUMsSUFBa0Q7UUFFcEUsTUFBTSxFQUFDLE9BQU8sRUFBRSxNQUFNLEVBQUMsR0FBRyxJQUFJLENBQUM7UUFDL0IsTUFBTSxFQUFDLENBQUMsRUFBRSxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7UUFDdEIsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztRQUMvQyxNQUFNLEdBQUcsR0FBRyxPQUFPLENBQUMsU0FBUyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUMsRUFBRSxDQUFDO1FBRS9DLE1BQU0sVUFBVSxHQUFHLEtBQUssSUFBSSxJQUFJLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQztRQUNuRCxNQUFNLFFBQVEsR0FBRyxZQUFZLENBQUMsMEJBQTBCLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUM7UUFDM0UsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxRQUFRLEVBQUUsVUFBVSxD
QUFDLENBQUM7UUFFckQsb0NBQW9DO1FBQ3BDLElBQUksSUFBSSxDQUFDLGFBQWEsQ0FBQyxRQUFRLENBQUMsS0FBSyxDQUFDLEVBQUU7WUFDdEMsT0FBTyxHQUFHLENBQUM7U0FDWjtRQUVELE1BQU0sV0FBVyxHQUFHLElBQUksVUFBVSxDQUFDLElBQUksVUFBVSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQztRQUNuRSxNQUFNLFdBQVcsR0FBRyxJQUFJLFVBQVUsQ0FBQyxJQUFJLFVBQVUsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUM7UUFDbkUsTUFBTSxLQUFLLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsR0FBRyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztRQUNuRCxNQUFNLFVBQVUsR0FBRyxHQUFHLEVBQUUsQ0FBQyxRQUFRLENBQzdCLEdBQUcsRUFBRSxXQUFXLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxNQUFNLEVBQUUsR0FBRyxFQUFFLFdBQVcsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFDbEUsUUFBUSxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsRUFBRSxLQUFLLENBQUMsQ0FBQztRQUU5QixVQUFVLEVBQUUsQ0FBQztRQUNiLE9BQU8sR0FBRyxDQUFDO0lBQ2IsQ0FBQztJQUVELE9BQU8sRUFBQyxVQUFVLEVBQUUsV0FBVyxFQUFFLE1BQU0sRUFBRSxTQUFTLEVBQUUsVUFBVSxFQUFDLENBQUM7QUFDbEUsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE5IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIEJpbmFyeUlucHV0cywgRGF0YVR5cGUsIEtlcm5lbENv
bmZpZywgVGVuc29ySW5mbywgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtCYWNrZW5kV2FzbX0gZnJvbSAnLi4vYmFja2VuZF93YXNtJztcblxuaW1wb3J0IHtDcHBEVHlwZX0gZnJvbSAnLi90eXBlcyc7XG5cbmV4cG9ydCBmdW5jdGlvbiBjcmVhdGVCaW5hcnlLZXJuZWxDb25maWcoXG4gICAga2VybmVsTmFtZTogc3RyaW5nLCBzdXBwb3J0c0Z1bGxCcm9hZGNhc3Q6IGJvb2xlYW4sXG4gICAgZHR5cGU/OiBEYXRhVHlwZSk6IEtlcm5lbENvbmZpZyB7XG4gIGxldCB3YXNtRnVuYzpcbiAgICAgIChhSWQ6IG51bWJlciwgYVNoYXBlOiBVaW50OEFycmF5LCBhU2hhcGVMZW46IG51bWJlciwgYklkOiBudW1iZXIsXG4gICAgICAgYlNoYXBlOiBVaW50OEFycmF5LCBiU2hhcGVMZW46IG51bWJlciwgZHR5cGU6IG51bWJlciwgb3V0SWQ6IG51bWJlcikgPT5cbiAgICAgICAgICB2b2lkO1xuXG4gIGZ1bmN0aW9uIHNldHVwRnVuYyhiYWNrZW5kOiBCYWNrZW5kV2FzbSk6IHZvaWQge1xuICAgIHdhc21GdW5jID0gYmFja2VuZC53YXNtLmN3cmFwKGtlcm5lbE5hbWUsIG51bGwgLyogdm9pZCAqLywgW1xuICAgICAgJ251bWJlcicsICAvLyBhX2lkLFxuICAgICAgJ2FycmF5JywgICAvLyBhX3NoYXBlXG4gICAgICAnbnVtYmVyJywgIC8vIGFfc2hhcGUubGVuZ3RoXG4gICAgICAnbnVtYmVyJywgIC8vIGJfaWRcbiAgICAgICdhcnJheScsICAgLy8gYl9zaGFwZVxuICAgICAgJ251bWJlcicsICAvLyBiX3NoYXBlLmxlbmd0aFxuICAgICAgJ251bWJlcicsICAvLyBkdHlwZVxuICAgICAgJ251bWJlcicgICAvLyBvdXRfaWRcbiAgICBdKTtcbiAgfVxuXG4gIGZ1bmN0aW9uIGtlcm5lbEZ1bmMoYXJnczoge2JhY2tlbmQ6IEJhY2tlbmRXYXNtLCBpbnB1dHM6IEJpbmFyeUlucHV0c30pOlxuICAgICAgVGVuc29ySW5mbyB7XG4gICAgY29uc3Qge2JhY2tlbmQsIGlucHV0c30gPSBhcmdzO1xuICAgIGNvbnN0IHthLCBifSA9IGlucHV0cztcbiAgICBjb25zdCBhSWQgPSBiYWNrZW5kLmRhdGFJZE1hcC5nZXQoYS5kYXRhSWQpLmlkO1xuICAgIGNvbnN0IGJJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldChiLmRhdGFJZCkuaWQ7XG5cbiAgICBjb25zdCBvdXRwdXRUeXBlID0gZHR5cGUgIT0gbnVsbCA/IGR0eXBlIDogYS5kdHlwZTtcbiAgICBjb25zdCBuZXdTaGFwZSA9IGJhY2tlbmRfdXRpbC5hc3NlcnRBbmRHZXRCcm9hZGNhc3RTaGFwZShhLnNoYXBlLCBiLnNoYXBlKTtcbiAgICBjb25zdCBvdXQgPSBiYWNrZW5kLm1ha2VPdXRwdXQobmV3U2hhcGUsIG91dHB1dFR5cGUpO1xuXG4gICAgLy8gU2hvcnQtY2lyY3VpdCB6ZXJvLXNpemVkIHRlbnNvcnMuXG4gICAgaWYgKHV0aWwuc2l6ZUZyb21TaGFwZShuZXdTaGFwZSkgPT09IDApIHtcbiAgICAgIHJldHVybiBvdXQ7XG4gICAgfVxuXG4gICAgY29uc3QgYVNoYXBlQnl0ZXMgPSBuZXcgVWludDhBcnJheShuZXcgSW50MzJBcnJheShhLnNoYXBlKS5idWZmZXIpO1xu
ICAgIGNvbnN0IGJTaGFwZUJ5dGVzID0gbmV3IFVpbnQ4QXJyYXkobmV3IEludDMyQXJyYXkoYi5zaGFwZSkuYnVmZmVyKTtcbiAgICBjb25zdCBvdXRJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldChvdXQuZGF0YUlkKS5pZDtcbiAgICBjb25zdCBrZXJuZWxGdW5jID0gKCkgPT4gd2FzbUZ1bmMoXG4gICAgICAgIGFJZCwgYVNoYXBlQnl0ZXMsIGEuc2hhcGUubGVuZ3RoLCBiSWQsIGJTaGFwZUJ5dGVzLCBiLnNoYXBlLmxlbmd0aCxcbiAgICAgICAgQ3BwRFR5cGVbYS5kdHlwZV0sIG91dElkKTtcblxuICAgIGtlcm5lbEZ1bmMoKTtcbiAgICByZXR1cm4gb3V0O1xuICB9XG5cbiAgcmV0dXJuIHtrZXJuZWxOYW1lLCBiYWNrZW5kTmFtZTogJ3dhc20nLCBzZXR1cEZ1bmMsIGtlcm5lbEZ1bmN9O1xufVxuIl19