UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:

- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier

79 lines 12.5 kB
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { _FusedMatMul, broadcast_util } from '@tensorflow/tfjs-core';
import { FusableActivation } from './types';

/**
 * JS handle to the wasm `_FusedMatMul` export, created in `setup`.
 * Called as (aId, aShapeBytes, aRank, bId, bShapeBytes, bRank, transposeA,
 * transposeB, activation, biasId, preluActivationWeightsId, leakyreluAlpha,
 * outId) and returns nothing.
 */
let wasmFusedMatMul;

/**
 * Binds `wasmFusedMatMul` to the module's `_FusedMatMul` export using
 * Emscripten's `cwrap`.
 *
 * @param backend The wasm backend whose `wasm.cwrap` is used for binding.
 */
function setup(backend) {
  wasmFusedMatMul = backend.wasm.cwrap(_FusedMatMul, null /* void */, [
    'number',  // a_id
    'array',   // a_shape
    'number',  // a_shape.length
    'number',  // b_id
    'array',   // b_shape
    'number',  // b_shape.length
    'number',  // transpose_a
    'number',  // transpose_b
    'number',  // activation
    'number',  // biasId
    'number',  // preluActivationWeightsId
    'number',  // leakyreluAlpha
    'number'   // out_id
  ]);
}

/**
 * Wasm kernel for `_FusedMatMul`: a batched matrix multiply of `a` and `b`
 * with an optional bias, an optional PReLU weights tensor, and a fused
 * activation. Allocates the output via `backend.makeOutput` and delegates the
 * computation to `wasmFusedMatMul`.
 *
 * @param args.inputs `a`, `b`, and the optional `bias` and
 *     `preluActivationWeights` tensors.
 * @param args.backend The wasm backend (provides `dataIdMap` and
 *     `makeOutput`).
 * @param args.attrs `transposeA`, `transposeB`, `activation` and
 *     `leakyreluAlpha`.
 * @returns Output tensor info with shape
 *     `[...broadcastBatchDims, leftDim, rightDim]` and the dtype of `a`.
 * @throws Error if `a` or `b` is not float32, if `bias` is not rank 1, or if
 *     `activation` has no entry in `FusableActivation`.
 */
function fusedBatchMatMul(args) {
  const { inputs, backend, attrs } = args;
  const { a, b, bias, preluActivationWeights } = inputs;

  if (a.dtype !== 'float32' || b.dtype !== 'float32') {
    throw new Error(`_FusedMatMul for non-float32 tensors not yet supported.`);
  }

  const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
  const aId = backend.dataIdMap.get(a.dataId).id;
  const bId = backend.dataIdMap.get(b.dataId).id;

  // 0 is passed when no bias is given (presumably read as "absent" by the
  // wasm side — confirm against the C++ kernel).
  let biasId = 0;
  if (bias != null) {
    const biasData = backend.dataIdMap.get(bias.dataId);
    if (biasData.shape.length !== 1) {
      throw new Error(
          `_FusedMatMul only supports rank-1 bias but got ` +
          `rank ${biasData.shape.length}.`);
    }
    biasId = biasData.id;
  }

  // Same convention as biasId: 0 when no PReLU weights are supplied.
  const preluActivationWeightsId = preluActivationWeights == null ?
      0 :
      backend.dataIdMap.get(preluActivationWeights.dataId).id;

  const fusedActivation = FusableActivation[activation];
  if (fusedActivation == null) {
    throw new Error(
        `${activation} activation not yet supported for _FusedMatMul ` +
        `in the wasm backend.`);
  }

  // NOTE(review): shape[1] / shape[2] are indexed directly, so this assumes
  // rank-3 inputs laid out as [batch, rows, cols]; confirm callers reshape
  // to 3D before invoking this kernel.
  const leftDim = transposeA ? a.shape[2] : a.shape[1];
  const rightDim = transposeB ? b.shape[1] : b.shape[2];
  // Leading (batch) dims of a and b must be broadcast-compatible.
  const batchDims = broadcast_util.assertAndGetBroadcastShape(
      a.shape.slice(0, -2), b.shape.slice(0, -2));

  const out = backend.makeOutput([...batchDims, leftDim, rightDim], a.dtype);
  const outId = backend.dataIdMap.get(out.dataId).id;

  // Shapes are handed to the 'array' cwrap parameters as the raw bytes of an
  // Int32Array.
  const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer);
  const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer);

  // leakyreluAlpha is optional; 0 is sent when it is missing (or falsy).
  wasmFusedMatMul(
      aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length,
      transposeA, transposeB, fusedActivation, biasId,
      preluActivationWeightsId, leakyreluAlpha || 0, outId);

  return out;
}

/** Registers `fusedBatchMatMul` as the `_FusedMatMul` kernel for 'wasm'. */
export const _fusedMatMulConfig = {
  kernelName: _FusedMatMul,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: fusedBatchMatMul
};
//#
sourceMappingURL=data:application/json;base64,{"version":3,"file":"_FusedMatMul.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-wasm/src/kernels/_FusedMatMul.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAyC,cAAc,EAA2B,MAAM,uBAAuB,CAAC;AAIpI,OAAO,EAAC,iBAAiB,EAAC,MAAM,SAAS,CAAC;AAE1C,IAAI,eAKQ,CAAC;AAEb,SAAS,KAAK,CAAC,OAAoB;IACjC,eAAe,GAAG,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,YAAY,EAAE,IAAI,CAAC,UAAU,EAAE;QAClE,QAAQ;QACR,OAAO;QACP,QAAQ;QACR,QAAQ;QACR,OAAO;QACP,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ,CAAG,SAAS;KACrB,CAAC,CAAC;AACL,CAAC;AAED,SAAS,gBAAgB,CAAC,IAIzB;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,CAAC,EAAE,IAAI,EAAE,sBAAsB,EAAC,GAAG,MAAM,CAAC;IAEpD,IAAI,CAAC,CAAC,KAAK,KAAK,SAAS,IAAI,CAAC,CAAC,KAAK,KAAK,SAAS,EAAE;QAClD,MAAM,IAAI,KAAK,CACX,6DAA6D,CAAC,CAAC;KACpE;IAED,MAAM,EAAC,UAAU,EAAE,UAAU,EAAE,UAAU,EAAE,cAAc,EAAC,GAAG,KAAK,CAAC;IACnE,MAAM,GAAG,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAC/C,MAAM,GAAG,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAE/C,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,IAAI,IAAI,IAAI,IAAI,EAAE;QAChB,MAAM,QAAQ,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;QACpD,IAAI,QAAQ,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;YAC/B,MAAM,IAAI,KAAK,CACX,iDAAiD;gBACjD,QAAQ,QAAQ,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;SACvC;QACD,MAAM,GAAG,QAAQ,CAAC,EAAE,CAAC;KACtB;IACD,MAAM,wBAAwB,GAAG,sBAAsB,IAAI,IAAI,CAAC,CAAC;QAC7D,CAAC,CAAC,CAAC;QACH,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,sBAAsB,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAC5D,MAAM,eAAe,GACjB,iBAAiB,CAAC,UAC8B,CAAC,CAAC;IACtD,IAAI,eAAe,IAAI,IAAI,EAAE;QAC3B,MAAM,IAAI,KAAK,CACX,GAAG,UAAU,gDAAgD;YAC7D,sBAAsB,CAAC,CAAC;KAC7B;IAED,MAAM,OAAO,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACrD,MAAM,QAAQ,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACtD,MAA
M,SAAS,GAAG,cAAc,CAAC,0BAA0B,CACvD,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IAEhD,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,CAAC,CAAC,GAAG,SAAS,EAAE,OAAO,EAAE,QAAQ,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;IAC3E,MAAM,KAAK,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAEnD,MAAM,WAAW,GAAG,IAAI,UAAU,CAAC,IAAI,UAAU,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,CAAC;IACnE,MAAM,WAAW,GAAG,IAAI,UAAU,CAAC,IAAI,UAAU,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,CAAC;IAEnE,eAAe,CACX,GAAG,EAAE,WAAW,EAAE,CAAC,CAAC,KAAK,CAAC,MAAM,EAAE,GAAG,EAAE,WAAW,EAAE,CAAC,CAAC,KAAK,CAAC,MAAM,EAClE,UAAU,EAAE,UAAU,EAAE,eAAe,EAAE,MAAM,EAAE,wBAAwB,EACzE,cAAc,IAAI,CAAC,EAAE,KAAK,CAAC,CAAC;IAEhC,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,kBAAkB,GAAiB;IAC9C,UAAU,EAAE,YAAY;IACxB,WAAW,EAAE,MAAM;IACnB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,gBAAyC;CACtD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {_FusedMatMul, _FusedMatMulAttrs, _FusedMatMulInputs, broadcast_util, KernelConfig, KernelFunc} from '@tensorflow/tfjs-core';\n\nimport {BackendWasm} from '../backend_wasm';\n\nimport {FusableActivation} from './types';\n\nlet wasmFusedMatMul:\n    (aId: number, aShape: Uint8Array, aShapeSize: number, bId: number,\n     bShape: Uint8Array, 
bShapeSize: number, transposeA: boolean,\n     transposeB: boolean, activation: number, biasId: number,\n     preluActivationWeightsId: number, leakyreluAlpha: number, outId: number) =>\n        void;\n\nfunction setup(backend: BackendWasm) {\n  wasmFusedMatMul = backend.wasm.cwrap(_FusedMatMul, null /* void */, [\n    'number',  // a_id\n    'array',   // a_shape\n    'number',  // a_shape.length\n    'number',  // b_id\n    'array',   // b_shape\n    'number',  // b_shape.length\n    'number',  // transpose_a\n    'number',  // transpose_b\n    'number',  // activation\n    'number',  // biasId\n    'number',  // preluActivationWeightsId\n    'number',  // leakyreluAlpha\n    'number'   // out_id\n  ]);\n}\n\nfunction fusedBatchMatMul(args: {\n  inputs: _FusedMatMulInputs,\n  backend: BackendWasm,\n  attrs: _FusedMatMulAttrs\n}) {\n  const {inputs, backend, attrs} = args;\n  const {a, b, bias, preluActivationWeights} = inputs;\n\n  if (a.dtype !== 'float32' || b.dtype !== 'float32') {\n    throw new Error(\n        `_FusedMatMul for non non-float32 tensors not yet supported.`);\n  }\n\n  const {transposeA, transposeB, activation, leakyreluAlpha} = attrs;\n  const aId = backend.dataIdMap.get(a.dataId).id;\n  const bId = backend.dataIdMap.get(b.dataId).id;\n\n  let biasId = 0;\n  if (bias != null) {\n    const biasData = backend.dataIdMap.get(bias.dataId);\n    if (biasData.shape.length !== 1) {\n      throw new Error(\n          `_FusedMatMul only supports rank-1 bias but got ` +\n          `rank ${biasData.shape.length}.`);\n    }\n    biasId = biasData.id;\n  }\n  const preluActivationWeightsId = preluActivationWeights == null ?\n      0 :\n      backend.dataIdMap.get(preluActivationWeights.dataId).id;\n  const fusedActivation =\n      FusableActivation[activation as unknown as\n                        keyof typeof FusableActivation];\n  if (fusedActivation == null) {\n    throw new Error(\n        `${activation} activation not yet supported for FusedConv2D ` 
+\n        `in the wasm backend.`);\n  }\n\n  const leftDim = transposeA ? a.shape[2] : a.shape[1];\n  const rightDim = transposeB ? b.shape[1] : b.shape[2];\n  const batchDims = broadcast_util.assertAndGetBroadcastShape(\n      a.shape.slice(0, -2), b.shape.slice(0, -2));\n\n  const out = backend.makeOutput([...batchDims, leftDim, rightDim], a.dtype);\n  const outId = backend.dataIdMap.get(out.dataId).id;\n\n  const aShapeBytes = new Uint8Array(new Int32Array(a.shape).buffer);\n  const bShapeBytes = new Uint8Array(new Int32Array(b.shape).buffer);\n\n  wasmFusedMatMul(\n      aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length,\n      transposeA, transposeB, fusedActivation, biasId, preluActivationWeightsId,\n      leakyreluAlpha || 0, outId);\n\n  return out;\n}\n\nexport const _fusedMatMulConfig: KernelConfig = {\n  kernelName: _FusedMatMul,\n  backendName: 'wasm',\n  setupFunc: setup,\n  kernelFunc: fusedBatchMatMul as unknown as KernelFunc\n};\n"]}