@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:
- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier
68 lines • 9.17 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { backend_util, Max, util } from '@tensorflow/tfjs-core';
import { permuteAxesAndTranspose } from './kernel_utils';
import { CppDType } from './types';
let wasmMax;
function setup(backend) {
wasmMax = backend.wasm.cwrap(Max, null /*void*/, [
'number',
'number',
'number',
'number', // out_id
]);
}
function max(args) {
const { backend, inputs, attrs } = args;
const { reductionIndices: axis, keepDims } = attrs;
const { x } = inputs;
const xId = backend.dataIdMap.get(x.dataId).id;
let inputId = xId;
let input = x;
const { transposed, axes, originalAxes, inputWasTransposed } = permuteAxesAndTranspose(x, axis, backend);
if (inputWasTransposed) {
const transposedId = backend.dataIdMap.get(transposed.dataId).id;
input = transposed;
inputId = transposedId;
}
const inputRank = input.shape.length;
backend_util.assertAxesAreInnerMostDims('max', axes, inputRank);
const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(input.shape, axes);
const reduceSize = util.sizeFromShape(reduceShape);
const out = backend.makeOutput(outShape, x.dtype);
if (util.sizeFromShape(input.shape) !== 0) {
const outId = backend.dataIdMap.get(out.dataId).id;
wasmMax(inputId, CppDType[x.dtype], reduceSize, outId);
}
if (inputWasTransposed) {
// dispose of the transposed tensor.
backend.disposeData(transposed.dataId);
}
if (keepDims) {
// reshape
const newShape = backend_util.expandShapeToKeepDim(out.shape, originalAxes);
out.shape = newShape;
}
return out;
}
/**
 * Kernel configuration for `Max` on the 'wasm' backend: pairs the kernel name
 * with its setup hook (`setup`) and its implementation (`max`).
 */
export const maxConfig = {
  kernelName: Max,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: max,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWF4LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdhc20vc3JjL2tlcm5lbHMvTWF4LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQTRCLEdBQUcsRUFBbUMsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFJekgsT0FBTyxFQUFDLHVCQUF1QixFQUFDLE1BQU0sZ0JBQWdCLENBQUM7QUFDdkQsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFNBQVMsQ0FBQztBQUVqQyxJQUFJLE9BQ0ksQ0FBQztBQUVULFNBQVMsS0FBSyxDQUFDLE9BQW9CO0lBQ2pDLE9BQU8sR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEtBQUssQ0FBQyxHQUFHLEVBQUUsSUFBSSxDQUFDLFFBQVEsRUFBRTtRQUMvQyxRQUFRO1FBQ1IsUUFBUTtRQUNSLFFBQVE7UUFDUixRQUFRLEVBQUcsU0FBUztLQUNyQixDQUFDLENBQUM7QUFDTCxDQUFDO0FBRUQsU0FBUyxHQUFHLENBQUMsSUFBZ0U7SUFFM0UsTUFBTSxFQUFDLE9BQU8sRUFBRSxNQUFNLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxnQkFBZ0IsRUFBRSxJQUFJLEVBQUUsUUFBUSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBQ2pELE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLEVBQUUsQ0FBQztJQUMvQyxJQUFJLE9BQU8sR0FBRyxHQUFHLENBQUM7SUFDbEIsSUFBSSxLQUFLLEdBQUcsQ0FBQyxDQUFDO0lBRWQsTUFBTSxFQUFDLFVBQVUsRUFBRSxJQUFJLEVBQUUsWUFBWSxFQUFFLGtCQUFrQixFQUFDLEdBQ3RELHVCQUF1QixDQUFDLENBQUMsRUFBRSxJQUFJLEVBQUUsT0FBTyxDQUFDLENBQUM7SUFFOUMsSUFBSSxrQkFBa0IsRUFBRTtRQUN0QixNQUFNLFlBQVksR0FBRyxPQUFPLENBQUMsU0FBUyxDQUFDLEdBQUcsQ0FBQyxVQUFVLENBQUMsTUFBTSxDQUFDLENBQUMsRUFBRSxDQUFDO1FBQ2pFLEtBQUssR0FBRyxVQUFVLENBQUM7UUFDbkIsT0FBTyxHQUFHLFlBQVksQ0FBQztLQUN4QjtJQUVELE1BQU0sU0FBUyxHQUFHLEtBQUssQ0FBQyxLQUFLLENBQUMsTUFBTSxDQUFDO0lBQ3JDLFlBQVksQ0FBQywwQkFBMEIsQ0FBQyxLQUFLLEVBQUUsSUFBSSxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBQ2hFLE1BQU0sQ0FBQyxRQUFRLEVBQUUsV0FBVyxDQUFDLEdBQ3pCLFlBQVksQ0FBQyx5QkFBeUIsQ0FBQyxLQUFLLENBQUMsS0FBSyxFQUFFLElBQUksQ0FBQyxDQUFDO0lBQzlELE1BQU0sVUFBVSxHQUFHLElBQUksQ0FBQyxhQUFhLENBQUMsV0FBVyxDQUFDLENBQUM7SUFFbkQsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0
FBQyxDQUFDO0lBQ2xELElBQUksSUFBSSxDQUFDLGFBQWEsQ0FBQyxLQUFLLENBQUMsS0FBSyxDQUFDLEtBQUssQ0FBQyxFQUFFO1FBQ3pDLE1BQU0sS0FBSyxHQUFHLE9BQU8sQ0FBQyxTQUFTLENBQUMsR0FBRyxDQUFDLEdBQUcsQ0FBQyxNQUFNLENBQUMsQ0FBQyxFQUFFLENBQUM7UUFDbkQsT0FBTyxDQUFDLE9BQU8sRUFBRSxRQUFRLENBQUMsQ0FBQyxDQUFDLEtBQUssQ0FBQyxFQUFFLFVBQVUsRUFBRSxLQUFLLENBQUMsQ0FBQztLQUN4RDtJQUVELElBQUksa0JBQWtCLEVBQUU7UUFDdEIsb0NBQW9DO1FBQ3BDLE9BQU8sQ0FBQyxXQUFXLENBQUMsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDO0tBQ3hDO0lBRUQsSUFBSSxRQUFRLEVBQUU7UUFDWixVQUFVO1FBQ1YsTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLG9CQUFvQixDQUFDLEdBQUcsQ0FBQyxLQUFLLEVBQUUsWUFBWSxDQUFDLENBQUM7UUFDNUUsR0FBRyxDQUFDLEtBQUssR0FBRyxRQUFRLENBQUM7S0FDdEI7SUFFRCxPQUFPLEdBQUcsQ0FBQztBQUNiLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxTQUFTLEdBQWlCO0lBQ3JDLFVBQVUsRUFBRSxHQUFHO0lBQ2YsV0FBVyxFQUFFLE1BQU07SUFDbkIsU0FBUyxFQUFFLEtBQUs7SUFDaEIsVUFBVSxFQUFFLEdBQTRCO0NBQ3pDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOSBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7YmFja2VuZF91dGlsLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIE1heCwgTWF4QXR0cnMsIE1heElucHV0cywgVGVuc29ySW5mbywgdXRpbH0gZnJvbSAnQHRlbnNvcm
Zsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtCYWNrZW5kV2FzbX0gZnJvbSAnLi4vYmFja2VuZF93YXNtJztcblxuaW1wb3J0IHtwZXJtdXRlQXhlc0FuZFRyYW5zcG9zZX0gZnJvbSAnLi9rZXJuZWxfdXRpbHMnO1xuaW1wb3J0IHtDcHBEVHlwZX0gZnJvbSAnLi90eXBlcyc7XG5cbmxldCB3YXNtTWF4OiAoeElkOiBudW1iZXIsIGR0eXBlOiBudW1iZXIsIHJlZHVjZVNpemU6IG51bWJlciwgb3V0SWQ6IG51bWJlcikgPT5cbiAgICB2b2lkO1xuXG5mdW5jdGlvbiBzZXR1cChiYWNrZW5kOiBCYWNrZW5kV2FzbSk6IHZvaWQge1xuICB3YXNtTWF4ID0gYmFja2VuZC53YXNtLmN3cmFwKE1heCwgbnVsbCAvKnZvaWQqLywgW1xuICAgICdudW1iZXInLCAgLy8geF9pZFxuICAgICdudW1iZXInLCAgLy8gZHR5cGVcbiAgICAnbnVtYmVyJywgIC8vIHJlZHVjZV9zaXplXG4gICAgJ251bWJlcicsICAvLyBvdXRfaWRcbiAgXSk7XG59XG5cbmZ1bmN0aW9uIG1heChhcmdzOiB7YmFja2VuZDogQmFja2VuZFdhc20sIGlucHV0czogTWF4SW5wdXRzLCBhdHRyczogTWF4QXR0cnN9KTpcbiAgICBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2JhY2tlbmQsIGlucHV0cywgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge3JlZHVjdGlvbkluZGljZXM6IGF4aXMsIGtlZXBEaW1zfSA9IGF0dHJzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHhJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldCh4LmRhdGFJZCkuaWQ7XG4gIGxldCBpbnB1dElkID0geElkO1xuICBsZXQgaW5wdXQgPSB4O1xuXG4gIGNvbnN0IHt0cmFuc3Bvc2VkLCBheGVzLCBvcmlnaW5hbEF4ZXMsIGlucHV0V2FzVHJhbnNwb3NlZH0gPVxuICAgICAgcGVybXV0ZUF4ZXNBbmRUcmFuc3Bvc2UoeCwgYXhpcywgYmFja2VuZCk7XG5cbiAgaWYgKGlucHV0V2FzVHJhbnNwb3NlZCkge1xuICAgIGNvbnN0IHRyYW5zcG9zZWRJZCA9IGJhY2tlbmQuZGF0YUlkTWFwLmdldCh0cmFuc3Bvc2VkLmRhdGFJZCkuaWQ7XG4gICAgaW5wdXQgPSB0cmFuc3Bvc2VkO1xuICAgIGlucHV0SWQgPSB0cmFuc3Bvc2VkSWQ7XG4gIH1cblxuICBjb25zdCBpbnB1dFJhbmsgPSBpbnB1dC5zaGFwZS5sZW5ndGg7XG4gIGJhY2tlbmRfdXRpbC5hc3NlcnRBeGVzQXJlSW5uZXJNb3N0RGltcygnbWF4JywgYXhlcywgaW5wdXRSYW5rKTtcbiAgY29uc3QgW291dFNoYXBlLCByZWR1Y2VTaGFwZV0gPVxuICAgICAgYmFja2VuZF91dGlsLmNvbXB1dGVPdXRBbmRSZWR1Y2VTaGFwZXMoaW5wdXQuc2hhcGUsIGF4ZXMpO1xuICBjb25zdCByZWR1Y2VTaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKHJlZHVjZVNoYXBlKTtcblxuICBjb25zdCBvdXQgPSBiYWNrZW5kLm1ha2VPdXRwdXQob3V0U2hhcGUsIHguZHR5cGUpO1xuICBpZiAodXRpbC5zaXplRnJvbVNoYXBlKGlucHV0LnNoYXBlKSAhPT0gMCkge1xuICAgIGNvbnN0IG91dElkID0gYmFja2VuZC5kYXRhSWRNYXAuZ2V0KG91dC5kYXRhSWQpLmlkO1xuICAgIHdhc21NYXgoaW
5wdXRJZCwgQ3BwRFR5cGVbeC5kdHlwZV0sIHJlZHVjZVNpemUsIG91dElkKTtcbiAgfVxuXG4gIGlmIChpbnB1dFdhc1RyYW5zcG9zZWQpIHtcbiAgICAvLyBkaXNwb3NlIG9mIHRoZSB0cmFuc3Bvc2VkIHRlbnNvci5cbiAgICBiYWNrZW5kLmRpc3Bvc2VEYXRhKHRyYW5zcG9zZWQuZGF0YUlkKTtcbiAgfVxuXG4gIGlmIChrZWVwRGltcykge1xuICAgIC8vIHJlc2hhcGVcbiAgICBjb25zdCBuZXdTaGFwZSA9IGJhY2tlbmRfdXRpbC5leHBhbmRTaGFwZVRvS2VlcERpbShvdXQuc2hhcGUsIG9yaWdpbmFsQXhlcyk7XG4gICAgb3V0LnNoYXBlID0gbmV3U2hhcGU7XG4gIH1cblxuICByZXR1cm4gb3V0O1xufVxuXG5leHBvcnQgY29uc3QgbWF4Q29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IE1heCxcbiAgYmFja2VuZE5hbWU6ICd3YXNtJyxcbiAgc2V0dXBGdW5jOiBzZXR1cCxcbiAga2VybmVsRnVuYzogbWF4IGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==