@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier
63 lines • 11.2 kB
JavaScript
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { backend_util, MaxPoolWithArgmax, util } from '@tensorflow/tfjs-core';
import { CppDType } from './types';
/**
 * JS handle to the C++ `MaxPoolWithArgmax` export. It is undefined until
 * `setup` runs; `maxPoolWithArgmax` calls through it afterwards.
 */
let wasmMaxPoolWithArgmax;

/**
 * Binds the WASM export `MaxPoolWithArgmax` via `cwrap` and stores the
 * resulting JS function in `wasmMaxPoolWithArgmax`.
 *
 * The export has no return value (`null` return type) and takes 19
 * arguments: every argument is numeric except `includeBatchIndex`, which is
 * a boolean.
 *
 * @param backend The WASM backend whose `wasm` module provides `cwrap`.
 */
function setup(backend) {
  const argTypes = [
    'number',   // xId
    'number',   // pooledId
    'number',   // indexesId
    'number',   // dtype
    'boolean',  // includeBatchIndex
    'number',   // batchSize
    'number',   // channelSize
    'number',   // inHeight
    'number',   // inWidth
    'number',   // outHeight
    'number',   // outWidth
    'number',   // strideHeight
    'number',   // strideWidth
    'number',   // dilationHeight
    'number',   // dilationWidth
    'number',   // effectiveFilterHeight
    'number',   // effectiveFilterWidth
    'number',   // padTop
    'number',   // padLeft
  ];
  wasmMaxPoolWithArgmax =
      backend.wasm.cwrap('MaxPoolWithArgmax', null, argTypes);
}
/**
 * WASM kernel for `MaxPoolWithArgmax`: 2D max pooling over a rank-4 input
 * that produces both the pooled values and an int32 index tensor.
 *
 * The pooling geometry comes from `backend_util.computePool2DInfo`. The
 * actual work is delegated to the C++ export bound in `setup`, which writes
 * into the two output tensors allocated here.
 *
 * @param args.inputs `{x}`: the tensor to pool; asserted to be rank 4.
 * @param args.attrs `{filterSize, strides, pad, includeBatchInIndex}`.
 *     `includeBatchInIndex` is forwarded unchanged to the C++ kernel
 *     (presumably it controls whether the batch offset is folded into each
 *     index; confirm against the C++ implementation).
 * @param args.backend The WASM backend that owns `x` and allocates outputs.
 * @returns `[pooled, indexes]`. Both have shape `convInfo.outShape`;
 *     `pooled` has `x.dtype` and `indexes` is `'int32'`.
 *
 * Fails via `util.assert` if `x` is not rank 4, or if the strides and the
 * (fixed, unit) dilations violate `eitherStridesOrDilationsAreOne`.
 */
export function maxPoolWithArgmax(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { filterSize, strides, pad, includeBatchInIndex } = attrs;

  util.assert(x.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`);

  // This op has no dilation attribute, so pooling always uses unit dilation.
  // A single constant feeds both the validation and the geometry computation
  // below, so the two cannot drift apart.
  const dilations = [1, 1];
  util.assert(backend_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
      `Got strides ${strides} and dilations '${dilations}'`);

  const convInfo = backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad);

  const pooled = backend.makeOutput(convInfo.outShape, x.dtype);
  const indexes = backend.makeOutput(convInfo.outShape, 'int32');

  // Argument order must match the cwrap signature registered in `setup`.
  wasmMaxPoolWithArgmax(
      backend.dataIdMap.get(x.dataId).id,
      backend.dataIdMap.get(pooled.dataId).id,
      backend.dataIdMap.get(indexes.dataId).id,
      CppDType[x.dtype],
      includeBatchInIndex,
      convInfo.batchSize,
      convInfo.inChannels,
      convInfo.inHeight,
      convInfo.inWidth,
      convInfo.outHeight,
      convInfo.outWidth,
      convInfo.strideHeight,
      convInfo.strideWidth,
      convInfo.dilationHeight,
      convInfo.dilationWidth,
      convInfo.effectiveFilterHeight,
      convInfo.effectiveFilterWidth,
      convInfo.padInfo.top,
      convInfo.padInfo.left);
  return [pooled, indexes];
}
/**
 * Kernel registration record binding the `MaxPoolWithArgmax` op to the
 * 'wasm' backend.
 *
 * `setupFunc` binds the C++ entry point via `cwrap`; `kernelFunc` allocates
 * the outputs and invokes that entry point. (Presumably the kernel registry
 * runs `setupFunc` before the first `kernelFunc` call; confirm against
 * tfjs-core's `registerKernel`.)
 */
export const maxPoolWithArgmaxConfig = {
  kernelName: MaxPoolWithArgmax,   // op name exported by tfjs-core
  backendName: 'wasm',
  setupFunc: setup,                // binds the WASM export
  kernelFunc: maxPoolWithArgmax,   // runs the op
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"MaxPoolWithArgmax.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-wasm/src/kernels/MaxPoolWithArgmax.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAA4B,iBAAiB,EAA+D,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAInK,OAAO,EAAC,QAAQ,EAAC,MAAM,SAAS,CAAC;AAEjC,IAAI,qBAMsE,CAAC;AAE3E,SAAS,KAAK,CAAC,OAAoB;IACjC,qBAAqB,GAAG,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,mBAAmB,EAAE,IAAI,EAAE;QACpE,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,SAAS;QACT,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ,EAAI,UAAU;KACvB,CAAC,CAAC;AACL,CAAC;AAED,MAAM,UAAU,iBAAiB,CAAC,IAIjC;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAC,GAAG,MAAM,CAAC;IACnB,MAAM,EAAC,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,mBAAmB,EAAC,GAAG,KAAK,CAAC;IAE9D,IAAI,CAAC,MAAM,CACP,CAAC,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EACpB,GAAG,EAAE,CAAC,uDACF,CAAC,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;IAC3B,MAAM,SAAS,GAAqB,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;IAC3C,IAAI,CAAC,MAAM,CACP,YAAY,CAAC,8BAA8B,CAAC,OAAO,EAAE,SAAS,CAAC,EAC/D,GAAG,EAAE,CAAC,2DAA2D;QAC7D,eAAe,OAAO,mBAAmB,SAAS,GAAG,CAAC,CAAC;IAE/D,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,CAAC,CAAC,KAAyC,EAAE,UAAU,EAAE,OAAO,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EACxE,GAAG,CAAC,CAAC;IAET,MAAM,MAAM,GAAG,OAAO,CAAC,UAAU,CAAC,QAAQ,CAAC,QAAQ,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;IAC9D,MAAM,OAAO,GAAG,OAAO,CAAC,UAAU,CAAC,QAAQ,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC;IAE/D,qBAAqB,CACjB,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,EAAE,EAClC,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,EAAE,EACvC,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,EAAE,EACxC,QAAQ,CAAC,CAAC,CAAC,KAAK,CAAC,EACjB,mBAAmB,EACnB,QAAQ,CAAC,SAAS,EAClB,QAAQ,CAAC,UAAU,EACnB,QAAQ,CAAC,QAAQ,EACjB,QAAQ,CAAC,OAAO,EAChB,QAAQ,CAAC,SAAS,EAClB,QAAQ,CAAC,QAAQ,EACjB,QAAQ,CAAC,YAAY,EACrB,QAAQ,CAAC,WAAW,EACpB,QAAQ,CAAC,cAAc,EACvB,QAAQ,CAAC,aAAa,EACtB,QAAQ,CAAC,qBA
AqB,EAC9B,QAAQ,CAAC,oBAAoB,EAC7B,QAAQ,CAAC,OAAO,CAAC,GAAG,EACpB,QAAQ,CAAC,OAAO,CAAC,IAAI,CACxB,CAAC;IACF,OAAO,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;AAC3B,CAAC;AAED,MAAM,CAAC,MAAM,uBAAuB,GAAiB;IACnD,UAAU,EAAE,iBAAiB;IAC7B,WAAW,EAAE,MAAM;IACnB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,iBAA0C;CACvD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, KernelConfig, KernelFunc, MaxPoolWithArgmax, MaxPoolWithArgmaxAttrs, MaxPoolWithArgmaxInputs, TensorInfo, util} from '@tensorflow/tfjs-core';\n\nimport {BackendWasm} from '../backend_wasm';\n\nimport {CppDType} from './types';\n\nlet wasmMaxPoolWithArgmax: (\n    xId: number, pooledId: number, indexesId: number, dtype: number,\n    includeBatchIndex: boolean, batchSize: number, channelSize: number,\n    inHeight: number, inWidth: number, outHeight: number, outWidth: number,\n    strideHeight: number, strideWidth: number, dilationHeight: number,\n    dilationWidth: number, effectiveFilterHeight: number,\n    effectiveFilterWidth: number, padTop: number, padLeft: number) => void;\n\nfunction setup(backend: BackendWasm) {\n  wasmMaxPoolWithArgmax = backend.wasm.cwrap('MaxPoolWithArgmax', null, [\n    'number',   // xId\n    'number',   // pooledId\n    'number',   // indexesId\n    'number',   // dtype\n    'boolean',  // includeBatchIndex\n    'number',   // 
batchSize\n    'number',   // channelSize\n    'number',   // inHeight\n    'number',   // inWidth\n    'number',   // outHeight\n    'number',   // outWidth\n    'number',   // strideHeight\n    'number',   // strideWidth\n    'number',   // dilationHeight\n    'number',   // dilationWidth\n    'number',   // effectiveFilterHeight\n    'number',   // effectiveFilterWidth\n    'number',   // padTop\n    'number',   // padLeft\n  ]);\n}\n\nexport function maxPoolWithArgmax(args: {\n  inputs: MaxPoolWithArgmaxInputs,\n  attrs: MaxPoolWithArgmaxAttrs,\n  backend: BackendWasm,\n}): TensorInfo[] {\n  const {inputs, backend, attrs} = args;\n  const {x} = inputs;\n  const {filterSize, strides, pad, includeBatchInIndex} = attrs;\n\n  util.assert(\n      x.shape.length === 4,\n      () => `Error in maxPool: input must be rank 4 but got rank ${\n          x.shape.length}.`);\n  const dilations: [number, number] = [1, 1];\n  util.assert(\n      backend_util.eitherStridesOrDilationsAreOne(strides, dilations),\n      () => 'Error in maxPool: Either strides or dilations must be 1. 
' +\n          `Got strides ${strides} and dilations '${dilations}'`);\n\n  const convInfo = backend_util.computePool2DInfo(\n      x.shape as [number, number, number, number], filterSize, strides, [1, 1],\n      pad);\n\n  const pooled = backend.makeOutput(convInfo.outShape, x.dtype);\n  const indexes = backend.makeOutput(convInfo.outShape, 'int32');\n\n  wasmMaxPoolWithArgmax(\n      backend.dataIdMap.get(x.dataId).id,\n      backend.dataIdMap.get(pooled.dataId).id,\n      backend.dataIdMap.get(indexes.dataId).id,\n      CppDType[x.dtype],\n      includeBatchInIndex,\n      convInfo.batchSize,\n      convInfo.inChannels,\n      convInfo.inHeight,\n      convInfo.inWidth,\n      convInfo.outHeight,\n      convInfo.outWidth,\n      convInfo.strideHeight,\n      convInfo.strideWidth,\n      convInfo.dilationHeight,\n      convInfo.dilationWidth,\n      convInfo.effectiveFilterHeight,\n      convInfo.effectiveFilterWidth,\n      convInfo.padInfo.top,\n      convInfo.padInfo.left,\n  );\n  return [pooled, indexes];\n}\n\nexport const maxPoolWithArgmaxConfig: KernelConfig = {\n  kernelName: MaxPoolWithArgmax,\n  backendName: 'wasm',\n  setupFunc: setup,\n  kernelFunc: maxPoolWithArgmax as unknown as KernelFunc\n};\n"]}