@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.
82 lines • 12 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { backend_util, MaxPool, util } from '@tensorflow/tfjs-core';
/** Wrapped WASM MaxPool routine; unset until `setup` has run. */
let wasmMaxPool;

/**
 * Binds the WASM module's MaxPool export to `wasmMaxPool` via `cwrap`.
 *
 * @param backend Backend object exposing the WASM module as `backend.wasm`.
 */
function setup(backend) {
  // One entry per argument, in the order `maxPool` passes them.
  const argTypes = [
    'number',  // xId
    'number',  // batchSize
    'number',  // inputHeight
    'number',  // inputWidth
    'number',  // filterHeight
    'number',  // filterWidth
    'number',  // padTop
    'number',  // padRight
    'number',  // padBottom
    'number',  // padLeft
    'number',  // dilationHeight
    'number',  // dilationWidth
    'number',  // strideHeight
    'number',  // strideWidth
    'number',  // inputChannels
    'number',  // outputChannels
    'number',  // outId
  ];
  const returnType = null;  // void
  wasmMaxPool = backend.wasm.cwrap(MaxPool, returnType, argTypes);
}
/**
 * MaxPool kernel implementation for the WASM backend.
 *
 * Computes the pooling geometry with `backend_util.computePool2DInfo`,
 * allocates a float32 output through `backend.makeOutput`, and passes the
 * input/output ids plus the geometry to the bound `wasmMaxPool` routine.
 *
 * @param args Object with `inputs` (only `inputs.x` is read), `attrs`
 *     (`filterSize`, `strides`, `pad`, `dimRoundingMode`) and `backend`.
 * @returns The output created by `backend.makeOutput`.
 * @throws {Error} If the computed data format is not 'channelsLast'. The
 *     dtype check is delegated to `util.assert`.
 */
function maxPool(args) {
  const {inputs, attrs, backend} = args;
  const {x} = inputs;
  const xId = backend.dataIdMap.get(x.dataId).id;

  // The TF API, and the CPU and WebGL backends, accept int32 input. The WASM
  // backend does not, because it relies on xnnpack, which only supports
  // float32. The check lives here rather than at the core op level so that
  // only the WASM backend is restricted.
  //
  // TODO: add support for int32 input.
  util.assert(
      x.dtype === 'float32',
      () => `Error in MaxPool: only float32 input is supported. Got ${x.dtype}.`);

  const poolInfo = backend_util.computePool2DInfo(
      x.shape, attrs.filterSize, attrs.strides, 1 /* dilations */, attrs.pad,
      attrs.dimRoundingMode);
  const {
    filterHeight,
    filterWidth,
    padInfo: {top: padTop, right: padRight, bottom: padBottom, left: padLeft},
    dilationHeight,
    dilationWidth,
    strideHeight,
    strideWidth,
    inChannels: inputChannels,
    outChannels: outputChannels,
    dataFormat,
    outShape,
  } = poolInfo;

  if (dataFormat !== 'channelsLast') {
    throw new Error(
        `wasm backend does not support dataFormat:'${dataFormat}'. Please use 'channelsLast'.`);
  }

  const output = backend.makeOutput(outShape, 'float32');
  const outId = backend.dataIdMap.get(output.dataId).id;

  // The first three input dimensions are forwarded as batch size, height and
  // width; the channel counts come from the computed pooling info.
  const batchSize = x.shape[0];
  const inputHeight = x.shape[1];
  const inputWidth = x.shape[2];

  wasmMaxPool(
      xId, batchSize, inputHeight, inputWidth, filterHeight, filterWidth,
      padTop, padRight, padBottom, padLeft, dilationHeight, dilationWidth,
      strideHeight, strideWidth, inputChannels, outputChannels, outId);
  return output;
}
/**
 * Kernel configuration that ties the MaxPool kernel name to its 'wasm'
 * backend implementation (`maxPool`) and its one-time binding step (`setup`).
 */
export const maxPoolConfig = {
  kernelName: MaxPool,
  backendName: 'wasm',
  // Binds the WASM routine that `maxPool` calls.
  setupFunc: setup,
  kernelFunc: maxPool,
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"MaxPool.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-wasm/src/kernels/MaxPool.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAA4B,OAAO,EAAyC,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAInI,IAAI,WAKqE,CAAC;AAE1E,SAAS,KAAK,CAAC,OAAoB;IACjC,WAAW,GAAG,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,OAAO,EAAE,IAAI,CAAC,UAAU,EAAE;QACzD,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ,EAAG,QAAQ;KACpB,CAAC,CAAC;AACL,CAAC;AAED,SAAS,OAAO,CACZ,IAAwE;IAC1E,MAAM,EAAC,MAAM,EAAE,KAAK,EAAE,OAAO,EAAC,GAAG,IAAI,CAAC;IAEtC,MAAM,CAAC,GAAG,MAAM,CAAC,CAAa,CAAC;IAC/B,MAAM,GAAG,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAE/C,wEAAwE;IACxE,4EAA4E;IAC5E,oBAAoB;IACpB,EAAE;IACF,2EAA2E;IAC3E,SAAS;IACT,EAAE;IACF,qCAAqC;IACrC,IAAI,CAAC,MAAM,CACP,CAAC,CAAC,KAAK,KAAK,SAAS,EACrB,GAAG,EAAE,CACD,0DAA0D,CAAC,CAAC,KAAK,GAAG,CAAC,CAAC;IAE9E,MAAM,EAAC,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,eAAe,EAAC,GAAG,KAAK,CAAC;IAC1D,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,CAAC,CAAC,KAAK,EAAE,UAAU,EAAE,OAAO,EAAE,CAAC,CAAC,eAAe,EAAE,GAAG,EAAE,eAAe,CAAC,CAAC;IAE3E,MAAM,YAAY,GAAG,QAAQ,CAAC,YAAY,CAAC;IAC3C,MAAM,WAAW,GAAG,QAAQ,CAAC,WAAW,CAAC;IACzC,MAAM,MAAM,GAAG,QAAQ,CAAC,OAAO,CAAC,GAAG,CAAC;IACpC,MAAM,QAAQ,GAAG,QAAQ,CAAC,OAAO,CAAC,KAAK,CAAC;IACxC,MAAM,SAAS,GAAG,QAAQ,CAAC,OAAO,CAAC,MAAM,CAAC;IAC1C,MAAM,OAAO,GAAG,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;IACtC,MAAM,cAAc,GAAG,QAAQ,CAAC,cAAc,CAAC;IAC/C,MAAM,aAAa,GAAG,QAAQ,CAAC,aAAa,CAAC;IAC7C,MAAM,YAAY,GAAG,QAAQ,CAAC,YAAY,CAAC;IAC3C,MAAM,WAAW,GAAG,QAAQ,CAAC,WAAW,CAAC;IACzC,MAAM,aAAa,GAAG,QAAQ,CAAC,UAAU,CAAC;IAC1C,MAAM,cAAc,GAAG,QAAQ,CAAC,WAAW,CAAC;IAE5C,IAAI,QAAQ,CAAC,UAAU,KAAK,cAAc,EAAE;QAC1C,MAAM,IAAI,KAAK,CACX,4CAA4C;YAC5C,GAAG,QAAQ,CAAC,UAAU,+BAA+B,CAAC,CAAC;KAC5D;IAED,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,CAAC,QAAQ,CAAC,QAAQ,EAAE,SAAS,CAAC,CAAC;IAC7D,MAAM,KAAK,GAAG,OAAO,CAAC,SAAS,C
AAC,GAAG,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAEnD,WAAW,CACP,GAAG,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,YAAY,EAAE,WAAW,EAClE,MAAM,EAAE,QAAQ,EAAE,SAAS,EAAE,OAAO,EAAE,cAAc,EAAE,aAAa,EACnE,YAAY,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EAAE,KAAK,CAAC,CAAC;IACrE,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,aAAa,GAAiB;IACzC,UAAU,EAAE,OAAO;IACnB,WAAW,EAAE,MAAM;IACnB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,OAAgC;CAC7C,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, KernelConfig, KernelFunc, MaxPool, MaxPoolAttrs, MaxPoolInputs, Tensor4D, util} from '@tensorflow/tfjs-core';\n\nimport {BackendWasm} from '../backend_wasm';\n\nlet wasmMaxPool: (\n    xId: number, batchSize: number, inputHeight: number, inputWidth: number,\n    filterHeight: number, filterWidth: number, padTop: number, padRight: number,\n    padBottom: number, padLeft: number, dilationHeight: number,\n    dilationWidth: number, strideHeight: number, strideWidth: number,\n    inputChannels: number, outputChannels: number, outId: number) => void;\n\nfunction setup(backend: BackendWasm) {\n  wasmMaxPool = backend.wasm.cwrap(MaxPool, null /* void */, [\n    'number',  // xId\n    'number',  // batchSize\n    'number',  // inputHeight\n    'number',  
// inputWidth\n    'number',  // filterHeight\n    'number',  // filterWidth\n    'number',  // padTop\n    'number',  // padRight\n    'number',  // padBottom\n    'number',  // padLeft\n    'number',  // dilationHeight\n    'number',  // dilationWidth\n    'number',  // strideHeight\n    'number',  // strideWidth\n    'number',  // inputChannels\n    'number',  // outputChannels\n    'number',  // outId\n  ]);\n}\n\nfunction maxPool(\n    args: {inputs: MaxPoolInputs, backend: BackendWasm, attrs: MaxPoolAttrs}) {\n  const {inputs, attrs, backend} = args;\n\n  const x = inputs.x as Tensor4D;\n  const xId = backend.dataIdMap.get(x.dataId).id;\n\n  // TF API supports int32 input. CPU and WebGL backend also support int32\n  // input. WASM backend doesn't support it because it uses xnnpack which only\n  // supports float32.\n  //\n  // Add the following assert only for the WASM backend instead of at core op\n  // level.\n  //\n  // TODO: add support for int32 input.\n  util.assert(\n      x.dtype === 'float32',\n      () =>\n          `Error in MaxPool: only float32 input is supported. 
Got ${x.dtype}.`);\n\n  const {filterSize, strides, pad, dimRoundingMode} = attrs;\n  const convInfo = backend_util.computePool2DInfo(\n      x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);\n\n  const filterHeight = convInfo.filterHeight;\n  const filterWidth = convInfo.filterWidth;\n  const padTop = convInfo.padInfo.top;\n  const padRight = convInfo.padInfo.right;\n  const padBottom = convInfo.padInfo.bottom;\n  const padLeft = convInfo.padInfo.left;\n  const dilationHeight = convInfo.dilationHeight;\n  const dilationWidth = convInfo.dilationWidth;\n  const strideHeight = convInfo.strideHeight;\n  const strideWidth = convInfo.strideWidth;\n  const inputChannels = convInfo.inChannels;\n  const outputChannels = convInfo.outChannels;\n\n  if (convInfo.dataFormat !== 'channelsLast') {\n    throw new Error(\n        `wasm backend does not support dataFormat:'` +\n        `${convInfo.dataFormat}'. Please use 'channelsLast'.`);\n  }\n\n  const out = backend.makeOutput(convInfo.outShape, 'float32');\n  const outId = backend.dataIdMap.get(out.dataId).id;\n\n  wasmMaxPool(\n      xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth,\n      padTop, padRight, padBottom, padLeft, dilationHeight, dilationWidth,\n      strideHeight, strideWidth, inputChannels, outputChannels, outId);\n  return out;\n}\n\nexport const maxPoolConfig: KernelConfig = {\n  kernelName: MaxPool,\n  backendName: 'wasm',\n  setupFunc: setup,\n  kernelFunc: maxPool as unknown as KernelFunc\n};\n"]}