UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.

70 lines 11.3 kB
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Conv3DBackpropFilterV2 } from '@tensorflow/tfjs-core';

/**
 * Wrapped WASM kernel function. It is assigned by `setup` and is undefined
 * until `setup` has been called.
 */
let wasmConv3DBackpropFilterV2;

/**
 * Wraps the WASM export named by `Conv3DBackpropFilterV2` with `cwrap` and
 * stores the result in `wasmConv3DBackpropFilterV2`.
 *
 * The wrapper is registered with a `null` return type and 24 `'number'`
 * parameters; the order below must match the call in
 * `conv3DBackpropFilterV2`.
 *
 * @param backend Backend whose `wasm.cwrap` is used to create the wrapper.
 */
function setup(backend) {
  wasmConv3DBackpropFilterV2 = backend.wasm.cwrap(Conv3DBackpropFilterV2, null, [
    'number', // xId
    'number', // dyId
    'number', // dwId
    'number', // batchSize
    'number', // inDepth
    'number', // inHeight
    'number', // inWidth
    'number', // inChannels
    'number', // outDepth
    'number', // outHeight
    'number', // outWidth
    'number', // outChannels
    'number', // strideDepth
    'number', // strideHeight
    'number', // strideWidth
    'number', // dilationDepth
    'number', // dilationHeight
    'number', // dilationWidth
    'number', // filterDepth
    'number', // filterHeight
    'number', // filterWidth
    'number', // padFront
    'number', // padTop
    'number', // padLeft
  ]);
}

/**
 * Runs the `Conv3DBackpropFilterV2` WASM kernel (judging by the `dy`/`dw`
 * naming, the gradient of a 3D convolution with respect to its filter).
 *
 * @param args Kernel arguments.
 * @param args.inputs Holds the `x` and `dy` tensors; both must be float32.
 * @param args.backend Backend providing `makeOutput` and `dataIdMap`.
 * @param args.attrs Holds `strides`, `pad` and `filterShape`.
 * @returns The output tensor `dw`, allocated with shape
 *     `convInfo.filterShape` and the dtype of `dy`.
 * @throws Error if `x` or `dy` is not float32.
 */
export function conv3DBackpropFilterV2(args) {
  const { inputs, backend, attrs } = args;
  const { x, dy } = inputs;
  const { strides, pad, filterShape } = attrs;

  // Each message names the tensor that actually failed the check.
  if (x.dtype !== 'float32') {
    throw new Error(`Tensor x must have dtype float32, got ${x.dtype}`);
  }
  if (dy.dtype !== 'float32') {
    throw new Error(`Tensor dy must have dtype float32, got ${dy.dtype}`);
  }

  const convInfo = backend_util.computeConv3DInfo(
      x.shape, filterShape, strides, /*dilations=*/ 1, pad);

  const dw = backend.makeOutput(convInfo.filterShape, dy.dtype);

  // Argument order must match the signature registered in `setup`.
  wasmConv3DBackpropFilterV2(
      backend.dataIdMap.get(x.dataId).id,   // xId
      backend.dataIdMap.get(dy.dataId).id,  // dyId
      backend.dataIdMap.get(dw.dataId).id,  // dwId
      convInfo.batchSize,
      convInfo.inDepth,
      convInfo.inHeight,
      convInfo.inWidth,
      convInfo.inChannels,
      convInfo.outDepth,
      convInfo.outHeight,
      convInfo.outWidth,
      convInfo.outChannels,
      convInfo.strideDepth,
      convInfo.strideHeight,
      convInfo.strideWidth,
      convInfo.dilationDepth,
      convInfo.dilationHeight,
      convInfo.dilationWidth,
      convInfo.filterDepth,
      convInfo.filterHeight,
      convInfo.filterWidth,
      convInfo.padInfo.front,
      convInfo.padInfo.top,
      convInfo.padInfo.left);
  return dw;
}

/** Kernel registration entry for the `'wasm'` backend. */
export const conv3DBackpropFilterV2Config = {
  kernelName: Conv3DBackpropFilterV2,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: conv3DBackpropFilterV2
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Conv3DBackpropFilterV2.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-wasm/src/kernels/Conv3DBackpropFilterV2.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,sBAAsB,EAAkG,MAAM,uBAAuB,CAAC;AAI5K,IAAI,0BAO0D,CAAC;AAE/D,SAAS,KAAK,CAAC,OAAoB;IACjC,0BAA0B;QACtB,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,sBAAsB,EAAE,IAAI,EAAE;YAC/C,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ,EAAG,UAAU;SACtB,CAAC,CAAC;AACT,CAAC;AAED,MAAM,UAAU,sBAAsB,CAAC,IAItC;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,EAAE,EAAC,GAAG,MAAM,CAAC;IACvB,MAAM,EAAC,OAAO,EAAE,GAAG,EAAE,WAAW,EAAC,GAAG,KAAK,CAAC;IAE1C,IAAI,CAAC,CAAC,KAAK,KAAK,SAAS,EAAE;QACzB,MAAM,IAAI,KAAK,CAAC,0CAA0C,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC;KACtE;IACD,IAAI,EAAE,CAAC,KAAK,KAAK,SAAS,EAAE;QAC1B,MAAM,IAAI,KAAK,CAAC,8CAA8C,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC;KAC3E;IAED,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,CAAC,CAAC,KAAiD,EAAE,WAAW,EAAE,OAAO;IACzE,cAAc,CAAA,CAAC,EAAE,GAAG,CAAC,CAAC;IAE1B,MAAM,EAAE,GAAG,OAAO,CAAC,UAAU,CAAC,QAAQ,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC;IAE9D,0BAA0B,CACtB,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,EAAE,EAClC,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,EACnC,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,EACnC,QA
AQ,CAAC,SAAS,EAClB,QAAQ,CAAC,OAAO,EAChB,QAAQ,CAAC,QAAQ,EACjB,QAAQ,CAAC,OAAO,EAChB,QAAQ,CAAC,UAAU,EACnB,QAAQ,CAAC,QAAQ,EACjB,QAAQ,CAAC,SAAS,EAClB,QAAQ,CAAC,QAAQ,EACjB,QAAQ,CAAC,WAAW,EACpB,QAAQ,CAAC,WAAW,EACpB,QAAQ,CAAC,YAAY,EACrB,QAAQ,CAAC,WAAW,EACpB,QAAQ,CAAC,aAAa,EACtB,QAAQ,CAAC,cAAc,EACvB,QAAQ,CAAC,aAAa,EACtB,QAAQ,CAAC,WAAW,EACpB,QAAQ,CAAC,YAAY,EACrB,QAAQ,CAAC,WAAW,EACpB,QAAQ,CAAC,OAAO,CAAC,KAAK,EACtB,QAAQ,CAAC,OAAO,CAAC,GAAG,EACpB,QAAQ,CAAC,OAAO,CAAC,IAAI,CACxB,CAAC;IACF,OAAO,EAAE,CAAC;AACZ,CAAC;AAED,MAAM,CAAC,MAAM,4BAA4B,GAAiB;IACxD,UAAU,EAAE,sBAAsB;IAClC,WAAW,EAAE,MAAM;IACnB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,sBAA+C;CAC5D,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, Conv3DBackpropFilterV2, Conv3DBackpropFilterV2Attrs, Conv3DBackpropFilterV2Inputs, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';\n\nimport {BackendWasm} from '../backend_wasm';\n\nlet wasmConv3DBackpropFilterV2: (\n    xId: number, dyId: number, dwId: number, batchSize: number, inDepth: number,\n    inHeight: number, inWidth: number, inChannels: number, outDepth: number,\n    outHeight: number, outWidth: number, outChannels: number,\n    strideDepth: number, strideHeight: number, strideWidth: number,\n    dilationDepth: number, dilationHeight: number, dilationWidth: number,\n    
filterDepth: number, filterHeight: number, filterWidth: number,\n    padFront: number, padTop: number, padLeft: number) => void;\n\nfunction setup(backend: BackendWasm) {\n  wasmConv3DBackpropFilterV2 =\n      backend.wasm.cwrap(Conv3DBackpropFilterV2, null, [\n        'number',  // xId\n        'number',  // dyId\n        'number',  // dwId\n        'number',  // batchSize\n        'number',  // inDepth\n        'number',  // inHeight\n        'number',  // inWidth\n        'number',  // inChannels\n        'number',  // outDepth\n        'number',  // outHeight\n        'number',  // outWidth\n        'number',  // outChannels\n        'number',  // strideDepth\n        'number',  // strideHeight\n        'number',  // strideWidth\n        'number',  // dilationDepth\n        'number',  // dilationHeight\n        'number',  // dilationWidth\n        'number',  // filterDepth\n        'number',  // filterHeight\n        'number',  // filterWidth\n        'number',  // padFront\n        'number',  // padTop\n        'number',  // padLeft\n      ]);\n}\n\nexport function conv3DBackpropFilterV2(args: {\n  inputs: Conv3DBackpropFilterV2Inputs,\n  attrs: Conv3DBackpropFilterV2Attrs,\n  backend: BackendWasm,\n}): TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {x, dy} = inputs;\n  const {strides, pad, filterShape} = attrs;\n\n  if (x.dtype !== 'float32') {\n    throw new Error(`Tensor dy must have dtype float32, got ${x.dtype}`);\n  }\n  if (dy.dtype !== 'float32') {\n    throw new Error(`Tensor filter must have dtype float32, got ${dy.dtype}`);\n  }\n\n  const convInfo = backend_util.computeConv3DInfo(\n      x.shape as [number, number, number, number, number], filterShape, strides,\n      /*dilations=*/1, pad);\n\n  const dw = backend.makeOutput(convInfo.filterShape, dy.dtype);\n\n  wasmConv3DBackpropFilterV2(\n      backend.dataIdMap.get(x.dataId).id,\n      backend.dataIdMap.get(dy.dataId).id,\n      backend.dataIdMap.get(dw.dataId).id,\n      
convInfo.batchSize,\n      convInfo.inDepth,\n      convInfo.inHeight,\n      convInfo.inWidth,\n      convInfo.inChannels,\n      convInfo.outDepth,\n      convInfo.outHeight,\n      convInfo.outWidth,\n      convInfo.outChannels,\n      convInfo.strideDepth,\n      convInfo.strideHeight,\n      convInfo.strideWidth,\n      convInfo.dilationDepth,\n      convInfo.dilationHeight,\n      convInfo.dilationWidth,\n      convInfo.filterDepth,\n      convInfo.filterHeight,\n      convInfo.filterWidth,\n      convInfo.padInfo.front,\n      convInfo.padInfo.top,\n      convInfo.padInfo.left,\n  );\n  return dw;\n}\n\nexport const conv3DBackpropFilterV2Config: KernelConfig = {\n  kernelName: Conv3DBackpropFilterV2,\n  backendName: 'wasm',\n  setupFunc: setup,\n  kernelFunc: conv3DBackpropFilterV2 as unknown as KernelFunc\n};\n"]}