@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:
- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier
84 lines • 11.7 kB
JavaScript
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { Transform, util } from '@tensorflow/tfjs-core';
let wasmTransform;
function setup(backend) {
wasmTransform = backend.wasm.cwrap(Transform, null /*void*/, [
'number',
'number',
'bool',
'number',
'number',
'number',
'number',
'number',
'number',
'array',
'number',
'array',
'number',
'number',
'number',
'number',
'number' // outId
]);
}
function transform(args) {
const { backend, inputs, attrs } = args;
const { image, transforms } = inputs;
const { interpolation, fillMode, fillValue, outputShape } = attrs;
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
const outShape = [batch, outHeight, outWidth,
numChannels];
const inputStrides = new Uint8Array(new Int32Array(util.computeStrides(image.shape)).buffer);
const outputStrides = new Uint8Array(new Int32Array(util.computeStrides(outShape)).buffer);
const out = backend.makeOutput(outShape, image.dtype);
const outId = backend.dataIdMap.get(out.dataId).id;
const imageData = backend.dataIdMap.get(image.dataId);
const imageId = imageData.id;
const transformsData = backend.dataIdMap.get(transforms.dataId);
const transformsId = transformsData.id;
const interpolationModeId = interpolation === 'nearest' ? 1 : 2;
let fillModeId;
switch (fillMode) {
case 'constant':
fillModeId = 1;
break;
case 'reflect':
fillModeId = 2;
break;
case 'wrap':
fillModeId = 3;
break;
case 'nearest':
fillModeId = 4;
break;
default:
fillModeId = 1;
break;
}
wasmTransform(imageId, transformsId, (transforms.shape[0] > 1), batch, outHeight, outWidth, numChannels, imageWidth, imageHeight, inputStrides, image.shape.length - 1, outputStrides, outShape.length - 1, interpolationModeId, fillModeId, fillValue, outId);
return out;
}
/**
 * Kernel registration record for `Transform` on the 'wasm' backend.
 * `setupFunc` is `setup`, which binds the native function. `kernelFunc` is
 * `transform`, which marshals arguments into that native function.
 * NOTE(review): when the backend calls `setupFunc` is decided outside this
 * file; presumably it runs once, before the first kernel call.
 */
export const transformConfig = {
  kernelName: Transform,
  backendName: 'wasm',
  setupFunc: setup,
  kernelFunc: transform,
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Transform.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-wasm/src/kernels/Transform.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAuC,SAAS,EAAmC,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAI7H,IAAI,aAM6D,CAAC;AAElE,SAAS,KAAK,CAAC,OAAoB;IACjC,aAAa,GAAG,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,SAAS,EAAE,IAAI,CAAC,QAAQ,EAAE;QAC3D,QAAQ;QACR,QAAQ;QACR,MAAM;QACN,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,OAAO;QACP,QAAQ;QACR,OAAO;QACP,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,QAAQ,CAAG,QAAQ;KACpB,CAAC,CAAC;AACL,CAAC;AAED,SAAS,SAAS,CACd,IAC0E;IAE5E,MAAM,EAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,KAAK,EAAE,UAAU,EAAC,GAAG,MAAM,CAAC;IACnC,MAAM,EAAC,aAAa,EAAE,QAAQ,EAAE,SAAS,EAAE,WAAW,EAAC,GAAG,KAAK,CAAC;IAEhE,MAAM,CAAC,KAAK,EAAE,WAAW,EAAE,UAAU,EAAE,WAAW,CAAC,GAAG,KAAK,CAAC,KAAK,CAAC;IAClE,MAAM,CAAC,SAAS,EAAE,QAAQ,CAAC,GACvB,WAAW,IAAI,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,WAAW,EAAE,UAAU,CAAC,CAAC;IAClE,MAAM,QAAQ,GACV,CAAC,KAAK,EAAE,SAAS,EAAE,QAAQ;QAC1B,WAAW,CAAqC,CAAC;IACtD,MAAM,YAAY,GACd,IAAI,UAAU,CAAC,IAAI,UAAU,CAAC,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;IAE5E,MAAM,aAAa,GACf,IAAI,UAAU,CAAC,IAAI,UAAU,CAAC,IAAI,CAAC,cAAc,CAAC,QAAQ,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;IAEzE,MAAM,GAAG,GAAG,OAAO,CAAC,UAAU,CAAC,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC;IACtD,MAAM,KAAK,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAEnD,MAAM,SAAS,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACtD,MAAM,OAAO,GAAG,SAAS,CAAC,EAAE,CAAC;IAE7B,MAAM,cAAc,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,UAAU,CAAC,MAAM,CAAC,CAAC;IAChE,MAAM,YAAY,GAAG,cAAc,CAAC,EAAE,CAAC;IAEvC,MAAM,mBAAmB,GAAG,aAAa,KAAK,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAChE,IAAI,UAAU,CAAC;IACf,QAAQ,QAAQ,EAAE;QAChB,KAAK,UAAU;YACb,UAAU,GAAG,CAAC,CAAC;YACf,MAAM;QACR,KAAK,SAAS;YACZ,UAAU,GAAG,CAAC,CAAC;YACf,MAAM;QACR,KAAK,MAAM;YACT,UAAU,GAAG,CAAC,CAAC;YACf,MAAM;QACR,KAAK,SAA
S;YACZ,UAAU,GAAG,CAAC,CAAC;YACf,MAAM;QACR;YACE,UAAU,GAAG,CAAC,CAAC;YACf,MAAM;KACT;IAED,aAAa,CACT,OAAO,EAAE,YAAY,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,EAAE,SAAS,EAClE,QAAQ,EAAE,WAAW,EAAE,UAAU,EAAE,WAAW,EAAE,YAAY,EAC5D,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,aAAa,EAAE,QAAQ,CAAC,MAAM,GAAG,CAAC,EAC1D,mBAAmB,EAAE,UAAU,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IAEvD,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,eAAe,GAAiB;IAC3C,UAAU,EAAE,SAAS;IACrB,WAAW,EAAE,MAAM;IACnB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,SAAkC;CAC/C,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2021 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {KernelConfig, KernelFunc, TensorInfo, Transform, TransformAttrs, TransformInputs, util} from '@tensorflow/tfjs-core';\n\nimport {BackendWasm} from '../backend_wasm';\n\nlet wasmTransform: (\n    imageId: number, transformsId: number, isBatchTransform: boolean,\n    batch: number, outHeight: number, outWidth: number, numChannels: number,\n    imageWidth: number, imageHeight: number, inputStrides: Uint8Array,\n    inputStridesLength: number, outputStrides: Uint8Array,\n    outputStridesLength: number, interpolationModeId: number,\n    fillModeId: number, fillValue: number, outId: number) => void;\n\nfunction setup(backend: BackendWasm): void {\n  wasmTransform = backend.wasm.cwrap(Transform, null /*void*/, 
[\n    'number',  // imageId\n    'number',  // transformsId\n    'bool',    // isBatchTransform\n    'number',  // batch\n    'number',  // outHeight\n    'number',  // outWidth\n    'number',  // numChannels\n    'number',  // imageWidth\n    'number',  // imageHeight\n    'array',   // inputStrides\n    'number',  // inputStridesLength\n    'array',   // outputStrides\n    'number',  // outputStridesLength\n    'number',  // interpolationModeId\n    'number',  // fillModeId\n    'number',  // fillValue\n    'number'   // outId\n  ]);\n}\n\nfunction transform(\n    args:\n        {backend: BackendWasm, inputs: TransformInputs, attrs: TransformAttrs}):\n    TensorInfo {\n  const {backend, inputs, attrs} = args;\n  const {image, transforms} = inputs;\n  const {interpolation, fillMode, fillValue, outputShape} = attrs;\n\n  const [batch, imageHeight, imageWidth, numChannels] = image.shape;\n  const [outHeight, outWidth] =\n      outputShape != null ? outputShape : [imageHeight, imageWidth];\n  const outShape =\n      [batch, outHeight, outWidth,\n       numChannels] as [number, number, number, number];\n  const inputStrides =\n      new Uint8Array(new Int32Array(util.computeStrides(image.shape)).buffer);\n\n  const outputStrides =\n      new Uint8Array(new Int32Array(util.computeStrides(outShape)).buffer);\n\n  const out = backend.makeOutput(outShape, image.dtype);\n  const outId = backend.dataIdMap.get(out.dataId).id;\n\n  const imageData = backend.dataIdMap.get(image.dataId);\n  const imageId = imageData.id;\n\n  const transformsData = backend.dataIdMap.get(transforms.dataId);\n  const transformsId = transformsData.id;\n\n  const interpolationModeId = interpolation === 'nearest' ? 
1 : 2;\n  let fillModeId;\n  switch (fillMode) {\n    case 'constant':\n      fillModeId = 1;\n      break;\n    case 'reflect':\n      fillModeId = 2;\n      break;\n    case 'wrap':\n      fillModeId = 3;\n      break;\n    case 'nearest':\n      fillModeId = 4;\n      break;\n    default:\n      fillModeId = 1;\n      break;\n  }\n\n  wasmTransform(\n      imageId, transformsId, (transforms.shape[0] > 1), batch, outHeight,\n      outWidth, numChannels, imageWidth, imageHeight, inputStrides,\n      image.shape.length - 1, outputStrides, outShape.length - 1,\n      interpolationModeId, fillModeId, fillValue, outId);\n\n  return out;\n}\n\nexport const transformConfig: KernelConfig = {\n  kernelName: Transform,\n  backendName: 'wasm',\n  setupFunc: setup,\n  kernelFunc: transform as unknown as KernelFunc\n};\n"]}