@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
109 lines • 12.2 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Common functions for TensorFlow.js Layers.
*/
import { VALID_DATA_FORMAT_VALUES, VALID_INTERPOLATION_FORMAT_VALUES, VALID_PADDING_MODE_VALUES, VALID_POOL_MODE_VALUES } from './keras_format/common';
import { checkStringTypeUnionValue } from './utils/generic_utils';
// Registry used by getUniqueTensorName(): maps each name that function has
// already handed out to the next numeric suffix to try when that same name is
// requested again. This enforces name uniqueness by appending an incrementing
// index, e.g. scope/name, scope/name_1, scope/name_2, etc. Composed names such
// as scope/name_1 are registered as well, so requesting one of them directly
// is also treated as a collision. Module-level state; nothing visible in this
// file ever clears it.
const nameMap = new Map();
export function checkDataFormat(value) {
checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
}
export function checkInterpolationFormat(value) {
checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value);
}
export function checkPaddingMode(value) {
checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);
}
export function checkPoolMode(value) {
checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);
}
// Separator placed between (and after) nested scope names, e.g. 'outer/inner/'.
const _nameScopeDivider = '/';
// Names of the currently active scopes, outermost first; nameScope() pushes a
// name on entry and pops it on exit.
const _nameScopeStack = [];
/**
* Enter namescope, which can be nested.
*/
export function nameScope(name, fn) {
_nameScopeStack.push(name);
try {
const val = fn();
_nameScopeStack.pop();
return val;
}
catch (e) {
_nameScopeStack.pop();
throw e;
}
}
/**
* Get the current namescope as a flat, concatenated string.
*/
function currentNameScopePrefix() {
if (_nameScopeStack.length === 0) {
return '';
}
else {
return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
}
/**
* Get the name a Tensor (or Variable) would have if not uniqueified.
* @param tensorName
* @return Scoped name string.
*/
export function getScopedTensorName(tensorName) {
if (!isValidTensorName(tensorName)) {
throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
}
return currentNameScopePrefix() + tensorName;
}
/**
* Get unique names for Tensors and Variables.
* @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
* `getScopedTensorName()`.
* @return A unique version of the given fully scoped name.
* If this is the first time that the scoped name is seen in this session,
* then the given `scopedName` is returned unaltered. If the same name is
* seen again (producing a collision), an incrementing suffix is added to the
* end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
*/
export function getUniqueTensorName(scopedName) {
if (!isValidTensorName(scopedName)) {
throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
}
if (!nameMap.has(scopedName)) {
nameMap.set(scopedName, 0);
}
const index = nameMap.get(scopedName);
nameMap.set(scopedName, nameMap.get(scopedName) + 1);
if (index > 0) {
const result = `${scopedName}_${index}`;
// Mark the composed name as used in case someone wants
// to call getUniqueTensorName("name_1").
nameMap.set(result, 1);
return result;
}
else {
return scopedName;
}
}
// Valid names start with an ASCII letter or digit, followed by any mix of
// ASCII letters, digits, '-', '.', '_' and '/'. No 'g'/'y' flag, so test()
// keeps no lastIndex state between calls.
const tensorNameRegex = /^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/;
/**
 * Determine whether a value is a valid tensor name.
 * @param name Candidate name.
 * @returns `true` if `name` is a string matching `tensorNameRegex`; `false`
 *   otherwise, including for non-string input such as `undefined` or `null`.
 */
export function isValidTensorName(name) {
    return typeof name === 'string' && tensorNameRegex.test(name);
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"common.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/common.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AACH,OAAO,EAAC,wBAAwB,EAAE,iCAAiC,EAAE,yBAAyB,EAAE,sBAAsB,EAAC,MAAM,uBAAuB,CAAC;AACrJ,OAAO,EAAC,yBAAyB,EAAC,MAAM,uBAAuB,CAAC;AAEhE,4EAA4E;AAC5E,gFAAgF;AAChF,2EAA2E;AAC3E,MAAM,OAAO,GAAwB,IAAI,GAAG,EAAkB,CAAC;AAE/D,MAAM,UAAU,eAAe,CAAC,KAAc;IAC5C,yBAAyB,CAAC,wBAAwB,EAAE,YAAY,EAAE,KAAK,CAAC,CAAC;AAC3E,CAAC;AAED,MAAM,UAAU,wBAAwB,CAAC,KAAc;IACrD,yBAAyB,CACrB,iCAAiC,EAAE,qBAAqB,EAAE,KAAK,CAAC,CAAC;AACvE,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,KAAc;IAC7C,yBAAyB,CAAC,yBAAyB,EAAE,aAAa,EAAE,KAAK,CAAC,CAAC;AAC7E,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,KAAc;IAC1C,yBAAyB,CAAC,sBAAsB,EAAE,UAAU,EAAE,KAAK,CAAC,CAAC;AACvE,CAAC;AAED,MAAM,eAAe,GAAa,EAAE,CAAC;AACrC,MAAM,iBAAiB,GAAG,GAAG,CAAC;AAE9B;;GAEG;AACH,MAAM,UAAU,SAAS,CAAI,IAAY,EAAE,EAAW;IACpD,eAAe,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;IAC3B,IAAI;QACF,MAAM,GAAG,GAAM,EAAE,EAAE,CAAC;QACpB,eAAe,CAAC,GAAG,EAAE,CAAC;QACtB,OAAO,GAAG,CAAC;KACZ;IAAC,OAAO,CAAC,EAAE;QACV,eAAe,CAAC,GAAG,EAAE,CAAC;QACtB,MAAM,CAAC,CAAC;KACT;AACH,CAAC;AAED;;GAEG;AACH,SAAS,sBAAsB;IAC7B,IAAI,eAAe,CAAC,MAAM,KAAK,CAAC,EAAE;QAChC,OAAO,EAAE,CAAC;KACX;SAAM;QACL,OAAO,eAAe,CAAC,IAAI,CAAC,iBAAiB,CAAC,GAAG,iBAAiB,CAAC;KACpE;AACH,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,mBAAmB,CAAC,UAAkB;IACpD,IAAI,CAAC,iBAAiB,CAAC,UAAU,CAAC,EAAE;QAClC,MAAM,IAAI,KAAK,CAAC,6BAA6B,GAAG,UAAU,GAAG,IAAI,CAAC,CAAC;KACpE;IACD,OAAO,sBAAsB,EAAE,GAAG,UAAU,CAAC;AAC/C,CAAC;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,mBAAmB,CAAC,UAAkB;IACpD,IAAI,CAAC,iBAAiB,CAAC,UAAU,CAAC,EAAE;QAClC,MAAM,IAAI,KAAK,CAAC,6BAA6B,GAAG,UAAU,GAAG,IAAI,CAAC,CAAC;KACpE;IACD,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE;QAC5B,OAAO,CAAC,GAAG,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;KAC5B;IACD,MAAM,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;IACtC,OAAO,CAAC,GAAG,CAAC,UAAU,EAAE,OAAO,CAAC,GAAG,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC;IAErD,IAAI,KAAK,GAAG,CAAC,EAAE;QACb,MAAM,MAAM,GAAG,GAAG,UAAU
,IAAI,KAAK,EAAE,CAAC;QACxC,uDAAuD;QACvD,yCAAyC;QACzC,OAAO,CAAC,GAAG,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;QACvB,OAAO,MAAM,CAAC;KACf;SAAM;QACL,OAAO,UAAU,CAAC;KACnB;AACH,CAAC;AAED,MAAM,eAAe,GAAG,IAAI,MAAM,CAAC,iCAAiC,CAAC,CAAC;AAEtE;;;;GAIG;AACH,MAAM,UAAU,iBAAiB,CAAC,IAAY;IAC5C,OAAO,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC;AACvC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n * Common functions for TensorFlow.js Layers.\n */\nimport {VALID_DATA_FORMAT_VALUES, VALID_INTERPOLATION_FORMAT_VALUES, VALID_PADDING_MODE_VALUES, VALID_POOL_MODE_VALUES} from './keras_format/common';\nimport {checkStringTypeUnionValue} from './utils/generic_utils';\n\n// A map from the requested scoped name of a Tensor to the number of Tensors\n// wanting that name so far.  This allows enforcing name uniqueness by appending\n// an incrementing index, e.g. 
scope/name, scope/name_1, scope/name_2, etc.\nconst nameMap: Map<string, number> = new Map<string, number>();\n\nexport function checkDataFormat(value?: string): void {\n  checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);\n}\n\nexport function checkInterpolationFormat(value?: string): void {\n  checkStringTypeUnionValue(\n      VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value);\n}\n\nexport function checkPaddingMode(value?: string): void {\n  checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);\n}\n\nexport function checkPoolMode(value?: string): void {\n  checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);\n}\n\nconst _nameScopeStack: string[] = [];\nconst _nameScopeDivider = '/';\n\n/**\n * Enter namescope, which can be nested.\n */\nexport function nameScope<T>(name: string, fn: () => T): T {\n  _nameScopeStack.push(name);\n  try {\n    const val: T = fn();\n    _nameScopeStack.pop();\n    return val;\n  } catch (e) {\n    _nameScopeStack.pop();\n    throw e;\n  }\n}\n\n/**\n * Get the current namescope as a flat, concatenated string.\n */\nfunction currentNameScopePrefix(): string {\n  if (_nameScopeStack.length === 0) {\n    return '';\n  } else {\n    return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;\n  }\n}\n\n/**\n * Get the name a Tensor (or Variable) would have if not uniqueified.\n * @param tensorName\n * @return Scoped name string.\n */\nexport function getScopedTensorName(tensorName: string): string {\n  if (!isValidTensorName(tensorName)) {\n    throw new Error('Not a valid tensor name: \\'' + tensorName + '\\'');\n  }\n  return currentNameScopePrefix() + tensorName;\n}\n\n/**\n * Get unique names for Tensors and Variables.\n * @param scopedName The fully-qualified name of the Tensor, i.e. 
as produced by\n *  `getScopedTensorName()`.\n * @return A unique version of the given fully scoped name.\n *   If this is the first time that the scoped name is seen in this session,\n *   then the given `scopedName` is returned unaltered.  If the same name is\n *   seen again (producing a collision), an incrementing suffix is added to the\n *   end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.\n */\nexport function getUniqueTensorName(scopedName: string): string {\n  if (!isValidTensorName(scopedName)) {\n    throw new Error('Not a valid tensor name: \\'' + scopedName + '\\'');\n  }\n  if (!nameMap.has(scopedName)) {\n    nameMap.set(scopedName, 0);\n  }\n  const index = nameMap.get(scopedName);\n  nameMap.set(scopedName, nameMap.get(scopedName) + 1);\n\n  if (index > 0) {\n    const result = `${scopedName}_${index}`;\n    // Mark the composed name as used in case someone wants\n    // to call getUniqueTensorName(\"name_1\").\n    nameMap.set(result, 1);\n    return result;\n  } else {\n    return scopedName;\n  }\n}\n\nconst tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\\._\\/]*$/);\n\n/**\n * Determine whether a string is a valid tensor name.\n * @param name\n * @returns A Boolean indicating whether `name` is a valid tensor name.\n */\nexport function isValidTensorName(name: string): boolean {\n  return !!name.match(tensorNameRegex);\n}\n"]}