UNPKG

@tensorflow/tfjs-converter

Version:

Tensorflow model converter for javascript

65 lines 12.2 kB
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
import { getParamValue } from './utils';

/**
 * Graph op name -> name of the `tfOps` function that implements it.
 *
 * Every op listed here is handled the same way: params 'a' and 'b' are
 * resolved through `getParamValue` and passed, in that order, to the mapped
 * function. A Map (rather than an object literal) keeps inherited keys such
 * as 'constructor' from ever matching an op name.
 */
const BINARY_OP_METHODS = new Map([
  ['BiasAdd', 'add'],
  ['AddV2', 'add'],
  ['Add', 'add'],
  ['FloorMod', 'mod'],
  ['Mod', 'mod'],
  ['Mul', 'mul'],
  ['RealDiv', 'div'],
  ['Div', 'div'],
  ['DivNoNan', 'divNoNan'],
  ['FloorDiv', 'floorDiv'],
  ['Sub', 'sub'],
  ['Minimum', 'minimum'],
  ['Maximum', 'maximum'],
  ['Pow', 'pow'],
  ['SquaredDifference', 'squaredDifference'],
]);

/**
 * Executes one arithmetic graph node.
 *
 * Dispatches on `node.op`: 'AddN' resolves the 'tensors' param and calls
 * `tfOps.addN`; every other supported op resolves params 'a' and 'b' and
 * calls the matching two-argument `tfOps` function.
 *
 * @param node Graph node to execute; only `node.op` is read here, the node
 *     itself is forwarded to `getParamValue`.
 * @param tensorMap Forwarded unchanged to `getParamValue`.
 * @param context Forwarded unchanged to `getParamValue`.
 * @returns A single-element array holding the result of the tfOps call.
 * @throws {TypeError} If `node.op` is not one of the ops handled here.
 */
export const executeOp = (node, tensorMap, context) => {
  const { op } = node;

  // AddN is the only op here that takes a single list-valued param.
  if (op === 'AddN') {
    return [tfOps.addN(getParamValue('tensors', node, tensorMap, context))];
  }

  const methodName = BINARY_OP_METHODS.get(op);
  if (methodName === undefined) {
    throw new TypeError(`Node type ${op} is not implemented`);
  }

  // Resolve 'a' before 'b' to keep the original evaluation order.
  const a = getParamValue('a', node, tensorMap, context);
  const b = getParamValue('b', node, tensorMap, context);
  return [tfOps[methodName](a, b)];
};

export const CATEGORY = 'arithmetic';
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"arithmetic_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/arithmetic_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtC,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EACtC,OAAyB,EAAY,EAAE;IACtC,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,SAAS,CAAC;QACf,KAAK,OAAO,CAAC;QACb,KAAK,KAAK,CAAC,CAAC;YACV,OAAO,CAAC,KAAK,CAAC,GAAG,CACZ,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,EACxD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,MAAM,CAAC,CAAC;YACX,OAAO,CAAC,KAAK,CAAC,IAAI,CACd,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc,CAAC,CAAC,CAAC;SACvE;QACD,KAAK,UAAU,CAAC;QAChB,KAAK,KAAK;YACR,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;QAC/D,KAAK,KAAK;YACR,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;QAC/D,KAAK,SAAS,CAAC;QACf,KAAK,KAAK,CAAC,CAAC;YACV,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,
EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,UAAU,CAAC,CAAC;YACf,OAAO,CAAC,KAAK,CAAC,QAAQ,CAClB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,UAAU,CAAC,CAAC;YACf,OAAO,CAAC,KAAK,CAAC,QAAQ,CAClB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,KAAK,CAAC,CAAC;YACV,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,SAAS,CAAC,CAAC;YACd,OAAO,CAAC,KAAK,CAAC,OAAO,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,SAAS,CAAC,CAAC;YACd,OAAO,CAAC,KAAK,CAAC,OAAO,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,KAAK,CAAC,CAAC;YACV,OAAO,CAAC,KAAK,CAAC,GAAG,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD,KAAK,mBAAmB,CAAC,CAAC;YACxB,OAAO,CAAC,KAAK,CAAC,iBAAiB,CAC3B,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,EACtD,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC,CAAC,CAAC;SAC9D;QACD;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,YAAY,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {getParamValue} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    (node: Node, tensorMap: NamedTensorsMap,\n     context: ExecutionContext): Tensor[] => {\n      switch (node.op) {\n        case 'BiasAdd':\n        case 'AddV2':\n        case 'Add': {\n          return [tfOps.add(\n              (getParamValue('a', node, tensorMap, context) as Tensor),\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        case 'AddN': {\n          return [tfOps.addN((\n              getParamValue('tensors', node, tensorMap, context) as Tensor[]))];\n        }\n        case 'FloorMod':\n        case 'Mod':\n          return [tfOps.mod(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        case 'Mul':\n          return [tfOps.mul(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              
getParamValue('b', node, tensorMap, context) as Tensor)];\n        case 'RealDiv':\n        case 'Div': {\n          return [tfOps.div(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        case 'DivNoNan': {\n          return [tfOps.divNoNan(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        case 'FloorDiv': {\n          return [tfOps.floorDiv(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        case 'Sub': {\n          return [tfOps.sub(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        case 'Minimum': {\n          return [tfOps.minimum(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        case 'Maximum': {\n          return [tfOps.maximum(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        case 'Pow': {\n          return [tfOps.pow(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        case 'SquaredDifference': {\n          return [tfOps.squaredDifference(\n              getParamValue('a', node, tensorMap, context) as Tensor,\n              getParamValue('b', node, tensorMap, context) as Tensor)];\n        }\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'arithmetic';\n"]}