@tensorflow/tfjs-backend-wasm
Version:
This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo:
- BlazeFace
- BodyPix
- CocoSSD
- Face landmarks detection
- HandPose
- KNN classifier
91 lines • 13.3 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { backend_util, GatherV2, util } from '@tensorflow/tfjs-core';
import { reshape } from './Reshape';
import { CppDType } from './types';
/**
 * Wrapped handle to the native `Gather` routine.
 *
 * It is assigned by `setup` and called with
 * (xId, dtype, xStrides, stridesSize, indicesId, batchSize, outStrides, outId).
 */
let wasmGather;
/**
 * Kernel setup hook: binds the wasm module's exported `Gather` function with
 * the module's `cwrap` helper and stores the wrapper in `wasmGather`.
 *
 * @param backend Wasm backend whose `wasm` module exposes `cwrap`.
 */
function setup(backend) {
  // One type tag per native argument, in call order. The two 'array' slots
  // receive the strides as Uint8Array views over Int32 data.
  const argTypes = [
    'number', // xId
    'number', // dtype
    'array', // xStrides
    'number', // stridesSize
    'number', // indicesId
    'number', // batchSize
    'array', // outStrides
    'number', // outId
  ];
  const returnType = null; // void
  wasmGather = backend.wasm.cwrap('Gather', returnType, argTypes);
}
/**
 * WebAssembly kernel for `GatherV2`. It gathers slices of `x` along `axis` at
 * the positions listed in `indices`, honouring `batchDims` (which is forwarded
 * to `collectGatherOpShapeInfo`).
 *
 * `x` is viewed as a 4-D tensor [batchSize, outerSize, dimSize, sliceSize] and
 * `indices` as a 2-D tensor [batchSize, indicesSize / batchSize]. The sizes
 * come from `collectGatherOpShapeInfo`. The native `Gather` routine fills a
 * 4-D result, which is relabelled with `shapeInfo.outputShape` before it is
 * returned.
 *
 * @param args.backend Wasm backend that owns the tensor data.
 * @param args.inputs `x` (source values) and `indices` (positions along `axis`).
 * @param args.attrs `axis` and `batchDims`.
 * @returns The output TensorInfo, with dtype `x.dtype` and shape
 *     `shapeInfo.outputShape` on every path, including an empty `x`.
 * @throws Error (via `util.assert`) if any index is outside
 *     [0, x.shape[axis] - 1].
 */
function gatherV2(args) {
  const { backend, inputs, attrs } = args;
  const { x, indices } = inputs;
  const { axis, batchDims } = attrs;
  // Throw an error when any index is out of bounds.
  const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
  const indicesVals = backend.readSync(indices.dataId);
  const axisDim = x.shape[parsedAxis];
  for (let i = 0; i < indicesVals.length; ++i) {
    const index = indicesVals[i];
    util.assert(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
  }
  const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims);
  const { batchSize, outerSize, dimSize, sliceSize } = shapeInfo;
  const indicesPerBatch = util.sizeFromShape(indices.shape) / batchSize;
  const flattenX = reshape({
    inputs: { x },
    attrs: { shape: [batchSize, outerSize, dimSize, sliceSize] },
    backend
  });
  const flattenIndex = reshape({
    inputs: { x: indices },
    attrs: { shape: [batchSize, indicesPerBatch] },
    backend
  });
  const flattenOutputShape = [batchSize, outerSize, indicesPerBatch, sliceSize];
  const out = backend.makeOutput(flattenOutputShape, x.dtype);
  // Call the native routine only when `x` has data. An empty `x` presumably
  // implies an empty output, since every index into a zero-length axis is
  // rejected above. Cleanup below still runs in either case.
  if (util.sizeFromShape(x.shape) !== 0) {
    const xId = backend.dataIdMap.get(flattenX.dataId).id;
    const indicesId = backend.dataIdMap.get(flattenIndex.dataId).id;
    const outId = backend.dataIdMap.get(out.dataId).id;
    // Strides go to the 'array' parameters as the raw bytes of Int32 arrays.
    const xStridesBytes = new Uint8Array(new Int32Array(util.computeStrides(flattenX.shape)).buffer);
    const outStridesBytes = new Uint8Array(new Int32Array(util.computeStrides(flattenOutputShape)).buffer);
    const stridesSize = flattenX.shape.length - 1;
    wasmGather(xId, CppDType[x.dtype], xStridesBytes, stridesSize, indicesId, batchSize, outStridesBytes, outId);
  }
  // Release the flattened views on every path, including the empty-input
  // path, so they are not leaked.
  backend.disposeData(flattenX.dataId);
  backend.disposeData(flattenIndex.dataId);
  // The native routine worked on the 4-D view. Expose the real output shape.
  out.shape = shapeInfo.outputShape;
  return out;
}
/**
 * Kernel registration record for `GatherV2` on the 'wasm' backend.
 * `setupFunc` is `setup`, which binds the native `Gather` export into
 * `wasmGather`. `kernelFunc` is `gatherV2`, the kernel implementation.
 */
export const gatherV2Config = {
kernelName: GatherV2,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: gatherV2
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"GatherV2.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-wasm/src/kernels/GatherV2.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,QAAQ,EAA2F,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAI5J,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAClC,OAAO,EAAC,QAAQ,EAAC,MAAM,SAAS,CAAC;AAEjC,IAAI,UAGsB,CAAC;AAE3B,SAAS,KAAK,CAAC,OAAoB;IACjC,UAAU,GAAG,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE;QACvD,QAAQ;QACR,QAAQ;QACR,OAAO;QACP,QAAQ;QACR,QAAQ;QACR,QAAQ;QACR,OAAO;QACP,QAAQ,CAAG,QAAQ;KACpB,CAAC,CAAC;AACL,CAAC;AAED,SAAS,QAAQ,CACb,IAA0E;IAE5E,MAAM,EAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,OAAO,EAAC,GAAG,MAAM,CAAC;IAC5B,MAAM,EAAC,IAAI,EAAE,SAAS,EAAC,GAAG,KAAK,CAAC;IAEhC,8CAA8C;IAC9C,MAAM,UAAU,GAAG,IAAI,CAAC,cAAc,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;IACzD,MAAM,WAAW,GAAG,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAC,MAAM,CAAe,CAAC;IACnE,MAAM,OAAO,GAAG,CAAC,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IACpC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,WAAW,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QAC3C,MAAM,KAAK,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC;QAC7B,IAAI,CAAC,MAAM,CACP,KAAK,IAAI,OAAO,GAAG,CAAC,IAAI,KAAK,IAAI,CAAC,EAClC,GAAG,EAAE,CACD,6BAA6B,KAAK,kBAAkB,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC;KAC7E;IAED,MAAM,SAAS,GAAG,YAAY,CAAC,YAAY,CAAC,wBAAwB,CAChE,CAAW,EAAE,OAAiB,EAAE,UAAU,EAAE,SAAS,CAAC,CAAC;IAE3D,MAAM,QAAQ,GAAG,OAAO,CAAC;QACvB,MAAM,EAAE,EAAC,CAAC,EAAC;QACX,KAAK,EAAE;YACL,KAAK,EAAE;gBACL,SAAS,CAAC,SAAS,EAAE,SAAS,CAAC,SAAS,EAAE,SAAS,CAAC,OAAO;gBAC3D,SAAS,CAAC,SAAS;aACpB;SACF;QACD,OAAO;KACR,CAAC,CAAC;IACH,MAAM,WAAW,GAAG,IAAI,CAAC,aAAa,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;IACtD,MAAM,YAAY,GAAG,OAAO,CAAC;QAC3B,MAAM,EAAE,EAAC,CAAC,EAAE,OAAO,EAAC;QACpB,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,SAAS,CAAC,SAAS,EAAE,WAAW,GAAG,SAAS,CAAC,SAAS,CAAC,EAAC;QACxE,OAAO;KACR,CAAC,CAAC;IACH,MAAM,kBAAkB,GAAG;QACzB,SAAS,CAAC,SAAS,EAAE,SAAS,CAAC,SAAS,EAAE,WAAW,GAAG,SAAS,CAAC,SAAS;QAC3E,SAAS,CAAC,SAAS;KACpB,CAAC;IAEF,MAAM,GAAG
,GAAG,OAAO,CAAC,UAAU,CAAC,kBAAkB,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;IAC5D,IAAI,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE;QACrC,OAAO,GAAG,CAAC;KACZ;IACD,MAAM,WAAW,GAAG,QAAQ,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;IAE9C,MAAM,KAAK,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC;IACrD,MAAM,GAAG,GAAG,KAAK,CAAC,EAAE,CAAC;IAErB,MAAM,WAAW,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC;IAC/D,MAAM,SAAS,GAAG,WAAW,CAAC,EAAE,CAAC;IAEjC,MAAM,KAAK,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAEnD,MAAM,aAAa,GAAG,IAAI,UAAU,CAChC,IAAI,UAAU,CAAC,IAAI,CAAC,cAAc,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;IAChE,MAAM,eAAe,GAAG,IAAI,UAAU,CAClC,IAAI,UAAU,CAAC,IAAI,CAAC,cAAc,CAAC,kBAAkB,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;IAEpE,UAAU,CACN,GAAG,EAAE,QAAQ,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE,aAAa,EAAE,WAAW,EAAE,SAAS,EAC7D,SAAS,CAAC,SAAS,EAAE,eAAe,EAAE,KAAK,CAAC,CAAC;IAEjD,OAAO,CAAC,WAAW,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC;IACrC,OAAO,CAAC,WAAW,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC;IAEzC,UAAU;IACV,GAAG,CAAC,KAAK,GAAG,SAAS,CAAC,WAAW,CAAC;IAClC,OAAO,GAAG,CAAC;AACb,CAAC;AAED,MAAM,CAAC,MAAM,cAAc,GAAiB;IAC1C,UAAU,EAAE,QAAQ;IACpB,WAAW,EAAE,MAAM;IACnB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,QAAiC;CAC9C,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, Tensor, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {BackendWasm} from '../backend_wasm';\n\nimport {reshape} from './Reshape';\nimport {CppDType} from './types';\n\nlet wasmGather: (\n    xId: number, dtype: CppDType, xStrides: Uint8Array, stridesSize: number,\n    indicesId: number, batchSize: number, outStrides: Uint8Array,\n    outId: number) => void;\n\nfunction setup(backend: BackendWasm): void {\n  wasmGather = backend.wasm.cwrap('Gather', null /*void*/, [\n    'number',  // xId\n    'number',  // dtype\n    'array',   // xStrides\n    'number',  // stridesSize\n    'number',  // indicesId\n    'number',  // batchSize\n    'array',   // outStrides\n    'number'   // outId\n  ]);\n}\n\nfunction gatherV2(\n    args: {backend: BackendWasm, inputs: GatherV2Inputs, attrs: GatherV2Attrs}):\n    TensorInfo {\n  const {backend, inputs, attrs} = args;\n  const {x, indices} = inputs;\n  const {axis, batchDims} = attrs;\n\n  // Throw error when any index is out of bound.\n  const parsedAxis = util.parseAxisParam(axis, x.shape)[0];\n  const indicesVals = backend.readSync(indices.dataId) as TypedArray;\n  const axisDim = x.shape[parsedAxis];\n  for (let i = 0; i < indicesVals.length; ++i) {\n   
 const index = indicesVals[i];\n    util.assert(\n        index <= axisDim - 1 && index >= 0,\n        () =>\n            `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);\n  }\n\n  const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(\n      x as Tensor, indices as Tensor, parsedAxis, batchDims);\n\n  const flattenX = reshape({\n    inputs: {x},\n    attrs: {\n      shape: [\n        shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,\n        shapeInfo.sliceSize\n      ]\n    },\n    backend\n  });\n  const indicesSize = util.sizeFromShape(indices.shape);\n  const flattenIndex = reshape({\n    inputs: {x: indices},\n    attrs: {shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize]},\n    backend\n  });\n  const flattenOutputShape = [\n    shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,\n    shapeInfo.sliceSize\n  ];\n\n  const out = backend.makeOutput(flattenOutputShape, x.dtype);\n  if (util.sizeFromShape(x.shape) === 0) {\n    return out;\n  }\n  const stridesSize = flattenX.shape.length - 1;\n\n  const xData = backend.dataIdMap.get(flattenX.dataId);\n  const xId = xData.id;\n\n  const indicesData = backend.dataIdMap.get(flattenIndex.dataId);\n  const indicesId = indicesData.id;\n\n  const outId = backend.dataIdMap.get(out.dataId).id;\n\n  const xStridesBytes = new Uint8Array(\n      new Int32Array(util.computeStrides(flattenX.shape)).buffer);\n  const outStridesBytes = new Uint8Array(\n      new Int32Array(util.computeStrides(flattenOutputShape)).buffer);\n\n  wasmGather(\n      xId, CppDType[x.dtype], xStridesBytes, stridesSize, indicesId,\n      shapeInfo.batchSize, outStridesBytes, outId);\n\n  backend.disposeData(flattenX.dataId);\n  backend.disposeData(flattenIndex.dataId);\n\n  // reshape\n  out.shape = shapeInfo.outputShape;\n  return out;\n}\n\nexport const gatherV2Config: KernelConfig = {\n  kernelName: GatherV2,\n  backendName: 'wasm',\n  setupFunc: setup,\n  
kernelFunc: gatherV2 as unknown as KernelFunc\n};\n"]}