@tensorflow/tfjs-converter
Version:
TensorFlow model converter for JavaScript
71 lines • 12.2 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
import { cloneTensor, getParamValue, getTensor } from './utils';
/**
 * Executes a node whose op belongs to the 'graph' category: constants,
 * placeholders, identity-style pass-throughs, shape queries, NoOp and Print.
 *
 * Every case yields the node's outputs as an array of tensors.
 *
 * @param node The node to execute. `node.op` selects the branch; `node.name`
 *     is the key used to look up constant and fed tensors.
 * @param tensorMap Map from node name to that node's tensor array.
 * @param context Execution context, forwarded unchanged to the
 *     `getParamValue` / `getTensor` helpers.
 * @returns The output tensors of `node`.
 * @throws TypeError if `node.op` is not handled by this executor.
 */
export const executeOp = (node, tensorMap, context) => {
  switch (node.op) {
    case 'Const': {
      // Returned exactly as stored under the node's name (not cloned).
      return tensorMap[node.name];
    }
    case 'PlaceholderWithDefault': {
      // Prefer a tensor looked up for this node; otherwise fall back to the
      // node's 'default' input.
      const def = getParamValue('default', node, tensorMap, context);
      return [getTensor(node.name, tensorMap, context) || def];
    }
    case 'Placeholder': {
      return [getTensor(node.name, tensorMap, context)];
    }
    case 'Identity':
    case 'StopGradient':
    case 'FakeQuantWithMinMaxVars': {  // This op is currently ignored.
      // All three forward a copy of 'x'; no quantization is applied for
      // FakeQuantWithMinMaxVars.
      const data = getParamValue('x', node, tensorMap, context);
      return [cloneTensor(data)];
    }
    case 'IdentityN': {
      return getParamValue('x', node, tensorMap, context)
          .map((t) => cloneTensor(t));
    }
    case 'Snapshot': {
      const snapshot = getParamValue('x', node, tensorMap, context);
      return [cloneTensor(snapshot)];
    }
    case 'Shape': {
      const x = getParamValue('x', node, tensorMap, context);
      return [tfOps.tensor1d(x.shape, 'int32')];
    }
    case 'ShapeN': {
      // Emit int32 shape vectors, consistent with the 'Shape' case above.
      return getParamValue('x', node, tensorMap, context)
          .map((t) => tfOps.tensor1d(t.shape, 'int32'));
    }
    case 'Size': {
      const x = getParamValue('x', node, tensorMap, context);
      return [tfOps.scalar(x.size, 'int32')];
    }
    case 'Rank': {
      const x = getParamValue('x', node, tensorMap, context);
      return [tfOps.scalar(x.rank, 'int32')];
    }
    case 'NoOp': {
      // A placeholder output so downstream consumers always see one tensor.
      return [tfOps.scalar(1)];
    }
    case 'Print': {
      const input = getParamValue('x', node, tensorMap, context);
      const data = getParamValue('data', node, tensorMap, context);
      const message = getParamValue('message', node, tensorMap, context);
      const summarize = getParamValue('summarize', node, tensorMap, context);
      console.warn('The graph has a tf.print() operation, ' +
          'usually used for debugging, which slows down performance.');
      console.log(message);
      for (const tensor of data) {
        // dataSync() reads the values synchronously; only the first
        // `summarize` entries are logged.
        console.log(Array.prototype.slice.call(tensor.dataSync())
            .slice(0, summarize));
      }
      // The input passes through unchanged (not cloned).
      return [input];
    }
    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};
/**
 * Category label ('graph') for the ops handled by this module's executeOp.
 * NOTE(review): presumably used by the op dispatcher to route nodes to this
 * executor — confirm against the caller.
 */
export const CATEGORY = 'graph';
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"graph_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/graph_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,WAAW,EAAE,aAAa,EAAE,SAAS,EAAC,MAAM,SAAS,CAAC;AAE9D,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EACtC,OAAyB,EAAY,EAAE;IACtC,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,OAAO,CAAC,CAAC;YACZ,OAAO,SAAS,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SAC7B;QACD,KAAK,wBAAwB;YAC3B,MAAM,GAAG,GACL,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACjE,OAAO,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;QAC3D,KAAK,aAAa;YAChB,OAAO,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;QACpD,KAAK,UAAU,CAAC;QAChB,KAAK,cAAc,CAAC;QACpB,KAAK,yBAAyB,CAAC,CAAC,EAAG,gCAAgC;YACjE,MAAM,IAAI,GAAG,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACpE,OAAO,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC;SAC5B;QACD,KAAK,WAAW;YACd,OAAQ,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc;iBAC5D,GAAG,CAAC,CAAC,CAAS,EAAE,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC1C,KAAK,UAAU;YACb,MAAM,QAAQ,GACT,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YAC7D,OAAO,CAAC,WAAW,CAAC,QAAQ,CAAC,CAAC,CAAC;QACjC,KAAK,OAAO;YACV,OAAO,CAAC,KAAK,CAAC,QAAQ,CACjB,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,KAAK,EAC9D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,QAAQ;YACX,OAAQ,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc;iBAC5D,GAAG,CAAC,CAAC,CAAS,EAAE,EAAE,CAAC,KAAK,CAAC,QAAQ,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;QACnD,KAAK,MAAM;YACT,OAAO,CAAC,KAAK,CAAC,MAAM,CACf,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,IAAI,EAC7D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,MAAM;YACT,OAAO,CAAC,KAAK,CAAC,MAAM,CACf,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,IAAI,EAC7D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,MAAM;YACT,OAAO,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3B,KAAK,OAAO;YACV,MAAM,KAAK,GAAG,aAAa,CAAC
,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACrE,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACjE,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACnE,OAAO,CAAC,IAAI,CACR,uCAAuC;gBACvC,2DAA2D,CAAC,CAAC;YACjE,OAAO,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC;YACrB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;gBACpC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,QAAQ,EAAE,CAAC;qBACzC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;aACvC;YACD,OAAO,CAAC,KAAK,CAAC,CAAC;QAEjB;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,OAAO,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {cloneTensor, getParamValue, getTensor} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    (node: Node, 
tensorMap: NamedTensorsMap,\n     context: ExecutionContext): Tensor[] => {\n      switch (node.op) {\n        case 'Const': {\n          return tensorMap[node.name];\n        }\n        case 'PlaceholderWithDefault':\n          const def =\n              getParamValue('default', node, tensorMap, context) as Tensor;\n          return [getTensor(node.name, tensorMap, context) || def];\n        case 'Placeholder':\n          return [getTensor(node.name, tensorMap, context)];\n        case 'Identity':\n        case 'StopGradient':\n        case 'FakeQuantWithMinMaxVars': {  // This op is currently ignored.\n          const data = getParamValue('x', node, tensorMap, context) as Tensor;\n          return [cloneTensor(data)];\n        }\n        case 'IdentityN':\n          return (getParamValue('x', node, tensorMap, context) as Tensor[])\n              .map((t: Tensor) => cloneTensor(t));\n        case 'Snapshot':\n          const snapshot =\n              (getParamValue('x', node, tensorMap, context) as Tensor);\n          return [cloneTensor(snapshot)];\n        case 'Shape':\n          return [tfOps.tensor1d(\n              (getParamValue('x', node, tensorMap, context) as Tensor).shape,\n              'int32')];\n        case 'ShapeN':\n          return (getParamValue('x', node, tensorMap, context) as Tensor[])\n              .map((t: Tensor) => tfOps.tensor1d(t.shape));\n        case 'Size':\n          return [tfOps.scalar(\n              (getParamValue('x', node, tensorMap, context) as Tensor).size,\n              'int32')];\n        case 'Rank':\n          return [tfOps.scalar(\n              (getParamValue('x', node, tensorMap, context) as Tensor).rank,\n              'int32')];\n        case 'NoOp':\n          return [tfOps.scalar(1)];\n        case 'Print':\n          const input = getParamValue('x', node, tensorMap, context) as Tensor;\n          const data =\n              getParamValue('data', node, tensorMap, context) as Tensor[];\n          const message 
=\n              getParamValue('message', node, tensorMap, context) as string;\n          const summarize =\n              getParamValue('summarize', node, tensorMap, context) as number;\n          console.warn(\n              'The graph has a tf.print() operation,' +\n              'usually used for debugging, which slows down performance.');\n          console.log(message);\n          for (let i = 0; i < data.length; i++) {\n            console.log(Array.prototype.slice.call(data[i].dataSync())\n                            .slice(0, summarize));\n          }\n          return [input];\n\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'graph';\n"]}