@tensorflow/tfjs-layers
Version:
TensorFlow layers API in JavaScript
76 lines • 7.22 kB
JavaScript
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
import { ValueError } from '../errors';
// tslint:enable
/**
 * Check whether `x` is an Array of Shapes rather than a single Shape.
 *
 * Only the first element is inspected: `x` counts as an Array of Shapes when
 * it is an Array whose first element is itself an Array.
 *
 * @param x A single Shape or an Array of Shapes.
 * @return `true` if `x` is an Array of Shapes, otherwise `false`.
 */
export function isArrayOfShapes(x) {
  if (!Array.isArray(x)) {
    return false;
  }
  const first = x[0];
  return Array.isArray(first);
}
/**
 * Normalize a single Shape or a list of Shapes into a list of Shapes.
 *
 * An empty input yields a new empty list. If the first element of `x` is an
 * Array, `x` is already a list of Shapes and is returned unchanged (the same
 * reference). Otherwise, `x` is treated as one Shape and wrapped in a
 * one-element list.
 *
 * @param x A shape or list of shapes to normalize into a list of Shapes.
 * @return A list of Shapes.
 */
export function normalizeShapeList(x) {
  // This must run before the element check below; otherwise `[]` would be
  // wrapped into `[[]]`.
  if (x.length === 0) {
    return [];
  }
  const isAlreadyList = Array.isArray(x[0]);
  return isAlreadyList ? x : [x];
}
/**
 * Helper function to obtain exactly one Tensor.
 *
 * A non-Array argument is returned unchanged. An Array must contain exactly
 * one element, and that element is returned.
 *
 * @param xs A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
 * @return A single `tf.Tensor`: `xs` itself, or the sole element of `xs` when
 *   it is an `Array`.
 * @throws ValueError If `xs` is an `Array` and its length is not 1.
 */
export function getExactlyOneTensor(xs) {
  if (!Array.isArray(xs)) {
    return xs;
  }
  if (xs.length !== 1) {
    throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`);
  }
  return xs[0];
}
/**
 * Helper function to obtain exactly one instance of Shape.
 *
 * The input is treated as an `Array` of `Shape`s only when it is an `Array`
 * whose first element is itself an `Array`. Anything else is returned as a
 * single `Shape`. This includes flat shapes such as `[2, 3]` or `[null, 4]`
 * and the empty shape `[]`.
 *
 * @param shapes Input single `Shape` or Array of `Shape`s.
 * @returns If input is a single `Shape`, return it unchanged. If the input is
 *   an `Array` containing exactly one instance of `Shape`, return the instance.
 *   Otherwise, throw a `ValueError`.
 * @throws ValueError: If input is an `Array` of `Shape`s, and its length is not
 *   1.
 */
export function getExactlyOneShape(shapes) {
  const isListOfShapes = Array.isArray(shapes) && Array.isArray(shapes[0]);
  if (!isListOfShapes) {
    return shapes;
  }
  if (shapes.length !== 1) {
    throw new ValueError(`Expected exactly 1 Shape; got ${shapes.length}`);
  }
  return shapes[0];
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidHlwZXNfdXRpbHMuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWxheWVycy9zcmMvdXRpbHMvdHlwZXNfdXRpbHMudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7O0dBUUc7QUFLSCxPQUFPLEVBQUMsVUFBVSxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBRXJDLGdCQUFnQjtBQUVoQjs7R0FFRztBQUNILE1BQU0sVUFBVSxlQUFlLENBQUMsQ0FBZ0I7SUFDOUMsT0FBTyxLQUFLLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxJQUFJLEtBQUssQ0FBQyxPQUFPLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7QUFDakQsQ0FBQztBQUVEOzs7OztHQUtHO0FBQ0gsTUFBTSxVQUFVLGtCQUFrQixDQUFDLENBQWdCO0lBQ2pELElBQUksQ0FBQyxDQUFDLE1BQU0sS0FBSyxDQUFDLEVBQUU7UUFDbEIsT0FBTyxFQUFFLENBQUM7S0FDWDtJQUNELElBQUksQ0FBQyxLQUFLLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFO1FBQ3hCLE9BQU8sQ0FBQyxDQUFDLENBQVksQ0FBQztLQUN2QjtJQUNELE9BQU8sQ0FBWSxDQUFDO0FBQ3RCLENBQUM7QUFFRDs7Ozs7R0FLRztBQUNILE1BQU0sVUFBVSxtQkFBbUIsQ0FBQyxFQUFtQjtJQUNyRCxJQUFJLENBQVMsQ0FBQztJQUNkLElBQUksS0FBSyxDQUFDLE9BQU8sQ0FBQyxFQUFFLENBQUMsRUFBRTtRQUNyQixJQUFJLEVBQUUsQ0FBQyxNQUFNLEtBQUssQ0FBQyxFQUFFO1lBQ25CLE1BQU0sSUFBSSxVQUFVLENBQUMsdUNBQXVDLEVBQUUsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxDQUFDO1NBQzFFO1FBQ0QsQ0FBQyxHQUFHLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUNYO1NBQU07UUFDTCxDQUFDLEdBQUcsRUFBRSxDQUFDO0tBQ1I7SUFDRCxPQUFPLENBQUMsQ0FBQztBQUNYLENBQUM7QUFFRDs7Ozs7Ozs7O0dBU0c7QUFDSCxNQUFNLFVBQVUsa0JBQWtCLENBQUMsTUFBcUI7SUFDdEQsSUFBSSxLQUFLLENBQUMsT0FBTyxDQUFDLE1BQU0sQ0FBQyxJQUFJLEtBQUssQ0FBQyxPQUFPLENBQUMsTUFBTSxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUU7UUFDckQsSUFBSSxNQUFNLENBQUMsTUFBTSxLQUFLLENBQUMsRUFBRTtZQUN2QixNQUFNLEdBQUcsTUFBaUIsQ0FBQztZQUMzQixPQUFPLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQztTQUNsQjthQUFNO1lBQ0wsTUFBTSxJQUFJLFVBQVUsQ0FBQyxpQ0FBaUMsTUFBTSxDQUFDLE1BQU0sRUFBRSxDQUFDLENBQUM7U0FDeEU7S0FDRjtTQUFNO1FBQ0wsT0FBTyxNQUFlLENBQUM7S0FDeEI7QUFDSCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQ1xuICpcbiAqIFVzZSBvZiB0aGlzIHNvdXJjZSBjb2RlIGlzIGdvdmVybmVkIGJ5IGFuIE1JVC1zdHlsZVxuICogbGljZW5zZSB0aG
F0IGNhbiBiZSBmb3VuZCBpbiB0aGUgTElDRU5TRSBmaWxlIG9yIGF0XG4gKiBodHRwczovL29wZW5zb3VyY2Uub3JnL2xpY2Vuc2VzL01JVC5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuLyogT3JpZ2luYWwgc291cmNlOiB1dGlscy9nZW5lcmljX3V0aWxzLnB5ICovXG5cbmltcG9ydCB7VGVuc29yfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuaW1wb3J0IHtWYWx1ZUVycm9yfSBmcm9tICcuLi9lcnJvcnMnO1xuaW1wb3J0IHtTaGFwZX0gZnJvbSAnLi4va2VyYXNfZm9ybWF0L2NvbW1vbic7XG4vLyB0c2xpbnQ6ZW5hYmxlXG5cbi8qKlxuICogRGV0ZXJtaW5lIHdoZXRoZXIgdGhlIGlucHV0IGlzIGFuIEFycmF5IG9mIFNoYXBlcy5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGlzQXJyYXlPZlNoYXBlcyh4OiBTaGFwZXxTaGFwZVtdKTogYm9vbGVhbiB7XG4gIHJldHVybiBBcnJheS5pc0FycmF5KHgpICYmIEFycmF5LmlzQXJyYXkoeFswXSk7XG59XG5cbi8qKlxuICogU3BlY2lhbCBjYXNlIG9mIG5vcm1hbGl6aW5nIHNoYXBlcyB0byBsaXN0cy5cbiAqXG4gKiBAcGFyYW0geCBBIHNoYXBlIG9yIGxpc3Qgb2Ygc2hhcGVzIHRvIG5vcm1hbGl6ZSBpbnRvIGEgbGlzdCBvZiBTaGFwZXMuXG4gKiBAcmV0dXJuIEEgbGlzdCBvZiBTaGFwZXMuXG4gKi9cbmV4cG9ydCBmdW5jdGlvbiBub3JtYWxpemVTaGFwZUxpc3QoeDogU2hhcGV8U2hhcGVbXSk6IFNoYXBlW10ge1xuICBpZiAoeC5sZW5ndGggPT09IDApIHtcbiAgICByZXR1cm4gW107XG4gIH1cbiAgaWYgKCFBcnJheS5pc0FycmF5KHhbMF0pKSB7XG4gICAgcmV0dXJuIFt4XSBhcyBTaGFwZVtdO1xuICB9XG4gIHJldHVybiB4IGFzIFNoYXBlW107XG59XG5cbi8qKlxuICogSGVscGVyIGZ1bmN0aW9uIHRvIG9idGFpbiBleGFjdGx5IG9uZSBUZW5zb3IuXG4gKiBAcGFyYW0geHM6IEEgc2luZ2xlIGB0Zi5UZW5zb3JgIG9yIGFuIGBBcnJheWAgb2YgYHRmLlRlbnNvcmBzLlxuICogQHJldHVybiBBIHNpbmdsZSBgdGYuVGVuc29yYC4gSWYgYHhzYCBpcyBhbiBgQXJyYXlgLCByZXR1cm4gdGhlIGZpcnN0IG9uZS5cbiAqIEB0aHJvd3MgVmFsdWVFcnJvcjogSWYgYHhzYCBpcyBhbiBgQXJyYXlgIGFuZCBpdHMgbGVuZ3RoIGlzIG5vdCAxLlxuICovXG5leHBvcnQgZnVuY3Rpb24gZ2V0RXhhY3RseU9uZVRlbnNvcih4czogVGVuc29yfFRlbnNvcltdKTogVGVuc29yIHtcbiAgbGV0IHg6IFRlbnNvcjtcbiAgaWYgKEFycmF5LmlzQXJyYXkoeHMpKSB7XG4gICAgaWYgKHhzLmxlbmd0aCAhPT0gMSkge1xuICAgICAgdGhyb3cgbmV3IFZhbHVlRXJyb3IoYEV4cGVjdGVkIFRlbnNvciBsZW5ndGggdG8gYmUgMTsgZ290ICR7eHMubGVuZ3RofWApO1xuICAgIH1cbiAgICB4ID0geHNbMF07XG4gIH0gZWxzZSB7XG4gICAgeCA9IHhzO1xuICB9XG4gIHJldHVybiB4O1xufVxuXG
4vKipcbiAqIEhlbHBlciBmdW5jdGlvbiB0byBvYnRhaW4gZXhhY3RseSBvbiBpbnN0YW5jZSBvZiBTaGFwZS5cbiAqXG4gKiBAcGFyYW0gc2hhcGVzIElucHV0IHNpbmdsZSBgU2hhcGVgIG9yIEFycmF5IG9mIGBTaGFwZWBzLlxuICogQHJldHVybnMgSWYgaW5wdXQgaXMgYSBzaW5nbGUgYFNoYXBlYCwgcmV0dXJuIGl0IHVuY2hhbmdlZC4gSWYgdGhlIGlucHV0IGlzXG4gKiAgIGFuIGBBcnJheWAgY29udGFpbmluZyBleGFjdGx5IG9uZSBpbnN0YW5jZSBvZiBgU2hhcGVgLCByZXR1cm4gdGhlIGluc3RhbmNlLlxuICogICBPdGhlcndpc2UsIHRocm93IGEgYFZhbHVlRXJyb3JgLlxuICogQHRocm93cyBWYWx1ZUVycm9yOiBJZiBpbnB1dCBpcyBhbiBgQXJyYXlgIG9mIGBTaGFwZWBzLCBhbmQgaXRzIGxlbmd0aCBpcyBub3RcbiAqICAgMS5cbiAqL1xuZXhwb3J0IGZ1bmN0aW9uIGdldEV4YWN0bHlPbmVTaGFwZShzaGFwZXM6IFNoYXBlfFNoYXBlW10pOiBTaGFwZSB7XG4gIGlmIChBcnJheS5pc0FycmF5KHNoYXBlcykgJiYgQXJyYXkuaXNBcnJheShzaGFwZXNbMF0pKSB7XG4gICAgaWYgKHNoYXBlcy5sZW5ndGggPT09IDEpIHtcbiAgICAgIHNoYXBlcyA9IHNoYXBlcyBhcyBTaGFwZVtdO1xuICAgICAgcmV0dXJuIHNoYXBlc1swXTtcbiAgICB9IGVsc2Uge1xuICAgICAgdGhyb3cgbmV3IFZhbHVlRXJyb3IoYEV4cGVjdGVkIGV4YWN0bHkgMSBTaGFwZTsgZ290ICR7c2hhcGVzLmxlbmd0aH1gKTtcbiAgICB9XG4gIH0gZWxzZSB7XG4gICAgcmV0dXJuIHNoYXBlcyBhcyBTaGFwZTtcbiAgfVxufVxuIl19