UNPKG

@tensorflow/tfjs-converter

Version:

Tensorflow model converter for javascript

89 lines 12.7 kB
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { getTensor } from '../executors/utils';
import { getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, getTensorShapeArrayParam, getTensorShapeParam } from '../operation_mapper';
/**
 * Helper class for lookup inputs and params for nodes in the model graph.
 *
 * All inputs and attributes are resolved eagerly, once, in the constructor:
 * `inputs` holds one entry per `node.inputNames` element (same order), and
 * `attrs` maps every key of `node.rawAttrs` to its decoded value.
 */
export class NodeValueImpl {
    /**
     * @param node Graph node whose `inputNames` and `rawAttrs` are resolved.
     * @param tensorMap Map handed to `getTensor` when resolving names.
     * @param context Execution context handed to `getTensor`.
     */
    constructor(node, tensorMap, context) {
        this.node = node;
        this.tensorMap = tensorMap;
        this.context = context;
        this.inputs = [];
        this.attrs = {};
        // One resolved tensor per input name, preserving the node's order.
        this.inputs = node.inputNames.map(name => this.getInput(name));
        if (node.rawAttrs != null) {
            this.attrs = Object.keys(node.rawAttrs)
                .reduce((attrs, key) => {
                attrs[key] = this.getAttr(key);
                return attrs;
            }, {});
        }
    }
    /**
     * Return the tensor bound to the given input name.
     * @param name String: name of the input, resolved via `getTensor` against
     *     this node's tensor map and execution context.
     */
    getInput(name) {
        return getTensor(name, this.tensorMap, this.context);
    }
    /**
     * Decode the raw attribute `name` into a plain value.
     *
     * The first populated field of the raw attribute decides the decoder:
     * `tensor`, then `i`/`f`, `s`, `b`, `shape`, `type`, and finally the
     * corresponding `list.*` variants. If the attribute is missing, or none
     * of these fields is populated, `defaultValue` is returned.
     *
     * @param name String: name of the attribute.
     * @param defaultValue Value returned when the attribute cannot be decoded;
     *     also forwarded to the `get*Param` helpers.
     */
    getAttr(name, defaultValue) {
        const rawAttrs = this.node.rawAttrs;
        // Missing attribute map or key: honour the documented fallback
        // instead of throwing a TypeError on `value.tensor` below.
        if (rawAttrs == null || rawAttrs[name] == null) {
            return defaultValue;
        }
        const value = rawAttrs[name];
        if (value.tensor != null) {
            // NOTE(review): resolves by the attribute's name through getTensor
            // rather than decoding `value.tensor` itself — confirm intended.
            return getTensor(name, this.tensorMap, this.context);
        }
        if (value.i != null || value.f != null) {
            return getNumberParam(rawAttrs, name, defaultValue);
        }
        if (value.s != null) {
            return getStringParam(rawAttrs, name, defaultValue);
        }
        if (value.b != null) {
            return getBoolParam(rawAttrs, name, defaultValue);
        }
        if (value.shape != null) {
            return getTensorShapeParam(rawAttrs, name, defaultValue);
        }
        if (value.type != null) {
            return getDtypeParam(rawAttrs, name, defaultValue);
        }
        const list = value.list;
        if (list != null) {
            if (list.i != null || list.f != null) {
                return getNumericArrayParam(rawAttrs, name, defaultValue);
            }
            if (list.s != null) {
                return getStringArrayParam(rawAttrs, name, defaultValue);
            }
            if (list.shape != null) {
                return getTensorShapeArrayParam(rawAttrs, name, defaultValue);
            }
            if (list.b != null) {
                return getBoolArrayParam(rawAttrs, name, defaultValue);
            }
            if (list.type != null) {
                return getDtypeArrayParam(rawAttrs, name, defaultValue);
            }
        }
        return defaultValue;
    }
}
//#
sourceMappingURL=data:application/json;base64,{"version":3,"file":"node_value_impl.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/custom_op/node_value_impl.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAMH,OAAO,EAAC,SAAS,EAAC,MAAM,oBAAoB,CAAC;AAC7C,OAAO,EAAC,iBAAiB,EAAE,YAAY,EAAE,kBAAkB,EAAE,aAAa,EAAE,cAAc,EAAE,oBAAoB,EAAE,mBAAmB,EAAE,cAAc,EAAE,wBAAwB,EAAE,mBAAmB,EAAC,MAAM,qBAAqB,CAAC;AAGjO;;GAEG;AACH,MAAM,OAAO,aAAa;IAGxB,YACY,IAAU,EAAU,SAA0B,EAC9C,OAAyB;QADzB,SAAI,GAAJ,IAAI,CAAM;QAAU,cAAS,GAAT,SAAS,CAAiB;QAC9C,YAAO,GAAP,OAAO,CAAkB;QAJrB,WAAM,GAAa,EAAE,CAAC;QACtB,UAAK,GAA+B,EAAE,CAAC;QAIrD,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC,CAAC;QAC/D,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,IAAI,CAAC,KAAK,GAAG,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC;iBACrB,MAAM,CAAC,CAAC,KAAiC,EAAE,GAAG,EAAE,EAAE;gBACjD,KAAK,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;gBAC/B,OAAO,KAAK,CAAC;YACf,CAAC,EAAE,EAAE,CAAC,CAAC;SACzB;IACH,CAAC;IAED;;;OAGG;IACK,QAAQ,CAAC,IAAY;QAC3B,OAAO,SAAS,CAAC,IAAI,EAAE,IAAI,CAAC,SAAS,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;IACvD,CAAC;IAED;;;OAGG;IACK,OAAO,CAAC,IAAY,EAAE,YAAwB;QACpD,MAAM,KAAK,GAAG,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC;QACvC,IAAI,KAAK,CAAC,MAAM,IAAI,IAAI,EAAE;YACxB,OAAO,SAAS,CAAC,IAAI,EAAE,IAAI,CAAC,SAAS,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;SACtD;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACtC,OAAO,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAsB,CAAC,CAAC;SACzE;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACnB,OAAO,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAsB,CAAC,CAAC;SACzE;QACD,IAAI,KAAK,CAAC,CAAC,IAAI,IAAI,EAAE;YACnB,OAAO,YAAY,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAuB,CAAC,CAAC;SACxE;QACD,IAAI,KAAK,CAAC,KAAK,IAAI,IAAI,EAAE;YACvB,OAAO,mBAAmB,CACtB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;SACzD;QACD,IAAI,KAAK,CAAC,IAAI,IAAI,IAAI,EAAE;YACtB,OAAO,aAAa,CAAC,IAAI,CAAC,IAAI,CAAC,QAA
Q,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;SAC1E;QACD,IAAI,KAAK,CAAC,IAAI,IAAI,IAAI,EAAE;YACtB,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBAChD,OAAO,oBAAoB,CACvB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;aACzD;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBACxB,OAAO,mBAAmB,CACtB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAwB,CAAC,CAAC;aACzD;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,KAAK,IAAI,IAAI,EAAE;gBAC5B,OAAO,wBAAwB,CAC3B,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAA0B,CAAC,CAAC;aAC3D;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,IAAI,EAAE;gBACxB,OAAO,iBAAiB,CACpB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAAyB,CAAC,CAAC;aAC1D;YACD,IAAI,KAAK,CAAC,IAAI,CAAC,IAAI,IAAI,IAAI,EAAE;gBAC3B,OAAO,kBAAkB,CACrB,IAAI,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,EAAE,YAA0B,CAAC,CAAC;aAC3D;SACF;QAED,OAAO,YAAY,CAAC;IACtB,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {DataType, Tensor} from '@tensorflow/tfjs-core';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {getTensor} from '../executors/utils';\nimport {getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, getTensorShapeArrayParam, 
getTensorShapeParam} from '../operation_mapper';\nimport {GraphNode, Node, ValueType} from '../types';\n\n/**\n * Helper class for lookup inputs and params for nodes in the model graph.\n */\nexport class NodeValueImpl implements GraphNode {\n  public readonly inputs: Tensor[] = [];\n  public readonly attrs: {[key: string]: ValueType} = {};\n  constructor(\n      private node: Node, private tensorMap: NamedTensorsMap,\n      private context: ExecutionContext) {\n    this.inputs = node.inputNames.map(name => this.getInput(name));\n    if (node.rawAttrs != null) {\n      this.attrs = Object.keys(node.rawAttrs)\n                       .reduce((attrs: {[key: string]: ValueType}, key) => {\n                         attrs[key] = this.getAttr(key);\n                         return attrs;\n                       }, {});\n    }\n  }\n\n  /**\n   * Return the value of the attribute or input param.\n   * @param name String: name of attribute or input param.\n   */\n  private getInput(name: string): Tensor {\n    return getTensor(name, this.tensorMap, this.context);\n  }\n\n  /**\n   * Return the value of the attribute or input param.\n   * @param name String: name of attribute or input param.\n   */\n  private getAttr(name: string, defaultValue?: ValueType): ValueType {\n    const value = this.node.rawAttrs[name];\n    if (value.tensor != null) {\n      return getTensor(name, this.tensorMap, this.context);\n    }\n    if (value.i != null || value.f != null) {\n      return getNumberParam(this.node.rawAttrs, name, defaultValue as number);\n    }\n    if (value.s != null) {\n      return getStringParam(this.node.rawAttrs, name, defaultValue as string);\n    }\n    if (value.b != null) {\n      return getBoolParam(this.node.rawAttrs, name, defaultValue as boolean);\n    }\n    if (value.shape != null) {\n      return getTensorShapeParam(\n          this.node.rawAttrs, name, defaultValue as number[]);\n    }\n    if (value.type != null) {\n      return 
getDtypeParam(this.node.rawAttrs, name, defaultValue as DataType);\n    }\n    if (value.list != null) {\n      if (value.list.i != null || value.list.f != null) {\n        return getNumericArrayParam(\n            this.node.rawAttrs, name, defaultValue as number[]);\n      }\n      if (value.list.s != null) {\n        return getStringArrayParam(\n            this.node.rawAttrs, name, defaultValue as string[]);\n      }\n      if (value.list.shape != null) {\n        return getTensorShapeArrayParam(\n            this.node.rawAttrs, name, defaultValue as number[][]);\n      }\n      if (value.list.b != null) {\n        return getBoolArrayParam(\n            this.node.rawAttrs, name, defaultValue as boolean[]);\n      }\n      if (value.list.type != null) {\n        return getDtypeArrayParam(\n            this.node.rawAttrs, name, defaultValue as DataType[]);\n      }\n    }\n\n    return defaultValue;\n  }\n}\n"]}