@tensorflow/tfjs-converter
Version:
Tensorflow model converter for javascript
142 lines • 18.6 kB
JavaScript
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import { parseNodeName } from '../operations/executors/utils';
/**
 * Given graph inputs and desired outputs, find the minimal set of nodes
 * to execute in order to compute the outputs. In addition, return other
 * useful info such as:
 * - Missing inputs needed to compute the output.
 * - Whether the subgraph contains dynamic ops (control flow, dynamic shape,
 *   hash table).
 * - Alternative inputs in order to avoid async (dynamic op) execution.
 *
 * @param inputs Map of user-provided inputs keyed by tensor name (e.g. `x` or
 *     `x:0`); only the node-name part of each key (via `parseNodeName`) is
 *     used.
 * @param outputs The nodes whose values are requested.
 * @param weightMap Map keyed by node name; nodes with a non-null entry are
 *     treated as already known and are not traversed further.
 * @param initNodes Optional nodes that are also treated as traversal dead
 *     ends.
 * @returns `inputs` and `outputs` (passed through), `usedNodes` (names of
 *     every visited node), `missingInputs` (names of visited nodes, each
 *     listed once, that have no inputs and are not user inputs, weights or
 *     init nodes), `dynamicNode` (the first dynamic node encountered, or
 *     null) and `syncInputs` (names of that node's children already visited
 *     when it was found, or null).
 */
export function getExecutionSubgraph(inputs, outputs, weightMap, initNodes) {
  const usedNodes = new Set();
  const missingInputs = [];
  let dynamicNode = null;
  let syncInputs = null;
  // Start with the outputs, going backwards and find all the nodes that are
  // needed to compute those outputs.
  const seen = new Set();
  // Sets keep the per-node membership checks in the loop below O(1).
  const inputNodeNames =
      new Set(Object.keys(inputs).map(name => parseNodeName(name)[0]));
  const initNodeNames = new Set(
      initNodes != null ? initNodes.map(node => parseNodeName(node.name)[0]) :
                          []);
  const frontier = [...outputs];
  while (frontier.length > 0) {
    const node = frontier.pop();
    // A node can reach the frontier more than once (it is both a requested
    // output and an input of another output, or it is listed twice in
    // `outputs`). Visit it only once so `missingInputs` has no duplicates.
    if (usedNodes.has(node.name)) {
      continue;
    }
    if (isControlFlow(node) || isDynamicShape(node) || isHashTable(node)) {
      if (dynamicNode == null) {
        dynamicNode = node;
        syncInputs = dynamicNode.children.map(child => child.name)
                         .filter(name => usedNodes.has(name));
      }
    }
    usedNodes.add(node.name);
    // Weights are dead end since we already have their values.
    if (weightMap[node.name] != null) {
      continue;
    }
    // This node is a dead end since it's one of the user-provided inputs.
    if (inputNodeNames.has(node.name)) {
      continue;
    }
    // This node is a dead end since it's one of the init nodes.
    if (initNodeNames.has(node.name)) {
      continue;
    }
    // This node has no inputs and nothing provides its value, so it is
    // reported as missing.
    if (node.inputs.length === 0) {
      missingInputs.push(node.name);
      continue;
    }
    node.inputs.forEach(input => {
      // Don't add to the frontier if it is already there.
      if (seen.has(input.name)) {
        return;
      }
      seen.add(input.name);
      frontier.push(input);
    });
  }
  return { inputs, outputs, usedNodes, missingInputs, dynamicNode, syncInputs };
}
/**
 * Given the execution info, return a list of nodes in topological order that
 * need to be executed to compute the output.
 *
 * @param graph The graph; `graph.nodes` (keyed by node name), `graph.weights`
 *     and the optional `graph.initNodes` are read.
 * @param weightMap Map keyed by node name; nodes with a truthy entry are still
 *     traversed but are left out of the returned list.
 * @param executionInfo Result of `getExecutionSubgraph`; its `usedNodes` and
 *     `inputs` are read.
 * @returns The used nodes without a `weightMap` entry, each appearing once,
 *     ordered so that every node comes after all of its inputs.
 */
export function getNodesInTopologicalOrder(graph, weightMap, executionInfo) {
  const { usedNodes, inputs } = executionInfo;
  const frontier = [];
  const inputNodes = Object.keys(inputs)
      .map(name => parseNodeName(name)[0])
      .map(name => graph.nodes[name]);
  const initNodes = graph.initNodes;
  // Seed the traversal with every used node whose value needs no upstream
  // computation: user inputs, weights and init nodes.
  inputNodes.forEach(input => {
    if (usedNodes.has(input.name)) {
      frontier.push(input);
    }
  });
  graph.weights.forEach(weight => {
    if (usedNodes.has(weight.name)) {
      frontier.push(weight);
    }
  });
  if (initNodes != null) {
    initNodes.forEach(node => {
      if (usedNodes.has(node.name)) {
        frontier.push(node);
      }
    });
  }
  const seen = new Set();
  const orderedNodes = [];
  while (frontier.length > 0) {
    const node = frontier.pop();
    // A node may be pushed more than once, e.g. when it consumes the same
    // parent twice (that parent lists it twice in `children`) or when a seed
    // node is duplicated. Emit it only the first time.
    if (seen.has(node.name)) {
      continue;
    }
    seen.add(node.name);
    if (!weightMap[node.name]) {
      orderedNodes.push(node);
    }
    // A child becomes ready once every one of its inputs has been emitted.
    node.children.forEach(child => {
      if (!seen.has(child.name) && usedNodes.has(child.name) &&
          child.inputs.every(input => seen.has(input.name))) {
        frontier.push(child);
      }
    });
  }
  return orderedNodes;
}
// Op names checked by `isControlFlow`. TensorFlow spells the op `If`
// (capitalized, like `StatelessIf`); the historical lowercase `'if'` entry
// is kept so that anything matching it today still does.
const CONTROL_FLOW_OPS = [
  'Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf',
  'StatelessWhile', 'if', 'If', 'While'
];
// Op names checked by `isDynamicShape`.
const DYNAMIC_SHAPE_OPS = [
  'NonMaxSuppressionV2', 'NonMaxSuppressionV3', 'NonMaxSuppressionV5', 'Where'
];
// Op names checked by `isHashTable`.
const HASH_TABLE_OPS = [
  'HashTable', 'HashTableV2', 'LookupTableImport', 'LookupTableImportV2',
  'LookupTableFind', 'LookupTableFindV2', 'LookupTableSize', 'LookupTableSizeV2'
];
/**
 * Reports whether the node's `op` string is one of the names in `opNames`.
 */
function nodeOpIsOneOf(opNames, node) {
  return opNames.indexOf(node.op) !== -1;
}
/** Returns true when `node.op` is listed in `CONTROL_FLOW_OPS`. */
export function isControlFlow(node) {
  return nodeOpIsOneOf(CONTROL_FLOW_OPS, node);
}
/** Returns true when `node.op` is listed in `DYNAMIC_SHAPE_OPS`. */
export function isDynamicShape(node) {
  return nodeOpIsOneOf(DYNAMIC_SHAPE_OPS, node);
}
/** Returns true when `node.op` is listed in `HASH_TABLE_OPS`. */
export function isHashTable(node) {
  return nodeOpIsOneOf(HASH_TABLE_OPS, node);
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"model_analysis.js","sourceRoot":"","sources":["../../../../../../tfjs-converter/src/executor/model_analysis.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAKH,OAAO,EAAC,aAAa,EAAC,MAAM,+BAA+B,CAAC;AAY5D;;;;;;;GAOG;AACH,MAAM,UAAU,oBAAoB,CAChC,MAAsB,EAAE,OAAe,EAAE,SAA0B,EACnE,SAAkB;IACpB,MAAM,SAAS,GAAG,IAAI,GAAG,EAAU,CAAC;IACpC,MAAM,aAAa,GAAa,EAAE,CAAC;IACnC,IAAI,WAAW,GAAS,IAAI,CAAC;IAC7B,IAAI,UAAU,GAAa,IAAI,CAAC;IAEhC,0EAA0E;IAC1E,mCAAmC;IACnC,MAAM,IAAI,GAAG,IAAI,GAAG,EAAU,CAAC;IAC/B,MAAM,cAAc,GAChB,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAE5D,IAAI,aAAa,GAAa,EAAE,CAAC;IACjC,IAAI,SAAS,IAAI,IAAI,EAAE;QACrB,aAAa,GAAG,SAAS,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;KACpE;IAED,MAAM,QAAQ,GAAG,CAAC,GAAG,OAAO,CAAC,CAAC;IAC9B,OAAO,QAAQ,CAAC,MAAM,GAAG,CAAC,EAAE;QAC1B,MAAM,IAAI,GAAG,QAAQ,CAAC,GAAG,EAAE,CAAC;QAC5B,IAAI,aAAa,CAAC,IAAI,CAAC,IAAI,cAAc,CAAC,IAAI,CAAC,IAAI,WAAW,CAAC,IAAI,CAAC,EAAE;YACpE,IAAI,WAAW,IAAI,IAAI,EAAE;gBACvB,WAAW,GAAG,IAAI,CAAC;gBACnB,UAAU,GAAG,WAAW,CAAC,QAAQ,CAAC,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,IAAI,CAAC;qBACxC,MAAM,CAAC,IAAI,CAAC,EAAE,CAAC,SAAS,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC;aACvD;SACF;QACD,SAAS,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QAEzB,2DAA2D;QAC3D,IAAI,SAAS,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,IAAI,EAAE;YAChC,SAAS;SACV;QACD,sEAAsE;QACtE,IAAI,cAAc,CAAC,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE;YAC5C,SAAS;SACV;QACD,4DAA4D;QAC5D,IAAI,aAAa,CAAC,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE;YAC3C,SAAS;SACV;QACD,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;YAC5B,aAAa,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;YAC9B,SAAS;SACV;QACD,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YAC1B,oDAAoD;YACpD,IAAI,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC,EAAE;gBACxB,OAAO;aACR;YACD,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC;YACrB,QAAQ,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QAC
vB,CAAC,CAAC,CAAC;KACJ;IACD,OAAO,EAAC,MAAM,EAAE,OAAO,EAAE,SAAS,EAAE,aAAa,EAAE,WAAW,EAAE,UAAU,EAAC,CAAC;AAC9E,CAAC;AAED;;;GAGG;AACH,MAAM,UAAU,0BAA0B,CACtC,KAAY,EAAE,SAA0B,EACxC,aAA4B;IAC9B,MAAM,EAAC,SAAS,EAAE,MAAM,EAAC,GAAG,aAAa,CAAC;IAC1C,MAAM,QAAQ,GAAW,EAAE,CAAC;IAC5B,MAAM,UAAU,GAAG,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC;SACd,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;SACnC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC;IACvD,MAAM,SAAS,GAAG,KAAK,CAAC,SAAS,CAAC;IAElC,UAAU,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QACzB,IAAI,SAAS,CAAC,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC,EAAE;YAC7B,QAAQ,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;SACtB;IACH,CAAC,CAAC,CAAC;IACH,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;QAC7B,IAAI,SAAS,CAAC,GAAG,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE;YAC9B,QAAQ,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;SACvB;IACH,CAAC,CAAC,CAAC;IACH,IAAI,SAAS,IAAI,IAAI,EAAE;QACrB,SAAS,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE;YACvB,IAAI,SAAS,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,EAAE;gBAC5B,QAAQ,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;aACrB;QACH,CAAC,CAAC,CAAC;KACJ;IACD,MAAM,IAAI,GAAG,IAAI,GAAG,EAAU,CAAC;IAC/B,MAAM,YAAY,GAAW,EAAE,CAAC;IAChC,OAAO,QAAQ,CAAC,MAAM,GAAG,CAAC,EAAE;QAC1B,MAAM,IAAI,GAAG,QAAQ,CAAC,GAAG,EAAE,CAAC;QAC5B,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QACpB,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,CAAC,EAAE;YACzB,YAAY,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SACzB;QACD,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;YAC5B,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC,IAAI,SAAS,CAAC,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC;gBAClD,KAAK,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,EAAE;gBACrD,QAAQ,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;aACtB;QACH,CAAC,CAAC,CAAC;KACJ;IACD,OAAO,YAAY,CAAC;AACtB,CAAC;AAED,MAAM,gBAAgB,GAAG;IACvB,QAAQ,EAAE,OAAO,EAAE,OAAO,EAAE,MAAM,EAAE,eAAe,EAAE,aAAa;IAClE,gBAAgB,EAAE,IAAI,EAAE,OAAO;CAChC,CAAC;AACF,MAAM,iBAAiB,GAAG;IACxB,qBAAqB,EAAE,qBAAqB,EAAE,qBAAqB,EAAE,OAAO;CAC7E,CAAC;AACF,MAAM,cAAc,GAAG;IACrB,WAAW,EAAE,aAAa,EAAE,mBAAmB,EAAE,qBAAqB;IACtE,iBAAiB,EA
AE,mBAAmB,EAAE,iBAAiB,EAAE,mBAAmB;CAC/E,CAAC;AAEF,MAAM,UAAU,aAAa,CAAC,IAAU;IACtC,OAAO,gBAAgB,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC;AAChD,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,IAAU;IACvC,OAAO,iBAAiB,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC;AACjD,CAAC;AAED,MAAM,UAAU,WAAW,CAAC,IAAU;IACpC,OAAO,cAAc,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC;AAC9C,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {NamedTensorMap} from '@tensorflow/tfjs-core';\n\nimport {NamedTensorsMap} from '../data/types';\nimport {parseNodeName} from '../operations/executors/utils';\nimport {Graph, Node} from '../operations/types';\n\nexport interface ExecutionInfo {\n  inputs: NamedTensorMap;\n  outputs: Node[];\n  usedNodes: Set<string>;\n  missingInputs: string[];\n  dynamicNode: Node;\n  syncInputs: string[];\n}\n\n/**\n * Given graph inputs and desired outputs, find the minimal set of nodes\n * to execute in order to compute the outputs. 
In addition return other useful\n * info such:\n * - Missing inputs needed to compute the output.\n * - Whether the subgraph contains dynamic ops (control flow, dynamic shape).\n * - Alternative inputs in order to avoid async (dynamic op) execution.\n */\nexport function getExecutionSubgraph(\n    inputs: NamedTensorMap, outputs: Node[], weightMap: NamedTensorsMap,\n    initNodes?: Node[]): ExecutionInfo {\n  const usedNodes = new Set<string>();\n  const missingInputs: string[] = [];\n  let dynamicNode: Node = null;\n  let syncInputs: string[] = null;\n\n  // Start with the outputs, going backwards and find all the nodes that are\n  // needed to compute those outputs.\n  const seen = new Set<string>();\n  const inputNodeNames =\n      Object.keys(inputs).map(name => parseNodeName(name)[0]);\n\n  let initNodeNames: string[] = [];\n  if (initNodes != null) {\n    initNodeNames = initNodes.map(node => parseNodeName(node.name)[0]);\n  }\n\n  const frontier = [...outputs];\n  while (frontier.length > 0) {\n    const node = frontier.pop();\n    if (isControlFlow(node) || isDynamicShape(node) || isHashTable(node)) {\n      if (dynamicNode == null) {\n        dynamicNode = node;\n        syncInputs = dynamicNode.children.map(child => child.name)\n                         .filter(name => usedNodes.has(name));\n      }\n    }\n    usedNodes.add(node.name);\n\n    // Weights are dead end since we already have their values.\n    if (weightMap[node.name] != null) {\n      continue;\n    }\n    // This node is a dead end since it's one of the user-provided inputs.\n    if (inputNodeNames.indexOf(node.name) !== -1) {\n      continue;\n    }\n    // This node is a dead end since it doesn't have any inputs.\n    if (initNodeNames.indexOf(node.name) !== -1) {\n      continue;\n    }\n    if (node.inputs.length === 0) {\n      missingInputs.push(node.name);\n      continue;\n    }\n    node.inputs.forEach(input => {\n      // Don't add to the frontier if it is already there.\n      
if (seen.has(input.name)) {\n        return;\n      }\n      seen.add(input.name);\n      frontier.push(input);\n    });\n  }\n  return {inputs, outputs, usedNodes, missingInputs, dynamicNode, syncInputs};\n}\n\n/**\n * Given the execution info, return a list of nodes in topological order that\n * need to be executed to compute the output.\n */\nexport function getNodesInTopologicalOrder(\n    graph: Graph, weightMap: NamedTensorsMap,\n    executionInfo: ExecutionInfo): Node[] {\n  const {usedNodes, inputs} = executionInfo;\n  const frontier: Node[] = [];\n  const inputNodes = Object.keys(inputs)\n                         .map(name => parseNodeName(name)[0])\n                         .map(name => graph.nodes[name]);\n  const initNodes = graph.initNodes;\n\n  inputNodes.forEach(input => {\n    if (usedNodes.has(input.name)) {\n      frontier.push(input);\n    }\n  });\n  graph.weights.forEach(weight => {\n    if (usedNodes.has(weight.name)) {\n      frontier.push(weight);\n    }\n  });\n  if (initNodes != null) {\n    initNodes.forEach(node => {\n      if (usedNodes.has(node.name)) {\n        frontier.push(node);\n      }\n    });\n  }\n  const seen = new Set<string>();\n  const orderedNodes: Node[] = [];\n  while (frontier.length > 0) {\n    const node = frontier.pop();\n    seen.add(node.name);\n    if (!weightMap[node.name]) {\n      orderedNodes.push(node);\n    }\n    node.children.forEach(child => {\n      if (!seen.has(child.name) && usedNodes.has(child.name) &&\n          child.inputs.every(input => seen.has(input.name))) {\n        frontier.push(child);\n      }\n    });\n  }\n  return orderedNodes;\n}\n\nconst CONTROL_FLOW_OPS = [\n  'Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf',\n  'StatelessWhile', 'if', 'While'\n];\nconst DYNAMIC_SHAPE_OPS = [\n  'NonMaxSuppressionV2', 'NonMaxSuppressionV3', 'NonMaxSuppressionV5', 'Where'\n];\nconst HASH_TABLE_OPS = [\n  'HashTable', 'HashTableV2', 'LookupTableImport', 'LookupTableImportV2',\n  
'LookupTableFind', 'LookupTableFindV2', 'LookupTableSize', 'LookupTableSizeV2'\n];\n\nexport function isControlFlow(node: Node) {\n  return CONTROL_FLOW_OPS.indexOf(node.op) >= 0;\n}\n\nexport function isDynamicShape(node: Node) {\n  return DYNAMIC_SHAPE_OPS.indexOf(node.op) >= 0;\n}\n\nexport function isHashTable(node: Node) {\n  return HASH_TABLE_OPS.indexOf(node.op) >= 0;\n}\n"]}