UNPKG

@tensorflow/tfjs-layers

Version:

TensorFlow layers API in JavaScript

109 lines 12.2 kB
/** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Common functions for TensorFlow.js Layers. */ import { VALID_DATA_FORMAT_VALUES, VALID_INTERPOLATION_FORMAT_VALUES, VALID_PADDING_MODE_VALUES, VALID_POOL_MODE_VALUES } from './keras_format/common'; import { checkStringTypeUnionValue } from './utils/generic_utils'; // A map from the requested scoped name of a Tensor to the number of Tensors // wanting that name so far. This allows enforcing name uniqueness by appending // an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc. const nameMap = new Map(); export function checkDataFormat(value) { checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value); } export function checkInterpolationFormat(value) { checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value); } export function checkPaddingMode(value) { checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value); } export function checkPoolMode(value) { checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value); } const _nameScopeStack = []; const _nameScopeDivider = '/'; /** * Enter namescope, which can be nested. */ export function nameScope(name, fn) { _nameScopeStack.push(name); try { const val = fn(); _nameScopeStack.pop(); return val; } catch (e) { _nameScopeStack.pop(); throw e; } } /** * Get the current namescope as a flat, concatenated string. */ function currentNameScopePrefix() { if (_nameScopeStack.length === 0) { return ''; } else { return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider; } } /** * Get the name a Tensor (or Variable) would have if not uniqueified. * @param tensorName * @return Scoped name string. */ export function getScopedTensorName(tensorName) { if (!isValidTensorName(tensorName)) { throw new Error('Not a valid tensor name: \'' + tensorName + '\''); } return currentNameScopePrefix() + tensorName; } /** * Get unique names for Tensors and Variables. * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by * `getScopedTensorName()`. * @return A unique version of the given fully scoped name. * If this is the first time that the scoped name is seen in this session, * then the given `scopedName` is returned unaltered. If the same name is * seen again (producing a collision), an incrementing suffix is added to the * end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc. */ export function getUniqueTensorName(scopedName) { if (!isValidTensorName(scopedName)) { throw new Error('Not a valid tensor name: \'' + scopedName + '\''); } if (!nameMap.has(scopedName)) { nameMap.set(scopedName, 0); } const index = nameMap.get(scopedName); nameMap.set(scopedName, nameMap.get(scopedName) + 1); if (index > 0) { const result = `${scopedName}_${index}`; // Mark the composed name as used in case someone wants // to call getUniqueTensorName("name_1"). nameMap.set(result, 1); return result; } else { return scopedName; } } const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/); /** * Determine whether a string is a valid tensor name. * @param name * @returns A Boolean indicating whether `name` is a valid tensor name. */ export function isValidTensorName(name) { return !!name.match(tensorNameRegex); } //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"common.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/common.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AACH,OAAO,EAAC,wBAAwB,EAAE,iCAAiC,EAAE,yBAAyB,EAAE,sBAAsB,EAAC,MAAM,uBAAuB,CAAC;AACrJ,OAAO,EAAC,yBAAyB,EAAC,MAAM,uBAAuB,CAAC;AAEhE,4EAA4E;AAC5E,gFAAgF;AAChF,2EAA2E;AAC3E,MAAM,OAAO,GAAwB,IAAI,GAAG,EAAkB,CAAC;AAE/D,MAAM,UAAU,eAAe,CAAC,KAAc;IAC5C,yBAAyB,CAAC,wBAAwB,EAAE,YAAY,EAAE,KAAK,CAAC,CAAC;AAC3E,CAAC;AAED,MAAM,UAAU,wBAAwB,CAAC,KAAc;IACrD,yBAAyB,CACrB,iCAAiC,EAAE,qBAAqB,EAAE,KAAK,CAAC,CAAC;AACvE,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,KAAc;IAC7C,yBAAyB,CAAC,yBAAyB,EAAE,aAAa,EAAE,KAAK,CAAC,CAAC;AAC7E,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,KAAc;IAC1C,yBAAyB,CAAC,sBAAsB,EAAE,UAAU,EAAE,KAAK,CAAC,CAAC;AACvE,CAAC;AAED,MAAM,eAAe,GAAa,EAAE,CAAC;AACrC,MAAM,iBAAiB,GAAG,GAAG,CAAC;AAE9B;;GAEG;AACH,MAAM,UAAU,SAAS,CAAI,IAAY,EAAE,EAAW;IACpD,eAAe,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;IAC3B,IAAI;QACF,MAAM,GAAG,GAAM,EAAE,EAAE,CAAC;QACpB,eAAe,CAAC,GAAG,EAAE,CAAC;QACtB,OAAO,GAAG,CAAC;KACZ;IAAC,OAAO,CAAC,EAAE;QACV,eAAe,CAAC,GAAG,EAAE,CAAC;QACtB,MAAM,CAAC,CAAC;KACT;AACH,CAAC;AAED;;GAEG;AACH,SAAS,sBAAsB;IAC7B,IAAI,eAAe,CAAC,MAAM,KAAK,CAAC,EAAE;QAChC,OAAO,EAAE,CAAC;KACX;SAAM;QACL,OAAO,eAAe,CAAC,IAAI,CAAC,iBAAiB,CAAC,GAAG,iBAAiB,CAAC;KACpE;AACH,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,mBAAmB,CAAC,UAAkB;IACpD,IAAI,CAAC,iBAAiB,CAAC,UAAU,CAAC,EAAE;QAClC,MAAM,IAAI,KAAK,CAAC,6BAA6B,GAAG,UAAU,GAAG,IAAI,CAAC,CAAC;KACpE;IACD,OAAO,sBAAsB,EAAE,GAAG,UAAU,CAAC;AAC/C,CAAC;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,mBAAmB,CAAC,UAAkB;IACpD,IAAI,CAAC,iBAAiB,CAAC,UAAU,CAAC,EAAE;QAClC,MAAM,IAAI,KAAK,CAAC,6BAA6B,GAAG,UAAU,GAAG,IAAI,CAAC,CAAC;KACpE;IACD,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE;QAC5B,OAAO,CAAC,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;KAC5B;IACD,MAAM,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;IACtC,OAAO,CAAC,GAAG,CAAC,UAAU,EAAE,OAAO,CAAC,GAAG,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC;IAErD,IAAI,KAAK,GAAG,CAAC,EAAE;QACb,MAAM,MAAM,GAAG,GAAG,UAAU,IAAI,KAAK,EAAE,CAAC;QACxC,uDAAuD;QACvD,yCAAyC;QACzC,OAAO,CAAC,GAAG,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;QACvB,OAAO,MAAM,CAAC;KACf;SAAM;QACL,OAAO,UAAU,CAAC;KACnB;AACH,CAAC;AAED,MAAM,eAAe,GAAG,IAAI,MAAM,CAAC,iCAAiC,CAAC,CAAC;AAEtE;;;;GAIG;AACH,MAAM,UAAU,iBAAiB,CAAC,IAAY;IAC5C,OAAO,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC;AACvC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * Common functions for TensorFlow.js Layers.\n */\nimport {VALID_DATA_FORMAT_VALUES, VALID_INTERPOLATION_FORMAT_VALUES, VALID_PADDING_MODE_VALUES, VALID_POOL_MODE_VALUES} from './keras_format/common';\nimport {checkStringTypeUnionValue} from './utils/generic_utils';\n\n// A map from the requested scoped name of a Tensor to the number of Tensors\n// wanting that name so far.  This allows enforcing name uniqueness by appending\n// an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc.\nconst nameMap: Map<string, number> = new Map<string, number>();\n\nexport function checkDataFormat(value?: string): void {\n  checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);\n}\n\nexport function checkInterpolationFormat(value?: string): void {\n  checkStringTypeUnionValue(\n      VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value);\n}\n\nexport function checkPaddingMode(value?: string): void {\n  checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);\n}\n\nexport function checkPoolMode(value?: string): void {\n  checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);\n}\n\nconst _nameScopeStack: string[] = [];\nconst _nameScopeDivider = '/';\n\n/**\n * Enter namescope, which can be nested.\n */\nexport function nameScope<T>(name: string, fn: () => T): T {\n  _nameScopeStack.push(name);\n  try {\n    const val: T = fn();\n    _nameScopeStack.pop();\n    return val;\n  } catch (e) {\n    _nameScopeStack.pop();\n    throw e;\n  }\n}\n\n/**\n * Get the current namescope as a flat, concatenated string.\n */\nfunction currentNameScopePrefix(): string {\n  if (_nameScopeStack.length === 0) {\n    return '';\n  } else {\n    return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;\n  }\n}\n\n/**\n * Get the name a Tensor (or Variable) would have if not uniqueified.\n * @param tensorName\n * @return Scoped name string.\n */\nexport function getScopedTensorName(tensorName: string): string {\n  if (!isValidTensorName(tensorName)) {\n    throw new Error('Not a valid tensor name: \\'' + tensorName + '\\'');\n  }\n  return currentNameScopePrefix() + tensorName;\n}\n\n/**\n * Get unique names for Tensors and Variables.\n * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by\n *  `getScopedTensorName()`.\n * @return A unique version of the given fully scoped name.\n *   If this is the first time that the scoped name is seen in this session,\n *   then the given `scopedName` is returned unaltered.  If the same name is\n *   seen again (producing a collision), an incrementing suffix is added to the\n *   end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.\n */\nexport function getUniqueTensorName(scopedName: string): string {\n  if (!isValidTensorName(scopedName)) {\n    throw new Error('Not a valid tensor name: \\'' + scopedName + '\\'');\n  }\n  if (!nameMap.has(scopedName)) {\n    nameMap.set(scopedName, 0);\n  }\n  const index = nameMap.get(scopedName);\n  nameMap.set(scopedName, nameMap.get(scopedName) + 1);\n\n  if (index > 0) {\n    const result = `${scopedName}_${index}`;\n    // Mark the composed name as used in case someone wants\n    // to call getUniqueTensorName(\"name_1\").\n    nameMap.set(result, 1);\n    return result;\n  } else {\n    return scopedName;\n  }\n}\n\nconst tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\\._\\/]*$/);\n\n/**\n * Determine whether a string is a valid tensor name.\n * @param name\n * @returns A Boolean indicating whether `name` is a valid tensor name.\n */\nexport function isValidTensorName(name: string): boolean {\n  return !!name.match(tensorNameRegex);\n}\n"]}