UNPKG

@tensorflow/tfjs-converter

Version:

TensorFlow model converter for JavaScript

89 lines 12.7 kB
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { getTensor } from '../executors/utils';
import { getBoolArrayParam, getBoolParam, getDtypeArrayParam, getDtypeParam, getNumberParam, getNumericArrayParam, getStringArrayParam, getStringParam, getTensorShapeArrayParam, getTensorShapeParam } from '../operation_mapper';
/**
 * Helper class for looking up the inputs and attribute params of a node in
 * the model graph.
 *
 * All inputs and attributes are resolved eagerly in the constructor and
 * exposed as the plain `inputs` array and `attrs` object.
 */
export class NodeValueImpl {
  /**
   * @param {Node} node Graph node whose `inputNames` are resolved into
   *     `inputs` and whose `rawAttrs` (if present) are resolved into `attrs`.
   * @param {NamedTensorsMap} tensorMap Tensors available for lookup by name.
   * @param {ExecutionContext} context Execution context passed through to
   *     `getTensor` for every lookup.
   */
  constructor(node, tensorMap, context) {
    this.node = node;
    this.tensorMap = tensorMap;
    this.context = context;
    // One resolved tensor per input name, in the node's declared input order.
    this.inputs = node.inputNames.map(name => this.getInput(name));
    this.attrs = {};
    if (node.rawAttrs != null) {
      this.attrs = Object.keys(node.rawAttrs).reduce((attrs, key) => {
        attrs[key] = this.getAttr(key);
        return attrs;
      }, {});
    }
  }

  /**
   * Return the tensor for an input of this node.
   * @param name String: name of the input to look up.
   * @returns The value `getTensor` resolves for `name` from the tensor map and
   *     execution context.
   */
  getInput(name) {
    return getTensor(name, this.tensorMap, this.context);
  }

  /**
   * Return the decoded value of one of this node's raw attributes.
   *
   * The first populated field of the raw attribute selects the decoder, checked
   * in this order: `tensor`, `i`/`f`, `s`, `b`, `shape`, `type`, then
   * `list.{i|f, s, shape, b, type}`.
   *
   * @param name String: name of the attribute.
   * @param defaultValue Value to return when the attribute is missing, null,
   *     or has no recognized field. It is also forwarded to the selected
   *     get*Param helper.
   * @returns The decoded attribute value, or `defaultValue`.
   */
  getAttr(name, defaultValue) {
    const rawAttrs = this.node.rawAttrs;
    const value = rawAttrs != null ? rawAttrs[name] : undefined;
    // A missing or explicitly null entry carries no typed field. Treat it like
    // an unrecognized entry and fall back to the default, instead of throwing
    // a TypeError on `value.tensor`.
    if (value == null) {
      return defaultValue;
    }
    if (value.tensor != null) {
      // NOTE(review): the tensor is looked up under the attribute's own name;
      // confirm this matches how tensor-valued attributes are registered.
      return getTensor(name, this.tensorMap, this.context);
    }
    if (value.i != null || value.f != null) {
      return getNumberParam(rawAttrs, name, defaultValue);
    }
    if (value.s != null) {
      return getStringParam(rawAttrs, name, defaultValue);
    }
    if (value.b != null) {
      return getBoolParam(rawAttrs, name, defaultValue);
    }
    if (value.shape != null) {
      return getTensorShapeParam(rawAttrs, name, defaultValue);
    }
    if (value.type != null) {
      return getDtypeParam(rawAttrs, name, defaultValue);
    }
    if (value.list != null) {
      const list = value.list;
      if (list.i != null || list.f != null) {
        return getNumericArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.s != null) {
        return getStringArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.shape != null) {
        return getTensorShapeArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.b != null) {
        return getBoolArrayParam(rawAttrs, name, defaultValue);
      }
      if (list.type != null) {
        return getDtypeArrayParam(rawAttrs, name, defaultValue);
      }
    }
    // No recognized field: hand back the caller's default.
    return defaultValue;
  }
}
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibm9kZV92YWx1ZV9pbXBsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb252ZXJ0ZXIvc3JjL29wZXJhdGlvbnMvY3VzdG9tX29wL25vZGVfdmFsdWVfaW1wbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFNSCxPQUFPLEVBQUMsU0FBUyxFQUFDLE1BQU0sb0JBQW9CLENBQUM7QUFDN0MsT0FBTyxFQUFDLGlCQUFpQixFQUFFLFlBQVksRUFBRSxrQkFBa0IsRUFBRSxhQUFhLEVBQUUsY0FBYyxFQUFFLG9CQUFvQixFQUFFLG1CQUFtQixFQUFFLGNBQWMsRUFBRSx3QkFBd0IsRUFBRSxtQkFBbUIsRUFBQyxNQUFNLHFCQUFxQixDQUFDO0FBR2pPOztHQUVHO0FBQ0gsTUFBTSxPQUFPLGFBQWE7SUFHeEIsWUFDWSxJQUFVLEVBQVUsU0FBMEIsRUFDOUMsT0FBeUI7UUFEekIsU0FBSSxHQUFKLElBQUksQ0FBTTtRQUFVLGNBQVMsR0FBVCxTQUFTLENBQWlCO1FBQzlDLFlBQU8sR0FBUCxPQUFPLENBQWtCO1FBSnJCLFdBQU0sR0FBYSxFQUFFLENBQUM7UUFDdEIsVUFBSyxHQUErQixFQUFFLENBQUM7UUFJckQsSUFBSSxDQUFDLE1BQU0sR0FBRyxJQUFJLENBQUMsVUFBVSxDQUFDLEdBQUcsQ0FBQyxJQUFJLENBQUMsRUFBRSxDQUFDLElBQUksQ0FBQyxRQUFRLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQztRQUMvRCxJQUFJLElBQUksQ0FBQyxRQUFRLElBQUksSUFBSSxFQUFFO1lBQ3pCLElBQUksQ0FBQyxLQUFLLEdBQUcsTUFBTSxDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsUUFBUSxDQUFDO2lCQUNyQixNQUFNLENBQUMsQ0FBQyxLQUFpQyxFQUFFLEdBQUcsRUFBRSxFQUFFO2dCQUNqRCxLQUFLLENBQUMsR0FBRyxDQUFDLEdBQUcsSUFBSSxDQUFDLE9BQU8sQ0FBQyxHQUFHLENBQUMsQ0FBQztnQkFDL0IsT0FBTyxLQUFLLENBQUM7WUFDZixDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUM7U0FDekI7SUFDSCxDQUFDO0lBRUQ7OztPQUdHO0lBQ0ssUUFBUSxDQUFDLElBQVk7UUFDM0IsT0FBTyxTQUFTLENBQUMsSUFBSSxFQUFFLElBQUksQ0FBQyxTQUFTLEVBQUUsSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDO0lBQ3ZELENBQUM7SUFFRDs7O09BR0c7SUFDSyxPQUFPLENBQUMsSUFBWSxFQUFFLFlBQXdCO1FBQ3BELE1BQU0sS0FBSyxHQUFHLElBQUksQ0FBQyxJQUFJLENBQUMsUUFBUSxDQUFDLElBQUksQ0FBQyxDQUFDO1FBQ3ZDLElBQUksS0FBSyxDQUFDLE1BQU0sSUFBSSxJQUFJLEVBQUU7WUFDeEIsT0FBTyxTQUFTLENBQUMsSUFBSSxFQUFFLElBQUksQ0FBQyxTQUFTLEVBQUUsSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDO1NBQ3REO1FBQ0QsSUFBSSxLQUFLLENBQUMsQ0FBQyxJQUFJLElBQUksSUFBSSxLQUFLLENBQUMsQ0FBQyxJQUFJLElBQUksRUFBRTtZQUN0QyxPQUFPLGNBQWMsQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRS
xJQUFJLEVBQUUsWUFBc0IsQ0FBQyxDQUFDO1NBQ3pFO1FBQ0QsSUFBSSxLQUFLLENBQUMsQ0FBQyxJQUFJLElBQUksRUFBRTtZQUNuQixPQUFPLGNBQWMsQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLEVBQUUsWUFBc0IsQ0FBQyxDQUFDO1NBQ3pFO1FBQ0QsSUFBSSxLQUFLLENBQUMsQ0FBQyxJQUFJLElBQUksRUFBRTtZQUNuQixPQUFPLFlBQVksQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLEVBQUUsWUFBdUIsQ0FBQyxDQUFDO1NBQ3hFO1FBQ0QsSUFBSSxLQUFLLENBQUMsS0FBSyxJQUFJLElBQUksRUFBRTtZQUN2QixPQUFPLG1CQUFtQixDQUN0QixJQUFJLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLEVBQUUsWUFBd0IsQ0FBQyxDQUFDO1NBQ3pEO1FBQ0QsSUFBSSxLQUFLLENBQUMsSUFBSSxJQUFJLElBQUksRUFBRTtZQUN0QixPQUFPLGFBQWEsQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLEVBQUUsWUFBd0IsQ0FBQyxDQUFDO1NBQzFFO1FBQ0QsSUFBSSxLQUFLLENBQUMsSUFBSSxJQUFJLElBQUksRUFBRTtZQUN0QixJQUFJLEtBQUssQ0FBQyxJQUFJLENBQUMsQ0FBQyxJQUFJLElBQUksSUFBSSxLQUFLLENBQUMsSUFBSSxDQUFDLENBQUMsSUFBSSxJQUFJLEVBQUU7Z0JBQ2hELE9BQU8sb0JBQW9CLENBQ3ZCLElBQUksQ0FBQyxJQUFJLENBQUMsUUFBUSxFQUFFLElBQUksRUFBRSxZQUF3QixDQUFDLENBQUM7YUFDekQ7WUFDRCxJQUFJLEtBQUssQ0FBQyxJQUFJLENBQUMsQ0FBQyxJQUFJLElBQUksRUFBRTtnQkFDeEIsT0FBTyxtQkFBbUIsQ0FDdEIsSUFBSSxDQUFDLElBQUksQ0FBQyxRQUFRLEVBQUUsSUFBSSxFQUFFLFlBQXdCLENBQUMsQ0FBQzthQUN6RDtZQUNELElBQUksS0FBSyxDQUFDLElBQUksQ0FBQyxLQUFLLElBQUksSUFBSSxFQUFFO2dCQUM1QixPQUFPLHdCQUF3QixDQUMzQixJQUFJLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLEVBQUUsWUFBMEIsQ0FBQyxDQUFDO2FBQzNEO1lBQ0QsSUFBSSxLQUFLLENBQUMsSUFBSSxDQUFDLENBQUMsSUFBSSxJQUFJLEVBQUU7Z0JBQ3hCLE9BQU8saUJBQWlCLENBQ3BCLElBQUksQ0FBQyxJQUFJLENBQUMsUUFBUSxFQUFFLElBQUksRUFBRSxZQUF5QixDQUFDLENBQUM7YUFDMUQ7WUFDRCxJQUFJLEtBQUssQ0FBQyxJQUFJLENBQUMsSUFBSSxJQUFJLElBQUksRUFBRTtnQkFDM0IsT0FBTyxrQkFBa0IsQ0FDckIsSUFBSSxDQUFDLElBQUksQ0FBQyxRQUFRLEVBQUUsSUFBSSxFQUFFLFlBQTBCLENBQUMsQ0FBQzthQUMzRDtTQUNGO1FBRUQsT0FBTyxZQUFZLENBQUM7SUFDdEIsQ0FBQztDQUNGIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTkgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3
UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge0RhdGFUeXBlLCBUZW5zb3J9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TmFtZWRUZW5zb3JzTWFwfSBmcm9tICcuLi8uLi9kYXRhL3R5cGVzJztcbmltcG9ydCB7RXhlY3V0aW9uQ29udGV4dH0gZnJvbSAnLi4vLi4vZXhlY3V0b3IvZXhlY3V0aW9uX2NvbnRleHQnO1xuaW1wb3J0IHtnZXRUZW5zb3J9IGZyb20gJy4uL2V4ZWN1dG9ycy91dGlscyc7XG5pbXBvcnQge2dldEJvb2xBcnJheVBhcmFtLCBnZXRCb29sUGFyYW0sIGdldER0eXBlQXJyYXlQYXJhbSwgZ2V0RHR5cGVQYXJhbSwgZ2V0TnVtYmVyUGFyYW0sIGdldE51bWVyaWNBcnJheVBhcmFtLCBnZXRTdHJpbmdBcnJheVBhcmFtLCBnZXRTdHJpbmdQYXJhbSwgZ2V0VGVuc29yU2hhcGVBcnJheVBhcmFtLCBnZXRUZW5zb3JTaGFwZVBhcmFtfSBmcm9tICcuLi9vcGVyYXRpb25fbWFwcGVyJztcbmltcG9ydCB7R3JhcGhOb2RlLCBOb2RlLCBWYWx1ZVR5cGV9IGZyb20gJy4uL3R5cGVzJztcblxuLyoqXG4gKiBIZWxwZXIgY2xhc3MgZm9yIGxvb2t1cCBpbnB1dHMgYW5kIHBhcmFtcyBmb3Igbm9kZXMgaW4gdGhlIG1vZGVsIGdyYXBoLlxuICovXG5leHBvcnQgY2xhc3MgTm9kZVZhbHVlSW1wbCBpbXBsZW1lbnRzIEdyYXBoTm9kZSB7XG4gIHB1YmxpYyByZWFkb25seSBpbnB1dHM6IFRlbnNvcltdID0gW107XG4gIHB1YmxpYyByZWFkb25seSBhdHRyczoge1trZXk6IHN0cmluZ106IFZhbHVlVHlwZX0gPSB7fTtcbiAgY29uc3RydWN0b3IoXG4gICAgICBwcml2YXRlIG5vZGU6IE5vZGUsIHByaXZhdGUgdGVuc29yTWFwOiBOYW1lZFRlbnNvcnNNYXAsXG4gICAgICBwcml2YXRlIGNvbnRleHQ6IEV4ZWN1dGlvbkNvbnRleHQpIHtcbiAgICB0aGlzLmlucHV0cyA9IG5vZGUuaW5wdXROYW1lcy5tYXAobmFtZSA9PiB0aGlzLm
dldElucHV0KG5hbWUpKTtcbiAgICBpZiAobm9kZS5yYXdBdHRycyAhPSBudWxsKSB7XG4gICAgICB0aGlzLmF0dHJzID0gT2JqZWN0LmtleXMobm9kZS5yYXdBdHRycylcbiAgICAgICAgICAgICAgICAgICAgICAgLnJlZHVjZSgoYXR0cnM6IHtba2V5OiBzdHJpbmddOiBWYWx1ZVR5cGV9LCBrZXkpID0+IHtcbiAgICAgICAgICAgICAgICAgICAgICAgICBhdHRyc1trZXldID0gdGhpcy5nZXRBdHRyKGtleSk7XG4gICAgICAgICAgICAgICAgICAgICAgICAgcmV0dXJuIGF0dHJzO1xuICAgICAgICAgICAgICAgICAgICAgICB9LCB7fSk7XG4gICAgfVxuICB9XG5cbiAgLyoqXG4gICAqIFJldHVybiB0aGUgdmFsdWUgb2YgdGhlIGF0dHJpYnV0ZSBvciBpbnB1dCBwYXJhbS5cbiAgICogQHBhcmFtIG5hbWUgU3RyaW5nOiBuYW1lIG9mIGF0dHJpYnV0ZSBvciBpbnB1dCBwYXJhbS5cbiAgICovXG4gIHByaXZhdGUgZ2V0SW5wdXQobmFtZTogc3RyaW5nKTogVGVuc29yIHtcbiAgICByZXR1cm4gZ2V0VGVuc29yKG5hbWUsIHRoaXMudGVuc29yTWFwLCB0aGlzLmNvbnRleHQpO1xuICB9XG5cbiAgLyoqXG4gICAqIFJldHVybiB0aGUgdmFsdWUgb2YgdGhlIGF0dHJpYnV0ZSBvciBpbnB1dCBwYXJhbS5cbiAgICogQHBhcmFtIG5hbWUgU3RyaW5nOiBuYW1lIG9mIGF0dHJpYnV0ZSBvciBpbnB1dCBwYXJhbS5cbiAgICovXG4gIHByaXZhdGUgZ2V0QXR0cihuYW1lOiBzdHJpbmcsIGRlZmF1bHRWYWx1ZT86IFZhbHVlVHlwZSk6IFZhbHVlVHlwZSB7XG4gICAgY29uc3QgdmFsdWUgPSB0aGlzLm5vZGUucmF3QXR0cnNbbmFtZV07XG4gICAgaWYgKHZhbHVlLnRlbnNvciAhPSBudWxsKSB7XG4gICAgICByZXR1cm4gZ2V0VGVuc29yKG5hbWUsIHRoaXMudGVuc29yTWFwLCB0aGlzLmNvbnRleHQpO1xuICAgIH1cbiAgICBpZiAodmFsdWUuaSAhPSBudWxsIHx8IHZhbHVlLmYgIT0gbnVsbCkge1xuICAgICAgcmV0dXJuIGdldE51bWJlclBhcmFtKHRoaXMubm9kZS5yYXdBdHRycywgbmFtZSwgZGVmYXVsdFZhbHVlIGFzIG51bWJlcik7XG4gICAgfVxuICAgIGlmICh2YWx1ZS5zICE9IG51bGwpIHtcbiAgICAgIHJldHVybiBnZXRTdHJpbmdQYXJhbSh0aGlzLm5vZGUucmF3QXR0cnMsIG5hbWUsIGRlZmF1bHRWYWx1ZSBhcyBzdHJpbmcpO1xuICAgIH1cbiAgICBpZiAodmFsdWUuYiAhPSBudWxsKSB7XG4gICAgICByZXR1cm4gZ2V0Qm9vbFBhcmFtKHRoaXMubm9kZS5yYXdBdHRycywgbmFtZSwgZGVmYXVsdFZhbHVlIGFzIGJvb2xlYW4pO1xuICAgIH1cbiAgICBpZiAodmFsdWUuc2hhcGUgIT0gbnVsbCkge1xuICAgICAgcmV0dXJuIGdldFRlbnNvclNoYXBlUGFyYW0oXG4gICAgICAgICAgdGhpcy5ub2RlLnJhd0F0dHJzLCBuYW1lLCBkZWZhdWx0VmFsdWUgYXMgbnVtYmVyW10pO1xuICAgIH1cbiAgICBpZiAodmFsdWUudHlwZSAhPSBudWxsKSB7XG4gICAgICByZXR1cm4gZ2V0RHR5cGVQYXJhbSh0aGlzLm5vZGUucmF3QXR0cnMsIG5hbWUsIGRlZmF1bH
RWYWx1ZSBhcyBEYXRhVHlwZSk7XG4gICAgfVxuICAgIGlmICh2YWx1ZS5saXN0ICE9IG51bGwpIHtcbiAgICAgIGlmICh2YWx1ZS5saXN0LmkgIT0gbnVsbCB8fCB2YWx1ZS5saXN0LmYgIT0gbnVsbCkge1xuICAgICAgICByZXR1cm4gZ2V0TnVtZXJpY0FycmF5UGFyYW0oXG4gICAgICAgICAgICB0aGlzLm5vZGUucmF3QXR0cnMsIG5hbWUsIGRlZmF1bHRWYWx1ZSBhcyBudW1iZXJbXSk7XG4gICAgICB9XG4gICAgICBpZiAodmFsdWUubGlzdC5zICE9IG51bGwpIHtcbiAgICAgICAgcmV0dXJuIGdldFN0cmluZ0FycmF5UGFyYW0oXG4gICAgICAgICAgICB0aGlzLm5vZGUucmF3QXR0cnMsIG5hbWUsIGRlZmF1bHRWYWx1ZSBhcyBzdHJpbmdbXSk7XG4gICAgICB9XG4gICAgICBpZiAodmFsdWUubGlzdC5zaGFwZSAhPSBudWxsKSB7XG4gICAgICAgIHJldHVybiBnZXRUZW5zb3JTaGFwZUFycmF5UGFyYW0oXG4gICAgICAgICAgICB0aGlzLm5vZGUucmF3QXR0cnMsIG5hbWUsIGRlZmF1bHRWYWx1ZSBhcyBudW1iZXJbXVtdKTtcbiAgICAgIH1cbiAgICAgIGlmICh2YWx1ZS5saXN0LmIgIT0gbnVsbCkge1xuICAgICAgICByZXR1cm4gZ2V0Qm9vbEFycmF5UGFyYW0oXG4gICAgICAgICAgICB0aGlzLm5vZGUucmF3QXR0cnMsIG5hbWUsIGRlZmF1bHRWYWx1ZSBhcyBib29sZWFuW10pO1xuICAgICAgfVxuICAgICAgaWYgKHZhbHVlLmxpc3QudHlwZSAhPSBudWxsKSB7XG4gICAgICAgIHJldHVybiBnZXREdHlwZUFycmF5UGFyYW0oXG4gICAgICAgICAgICB0aGlzLm5vZGUucmF3QXR0cnMsIG5hbWUsIGRlZmF1bHRWYWx1ZSBhcyBEYXRhVHlwZVtdKTtcbiAgICAgIH1cbiAgICB9XG5cbiAgICByZXR1cm4gZGVmYXVsdFZhbHVlO1xuICB9XG59XG4iXX0=