UNPKG

@tensorflow/tfjs-converter

Version:

TensorFlow model converter for JavaScript

65 lines 11 kB
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { HashTable } from '../../executor/hash_table';
import { getParamValue } from './utils';

/**
 * Resolves the `tableHandle` input of a table op to the hash table that
 * `resourceManager` has registered for that handle's id.
 *
 * @param {Object} node Graph node whose `tableHandle` parameter is read.
 * @param {Object} tensorMap Tensors available to the node's inputs.
 * @param {Object} context Execution context forwarded to getParamValue.
 * @param {Object} resourceManager Registry that stores the hash tables.
 * @returns {HashTable} The table registered for the handle's id.
 * @throws {Error} If no table is registered for the handle's id. Failing here
 *     with the op name and id is clearer than a later property-access
 *     TypeError on an undefined table.
 */
function getHashTableForNode(node, tensorMap, context, resourceManager) {
  const handle =
      getParamValue('tableHandle', node, tensorMap, context, resourceManager);
  const hashTable = resourceManager.getHashTableById(handle.id);
  if (hashTable == null) {
    throw new Error(
        `${node.op}: no hash table is registered for table handle id ` +
        `${handle.id}.`);
  }
  return hashTable;
}

/**
 * Executes a single hash-table graph op.
 *
 * Supported ops:
 * - `HashTable` / `HashTableV2`: create (or reuse) a table.
 * - `InitializeTable[V2]` / `LookupTableImport[V2]`: import keys/values.
 * - `LookupTableFind[V2]`: look keys up, using `defaultValue` for misses.
 * - `LookupTableSize[V2]`: report the table size.
 *
 * Tables are registered in `resourceManager` under the creating node's name.
 * Later ops address them through the id of the handle they receive.
 *
 * @param {Object} node Graph node to execute. `node.op` selects the op, and
 *     `node.name` names the table for creation ops.
 * @param {Object} tensorMap Tensors available to the node's inputs.
 * @param {Object} context Execution context forwarded to getParamValue.
 * @param {Object} resourceManager Registry that stores the hash tables.
 * @returns {Promise<Array>} Resolves to a one-element array holding the op's
 *     output: the table handle, the import or lookup result, or the size.
 * @throws {TypeError} If `node.op` is not a hash-table op.
 * @throws {Error} If a table op references a handle with no registered table.
 */
export const executeOp = async (node, tensorMap, context, resourceManager) => {
  switch (node.op) {
    case 'HashTable':
    case 'HashTableV2': {
      // Table is shared with initializer: reuse the one already registered
      // under this node's name instead of creating a second table.
      const existingTableHandle =
          resourceManager.getHashTableHandleByName(node.name);
      if (existingTableHandle != null) {
        return [existingTableHandle];
      }
      const keyDType = getParamValue('keyDType', node, tensorMap, context);
      const valueDType = getParamValue('valueDType', node, tensorMap, context);
      const hashTable = new HashTable(keyDType, valueDType);
      resourceManager.addHashTable(node.name, hashTable);
      return [hashTable.handle];
    }
    case 'InitializeTable':
    case 'InitializeTableV2':
    case 'LookupTableImport':
    case 'LookupTableImportV2': {
      const hashTable =
          getHashTableForNode(node, tensorMap, context, resourceManager);
      const keys = getParamValue('keys', node, tensorMap, context);
      const values = getParamValue('values', node, tensorMap, context);
      return [await hashTable.import(keys, values)];
    }
    case 'LookupTableFind':
    case 'LookupTableFindV2': {
      const hashTable =
          getHashTableForNode(node, tensorMap, context, resourceManager);
      const keys = getParamValue('keys', node, tensorMap, context);
      const defaultValue =
          getParamValue('defaultValue', node, tensorMap, context);
      return [await hashTable.find(keys, defaultValue)];
    }
    case 'LookupTableSize':
    case 'LookupTableSizeV2': {
      const hashTable =
          getHashTableForNode(node, tensorMap, context, resourceManager);
      return [hashTable.tensorSize()];
    }
    default:
      throw new TypeError(`Node type ${node.op} is not implemented`);
  }
};

export const CATEGORY = 'hash_table';
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiaGFzaF90YWJsZV9leGVjdXRvci5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29udmVydGVyL3NyYy9vcGVyYXRpb25zL2V4ZWN1dG9ycy9oYXNoX3RhYmxlX2V4ZWN1dG9yLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQU1ILE9BQU8sRUFBQyxTQUFTLEVBQUMsTUFBTSwyQkFBMkIsQ0FBQztBQUlwRCxPQUFPLEVBQUMsYUFBYSxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRXRDLE1BQU0sQ0FBQyxNQUFNLFNBQVMsR0FBNEIsS0FBSyxFQUNuRCxJQUFVLEVBQUUsU0FBMEIsRUFBRSxPQUF5QixFQUNqRSxlQUFnQyxFQUFxQixFQUFFO0lBQ3pELFFBQVEsSUFBSSxDQUFDLEVBQUUsRUFBRTtRQUNmLEtBQUssV0FBVyxDQUFDO1FBQ2pCLEtBQUssYUFBYSxDQUFDLENBQUM7WUFDbEIsTUFBTSxtQkFBbUIsR0FDckIsZUFBZSxDQUFDLHdCQUF3QixDQUFDLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQztZQUN4RCxvQ0FBb0M7WUFDcEMsSUFBSSxtQkFBbUIsSUFBSSxJQUFJLEVBQUU7Z0JBQy9CLE9BQU8sQ0FBQyxtQkFBbUIsQ0FBQyxDQUFDO2FBQzlCO2lCQUFNO2dCQUNMLE1BQU0sUUFBUSxHQUNWLGFBQWEsQ0FBQyxVQUFVLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQWEsQ0FBQztnQkFDcEUsTUFBTSxVQUFVLEdBQ1osYUFBYSxDQUFDLFlBQVksRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBYSxDQUFDO2dCQUV0RSxNQUFNLFNBQVMsR0FBRyxJQUFJLFNBQVMsQ0FBQyxRQUFRLEVBQUUsVUFBVSxDQUFDLENBQUM7Z0JBQ3RELGVBQWUsQ0FBQyxZQUFZLENBQUMsSUFBSSxDQUFDLElBQUksRUFBRSxTQUFTLENBQUMsQ0FBQztnQkFDbkQsT0FBTyxDQUFDLFNBQVMsQ0FBQyxNQUFNLENBQUMsQ0FBQzthQUMzQjtTQUNGO1FBQ0QsS0FBSyxpQkFBaUIsQ0FBQztRQUN2QixLQUFLLG1CQUFtQixDQUFDO1FBQ3pCLEtBQUssbUJBQW1CLENBQUM7UUFDekIsS0FBSyxxQkFBcUIsQ0FBQyxDQUFDO1lBQzFCLE1BQU0sTUFBTSxHQUFHLGFBQWEsQ0FDVCxhQUFhLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLEVBQ3ZDLGVBQWUsQ0FBVyxDQUFDO1lBQzlDLE1BQU0sSUFBSSxHQUFHLGFBQWEsQ0FBQyxNQUFNLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVcsQ0FBQztZQUN2RSxNQUFNLE1BQU0sR0FDUixhQUFhLENBQUMsUUFBUSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxDQUFXLENBQUM7WUFFaEUsTUFBTSxTQUFTLEdBQUcsZUFBZSxDQUFDLGdCQUFnQixDQUFDLE1BQU0sQ0FBQyxFQUFFLENBQUMsQ0FBQztZQUU5RCxPQUFPLENBQUMsTUFBTSxTQUFTLENBQUMsTUFBTSxDQUFDLElBQUksRUFBRSxNQUFNLENBQUMsQ0FBQyxDQUFDO1NBQy9DO1FBQ0QsS0FBSyxpQkFBaUIsQ0FBQztRQUN2QixLQUFLLG
1CQUFtQixDQUFDLENBQUM7WUFDeEIsTUFBTSxNQUFNLEdBQUcsYUFBYSxDQUNULGFBQWEsRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sRUFDdkMsZUFBZSxDQUFXLENBQUM7WUFDOUMsTUFBTSxJQUFJLEdBQUcsYUFBYSxDQUFDLE1BQU0sRUFBRSxJQUFJLEVBQUUsU0FBUyxFQUFFLE9BQU8sQ0FBVyxDQUFDO1lBQ3ZFLE1BQU0sWUFBWSxHQUNkLGFBQWEsQ0FBQyxjQUFjLEVBQUUsSUFBSSxFQUFFLFNBQVMsRUFBRSxPQUFPLENBQVcsQ0FBQztZQUV0RSxNQUFNLFNBQVMsR0FBRyxlQUFlLENBQUMsZ0JBQWdCLENBQUMsTUFBTSxDQUFDLEVBQUUsQ0FBQyxDQUFDO1lBQzlELE9BQU8sQ0FBQyxNQUFNLFNBQVMsQ0FBQyxJQUFJLENBQUMsSUFBSSxFQUFFLFlBQVksQ0FBQyxDQUFDLENBQUM7U0FDbkQ7UUFDRCxLQUFLLGlCQUFpQixDQUFDO1FBQ3ZCLEtBQUssbUJBQW1CLENBQUMsQ0FBQztZQUN4QixNQUFNLE1BQU0sR0FBRyxhQUFhLENBQ1QsYUFBYSxFQUFFLElBQUksRUFBRSxTQUFTLEVBQUUsT0FBTyxFQUN2QyxlQUFlLENBQVcsQ0FBQztZQUU5QyxNQUFNLFNBQVMsR0FBRyxlQUFlLENBQUMsZ0JBQWdCLENBQUMsTUFBTSxDQUFDLEVBQUUsQ0FBQyxDQUFDO1lBQzlELE9BQU8sQ0FBQyxTQUFTLENBQUMsVUFBVSxFQUFFLENBQUMsQ0FBQztTQUNqQztRQUNEO1lBQ0UsTUFBTSxTQUFTLENBQUMsYUFBYSxJQUFJLENBQUMsRUFBRSxxQkFBcUIsQ0FBQyxDQUFDO0tBQzlEO0FBQ0gsQ0FBQyxDQUFDO0FBRUYsTUFBTSxDQUFDLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT
09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtEYXRhVHlwZSwgVGVuc29yfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge05hbWVkVGVuc29yc01hcH0gZnJvbSAnLi4vLi4vZGF0YS90eXBlcyc7XG5pbXBvcnQge0V4ZWN1dGlvbkNvbnRleHR9IGZyb20gJy4uLy4uL2V4ZWN1dG9yL2V4ZWN1dGlvbl9jb250ZXh0JztcbmltcG9ydCB7SGFzaFRhYmxlfSBmcm9tICcuLi8uLi9leGVjdXRvci9oYXNoX3RhYmxlJztcbmltcG9ydCB7UmVzb3VyY2VNYW5hZ2VyfSBmcm9tICcuLi8uLi9leGVjdXRvci9yZXNvdXJjZV9tYW5hZ2VyJztcbmltcG9ydCB7SW50ZXJuYWxPcEFzeW5jRXhlY3V0b3IsIE5vZGV9IGZyb20gJy4uL3R5cGVzJztcblxuaW1wb3J0IHtnZXRQYXJhbVZhbHVlfSBmcm9tICcuL3V0aWxzJztcblxuZXhwb3J0IGNvbnN0IGV4ZWN1dGVPcDogSW50ZXJuYWxPcEFzeW5jRXhlY3V0b3IgPSBhc3luYyhcbiAgICBub2RlOiBOb2RlLCB0ZW5zb3JNYXA6IE5hbWVkVGVuc29yc01hcCwgY29udGV4dDogRXhlY3V0aW9uQ29udGV4dCxcbiAgICByZXNvdXJjZU1hbmFnZXI6IFJlc291cmNlTWFuYWdlcik6IFByb21pc2U8VGVuc29yW10+ID0+IHtcbiAgc3dpdGNoIChub2RlLm9wKSB7XG4gICAgY2FzZSAnSGFzaFRhYmxlJzpcbiAgICBjYXNlICdIYXNoVGFibGVWMic6IHtcbiAgICAgIGNvbnN0IGV4aXN0aW5nVGFibGVIYW5kbGUgPVxuICAgICAgICAgIHJlc291cmNlTWFuYWdlci5nZXRIYXNoVGFibGVIYW5kbGVCeU5hbWUobm9kZS5uYW1lKTtcbiAgICAgIC8vIFRhYmxlIGlzIHNoYXJlZCB3aXRoIGluaXRpYWxpemVyLlxuICAgICAgaWYgKGV4aXN0aW5nVGFibGVIYW5kbGUgIT0gbnVsbCkge1xuICAgICAgICByZXR1cm4gW2V4aXN0aW5nVGFibGVIYW5kbGVdO1xuICAgICAgfSBlbHNlIHtcbiAgICAgICAgY29uc3Qga2V5RFR5cGUgPVxuICAgICAgICAgICAgZ2V0UGFyYW1WYWx1ZSgna2V5RFR5cGUnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIERhdGFUeXBlO1xuICAgICAgICBjb25zdCB2YWx1ZURUeXBlID1cbiAgICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3ZhbHVlRFR5cGUnLCBub2RlLCB0ZW5zb3JNYXAsIGNvbnRleHQpIGFzIERhdGFUeXBlO1xuXG4gICAgICAgIGNvbnN0IGhhc2hUYWJsZSA9IG5ldyBIYXNoVGFibGUoa2V5RFR5cGUsIHZhbHVlRFR5cGUpO1xuICAgICAgICByZXNvdXJjZU1hbmFnZXIuYWRkSGFzaFRhYmxlKG5vZGUubmFtZSwgaGFzaFRhYmxlKTtcbiAgICAgICAgcmV0dXJuIFtoYXNoVGFibGUuaGFuZGxlXTtcbiAgICAgIH1cbiAgICB9XG4gICAgY2FzZSAnSW5pdGlhbGl6ZVRhYmxlJzpcbiAgICBjYXNlICdJbml0aWFsaXplVGFibGVWMic6XG4gICAgY2FzZSAnTG9va3VwVGFibGVJbXBvcnQnOlxuICAgIGNhc2UgJ0xvb2t1cFRhYmxlSW1wb3J0VjInOiB7XG4gICAgICBjb25zdCBoYW5kbGUgPSBnZXRQYXJhbVZhbHVlKFxuICAgICAgICAgICAgICAgIC
AgICAgICAgICd0YWJsZUhhbmRsZScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCxcbiAgICAgICAgICAgICAgICAgICAgICAgICByZXNvdXJjZU1hbmFnZXIpIGFzIFRlbnNvcjtcbiAgICAgIGNvbnN0IGtleXMgPSBnZXRQYXJhbVZhbHVlKCdrZXlzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3I7XG4gICAgICBjb25zdCB2YWx1ZXMgPVxuICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ3ZhbHVlcycsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yO1xuXG4gICAgICBjb25zdCBoYXNoVGFibGUgPSByZXNvdXJjZU1hbmFnZXIuZ2V0SGFzaFRhYmxlQnlJZChoYW5kbGUuaWQpO1xuXG4gICAgICByZXR1cm4gW2F3YWl0IGhhc2hUYWJsZS5pbXBvcnQoa2V5cywgdmFsdWVzKV07XG4gICAgfVxuICAgIGNhc2UgJ0xvb2t1cFRhYmxlRmluZCc6XG4gICAgY2FzZSAnTG9va3VwVGFibGVGaW5kVjInOiB7XG4gICAgICBjb25zdCBoYW5kbGUgPSBnZXRQYXJhbVZhbHVlKFxuICAgICAgICAgICAgICAgICAgICAgICAgICd0YWJsZUhhbmRsZScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCxcbiAgICAgICAgICAgICAgICAgICAgICAgICByZXNvdXJjZU1hbmFnZXIpIGFzIFRlbnNvcjtcbiAgICAgIGNvbnN0IGtleXMgPSBnZXRQYXJhbVZhbHVlKCdrZXlzJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0KSBhcyBUZW5zb3I7XG4gICAgICBjb25zdCBkZWZhdWx0VmFsdWUgPVxuICAgICAgICAgIGdldFBhcmFtVmFsdWUoJ2RlZmF1bHRWYWx1ZScsIG5vZGUsIHRlbnNvck1hcCwgY29udGV4dCkgYXMgVGVuc29yO1xuXG4gICAgICBjb25zdCBoYXNoVGFibGUgPSByZXNvdXJjZU1hbmFnZXIuZ2V0SGFzaFRhYmxlQnlJZChoYW5kbGUuaWQpO1xuICAgICAgcmV0dXJuIFthd2FpdCBoYXNoVGFibGUuZmluZChrZXlzLCBkZWZhdWx0VmFsdWUpXTtcbiAgICB9XG4gICAgY2FzZSAnTG9va3VwVGFibGVTaXplJzpcbiAgICBjYXNlICdMb29rdXBUYWJsZVNpemVWMic6IHtcbiAgICAgIGNvbnN0IGhhbmRsZSA9IGdldFBhcmFtVmFsdWUoXG4gICAgICAgICAgICAgICAgICAgICAgICAgJ3RhYmxlSGFuZGxlJywgbm9kZSwgdGVuc29yTWFwLCBjb250ZXh0LFxuICAgICAgICAgICAgICAgICAgICAgICAgIHJlc291cmNlTWFuYWdlcikgYXMgVGVuc29yO1xuXG4gICAgICBjb25zdCBoYXNoVGFibGUgPSByZXNvdXJjZU1hbmFnZXIuZ2V0SGFzaFRhYmxlQnlJZChoYW5kbGUuaWQpO1xuICAgICAgcmV0dXJuIFtoYXNoVGFibGUudGVuc29yU2l6ZSgpXTtcbiAgICB9XG4gICAgZGVmYXVsdDpcbiAgICAgIHRocm93IFR5cGVFcnJvcihgTm9kZSB0eXBlICR7bm9kZS5vcH0gaXMgbm90IGltcGxlbWVudGVkYCk7XG4gIH1cbn07XG5cbmV4cG9ydCBjb25zdCBDQVRFR09SWSA9ICdoYXNoX3RhYmxlJztcbiJdfQ==