UNPKG

@tensorflow/tfjs-backend-wasm

Version:

This package adds a WebAssembly backend to TensorFlow.js. It currently supports the following models from our [models](https://github.com/tensorflow/tfjs-models) repo: BlazeFace, BodyPix, CocoSSD, Face landmarks detection, HandPose, and KNN classifier.

114 lines 18.1 kB
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, SparseFillEmptyRows } from '@tensorflow/tfjs-core';
import { slice } from './Slice';
import { CppDType } from './types';

// Wrapper around the WASM `SparseFillEmptyRows` export. It stays undefined
// until `setup` has been called with a backend.
let wasmSparseFillEmptyRows;

/**
 * Wraps the WASM `SparseFillEmptyRows` export with `cwrap` and stores the
 * resulting callable in the module-level `wasmSparseFillEmptyRows`.
 *
 * @param backend Backend object whose `wasm.cwrap` is used for the binding.
 */
export function setup(backend) {
    // All twelve arguments, and the return value, are passed as numbers.
    const argumentTypes = [
        'number', // indicesId
        'number', // valuesId
        'number', // valuesDType
        'number', // indicesCount
        'number', // denseRows
        'number', // rank
        'number', // defaultValueId
        'number', // outputIndicesId
        'number', // outputValuesId
        'number', // emptyRowIndicatorId
        'number', // reverseIndexMapId
        'number', // exceptionValuesId
    ];
    wasmSparseFillEmptyRows =
        backend.wasm.cwrap('SparseFillEmptyRows', 'number', argumentTypes);
}

/**
 * Translates the exception buffer filled in by the WASM call into an error
 * message.
 *
 * Slot 0 is read as a status code: 1, 2 and 3 each select one of the
 * `backend_util.getSparseFillEmptyRows*` message builders, which receive the
 * following slots as arguments. Any other code yields an empty string.
 *
 * @param exceptionSlots Values read back from the exception tensor.
 * @returns The error message, or `''` when no known error code is set.
 */
function describeException(exceptionSlots) {
    const code = exceptionSlots[0];
    if (code === 1) {
        return backend_util.getSparseFillEmptyRowsIndicesDenseShapeMismatch(exceptionSlots[1]);
    }
    if (code === 2) {
        return backend_util.getSparseFillEmptyRowsNegativeIndexErrorMessage(exceptionSlots[1], exceptionSlots[2]);
    }
    if (code === 3) {
        return backend_util.getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(exceptionSlots[1], exceptionSlots[2], exceptionSlots[3]);
    }
    return '';
}

/**
 * Kernel function for `SparseFillEmptyRows` on the WASM backend.
 *
 * Allocates the output tensors at their maximum possible size, calls the
 * wrapped WASM function, raises any error it reported, and finally slices the
 * indices/values outputs down to the row count the WASM call returned.
 *
 * @param args Object carrying `backend` and `inputs`; `inputs` provides
 *     `indices`, `values`, `denseShape` and `defaultValue`.
 * @returns `[outputIndices, outputValues, emptyRowIndicator, reverseIndexMap]`.
 * @throws Error when the WASM call reports error code 1, 2 or 3. Every output
 *     allocated by this function is disposed before the error is thrown.
 */
export function sparseFillEmptyRows(args) {
    const { backend, inputs } = args;
    const { indices, values, denseShape, defaultValue } = inputs;
    const [indicesCount, rank] = indices.shape;
    const denseRows = backend.readSync(denseShape.dataId)[0];
    // Set output size to maximum possible and resize later (actual result
    // might be smaller).
    const maxOutputRows = indicesCount + denseRows;
    // Looks up the numeric id the backend associates with a tensor's data.
    const wasmIdOf = (tensorInfo) => backend.dataIdMap.get(tensorInfo.dataId).id;

    const indicesId = wasmIdOf(indices);
    const valuesId = wasmIdOf(values);
    const defaultValueId = wasmIdOf(defaultValue);

    const outputIndices = backend.makeOutput([maxOutputRows, rank], indices.dtype);
    const outputIndicesId = wasmIdOf(outputIndices);

    const outputValues = backend.makeOutput([maxOutputRows], values.dtype);
    const outputValuesId = wasmIdOf(outputValues);

    const emptyRowIndicator = backend.makeOutput([denseRows], 'bool');
    const emptyRowIndicatorId = wasmIdOf(emptyRowIndicator);

    const reverseIndexMap = backend.makeOutput([indicesCount], indices.dtype);
    const reverseIndexMapId = wasmIdOf(reverseIndexMap);

    // Four int32 slots that are read back after the call to detect errors.
    const exceptionValues = backend.makeOutput([4], 'int32');
    const exceptionValuesId = wasmIdOf(exceptionValues);

    const outputRows = wasmSparseFillEmptyRows(indicesId, valuesId, CppDType[values.dtype], indicesCount, denseRows, rank, defaultValueId, outputIndicesId, outputValuesId, emptyRowIndicatorId, reverseIndexMapId, exceptionValuesId);

    // Decode the message before releasing the buffer it is read from.
    const exceptionMessage = describeException(backend.readSync(exceptionValues.dataId));
    backend.disposeData(exceptionValues.dataId);
    if (exceptionMessage) {
        backend.disposeData(outputIndices.dataId);
        backend.disposeData(outputValues.dataId);
        backend.disposeData(emptyRowIndicator.dataId);
        backend.disposeData(reverseIndexMap.dataId);
        throw new Error(exceptionMessage);
    }

    if (outputRows === maxOutputRows) {
        // The maximum-size allocation was exactly right; nothing to trim.
        return [outputIndices, outputValues, emptyRowIndicator, reverseIndexMap];
    }

    // Overestimated output size.
    const trimmedIndices = slice({
        inputs: { x: outputIndices },
        attrs: { begin: 0, size: [outputRows, rank] },
        backend
    });
    const trimmedValues = slice({
        inputs: { x: outputValues },
        attrs: { begin: 0, size: outputRows },
        backend
    });
    backend.disposeData(outputIndices.dataId);
    backend.disposeData(outputValues.dataId);
    return [trimmedIndices, trimmedValues, emptyRowIndicator, reverseIndexMap];
}

/** Kernel registration entry for `SparseFillEmptyRows` on the `wasm` backend. */
export const sparseFillEmptyRowsConfig = {
    kernelName: SparseFillEmptyRows,
    backendName: 'wasm',
    setupFunc: setup,
    kernelFunc: sparseFillEmptyRows
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"SparseFillEmptyRows.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-wasm/src/kernels/SparseFillEmptyRows.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAA4B,mBAAmB,EAAwC,MAAM,uBAAuB,CAAC;AAGzI,OAAO,EAAC,KAAK,EAAC,MAAM,SAAS,CAAC;AAE9B,OAAO,EAAC,QAAQ,EAAC,MAAM,SAAS,CAAC;AAEjC,IAAI,uBAKoC,CAAC;AAEzC,MAAM,UAAU,KAAK,CAAC,OAAoB;IACxC,uBAAuB;QACnB,OAAO,CAAC,IAAI,CAAC,KAAK,CAAC,qBAAqB,EAAE,QAAQ,EAAE;YAClD,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ;YACR,QAAQ,EAAG,oBAAoB;SAChC,CAAC,CAAC;AACT,CAAC;AAED,MAAM,UAAU,mBAAmB,CAAC,IAGnC;IACC,MAAM,EAAC,OAAO,EAAE,MAAM,EAAC,GAAG,IAAI,CAAC;IAC/B,MAAM,EAAC,OAAO,EAAE,MAAM,EAAE,UAAU,EAAE,YAAY,EAAC,GAAG,MAAM,CAAC;IAE3D,MAAM,YAAY,GAAG,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IACtC,MAAM,IAAI,GAAG,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAC9B,MAAM,SAAS,GAAG,OAAO,CAAC,QAAQ,CAAC,UAAU,CAAC,MAAM,CAAC,CAAC,CAAC,CAAW,CAAC;IAEnE,sEAAsE;IACtE,qBAAqB;IACrB,MAAM,qBAAqB,GAAG,CAAC,YAAY,GAAG,SAAS,EAAE,IAAI,CAAC,CAAC;IAE/D,MAAM,SAAS,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAC
3D,MAAM,QAAQ,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IACzD,MAAM,cAAc,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAErE,MAAM,aAAa,GACf,OAAO,CAAC,UAAU,CAAC,qBAAqB,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC;IAC7D,MAAM,eAAe,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAEvE,MAAM,YAAY,GACd,OAAO,CAAC,UAAU,CAAC,qBAAqB,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC;IACxE,MAAM,cAAc,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAErE,MAAM,iBAAiB,GAAG,OAAO,CAAC,UAAU,CAAC,CAAC,SAAS,CAAC,EAAE,MAAM,CAAC,CAAC;IAClE,MAAM,mBAAmB,GACrB,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,iBAAiB,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAEvD,MAAM,eAAe,GAAG,OAAO,CAAC,UAAU,CAAC,CAAC,YAAY,CAAC,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC;IAC1E,MAAM,iBAAiB,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,eAAe,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAE3E,MAAM,eAAe,GAAG,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;IACzD,MAAM,iBAAiB,GAAG,OAAO,CAAC,SAAS,CAAC,GAAG,CAAC,eAAe,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC;IAE3E,MAAM,UAAU,GAAG,uBAAuB,CACtC,SAAS,EAAE,QAAQ,EAAE,QAAQ,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,YAAY,EAAE,SAAS,EACpE,IAAI,EAAE,cAAc,EAAE,eAAe,EAAE,cAAc,EACrD,mBAAmB,EAAE,iBAAiB,EAAE,iBAAiB,CAAC,CAAC;IAE/D,MAAM,oBAAoB,GACtB,OAAO,CAAC,QAAQ,CAAC,eAAe,CAAC,MAAM,CAAe,CAAC;IAE3D,IAAI,gBAAwB,CAAC;IAC7B,QAAQ,oBAAoB,CAAC,CAAC,CAAC,EAAE;QAC/B,KAAK,CAAC,CAAC,CAAC;YACN,gBAAgB;gBACZ,YAAY,CAAC,+CAA+C,CACxD,oBAAoB,CAAC,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM;SACP;QACD,KAAK,CAAC,CAAC,CAAC;YACN,gBAAgB;gBACZ,YAAY,CAAC,+CAA+C,CACxD,oBAAoB,CAAC,CAAC,CAAC,EAAE,oBAAoB,CAAC,CAAC,CAAC,CAAC,CAAC;YAC1D,MAAM;SACP;QACD,KAAK,CAAC;YACJ,gBAAgB;gBACZ,YAAY,CAAC,iDAAiD,CAC1D,oBAAoB,CAAC,CAAC,CAAC,EAAE,oBAAoB,CAAC,CAAC,CAAC,EAChD,oBAAoB,CAAC,CAAC,CAAC,CAAC,CAAC;YACjC,MAAM;QACR;YACE,gBAAgB,GAAG,EAAE,CAAC;KACzB;IAED,OAAO,CAAC,WAAW,CAAC,eAAe,CAAC,MAAM,CAAC,CAAC;IAC5C,IAAI,gBAAgB,EAAE;QACpB,OAAO,CAAC,WAAW,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;QAC1C,OAAO,CAAC,WAAW,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC;QACzC,OAAO,CAAC,
WAAW,CAAC,iBAAiB,CAAC,MAAM,CAAC,CAAC;QAC9C,OAAO,CAAC,WAAW,CAAC,eAAe,CAAC,MAAM,CAAC,CAAC;QAC5C,MAAM,IAAI,KAAK,CAAC,gBAAgB,CAAC,CAAC;KACnC;IAED,IAAI,cAAc,GAAG,aAAa,CAAC;IACnC,IAAI,aAAa,GAAG,YAAY,CAAC;IACjC,6BAA6B;IAC7B,IAAI,UAAU,KAAK,qBAAqB,CAAC,CAAC,CAAC,EAAE;QAC3C,cAAc,GAAG,KAAK,CAAC;YACrB,MAAM,EAAE,EAAC,CAAC,EAAE,aAAa,EAAC;YAC1B,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,EAAE,IAAI,EAAE,CAAC,UAAU,EAAE,IAAI,CAAC,EAAC;YAC3C,OAAO;SACR,CAAC,CAAC;QACH,aAAa,GAAG,KAAK,CAAC;YACpB,MAAM,EAAE,EAAC,CAAC,EAAE,YAAY,EAAC;YACzB,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,EAAE,IAAI,EAAE,UAAU,EAAC;YACnC,OAAO;SACR,CAAC,CAAC;QACH,OAAO,CAAC,WAAW,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;QAC1C,OAAO,CAAC,WAAW,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC;KAC1C;IAED,OAAO,CAAC,cAAc,EAAE,aAAa,EAAE,iBAAiB,EAAE,eAAe,CAAC,CAAC;AAC7E,CAAC;AAED,MAAM,CAAC,MAAM,yBAAyB,GAAiB;IACrD,UAAU,EAAE,mBAAmB;IAC/B,WAAW,EAAE,MAAM;IACnB,SAAS,EAAE,KAAK;IAChB,UAAU,EAAE,mBAA4C;CACzD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2021 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, KernelConfig, KernelFunc, SparseFillEmptyRows, SparseFillEmptyRowsInputs, TensorInfo} from '@tensorflow/tfjs-core';\n\nimport {BackendWasm} from '../backend_wasm';\nimport {slice} from './Slice';\n\nimport {CppDType} from './types';\n\nlet wasmSparseFillEmptyRows: (\n    indicesId: number, valuesId: number, valuesDType: number,\n    
indicesCount: number, denseRows: number, rank: number,\n    defaultValueId: number, outputIndicesId: number, outputValuesId: number,\n    emptyRowIndicatorId: number, reverseIndexMapId: number,\n    exceptionValuesId: number) => number;\n\nexport function setup(backend: BackendWasm): void {\n  wasmSparseFillEmptyRows =\n      backend.wasm.cwrap('SparseFillEmptyRows', 'number', [\n        'number',  // indicesId\n        'number',  // valuesId\n        'number',  // valuesDType\n        'number',  // indicesCount\n        'number',  // denseRows\n        'number',  // rank\n        'number',  // defaultValueId\n        'number',  // outputIndicesId\n        'number',  // outputValuesId\n        'number',  // emptyRowIndicatorId\n        'number',  // reverseIndexMapId\n        'number',  // exceptionValuesId\n      ]);\n}\n\nexport function sparseFillEmptyRows(args: {\n  backend: BackendWasm,\n  inputs: SparseFillEmptyRowsInputs,\n}): [TensorInfo, TensorInfo, TensorInfo, TensorInfo] {\n  const {backend, inputs} = args;\n  const {indices, values, denseShape, defaultValue} = inputs;\n\n  const indicesCount = indices.shape[0];\n  const rank = indices.shape[1];\n  const denseRows = backend.readSync(denseShape.dataId)[0] as number;\n\n  // Set output size to maximum possible and resize later (actual result\n  // might be smaller).\n  const maxOutputIndicesShape = [indicesCount + denseRows, rank];\n\n  const indicesId = backend.dataIdMap.get(indices.dataId).id;\n  const valuesId = backend.dataIdMap.get(values.dataId).id;\n  const defaultValueId = backend.dataIdMap.get(defaultValue.dataId).id;\n\n  const outputIndices =\n      backend.makeOutput(maxOutputIndicesShape, indices.dtype);\n  const outputIndicesId = backend.dataIdMap.get(outputIndices.dataId).id;\n\n  const outputValues =\n      backend.makeOutput(maxOutputIndicesShape.slice(0, 1), values.dtype);\n  const outputValuesId = backend.dataIdMap.get(outputValues.dataId).id;\n\n  const emptyRowIndicator = 
backend.makeOutput([denseRows], 'bool');\n  const emptyRowIndicatorId =\n      backend.dataIdMap.get(emptyRowIndicator.dataId).id;\n\n  const reverseIndexMap = backend.makeOutput([indicesCount], indices.dtype);\n  const reverseIndexMapId = backend.dataIdMap.get(reverseIndexMap.dataId).id;\n\n  const exceptionValues = backend.makeOutput([4], 'int32');\n  const exceptionValuesId = backend.dataIdMap.get(exceptionValues.dataId).id;\n\n  const outputRows = wasmSparseFillEmptyRows(\n      indicesId, valuesId, CppDType[values.dtype], indicesCount, denseRows,\n      rank, defaultValueId, outputIndicesId, outputValuesId,\n      emptyRowIndicatorId, reverseIndexMapId, exceptionValuesId);\n\n  const exceptionValuesArray =\n      backend.readSync(exceptionValues.dataId) as Int32Array;\n\n  let exceptionMessage: string;\n  switch (exceptionValuesArray[0]) {\n    case 1: {\n      exceptionMessage =\n          backend_util.getSparseFillEmptyRowsIndicesDenseShapeMismatch(\n              exceptionValuesArray[1]);\n      break;\n    }\n    case 2: {\n      exceptionMessage =\n          backend_util.getSparseFillEmptyRowsNegativeIndexErrorMessage(\n              exceptionValuesArray[1], exceptionValuesArray[2]);\n      break;\n    }\n    case 3:\n      exceptionMessage =\n          backend_util.getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(\n              exceptionValuesArray[1], exceptionValuesArray[2],\n              exceptionValuesArray[3]);\n      break;\n    default:\n      exceptionMessage = '';\n  }\n\n  backend.disposeData(exceptionValues.dataId);\n  if (exceptionMessage) {\n    backend.disposeData(outputIndices.dataId);\n    backend.disposeData(outputValues.dataId);\n    backend.disposeData(emptyRowIndicator.dataId);\n    backend.disposeData(reverseIndexMap.dataId);\n    throw new Error(exceptionMessage);\n  }\n\n  let resizedIndices = outputIndices;\n  let resizedValues = outputValues;\n  // Overestimated output size.\n  if (outputRows !== maxOutputIndicesShape[0]) {\n  
  resizedIndices = slice({\n      inputs: {x: outputIndices},\n      attrs: {begin: 0, size: [outputRows, rank]},\n      backend\n    });\n    resizedValues = slice({\n      inputs: {x: outputValues},\n      attrs: {begin: 0, size: outputRows},\n      backend\n    });\n    backend.disposeData(outputIndices.dataId);\n    backend.disposeData(outputValues.dataId);\n  }\n\n  return [resizedIndices, resizedValues, emptyRowIndicator, reverseIndexMap];\n}\n\nexport const sparseFillEmptyRowsConfig: KernelConfig = {\n  kernelName: SparseFillEmptyRows,\n  backendName: 'wasm',\n  setupFunc: setup,\n  kernelFunc: sparseFillEmptyRows as unknown as KernelFunc\n};\n"]}