@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
109 lines • 12.2 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Common functions for TensorFlow.js Layers.
*/
import { VALID_DATA_FORMAT_VALUES, VALID_INTERPOLATION_FORMAT_VALUES, VALID_PADDING_MODE_VALUES, VALID_POOL_MODE_VALUES } from './keras_format/common';
import { checkStringTypeUnionValue } from './utils/generic_utils';
// A map from the requested scoped name of a Tensor to the number of Tensors
// wanting that name so far. This allows enforcing name uniqueness by appending
// an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc.
const nameMap = new Map();
export function checkDataFormat(value) {
checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
}
export function checkInterpolationFormat(value) {
checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value);
}
export function checkPaddingMode(value) {
checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);
}
export function checkPoolMode(value) {
checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);
}
const _nameScopeStack = [];
const _nameScopeDivider = '/';
/**
* Enter namescope, which can be nested.
*/
export function nameScope(name, fn) {
_nameScopeStack.push(name);
try {
const val = fn();
_nameScopeStack.pop();
return val;
}
catch (e) {
_nameScopeStack.pop();
throw e;
}
}
/**
* Get the current namescope as a flat, concatenated string.
*/
function currentNameScopePrefix() {
if (_nameScopeStack.length === 0) {
return '';
}
else {
return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
}
/**
* Get the name a Tensor (or Variable) would have if not uniqueified.
* @param tensorName
* @return Scoped name string.
*/
export function getScopedTensorName(tensorName) {
if (!isValidTensorName(tensorName)) {
throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
}
return currentNameScopePrefix() + tensorName;
}
/**
* Get unique names for Tensors and Variables.
* @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
* `getScopedTensorName()`.
* @return A unique version of the given fully scoped name.
* If this is the first time that the scoped name is seen in this session,
* then the given `scopedName` is returned unaltered. If the same name is
* seen again (producing a collision), an incrementing suffix is added to the
* end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
*/
export function getUniqueTensorName(scopedName) {
if (!isValidTensorName(scopedName)) {
throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
}
if (!nameMap.has(scopedName)) {
nameMap.set(scopedName, 0);
}
const index = nameMap.get(scopedName);
nameMap.set(scopedName, nameMap.get(scopedName) + 1);
if (index > 0) {
const result = `${scopedName}_${index}`;
// Mark the composed name as used in case someone wants
// to call getUniqueTensorName("name_1").
nameMap.set(result, 1);
return result;
}
else {
return scopedName;
}
}
const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);
/**
* Determine whether a string is a valid tensor name.
* @param name
* @returns A Boolean indicating whether `name` is a valid tensor name.
*/
export function isValidTensorName(name) {
return !!name.match(tensorNameRegex);
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29tbW9uLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vdGZqcy1sYXllcnMvc3JjL2NvbW1vbi50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7R0FRRztBQUVIOztHQUVHO0FBQ0gsT0FBTyxFQUFDLHdCQUF3QixFQUFFLGlDQUFpQyxFQUFFLHlCQUF5QixFQUFFLHNCQUFzQixFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFDckosT0FBTyxFQUFDLHlCQUF5QixFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFFaEUsNEVBQTRFO0FBQzVFLGdGQUFnRjtBQUNoRiwyRUFBMkU7QUFDM0UsTUFBTSxPQUFPLEdBQXdCLElBQUksR0FBRyxFQUFrQixDQUFDO0FBRS9ELE1BQU0sVUFBVSxlQUFlLENBQUMsS0FBYztJQUM1Qyx5QkFBeUIsQ0FBQyx3QkFBd0IsRUFBRSxZQUFZLEVBQUUsS0FBSyxDQUFDLENBQUM7QUFDM0UsQ0FBQztBQUVELE1BQU0sVUFBVSx3QkFBd0IsQ0FBQyxLQUFjO0lBQ3JELHlCQUF5QixDQUNyQixpQ0FBaUMsRUFBRSxxQkFBcUIsRUFBRSxLQUFLLENBQUMsQ0FBQztBQUN2RSxDQUFDO0FBRUQsTUFBTSxVQUFVLGdCQUFnQixDQUFDLEtBQWM7SUFDN0MseUJBQXlCLENBQUMseUJBQXlCLEVBQUUsYUFBYSxFQUFFLEtBQUssQ0FBQyxDQUFDO0FBQzdFLENBQUM7QUFFRCxNQUFNLFVBQVUsYUFBYSxDQUFDLEtBQWM7SUFDMUMseUJBQXlCLENBQUMsc0JBQXNCLEVBQUUsVUFBVSxFQUFFLEtBQUssQ0FBQyxDQUFDO0FBQ3ZFLENBQUM7QUFFRCxNQUFNLGVBQWUsR0FBYSxFQUFFLENBQUM7QUFDckMsTUFBTSxpQkFBaUIsR0FBRyxHQUFHLENBQUM7QUFFOUI7O0dBRUc7QUFDSCxNQUFNLFVBQVUsU0FBUyxDQUFJLElBQVksRUFBRSxFQUFXO0lBQ3BELGVBQWUsQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLENBQUM7SUFDM0IsSUFBSTtRQUNGLE1BQU0sR0FBRyxHQUFNLEVBQUUsRUFBRSxDQUFDO1FBQ3BCLGVBQWUsQ0FBQyxHQUFHLEVBQUUsQ0FBQztRQUN0QixPQUFPLEdBQUcsQ0FBQztLQUNaO0lBQUMsT0FBTyxDQUFDLEVBQUU7UUFDVixlQUFlLENBQUMsR0FBRyxFQUFFLENBQUM7UUFDdEIsTUFBTSxDQUFDLENBQUM7S0FDVDtBQUNILENBQUM7QUFFRDs7R0FFRztBQUNILFNBQVMsc0JBQXNCO0lBQzdCLElBQUksZUFBZSxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7UUFDaEMsT0FBTyxFQUFFLENBQUM7S0FDWDtTQUFNO1FBQ0wsT0FBTyxlQUFlLENBQUMsSUFBSSxDQUFDLGlCQUFpQixDQUFDLEdBQUcsaUJBQWlCLENBQUM7S0FDcEU7QUFDSCxDQUFDO0FBRUQ7Ozs7R0FJRztBQUNILE1BQU0sVUFBVSxtQkFBbUIsQ0FBQyxVQUFrQjtJQUNwRCxJQUFJLENBQUMsaUJBQWlCLENBQUMsVUFBVSxDQUFDLEVBQUU7UUFDbEMsTUFBTSxJQUFJLEtBQUssQ0FBQyw2QkFBNkIsR0FBRyxVQUFVLEdBQUcsSUFBSSxDQUFDLENBQUM7S0FDcEU7SUFDRCxPQUFPLHNCQUFzQixFQUFFLEdBQUcsVUFBVSxDQUFDO0FBQy9DLENBQUM7QUFFRDs7Ozs7Ozs7O0dBU0c7QUFDSCxNQUFNLFVBQVUsbUJBQW1CLENBQUMsVUFBa0I7SUFDcEQsSUFBSSxDQUFDLGlCQUFpQixDQUFDLFVBQVUsQ0FBQyxFQUFFO1FBQ2xDLE1BQU0sSUFBSSxLQUFLLENBQUMsNkJBQTZCLEdBQUcsVUFBVSxHQUFHLElBQUksQ0FBQyxDQUFDO0tBQ3BFO0lBQ0QsSUFBSSxDQUFDLE9BQU8sQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLEVBQUU7UUFDNUIsT0FBTyxDQUFDLEdBQUcsQ0FBQyxVQUFVLEVBQUUsQ0FBQyxDQUFDLENBQUM7S0FDNUI7SUFDRCxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsR0FBRyxDQUFDLFVBQVUsQ0FBQyxDQUFDO0lBQ3RDLE9BQU8sQ0FBQyxHQUFHLENBQUMsVUFBVSxFQUFFLE9BQU8sQ0FBQyxHQUFHLENBQUMsVUFBVSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7SUFFckQsSUFBSSxLQUFLLEdBQUcsQ0FBQyxFQUFFO1FBQ2IsTUFBTSxNQUFNLEdBQUcsR0FBRyxVQUFVLElBQUksS0FBSyxFQUFFLENBQUM7UUFDeEMsdURBQXVEO1FBQ3ZELHlDQUF5QztRQUN6QyxPQUFPLENBQUMsR0FBRyxDQUFDLE1BQU0sRUFBRSxDQUFDLENBQUMsQ0FBQztRQUN2QixPQUFPLE1BQU0sQ0FBQztLQUNmO1NBQU07UUFDTCxPQUFPLFVBQVUsQ0FBQztLQUNuQjtBQUNILENBQUM7QUFFRCxNQUFNLGVBQWUsR0FBRyxJQUFJLE1BQU0sQ0FBQyxpQ0FBaUMsQ0FBQyxDQUFDO0FBRXRFOzs7O0dBSUc7QUFDSCxNQUFNLFVBQVUsaUJBQWlCLENBQUMsSUFBWTtJQUM1QyxPQUFPLENBQUMsQ0FBQyxJQUFJLENBQUMsS0FBSyxDQUFDLGVBQWUsQ0FBQyxDQUFDO0FBQ3ZDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOCBHb29nbGUgTExDXG4gKlxuICogVXNlIG9mIHRoaXMgc291cmNlIGNvZGUgaXMgZ292ZXJuZWQgYnkgYW4gTUlULXN0eWxlXG4gKiBsaWNlbnNlIHRoYXQgY2FuIGJlIGZvdW5kIGluIHRoZSBMSUNFTlNFIGZpbGUgb3IgYXRcbiAqIGh0dHBzOi8vb3BlbnNvdXJjZS5vcmcvbGljZW5zZXMvTUlULlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG4vKipcbiAqIENvbW1vbiBmdW5jdGlvbnMgZm9yIFRlbnNvckZsb3cuanMgTGF5ZXJzLlxuICovXG5pbXBvcnQge1ZBTElEX0RBVEFfRk9STUFUX1ZBTFVFUywgVkFMSURfSU5URVJQT0xBVElPTl9GT1JNQVRfVkFMVUVTLCBWQUxJRF9QQURESU5HX01PREVfVkFMVUVTLCBWQUxJRF9QT09MX01PREVfVkFMVUVTfSBmcm9tICcuL2tlcmFzX2Zvcm1hdC9jb21tb24nO1xuaW1wb3J0IHtjaGVja1N0cmluZ1R5cGVVbmlvblZhbHVlfSBmcm9tICcuL3V0aWxzL2dlbmVyaWNfdXRpbHMnO1xuXG4vLyBBIG1hcCBmcm9tIHRoZSByZXF1ZXN0ZWQgc2NvcGVkIG5hbWUgb2YgYSBUZW5zb3IgdG8gdGhlIG51bWJlciBvZiBUZW5zb3JzXG4vLyB3YW50aW5nIHRoYXQgbmFtZSBzbyBmYXIuICBUaGlzIGFsbG93cyBlbmZvcmNpbmcgbmFtZSB1bmlxdWVuZXNzIGJ5IGFwcGVuZGluZ1xuLy8gYW4gaW5jcmVtZW50aW5nIGluZGV4LCBlLmcuIHNjb3BlL25hbWUsIHNjb3BlL25hbWVfMSwgc2NvcGUvbmFtZV8yLCBldGMuXG5jb25zdCBuYW1lTWFwOiBNYXA8c3RyaW5nLCBudW1iZXI+ID0gbmV3IE1hcDxzdHJpbmcsIG51bWJlcj4oKTtcblxuZXhwb3J0IGZ1bmN0aW9uIGNoZWNrRGF0YUZvcm1hdCh2YWx1ZT86IHN0cmluZyk6IHZvaWQge1xuICBjaGVja1N0cmluZ1R5cGVVbmlvblZhbHVlKFZBTElEX0RBVEFfRk9STUFUX1ZBTFVFUywgJ0RhdGFGb3JtYXQnLCB2YWx1ZSk7XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBjaGVja0ludGVycG9sYXRpb25Gb3JtYXQodmFsdWU/OiBzdHJpbmcpOiB2b2lkIHtcbiAgY2hlY2tTdHJpbmdUeXBlVW5pb25WYWx1ZShcbiAgICAgIFZBTElEX0lOVEVSUE9MQVRJT05fRk9STUFUX1ZBTFVFUywgJ0ludGVycG9sYXRpb25Gb3JtYXQnLCB2YWx1ZSk7XG59XG5cbmV4cG9ydCBmdW5jdGlvbiBjaGVja1BhZGRpbmdNb2RlKHZhbHVlPzogc3RyaW5nKTogdm9pZCB7XG4gIGNoZWNrU3RyaW5nVHlwZVVuaW9uVmFsdWUoVkFMSURfUEFERElOR19NT0RFX1ZBTFVFUywgJ1BhZGRpbmdNb2RlJywgdmFsdWUpO1xufVxuXG5leHBvcnQgZnVuY3Rpb24gY2hlY2tQb29sTW9kZSh2YWx1ZT86IHN0cmluZyk6IHZvaWQge1xuICBjaGVja1N0cmluZ1R5cGVVbmlvblZhbHVlKFZBTElEX1BPT0xfTU9ERV9WQUxVRVMsICdQb29sTW9kZScsIHZhbHVlKTtcbn1cblxuY29uc3QgX25hbWVTY29wZVN0YWNrOiBzdHJpbmdbXSA9IFtdO1xuY29uc3QgX25hbWVTY29wZURpdmlkZXIgPSAnLyc7XG5cbi8qKlxuICogRW50ZXIgbmFtZXNjb3BlLCB3aGljaCBjYW4gYmUgbmVzdGVkLlxuICovXG5leHBvcnQgZnVuY3Rpb24gbmFtZVNjb3BlPFQ+KG5hbWU6IHN0cmluZywgZm46ICgpID0+IFQpOiBUIHtcbiAgX25hbWVTY29wZVN0YWNrLnB1c2gobmFtZSk7XG4gIHRyeSB7XG4gICAgY29uc3QgdmFsOiBUID0gZm4oKTtcbiAgICBfbmFtZVNjb3BlU3RhY2sucG9wKCk7XG4gICAgcmV0dXJuIHZhbDtcbiAgfSBjYXRjaCAoZSkge1xuICAgIF9uYW1lU2NvcGVTdGFjay5wb3AoKTtcbiAgICB0aHJvdyBlO1xuICB9XG59XG5cbi8qKlxuICogR2V0IHRoZSBjdXJyZW50IG5hbWVzY29wZSBhcyBhIGZsYXQsIGNvbmNhdGVuYXRlZCBzdHJpbmcuXG4gKi9cbmZ1bmN0aW9uIGN1cnJlbnROYW1lU2NvcGVQcmVmaXgoKTogc3RyaW5nIHtcbiAgaWYgKF9uYW1lU2NvcGVTdGFjay5sZW5ndGggPT09IDApIHtcbiAgICByZXR1cm4gJyc7XG4gIH0gZWxzZSB7XG4gICAgcmV0dXJuIF9uYW1lU2NvcGVTdGFjay5qb2luKF9uYW1lU2NvcGVEaXZpZGVyKSArIF9uYW1lU2NvcGVEaXZpZGVyO1xuICB9XG59XG5cbi8qKlxuICogR2V0IHRoZSBuYW1lIGEgVGVuc29yIChvciBWYXJpYWJsZSkgd291bGQgaGF2ZSBpZiBub3QgdW5pcXVlaWZpZWQuXG4gKiBAcGFyYW0gdGVuc29yTmFtZVxuICogQHJldHVybiBTY29wZWQgbmFtZSBzdHJpbmcuXG4gKi9cbmV4cG9ydCBmdW5jdGlvbiBnZXRTY29wZWRUZW5zb3JOYW1lKHRlbnNvck5hbWU6IHN0cmluZyk6IHN0cmluZyB7XG4gIGlmICghaXNWYWxpZFRlbnNvck5hbWUodGVuc29yTmFtZSkpIHtcbiAgICB0aHJvdyBuZXcgRXJyb3IoJ05vdCBhIHZhbGlkIHRlbnNvciBuYW1lOiBcXCcnICsgdGVuc29yTmFtZSArICdcXCcnKTtcbiAgfVxuICByZXR1cm4gY3VycmVudE5hbWVTY29wZVByZWZpeCgpICsgdGVuc29yTmFtZTtcbn1cblxuLyoqXG4gKiBHZXQgdW5pcXVlIG5hbWVzIGZvciBUZW5zb3JzIGFuZCBWYXJpYWJsZXMuXG4gKiBAcGFyYW0gc2NvcGVkTmFtZSBUaGUgZnVsbHktcXVhbGlmaWVkIG5hbWUgb2YgdGhlIFRlbnNvciwgaS5lLiBhcyBwcm9kdWNlZCBieVxuICogIGBnZXRTY29wZWRUZW5zb3JOYW1lKClgLlxuICogQHJldHVybiBBIHVuaXF1ZSB2ZXJzaW9uIG9mIHRoZSBnaXZlbiBmdWxseSBzY29wZWQgbmFtZS5cbiAqICAgSWYgdGhpcyBpcyB0aGUgZmlyc3QgdGltZSB0aGF0IHRoZSBzY29wZWQgbmFtZSBpcyBzZWVuIGluIHRoaXMgc2Vzc2lvbixcbiAqICAgdGhlbiB0aGUgZ2l2ZW4gYHNjb3BlZE5hbWVgIGlzIHJldHVybmVkIHVuYWx0ZXJlZC4gIElmIHRoZSBzYW1lIG5hbWUgaXNcbiAqICAgc2VlbiBhZ2FpbiAocHJvZHVjaW5nIGEgY29sbGlzaW9uKSwgYW4gaW5jcmVtZW50aW5nIHN1ZmZpeCBpcyBhZGRlZCB0byB0aGVcbiAqICAgZW5kIG9mIHRoZSBuYW1lLCBzbyBpdCB0YWtlcyB0aGUgZm9ybSAnc2NvcGUvbmFtZV8xJywgJ3Njb3BlL25hbWVfMicsIGV0Yy5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGdldFVuaXF1ZVRlbnNvck5hbWUoc2NvcGVkTmFtZTogc3RyaW5nKTogc3RyaW5nIHtcbiAgaWYgKCFpc1ZhbGlkVGVuc29yTmFtZShzY29wZWROYW1lKSkge1xuICAgIHRocm93IG5ldyBFcnJvcignTm90IGEgdmFsaWQgdGVuc29yIG5hbWU6IFxcJycgKyBzY29wZWROYW1lICsgJ1xcJycpO1xuICB9XG4gIGlmICghbmFtZU1hcC5oYXMoc2NvcGVkTmFtZSkpIHtcbiAgICBuYW1lTWFwLnNldChzY29wZWROYW1lLCAwKTtcbiAgfVxuICBjb25zdCBpbmRleCA9IG5hbWVNYXAuZ2V0KHNjb3BlZE5hbWUpO1xuICBuYW1lTWFwLnNldChzY29wZWROYW1lLCBuYW1lTWFwLmdldChzY29wZWROYW1lKSArIDEpO1xuXG4gIGlmIChpbmRleCA+IDApIHtcbiAgICBjb25zdCByZXN1bHQgPSBgJHtzY29wZWROYW1lfV8ke2luZGV4fWA7XG4gICAgLy8gTWFyayB0aGUgY29tcG9zZWQgbmFtZSBhcyB1c2VkIGluIGNhc2Ugc29tZW9uZSB3YW50c1xuICAgIC8vIHRvIGNhbGwgZ2V0VW5pcXVlVGVuc29yTmFtZShcIm5hbWVfMVwiKS5cbiAgICBuYW1lTWFwLnNldChyZXN1bHQsIDEpO1xuICAgIHJldHVybiByZXN1bHQ7XG4gIH0gZWxzZSB7XG4gICAgcmV0dXJuIHNjb3BlZE5hbWU7XG4gIH1cbn1cblxuY29uc3QgdGVuc29yTmFtZVJlZ2V4ID0gbmV3IFJlZ0V4cCgvXltBLVphLXowLTldWy1BLVphLXowLTlcXC5fXFwvXSokLyk7XG5cbi8qKlxuICogRGV0ZXJtaW5lIHdoZXRoZXIgYSBzdHJpbmcgaXMgYSB2YWxpZCB0ZW5zb3IgbmFtZS5cbiAqIEBwYXJhbSBuYW1lXG4gKiBAcmV0dXJucyBBIEJvb2xlYW4gaW5kaWNhdGluZyB3aGV0aGVyIGBuYW1lYCBpcyBhIHZhbGlkIHRlbnNvciBuYW1lLlxuICovXG5leHBvcnQgZnVuY3Rpb24gaXNWYWxpZFRlbnNvck5hbWUobmFtZTogc3RyaW5nKTogYm9vbGVhbiB7XG4gIHJldHVybiAhIW5hbWUubWF0Y2godGVuc29yTmFtZVJlZ2V4KTtcbn1cbiJdfQ==