UNPKG

@tensorflow/tfjs-converter

Version:

TensorFlow model converter for JavaScript

134 lines 15.7 kB
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { keep, scalar, stack, tidy, unstack, util } from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';

/**
 * Hashtable contains a set of tensors, which can be accessed by key.
 *
 * Keys are stored as the primitive values read from a key tensor. Values are
 * stored as individual tensors, which the table owns and disposes.
 */
export class HashTable {
  /**
   * The tensor id of this table's handle tensor.
   */
  get id() {
    return this.handle.id;
  }

  /**
   * Constructor of HashTable. Creates a hash table.
   *
   * @param keyDType `dtype` of the table keys.
   * @param valueDType `dtype` of the table values.
   */
  constructor(keyDType, valueDType) {
    this.keyDType = keyDType;
    this.valueDType = valueDType;
    // Scalar whose tensor id is exposed via `id` and which `import` returns.
    this.handle = scalar(0);
    // Maps primitive key values to the value tensors owned by this table.
    // tslint:disable-next-line: no-any
    this.tensorMap = new Map();
    // `keep` so the handle is not released by an enclosing `tidy` scope.
    keep(this.handle);
  }

  /**
   * Dispose the tensors and handle and clear the hashtable.
   */
  clearAndClose() {
    this.tensorMap.forEach(value => value.dispose());
    this.tensorMap.clear();
    this.handle.dispose();
  }

  /**
   * The number of items in the hash table.
   *
   * @returns The number of stored keys.
   */
  size() {
    return this.tensorMap.size;
  }

  /**
   * The number of items in the hash table as a rank-0 tensor.
   *
   * @returns A new `int32` scalar; the caller owns it.
   */
  tensorSize() {
    return tfOps.scalar(this.size(), 'int32');
  }

  /**
   * Replaces the contents of the table with the specified keys and values.
   *
   * The previous contents are released only after the new keys and values
   * have been validated, so a rejected import leaves the table unchanged.
   * If `keys` contains duplicates, the last occurrence wins and the tensors
   * stored for earlier occurrences are disposed instead of being leaked.
   *
   * @param keys Keys to store in the hashtable.
   * @param values Values to store in the hashtable; slice `i` along the first
   *     axis is stored under the `i`-th key.
   * @returns The table handle.
   * @throws If the key/value dtypes do not match the table, or if the number
   *     of keys differs from the number of value slices.
   */
  async import(keys, values) {
    this.checkKeyAndValueTensor(keys, values);

    // We only store the primitive values of the keys, this allows lookup
    // to be O(1).
    const $keys = await keys.data();

    return tidy(() => {
      const $values = unstack(values);

      const keysLength = $keys.length;
      const valuesLength = $values.length;

      util.assert(
          keysLength === valuesLength,
          () => `The number of elements doesn't match, keys has ` +
              `${keysLength} elements, the values has ${valuesLength} ` +
              `elements.`);

      // Validation succeeded: only now release the previous contents.
      this.tensorMap.forEach(value => value.dispose());
      this.tensorMap.clear();

      for (let i = 0; i < keysLength; i++) {
        const key = $keys[i];
        const value = $values[i];

        // A repeated key replaces a tensor kept earlier in this loop; that
        // tensor is no longer reachable from the map, so free it here.
        const previous = this.tensorMap.get(key);
        if (previous != null) {
          previous.dispose();
        }

        // Stored values must outlive this `tidy` scope.
        keep(value);
        this.tensorMap.set(key, value);
      }

      return this.handle;
    });
  }

  /**
   * Looks up keys in a hash table, outputs the corresponding values.
   *
   * Performs batch lookups, for every element in the key tensor, `find`
   * stacks the corresponding value into the return tensor.
   *
   * If an element is not present in the table, the given `defaultValue` is
   * used.
   *
   * NOTE(review): when `keys` has no elements, `stack` is called with an
   * empty list — confirm tfjs-core's behavior for that case.
   *
   * @param keys Keys to look up. Must have the same type as the keys of the
   *     table.
   * @param defaultValue The scalar `defaultValue` is the value output for keys
   *     not present in the table. It must also be of the same type as the
   *     table values.
   * @returns A new tensor stacking one value per looked-up key.
   */
  async find(keys, defaultValue) {
    this.checkKeyAndValueTensor(keys, defaultValue);

    const $keys = await keys.data();

    return tidy(() => {
      const result = [];

      for (let i = 0; i < $keys.length; i++) {
        const key = $keys[i];

        const value = this.findWithDefault(key, defaultValue);
        result.push(value);
      }

      return stack(result);
    });
  }

  /**
   * Returns the stored tensor for `key`, or `defaultValue` when absent.
   */
  // tslint:disable-next-line: no-any
  findWithDefault(key, defaultValue) {
    const result = this.tensorMap.get(key);

    return result != null ? result : defaultValue;
  }

  /**
   * Throws if `key`/`value` dtypes differ from the table's key/value dtypes.
   */
  checkKeyAndValueTensor(key, value) {
    if (key.dtype !== this.keyDType) {
      throw new Error(
          `Expect key dtype ${this.keyDType}, but got ` +
          `${key.dtype}`);
    }

    if (value.dtype !== this.valueDType) {
      throw new Error(
          `Expect value dtype ${this.valueDType}, but got ` +
          `${value.dtype}`);
    }
  }
}