UNPKG

@tensorflow/tfjs-converter

Version:

TensorFlow model converter for JavaScript

106 lines 18.9 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tfc from '@tensorflow/tfjs-core';
import { NodeValueImpl } from './custom_op/node_value_impl';
import { getRegisteredOp } from './custom_op/register';
import * as arithmetic from './executors/arithmetic_executor';
import * as basicMath from './executors/basic_math_executor';
import * as control from './executors/control_executor';
import * as convolution from './executors/convolution_executor';
import * as creation from './executors/creation_executor';
import * as dynamic from './executors/dynamic_executor';
import * as evaluation from './executors/evaluation_executor';
import * as graph from './executors/graph_executor';
import * as hashTable from './executors/hash_table_executor';
import * as image from './executors/image_executor';
import * as logical from './executors/logical_executor';
import * as matrices from './executors/matrices_executor';
import * as normalization from './executors/normalization_executor';
import * as reduction from './executors/reduction_executor';
import * as sliceJoin from './executors/slice_join_executor';
import * as sparse from './executors/sparse_executor';
import * as spectral from './executors/spectral_executor';
import * as string from './executors/string_executor';
import * as transformation from './executors/transformation_executor';

/**
 * Executor modules, keyed by node category, whose `executeOp` is always run
 * inside `tfc.tidy()` so intermediate tensors they allocate are disposed and
 * only the returned tensor(s) survive.
 *
 * 'control', 'dynamic', 'hash_table' and 'custom' are deliberately absent and
 * are dispatched in `runNodeExecutor` without a tidy scope.
 * NOTE(review): presumably because those executors may return Promises or
 * hold tensors whose lifetime outlives a single op, which `tidy` cannot
 * manage — confirm before moving a category into this table.
 *
 * A `Map` (rather than a plain object) is used so that category strings such
 * as 'toString' can never resolve to an inherited prototype member.
 */
const TIDY_EXECUTORS = new Map([
  ['arithmetic', arithmetic],
  ['basic_math', basicMath],
  ['convolution', convolution],
  ['creation', creation],
  ['evaluation', evaluation],
  ['image', image],
  ['graph', graph],
  ['logical', logical],
  ['matrices', matrices],
  ['normalization', normalization],
  ['reduction', reduction],
  ['slice_join', sliceJoin],
  ['sparse', sparse],
  ['spectral', spectral],
  ['string', string],
  ['transformation', transformation],
]);

/**
 * Dispatches `node` to the executor for `node.category` and returns that
 * executor's raw result (a tensor, an array of tensors, or a Promise of
 * either — whatever the individual executor produces).
 *
 * @param node graph node; `node.category` selects the executor.
 * @param tensorMap contains tensors for executed nodes and weights.
 * @param context contains tensors and information for running the node.
 * @param resourceManager Optional; forwarded only to the 'hash_table'
 *     executor.
 * @throws {TypeError} if the category is 'custom' and no custom executor is
 *     registered for `node.op`, or if the category is not recognized.
 */
function runNodeExecutor(node, tensorMap, context, resourceManager) {
  const tidyExecutor = TIDY_EXECUTORS.get(node.category);
  if (tidyExecutor !== undefined) {
    return tfc.tidy(() => tidyExecutor.executeOp(node, tensorMap, context));
  }

  switch (node.category) {
    case 'control':
      return control.executeOp(node, tensorMap, context);
    case 'dynamic':
      return dynamic.executeOp(node, tensorMap, context);
    case 'hash_table':
      return hashTable.executeOp(node, tensorMap, context, resourceManager);
    case 'custom': {
      // Braces give `opMapper` its own scope (no lexical declarations leaking
      // across `case` clauses).
      const opMapper = getRegisteredOp(node.op);
      if (opMapper && opMapper.customExecutor) {
        return opMapper.customExecutor(
            new NodeValueImpl(node, tensorMap, context));
      }
      throw new TypeError(`Custom op ${node.op} is not registered.`);
    }
    default:
      throw new TypeError(
          `Unknown op '${node.op}'. File an issue at ` +
          `https://github.com/tensorflow/tfjs/issues so we can add it` +
          `, or register a custom execution with tf.registerOp()`);
  }
}

/**
 * Executes the op defined by the node object.
 *
 * @param node graph node to execute; `node.category` selects the executor and
 *     `node.op` names the op.
 * @param tensorMap contains tensors for executed nodes and weights
 * @param context contains tensors and information for running the current node.
 * @param resourceManager Optional. Contains global resources of the model.
 *     Only the 'hash_table' executor receives it.
 * @returns the op's outputs as an array, or a Promise resolving to that array
 *     when the executor returned a Promise. A single (non-array) result is
 *     wrapped in a one-element array.
 * @throws {TypeError} (synchronously) for an unregistered custom op or an
 *     unknown category.
 */
export function executeOp(node, tensorMap, context, resourceManager) {
  const value = runNodeExecutor(node, tensorMap, context, resourceManager);
  // `[].concat(x)` turns either a single result or an array of results into
  // a flat, one-level array.
  if (tfc.util.isPromise(value)) {
    return value.then((data) => [].concat(data));
  }
  return [].concat(value);
}
// NOTE(review): the inline source map below was generated for the previous
// layout of this compiled file and is now stale; regenerate it from the
// TypeScript build rather than editing it by hand.
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"operation_executor.js","sourceRoot":"","sources":["../../../../../../tfjs-converter/src/operations/operation_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAM7C,OAAO,EAAC,aAAa,EAAC,MAAM,6BAA6B,CAAC;AAC1D,OAAO,EAAC,eAAe,EAAC,MAAM,sBAAsB,CAAC;AACrD,OAAO,KAAK,UAAU,MAAM,iCAAiC,CAAC;AAC9D,OAAO,KAAK,SAAS,MAAM,iCAAiC,CAAC;AAC7D,OAAO,KAAK,OAAO,MAAM,8BAA8B,CAAC;AACxD,OAAO,KAAK,WAAW,MAAM,kCAAkC,CAAC;AAChE,OAAO,KAAK,QAAQ,MAAM,+BAA+B,CAAC;AAC1D,OAAO,KAAK,OAAO,MAAM,8BAA8B,CAAC;AACxD,OAAO,KAAK,UAAU,MAAM,iCAAiC,CAAC;AAC9D,OAAO,KAAK,KAAK,MAAM,4BAA4B,CAAC;AACpD,OAAO,KAAK,SAAS,MAAM,iCAAiC,CAAC;AAC7D,OAAO,KAAK,KAAK,MAAM,4BAA4B,CAAC;AACpD,OAAO,KAAK,OAAO,MAAM,8BAA8B,CAAC;AACxD,OAAO,KAAK,QAAQ,MAAM,+BAA+B,CAAC;AAC1D,OAAO,KAAK,aAAa,MAAM,oCAAoC,CAAC;AACpE,OAAO,KAAK,SAAS,MAAM,gCAAgC,CAAC;AAC5D,OAAO,KAAK,SAAS,MAAM,iCAAiC,CAAC;AAC7D,OAAO,KAAK,MAAM,MAAM,6BAA6B,CAAC;AACtD,OAAO,KAAK,QAAQ,MAAM,+BAA+B,CAAC;AAC1D,OAAO,KAAK,MAAM,MAAM,6BAA6B,CAAC;AACtD,OAAO,KAAK,cAAc,MAAM,qCAAqC,CAAC;AAGtE;;;;;;GAMG;AACH,MAAM,UAAU,SAAS,CACrB,IAAU,EAAE,SAA0B,EAAE,OAAyB,EACjE,eAAiC;IACnC,MAAM,KAAK,GACP,CAAC,CAAC,IAAU,EAAE,SAA0B,EAAE,OAAyB,EAAE,EAAE;QACrE,QAAQ,IAAI,CAAC,QAAQ,EAAE;YACrB,KAAK,YAAY;gBACf,OAAO,GAAG,CAAC,IAAI,CACX,GAAG,EAAE,CAAC,UAAU,CAAC,SAAS,CAAC,IAAI,E
AAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YAC5D,KAAK,YAAY;gBACf,OAAO,GAAG,CAAC,IAAI,CACX,GAAG,EAAE,CAAC,SAAS,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YAC3D,KAAK,SAAS;gBACZ,OAAO,OAAO,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YACrD,KAAK,aAAa;gBAChB,OAAO,GAAG,CAAC,IAAI,CACX,GAAG,EAAE,CAAC,WAAW,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YAC7D,KAAK,UAAU;gBACb,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YACtE,KAAK,SAAS;gBACZ,OAAO,OAAO,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;YACrD,KAAK,YAAY;gBACf,OAAO,GAAG,CAAC,IAAI,CACX,GAAG,EAAE,CAAC,UAAU,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YAC5D,KAAK,OAAO;gBACV,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YACnE,KAAK,OAAO;gBACV,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YACnE,KAAK,SAAS;gBACZ,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YACrE,KAAK,UAAU;gBACb,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YACtE,KAAK,eAAe;gBAClB,OAAO,GAAG,CAAC,IAAI,CACX,GAAG,EAAE,CAAC,aAAa,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YAC/D,KAAK,WAAW;gBACd,OAAO,GAAG,CAAC,IAAI,CACX,GAAG,EAAE,CAAC,SAAS,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YAC3D,KAAK,YAAY;gBACf,OAAO,GAAG,CAAC,IAAI,CACX,GAAG,EAAE,CAAC,SAAS,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YAC3D,KAAK,QAAQ;gBACX,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YACpE,KAAK,UAAU;gBACb,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YACtE,KAAK,QAAQ;gBACX,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,SAAS,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YACpE,KAAK,gBAAgB;gBACnB,OAAO,GAAG,CAAC,IAAI,CACX,GAAG,EAAE,CAAC,cAAc,CAAC,SAAS,CAAC,IAAI,EA
AE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;YAChE,KAAK,YAAY;gBACf,OAAO,SAAS,CAAC,SAAS,CACtB,IAAI,EAAE,SAAS,EAAE,OAAO,EAAE,eAAe,CAAC,CAAC;YACjD,KAAK,QAAQ;gBACX,MAAM,QAAQ,GAAG,eAAe,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;gBAC1C,IAAI,QAAQ,IAAI,QAAQ,CAAC,cAAc,EAAE;oBACvC,OAAO,QAAQ,CAAC,cAAc,CAC1B,IAAI,aAAa,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;iBAClD;qBAAM;oBACL,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;iBAC5D;YACH;gBACE,MAAM,SAAS,CACX,eAAe,IAAI,CAAC,EAAE,sBAAsB;oBAC5C,4DAA4D;oBAC5D,uDAAuD,CAAC,CAAC;SAChE;IACH,CAAC,CAAC,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;IACjC,IAAI,GAAG,CAAC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE;QAC7B,OAAQ,KAA6B,CAAC,IAAI,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC;KACvE;IACD,OAAO,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;AAC1B,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tfc from '@tensorflow/tfjs-core';\n\nimport {NamedTensorsMap} from '../data/types';\nimport {ExecutionContext} from '../executor/execution_context';\nimport {ResourceManager} from '../executor/resource_manager';\n\nimport {NodeValueImpl} from './custom_op/node_value_impl';\nimport {getRegisteredOp} from './custom_op/register';\nimport * as arithmetic from './executors/arithmetic_executor';\nimport * as basicMath from './executors/basic_math_executor';\nimport * as 
control from './executors/control_executor';\nimport * as convolution from './executors/convolution_executor';\nimport * as creation from './executors/creation_executor';\nimport * as dynamic from './executors/dynamic_executor';\nimport * as evaluation from './executors/evaluation_executor';\nimport * as graph from './executors/graph_executor';\nimport * as hashTable from './executors/hash_table_executor';\nimport * as image from './executors/image_executor';\nimport * as logical from './executors/logical_executor';\nimport * as matrices from './executors/matrices_executor';\nimport * as normalization from './executors/normalization_executor';\nimport * as reduction from './executors/reduction_executor';\nimport * as sliceJoin from './executors/slice_join_executor';\nimport * as sparse from './executors/sparse_executor';\nimport * as spectral from './executors/spectral_executor';\nimport * as string from './executors/string_executor';\nimport * as transformation from './executors/transformation_executor';\nimport {Node} from './types';\n\n/**\n * Executes the op defined by the node object.\n * @param node\n * @param tensorMap contains tensors for executed nodes and weights\n * @param context contains tensors and information for running the current node.\n * @param resourceManager Optional. 
Contains global resources of the model.\n */\nexport function executeOp(\n    node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext,\n    resourceManager?: ResourceManager): tfc.Tensor[]|Promise<tfc.Tensor[]> {\n  const value =\n      ((node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext) => {\n        switch (node.category) {\n          case 'arithmetic':\n            return tfc.tidy(\n                () => arithmetic.executeOp(node, tensorMap, context));\n          case 'basic_math':\n            return tfc.tidy(\n                () => basicMath.executeOp(node, tensorMap, context));\n          case 'control':\n            return control.executeOp(node, tensorMap, context);\n          case 'convolution':\n            return tfc.tidy(\n                () => convolution.executeOp(node, tensorMap, context));\n          case 'creation':\n            return tfc.tidy(() => creation.executeOp(node, tensorMap, context));\n          case 'dynamic':\n            return dynamic.executeOp(node, tensorMap, context);\n          case 'evaluation':\n            return tfc.tidy(\n                () => evaluation.executeOp(node, tensorMap, context));\n          case 'image':\n            return tfc.tidy(() => image.executeOp(node, tensorMap, context));\n          case 'graph':\n            return tfc.tidy(() => graph.executeOp(node, tensorMap, context));\n          case 'logical':\n            return tfc.tidy(() => logical.executeOp(node, tensorMap, context));\n          case 'matrices':\n            return tfc.tidy(() => matrices.executeOp(node, tensorMap, context));\n          case 'normalization':\n            return tfc.tidy(\n                () => normalization.executeOp(node, tensorMap, context));\n          case 'reduction':\n            return tfc.tidy(\n                () => reduction.executeOp(node, tensorMap, context));\n          case 'slice_join':\n            return tfc.tidy(\n                () => sliceJoin.executeOp(node, tensorMap, 
context));\n          case 'sparse':\n            return tfc.tidy(() => sparse.executeOp(node, tensorMap, context));\n          case 'spectral':\n            return tfc.tidy(() => spectral.executeOp(node, tensorMap, context));\n          case 'string':\n            return tfc.tidy(() => string.executeOp(node, tensorMap, context));\n          case 'transformation':\n            return tfc.tidy(\n                () => transformation.executeOp(node, tensorMap, context));\n          case 'hash_table':\n            return hashTable.executeOp(\n                node, tensorMap, context, resourceManager);\n          case 'custom':\n            const opMapper = getRegisteredOp(node.op);\n            if (opMapper && opMapper.customExecutor) {\n              return opMapper.customExecutor(\n                  new NodeValueImpl(node, tensorMap, context));\n            } else {\n              throw TypeError(`Custom op ${node.op} is not registered.`);\n            }\n          default:\n            throw TypeError(\n                `Unknown op '${node.op}'. File an issue at ` +\n                `https://github.com/tensorflow/tfjs/issues so we can add it` +\n                `, or register a custom execution with tf.registerOp()`);\n        }\n      })(node, tensorMap, context);\n  if (tfc.util.isPromise(value)) {\n    return (value as Promise<tfc.Tensor>).then((data) => [].concat(data));\n  }\n  return [].concat(value);\n}\n"]}