@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:
- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier
85 lines • 14.4 kB
JavaScript
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { backend_util, Conv2DBackpropInput, util } from '@tensorflow/tfjs-core';
let wasmConv2DBackpropInput;

/**
 * Binds the `Conv2DBackpropInput` export of the module at `backend.wasm`
 * through its `cwrap` and caches the resulting JS function in
 * `wasmConv2DBackpropInput`.
 *
 * The export is declared with a `null` return type and 27 numeric arguments.
 * Their order must match the positional call made in `conv2DBackpropInput`.
 *
 * @param {Object} backend - backend whose `wasm` property provides `cwrap`.
 */
function setup(backend) {
  // Names are only used to make the positional layout readable; every
  // argument is passed to the wasm side as a plain number.
  const argNames = [
    'dyId', 'filterId', 'batchSize', 'filterHeight', 'filterWidth',
    'inHeight', 'inWidth', 'inChannels', 'outHeight', 'outWidth',
    'outChannels', 'strideHeight', 'strideWidth', 'topPad', 'leftPad',
    'fltS0', 'fltS1', 'fltS2', 'xBatchStride', 'xRowStride', 'xColStride',
    'xChannelStride', 'yBatchStride', 'yRowStride', 'yColStride',
    'yChannelStride', 'outId',
  ];
  const argTypes = argNames.map(() => 'number');
  wasmConv2DBackpropInput =
    backend.wasm.cwrap(Conv2DBackpropInput, null, argTypes);
}
/**
 * Selects the batch/row/column/channel element strides of a rank-4 tensor
 * from its row-major strides, according to the data layout.
 *
 * @param {number[]} shape - rank-4 tensor shape (NHWC or NCHW).
 * @param {boolean} isChannelsLast - true for NHWC, false for NCHW.
 * @returns {{batch: number, row: number, col: number, channel: number}}
 */
function layoutStrides(shape, isChannelsLast) {
  const [batch, second, third] = util.computeStrides(shape);
  if (isChannelsLast) {
    return { batch, row: second, col: third, channel: 1 };
  }
  return { batch, row: third, col: 1, channel: second };
}

/**
 * Kernel function for `Conv2DBackpropInput` on the wasm backend.
 *
 * Derives the convolution geometry with `backend_util.computeConv2DInfo`
 * (dilation fixed at 1). It allocates a float32 output of shape
 * `convInfo.inShape` and forwards tensor ids, geometry, paddings and strides
 * to the wasm export bound in `setup`. The output appears to be the input
 * gradient (dx); confirm against the wasm kernel.
 *
 * @param {{backend: Object, inputs: {dy: Object, filter: Object}, attrs: Object}} args
 *   `attrs` carries `strides`, `pad`, `dataFormat`, `dimRoundingMode` and
 *   `inputShape`.
 * @returns {Object} the newly allocated output tensor info.
 */
function conv2DBackpropInput({ backend, inputs, attrs }) {
  const { dy, filter } = inputs;
  const { strides, pad, dataFormat, dimRoundingMode, inputShape } = attrs;

  const convInfo = backend_util.computeConv2DInfo(
    inputShape,
    filter.shape,
    strides,
    1 /* dilations */,
    pad,
    dimRoundingMode,
    false /* depthwise */,
    backend_util.convertConv2DDataFormat(dataFormat));

  // Padding passed to the kernel is the filter extent minus one minus the
  // padding reported for the forward convolution.
  const topPad = convInfo.filterHeight - 1 - convInfo.padInfo.top;
  const leftPad = convInfo.filterWidth - 1 - convInfo.padInfo.left;

  const isChannelsLast = convInfo.dataFormat === 'channelsLast';
  const x = layoutStrides(convInfo.inShape, isChannelsLast);
  const y = layoutStrides(dy.shape, isChannelsLast);
  const [fltS0, fltS1, fltS2] = util.computeStrides(filter.shape);

  const out = backend.makeOutput(convInfo.inShape, 'float32');
  const idOf = (tensor) => backend.dataIdMap.get(tensor.dataId).id;
  const outId = idOf(out);
  const dyId = idOf(dy);
  const filterId = idOf(filter);

  wasmConv2DBackpropInput(
    dyId, filterId,
    convInfo.batchSize, convInfo.filterHeight, convInfo.filterWidth,
    convInfo.inHeight, convInfo.inWidth, convInfo.inChannels,
    convInfo.outHeight, convInfo.outWidth, convInfo.outChannels,
    convInfo.strideHeight, convInfo.strideWidth,
    topPad, leftPad,
    fltS0, fltS1, fltS2,
    x.batch, x.row, x.col, x.channel,
    y.batch, y.row, y.col, y.channel,
    outId);
  return out;
}
/**
 * Kernel registration record: pairs the `Conv2DBackpropInput` kernel name and
 * the `'wasm'` backend with the one-time `setup` binding step and the
 * `conv2DBackpropInput` kernel function.
 */
export const conv2DBackpropInputConfig = {
  kernelName: Conv2DBackpropInput,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: conv2DBackpropInput,
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Conv2DBackpropInput.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-wasm/src/kernels/Conv2DBackpropInput.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,mBAAmB,EAA6F,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAIzK,IAAI,uBAQ8C,CAAC;AAEnD,SAAS,KAAK,CAAC,OAAoB;IACjC,uBAAuB,GAAG,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,mBAAmB,EAAE,IAAI,EAAE;QACtE,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ,EAAG,QAAQ;KACpB,CAAC,CAAC;AACL,CAAC;AAED,SAAS,mBAAmB,CAAC,IAI5B;IACC,MAAM,EAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,EAAE,EAAE,MAAM,EAAC,GAAG,MAAM,CAAC;IAC5B,MAAM,EAAC,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,eAAe,EAAE,UAAU,EAAC,GAAG,KAAK,CAAC;IAEtE,MAAM,SAAS,GAAG,CAAC,CAAC;IAEpB,MAAM,WAAW,GAAG,YAAY,CAAC,uBAAuB,CAAC,UAAU,CAAC,CAAC;IACrE,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,UAAU,EAAE,MAAM,CAAC,KAAyC,EAAE,OAAO,EACrE,SAAS,EAAE,GAAG,EAAE,eAAe,EAAE,KAAK,CAAC,eAAe,EAAE,WAAW,CAAC,CAAC;IACzE,MAAM,EACJ,SAAS,EACT,YAAY,EACZ,WAAW,EACX,UAAU,EACV,QAAQ,EACR,OAAO,EACP,WAAW,EACX,SAAS,EACT,QAAQ,EACR,YAAY,EACZ,WAAW,EACZ,GAAG,QAAQ,CAAC;IAEb,MAAM,MAAM,GAAG,YAAY,GAAG,CAAC,GAAG,QAAQ,CAAC,OAAO,CAAC,GAAG,CAAC;IACvD,MAAM,OAAO,GAAG,WAAW,GAAG,CAAC,GAAG,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;IAExD,MAAM,cAAc,GAAG,QAAQ,CAAC,UAAU,KAAK,cAAc,CAAC;IAC9D,MAAM,SAAS,GAAG,IAAI,CAAC,cAAc,CAAC,QAAQ,CAAC,OAAO,CAAC,CAAC;IACxD,MAAM,SAAS,GAAG,IAAI,CAAC,cAAc,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC;IAChD,MAAM,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,GAAG,IAAI,CAAC,cAAc,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;IAChE,MAAM,YAAY,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;IAClC,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC;IAChE,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACrD,MAAM,
cAAc,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC;IACzD,MAAM,YAAY,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;IAClC,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC;IAChE,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACrD,MAAM,cAAc,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC;IAEzD,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,CAAC,QAAQ,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC;IAC5D,MAAM,KAAK,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IACnD,MAAM,IAAI,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IACjD,MAAM,QAAQ,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAEzD,uBAAuB,CACnB,IAAI,EAAE,QAAQ,EAAE,SAAS,EAAE,YAAY,EAAE,WAAW,EAAE,QAAQ,EAAE,OAAO,EACvE,UAAU,EAAE,SAAS,EAAE,QAAQ,EAAE,WAAW,EAAE,YAAY,EAAE,WAAW,EACvE,MAAM,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,YAAY,EAAE,UAAU,EAC9D,UAAU,EAAE,cAAc,EAAE,YAAY,EAAE,UAAU,EAAE,UAAU,EAChE,cAAc,EAAE,KAAK,CAAC,CAAC;IAC3B,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,yBAAyB,GAAiB;IACrD,UAAU,EAAE,mBAAmB;IAC/B,WAAW,EAAE,MAAM;IACnB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,mBAA4C;CACzD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, Conv2DBackpropInput, Conv2DBackpropInputAttrs, Conv2DBackpropInputInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';\n\nimport {BackendWasm} from '../backend_wasm';\n\nlet wasmConv2DBackpropInput: (\n    dyId: number, filterId: number, batchSize: number, filterHeight: number,\n    filterWidth: number, inHeight: number, inWidth: number, inChannels: number,\n    outHeight: number, outWidth: number, outChannels: number,\n    strideHeight: number, strideWidth: number, topPad: number, leftPad: number,\n    fltS0: number, fltS1: number, fltS2: number, xBatchStride: number,\n    xRowStride: number, xColStride: number, xChannelStride: number,\n    yBatchStride: number, yRowStride: number, yColStride: number,\n    yChannelStride: number, outId: number) => void;\n\nfunction setup(backend: BackendWasm): void {\n  wasmConv2DBackpropInput = backend.wasm.cwrap(Conv2DBackpropInput, null, [\n    'number',  // dyId\n    'number',  // filterId\n    'number',  // batchSize\n    'number',  // filterHeight\n    'number',  // filterWidth\n    'number',  // inHeight\n    'number',  // inWidth\n    'number',  // inChannels\n    'number',  // outHeight\n    'number',  // outWidth\n    'number',  // outChannels\n    'number',  // strideHeight\n    'number',  // strideWidth\n 
   'number',  // topPad\n    'number',  // leftPad\n    'number',  // fltS0\n    'number',  // fltS1\n    'number',  // fltS2\n    'number',  // xBatchStride\n    'number',  // xRowStride\n    'number',  // xColStride\n    'number',  // xChannelStride\n    'number',  // yBatchStride\n    'number',  // yRowStride\n    'number',  // yColStride\n    'number',  // yChannelStride\n    'number',  // outId\n  ]);\n}\n\nfunction conv2DBackpropInput(args: {\n  backend: BackendWasm,\n  inputs: Conv2DBackpropInputInputs,\n  attrs: Conv2DBackpropInputAttrs\n}): TensorInfo {\n  const {backend, inputs, attrs} = args;\n  const {dy, filter} = inputs;\n  const {strides, pad, dataFormat, dimRoundingMode, inputShape} = attrs;\n\n  const dilations = 1;\n\n  const $dataFormat = backend_util.convertConv2DDataFormat(dataFormat);\n  const convInfo = backend_util.computeConv2DInfo(\n      inputShape, filter.shape as [number, number, number, number], strides,\n      dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);\n  const {\n    batchSize,\n    filterHeight,\n    filterWidth,\n    inChannels,\n    inHeight,\n    inWidth,\n    outChannels,\n    outHeight,\n    outWidth,\n    strideHeight,\n    strideWidth\n  } = convInfo;\n\n  const topPad = filterHeight - 1 - convInfo.padInfo.top;\n  const leftPad = filterWidth - 1 - convInfo.padInfo.left;\n\n  const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n  const dxStrides = util.computeStrides(convInfo.inShape);\n  const dyStrides = util.computeStrides(dy.shape);\n  const [fltS0, fltS1, fltS2] = util.computeStrides(filter.shape);\n  const xBatchStride = dxStrides[0];\n  const xRowStride = isChannelsLast ? dxStrides[1] : dxStrides[2];\n  const xColStride = isChannelsLast ? dxStrides[2] : 1;\n  const xChannelStride = isChannelsLast ? 1 : dxStrides[1];\n  const yBatchStride = dyStrides[0];\n  const yRowStride = isChannelsLast ? dyStrides[1] : dyStrides[2];\n  const yColStride = isChannelsLast ? 
dyStrides[2] : 1;\n  const yChannelStride = isChannelsLast ? 1 : dyStrides[1];\n\n  const out = backend.makeOutput(convInfo.inShape, 'float32');\n  const outId = backend.dataIdMap.get(out.dataId).id;\n  const dyId = backend.dataIdMap.get(dy.dataId).id;\n  const filterId = backend.dataIdMap.get(filter.dataId).id;\n\n  wasmConv2DBackpropInput(\n      dyId, filterId, batchSize, filterHeight, filterWidth, inHeight, inWidth,\n      inChannels, outHeight, outWidth, outChannels, strideHeight, strideWidth,\n      topPad, leftPad, fltS0, fltS1, fltS2, xBatchStride, xRowStride,\n      xColStride, xChannelStride, yBatchStride, yRowStride, yColStride,\n      yChannelStride, outId);\n  return out;\n}\n\nexport const conv2DBackpropInputConfig: KernelConfig = {\n  kernelName: Conv2DBackpropInput,\n  backendName: 'wasm',\n  setupFunc: setup,\n  kernelFunc: conv2DBackpropInput as unknown as KernelFunc\n};\n"]}