/*
 * @tensorflow/tfjs-core
 * Version: (not recorded)
 * Hardware-accelerated JavaScript library for machine intelligence.
 * 484 lines • 13.9 kB
 * JavaScript
 */
// Flag this CommonJS module as the compiled form of an ES module so that
// module-interop helpers treat `exports` as a namespace object.
Object.defineProperty(exports, '__esModule', { value: true });
function shuffle(array) {
var counter = array.length;
var temp = 0;
var index = 0;
while (counter > 0) {
index = (Math.random() * counter) | 0;
counter--;
temp = array[counter];
array[counter] = array[index];
array[index] = temp;
}
}
exports.shuffle = shuffle;
/**
 * Restricts `x` to the closed interval [min, max].
 *
 * @param min Lower bound.
 * @param x Value to clamp.
 * @param max Upper bound.
 * @returns `x` limited to [min, max].
 */
function clamp(min, x, max) {
    var upperBounded = Math.min(x, max);
    return Math.max(min, upperBounded);
}
exports.clamp = clamp;
/**
 * Returns `val` when it is even, otherwise `val + 1`.
 * Intended for integer inputs.
 *
 * @param val Integer to round up to an even number.
 * @returns The smallest even number >= val (for integer val).
 */
function nearestLargerEven(val) {
    var isOdd = val % 2 !== 0;
    return isOdd ? val + 1 : val;
}
exports.nearestLargerEven = nearestLargerEven;
/**
 * Adds up the elements of an array-like, left to right.
 *
 * @param arr Array or typed array of numbers.
 * @returns The sum; 0 for an empty input.
 */
function sum(arr) {
    var total = 0;
    var count = arr.length;
    for (var k = 0; k < count; k++) {
        total += arr[k];
    }
    return total;
}
exports.sum = sum;
/**
 * Draws a pseudo-random number by linearly interpolating between `a` and
 * `b` with a `Math.random()` weight (so the result lies in [a, b)).
 * Not suitable for cryptographic use.
 *
 * @param a Lower end of the range.
 * @param b Upper end of the range.
 */
function randUniform(a, b) {
    var weight = Math.random();
    var upperPart = b * weight;
    var lowerPart = (1 - weight) * a;
    return upperPart + lowerPart;
}
exports.randUniform = randUniform;
/**
 * Squared Euclidean distance between two equal-length vectors.
 * Elements are coerced with `Number()` before subtracting; iteration
 * follows `a.length`.
 *
 * @param a First vector (array-like).
 * @param b Second vector (array-like).
 * @returns Sum of squared element-wise differences.
 */
function distSquared(a, b) {
    var total = 0;
    var idx = 0;
    while (idx < a.length) {
        var delta = Number(a[idx]) - Number(b[idx]);
        total += delta * delta;
        idx++;
    }
    return total;
}
exports.distSquared = distSquared;
/**
 * Throws an Error when `expr` is falsy.
 *
 * @param expr Condition expected to hold.
 * @param msg Error message, or a zero-argument function that builds it.
 *     Passing a function defers building the message until the assertion
 *     actually fails. Any other value (including omitting it) is passed to
 *     the Error constructor as-is.
 * @throws {Error} When `expr` is falsy.
 */
function assert(expr, msg) {
    if (!expr) {
        // Only call `msg` when it is callable. Previously a missing or
        // non-string message turned the failure into an unrelated
        // "msg is not a function" TypeError.
        throw new Error(typeof msg === 'function' ? msg() : msg);
    }
}
exports.assert = assert;
/**
 * Asserts that two shapes are element-wise identical.
 *
 * @param shapeA First shape.
 * @param shapeB Second shape.
 * @param errorMessagePrefix Optional text prepended to the failure message
 *     (defaults to '').
 * @throws When the shapes differ (via `assert`).
 */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) {
    var prefix = typeof errorMessagePrefix === 'undefined' ? '' : errorMessagePrefix;
    var shapesEqual = arraysEqual(shapeA, shapeB);
    assert(shapesEqual, prefix + (" Shapes " + shapeA + " and " + shapeB + " must match"));
}
exports.assertShapesMatch = assertShapesMatch;
/**
 * Asserts that a tensor-constructor input is neither null nor undefined.
 *
 * @param a Value passed to the tensor constructor.
 * @throws When `a` is null or undefined (via `assert`).
 */
function assertNonNull(a) {
    var isPresent = a != null;
    assert(isPresent, "The input to the tensor constructor must be a non-null value.");
}
exports.assertNonNull = assertNonNull;
/**
 * Recursively flattens nested arrays and typed arrays into a flat list.
 *
 * @param arr A (possibly nested) array, a typed array, or a scalar leaf.
 * @param ret Optional accumulator; leaves are appended to it. A fresh
 *     array is used when omitted.
 * @returns The accumulator holding all leaves in depth-first order.
 */
function flatten(arr, ret) {
    var out = typeof ret === 'undefined' ? [] : ret;
    var isContainer = Array.isArray(arr) || isTypedArray(arr);
    if (!isContainer) {
        out.push(arr);
        return out;
    }
    for (var k = 0; k < arr.length; ++k) {
        flatten(arr[k], out);
    }
    return out;
}
exports.flatten = flatten;
/**
 * Number of elements described by a shape: the product of its dims.
 *
 * @param shape Array of dimension sizes.
 * @returns The product of all dims, or 1 for an empty (scalar) shape.
 */
function sizeFromShape(shape) {
    var rank = shape.length;
    if (rank === 0) {
        return 1;
    }
    var product = shape[0];
    var dim = 1;
    while (dim < rank) {
        product *= shape[dim];
        dim++;
    }
    return product;
}
exports.sizeFromShape = sizeFromShape;
/**
 * Whether `shape` has rank 0 (no dimensions).
 *
 * @param shape Array of dimension sizes.
 */
function isScalarShape(shape) {
    var rank = shape.length;
    return rank === 0;
}
exports.isScalarShape = isScalarShape;
/**
 * Shallow element-wise equality (`===`) of two array-likes.
 *
 * @param n1 First array-like (may be null/undefined).
 * @param n2 Second array-like (may be null/undefined).
 * @returns true when both are the same reference, or both are non-null with
 *     equal length and strictly equal elements.
 */
function arraysEqual(n1, n2) {
    if (n1 === n2) {
        return true;
    }
    if (n1 == null || n2 == null || n1.length !== n2.length) {
        return false;
    }
    for (var k = 0; k < n1.length; k++) {
        if (n1[k] !== n2[k]) {
            return false;
        }
    }
    return true;
}
exports.arraysEqual = arraysEqual;
/**
 * Whether `a` has no fractional part (uses `%`, so numeric strings are
 * coerced and non-finite values yield false).
 *
 * @param a Value to test.
 */
function isInt(a) {
    var fractional = a % 1;
    return fractional === 0;
}
exports.isInt = isInt;
/**
 * Hyperbolic tangent of `x`. Uses `Math.tanh` when the engine provides it,
 * otherwise falls back to a manual computation.
 *
 * @param x Input value.
 * @returns tanh(x), in [-1, 1] (NaN for NaN input).
 */
function tanh(x) {
    if (Math.tanh != null) {
        return Math.tanh(x);
    }
    // Fallback for engines without Math.tanh (and therefore possibly
    // without Math.sign). Evaluate with a non-positive exponent so exp()
    // never overflows. The naive (e^{2x} - 1) / (e^{2x} + 1) becomes
    // Infinity / Infinity = NaN once 2x exceeds ~709. This form also covers
    // +/-Infinity: exp(-Infinity) is 0, giving +/-1.
    var e = Math.exp(-2 * Math.abs(x));
    var magnitude = (1 - e) / (1 + e);
    return x < 0 ? -magnitude : magnitude;
}
exports.tanh = tanh;
/**
 * Factors `size` into a near-square 2D shape [a, size / a], where `a` is
 * the largest divisor of `size` that is > 1 and <= floor(sqrt(size)).
 *
 * @param size Total element count.
 * @returns [a, size / a], or [1, size] when no such divisor exists.
 */
function sizeToSquarishShape(size) {
    var width = Math.floor(Math.sqrt(size));
    while (width > 1) {
        if (size % width === 0) {
            return [width, size / width];
        }
        width--;
    }
    return [1, size];
}
exports.sizeToSquarishShape = sizeToSquarishShape;
/**
 * Builds a Uint32Array holding 0..n-1 and permutes it with `shuffle`.
 *
 * @param n Number of indices.
 * @returns The shuffled index array.
 */
function createShuffledIndices(n) {
    var indices = new Uint32Array(n);
    for (var k = 0; k < indices.length; ++k) {
        indices[k] = k;
    }
    shuffle(indices);
    return indices;
}
exports.createShuffledIndices = createShuffledIndices;
/**
 * Pads string `a` on the right with spaces up to `size` characters.
 * Strings already at least `size` long are returned unchanged.
 *
 * @param a String to pad.
 * @param size Target length.
 */
function rightPad(a, size) {
    var missing = size - a.length;
    return missing > 0 ? a + ' '.repeat(missing) : a;
}
exports.rightPad = rightPad;
/**
 * Polls `checkFn` until it returns truthy, waiting `delayFn(attempt)` ms
 * between attempts.
 *
 * @param checkFn Zero-argument predicate; polling stops once it is truthy.
 * @param delayFn Maps the 1-based number of failed attempts to the delay
 *     (ms) before the next one. Defaults to 0 ms.
 * @param maxCounter Optional cap on attempts. When reached, the promise is
 *     rejected.
 * @returns A promise that resolves (with no value) once `checkFn` succeeds.
 *     It rejects with an Error when `maxCounter` attempts have failed, or
 *     with whatever `checkFn`/`delayFn` threw.
 */
function repeatedTry(checkFn, delayFn, maxCounter) {
    if (delayFn === void 0) { delayFn = function (counter) { return 0; }; }
    return new Promise(function (resolve, reject) {
        var tryCount = 0;
        var tryFn = function () {
            // Retries run from setTimeout, outside the Promise executor.
            // Without this try/catch an exception would surface as an
            // uncaught error and leave the promise pending forever.
            try {
                if (checkFn()) {
                    resolve();
                    return;
                }
                tryCount++;
                // Check the cap before asking for a delay that would never
                // be used.
                if (maxCounter != null && tryCount >= maxCounter) {
                    reject(new Error('repeatedTry: check did not pass after ' +
                        tryCount + ' attempts'));
                    return;
                }
                setTimeout(tryFn, delayFn(tryCount));
            }
            catch (err) {
                reject(err);
            }
        };
        tryFn();
    });
}
exports.repeatedTry = repeatedTry;
/**
 * Resolves a single -1 ("infer me") entry in `shape` so that the product of
 * all dims equals `size`.
 *
 * @param shape Dimension sizes; at most one entry may be -1.
 * @param size Total element count. When `shape` has no -1, a positive
 *     `size` must equal the product of `shape`; a non-positive `size` skips
 *     that check.
 * @returns `shape` itself (same reference) when nothing is implicit,
 *     otherwise a copy with the -1 replaced.
 * @throws Error on multiple -1 entries, other negative dims, a size
 *     mismatch, a zero known-product, or a non-integer inferred dim.
 */
function inferFromImplicitShape(shape, size) {
    var knownProduct = 1;
    var inferredAt = -1;
    for (var dim = 0; dim < shape.length; ++dim) {
        var extent = shape[dim];
        if (extent >= 0) {
            knownProduct *= extent;
            continue;
        }
        if (extent === -1) {
            if (inferredAt !== -1) {
                throw Error("Shapes can only have 1 implicit size. " +
                    ("Found -1 at dim " + inferredAt + " and dim " + dim));
            }
            inferredAt = dim;
        }
        else if (extent < 0) {
            throw Error("Shapes can not be < 0. Found " + extent + " at dim " + dim);
        }
    }
    if (inferredAt === -1) {
        if (size > 0 && size !== knownProduct) {
            throw Error("Size(" + size + ") must match the product of shape " + shape);
        }
        return shape;
    }
    if (knownProduct === 0) {
        throw Error("Cannot infer the missing size in [" + shape + "] when " +
            "there are 0 elements");
    }
    if (size % knownProduct !== 0) {
        throw Error("The implicit shape can't be a fractional number. " +
            ("Got " + size + " / " + knownProduct));
    }
    var resolved = shape.slice();
    resolved[inferredAt] = size / knownProduct;
    return resolved;
}
exports.inferFromImplicitShape = inferFromImplicitShape;
/**
 * Normalizes an axis argument into an array of non-negative axis indices.
 *
 * @param axis A single axis, an array of axes, or null/undefined for
 *     "all axes". Negative values count from the end.
 * @param shape Shape whose rank bounds the valid axis range.
 * @returns Axis indices in [0, rank).
 * @throws (via `assert`) when an axis is outside [-rank, rank) or is not
 *     an integer.
 */
function parseAxisParam(axis, shape) {
    var rank = shape.length;
    var axes = axis == null ?
        shape.map(function (_, dim) { return dim; }) :
        [].concat(axis);
    var allInRange = axes.every(function (ax) { return ax >= -rank && ax < rank; });
    assert(allInRange, "All values in axis param must be in range [-" + rank + ", " + rank + ") but " +
        ("got axis " + axes));
    var allIntegers = axes.every(function (ax) { return isInt(ax); });
    assert(allIntegers, "All values in axis param must be integers but " +
        ("got axis " + axes));
    return axes.map(function (ax) { return ax < 0 ? rank + ax : ax; });
}
exports.parseAxisParam = parseAxisParam;
/**
 * Computes the shape left after removing size-1 dimensions.
 *
 * @param shape Input dimension sizes.
 * @param axis Optional axis or axes to squeeze (negative values count from
 *     the end). When null/undefined or an empty array, every size-1 dim is
 *     removed. Otherwise only the listed dims are removed, and each must
 *     have size 1.
 * @returns `newShape`, the remaining dims, and `keptDims`, their indices in
 *     `shape`.
 * @throws Error when a listed axis does not have size 1.
 */
function squeezeShape(shape, axis) {
    var newShape = [];
    var keptDims = [];
    // An explicitly empty axis list means "no restriction": squeeze every
    // size-1 dim, as when axis is omitted (TensorFlow's squeeze semantics).
    var isEmptyAxisList = axis != null && Array.isArray(axis) && axis.length === 0;
    // The walk below consumes `axes` in ascending order, so sort numerically.
    // The default Array#sort compares as strings and leaves [10, 2] as is.
    var axes = (axis == null || isEmptyAxisList) ? null :
        parseAxisParam(axis, shape).sort(function (a, b) { return a - b; });
    var j = 0;
    for (var i = 0; i < shape.length; ++i) {
        if (axes != null) {
            if (axes[j] === i && shape[i] !== 1) {
                throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1");
            }
            // A size-1 dim that was not requested is kept.
            if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
                newShape.push(shape[i]);
                keptDims.push(i);
            }
            if (axes[j] <= i) {
                j++;
            }
        }
        if (shape[i] !== 1) {
            newShape.push(shape[i]);
            keptDims.push(i);
        }
    }
    return { newShape: newShape, keptDims: keptDims };
}
exports.squeezeShape = squeezeShape;
/**
 * Allocates a zero-filled typed array for a numeric dtype.
 *
 * @param dtype 'float32' (also the default for null/undefined), 'int32',
 *     or 'bool'.
 * @param size Number of elements.
 * @returns Float32Array, Int32Array or Uint8Array respectively.
 * @throws Error for any other dtype.
 */
function getTypedArrayFromDType(dtype, size) {
    switch (dtype == null ? 'float32' : dtype) {
        case 'float32':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        default:
            throw new Error("Unknown data type " + dtype);
    }
}
exports.getTypedArrayFromDType = getTypedArrayFromDType;
/**
 * Allocates storage for `size` elements of `dtype`: a typed array for
 * numeric dtypes, or a plain Array for 'string'.
 *
 * @param dtype 'float32' (also the default for null/undefined), 'int32',
 *     'bool', or 'string'.
 * @param size Number of elements.
 * @throws Error for any other dtype.
 */
function getArrayFromDType(dtype, size) {
    switch (dtype == null ? 'float32' : dtype) {
        case 'float32':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        case 'string':
            return new Array(size);
        default:
            throw new Error("Unknown data type " + dtype);
    }
}
exports.getArrayFromDType = getArrayFromDType;
/**
 * For float32 results, throws if any value is NaN or +/-Infinity. Other
 * dtypes are not checked.
 *
 * @param vals Array-like of computed values.
 * @param dtype dtype of `vals`.
 * @param name Operation name used in the error message.
 * @throws Error naming the op and the first offending value.
 */
function checkComputationForErrors(vals, dtype, name) {
    if (dtype === 'float32') {
        for (var k = 0; k < vals.length; k++) {
            var value = vals[k];
            // The global isFinite is false for NaN as well as +/-Infinity.
            if (!isFinite(value)) {
                throw Error("The result of the '" + name + "' is " + value + ".");
            }
        }
    }
}
exports.checkComputationForErrors = checkComputationForErrors;
/**
 * Throws if any value about to be uploaded is NaN or +/-Infinity.
 *
 * @param vals Array-like of values.
 * @param dtype dtype named in the error message.
 * @throws Error naming the dtype and the first offending value.
 */
function checkConversionForErrors(vals, dtype) {
    for (var k = 0; k < vals.length; k++) {
        var value = vals[k];
        // The global isFinite is false for NaN as well as +/-Infinity.
        if (!isFinite(value)) {
            throw Error("A tensor of type " + dtype + " being uploaded contains " + value + ".");
        }
    }
}
exports.checkConversionForErrors = checkConversionForErrors;
/**
 * Decides whether converting from `oldType` to `newType` is treated as
 * lossy.
 *
 * Rules:
 * - complex64 target: never lossy.
 * - float32 target: lossy only from complex64.
 * - int32 target: lossy from float32 or complex64.
 * - bool target: lossy unless the source is bool.
 * - any other target: lossy.
 *
 * @param oldType Source dtype.
 * @param newType Target dtype.
 */
function hasEncodingLoss(oldType, newType) {
    switch (newType) {
        case 'complex64':
            return false;
        case 'float32':
            return oldType === 'complex64';
        case 'int32':
            return oldType === 'float32' || oldType === 'complex64';
        case 'bool':
            return oldType !== 'bool';
        default:
            return true;
    }
}
exports.hasEncodingLoss = hasEncodingLoss;
/**
 * Whether `a` is one of the typed-array kinds used for tensor storage
 * here: Float32Array, Int32Array, or Uint8Array. Other typed arrays
 * (e.g. Float64Array, Uint8ClampedArray) are not recognized.
 *
 * @param a Value to test.
 */
function isTypedArray(a) {
    return [Float32Array, Int32Array, Uint8Array].some(function (Ctor) {
        return a instanceof Ctor;
    });
}
exports.isTypedArray = isTypedArray;
/**
 * Storage size in bytes of one element of `dtype`.
 *
 * @param dtype 'float32' or 'int32' (4), 'complex64' (8), or 'bool' (1).
 * @throws Error for any other dtype.
 */
function bytesPerElement(dtype) {
    switch (dtype) {
        case 'float32':
        case 'int32':
            return 4;
        case 'complex64':
            return 8;
        case 'bool':
            return 1;
        default:
            throw new Error("Unknown dtype " + dtype);
    }
}
exports.bytesPerElement = bytesPerElement;
/**
 * Estimates the memory used by an array of strings, counting 2 bytes per
 * string `.length` unit.
 *
 * @param arr Array of strings, or null/undefined.
 * @returns Estimated byte count; 0 for null/undefined.
 */
function bytesFromStringArray(arr) {
    var total = 0;
    if (arr != null) {
        arr.forEach(function (str) {
            total += 2 * str.length;
        });
    }
    return total;
}
exports.bytesFromStringArray = bytesFromStringArray;
/**
 * Whether `value` is a string primitive or a String wrapper object.
 *
 * @param value Value to test.
 */
function isString(value) {
    if (typeof value === 'string') {
        return true;
    }
    return value instanceof String;
}
exports.isString = isString;
/**
 * Whether `value` is a boolean primitive (Boolean wrapper objects are not).
 *
 * @param value Value to test.
 */
function isBoolean(value) {
    var kind = typeof value;
    return kind === 'boolean';
}
exports.isBoolean = isBoolean;
/**
 * Whether `value` is a number primitive (NaN and Infinity included).
 *
 * @param value Value to test.
 */
function isNumber(value) {
    var kind = typeof value;
    return kind === 'number';
}
exports.isNumber = isNumber;
/**
 * Infers a dtype from tensor-like input.
 *
 * Nested JS arrays are classified by their first leaf only. Float32Array
 * maps to 'float32'; Int32Array and Uint8Array map to 'int32'. Numbers map
 * to 'float32', strings (primitive or wrapper) to 'string', and booleans
 * to 'bool'. Anything else (including an empty array) defaults to
 * 'float32'.
 *
 * @param values Scalar, (nested) array, or typed array.
 */
function inferDtype(values) {
    if (values instanceof Array) {
        return inferDtype(values[0]);
    }
    if (values instanceof Float32Array) {
        return 'float32';
    }
    if (values instanceof Int32Array || values instanceof Uint8Array) {
        return 'int32';
    }
    if (typeof values === 'number') {
        return 'float32';
    }
    if (typeof values === 'string' || values instanceof String) {
        return 'string';
    }
    if (typeof values === 'boolean') {
        return 'bool';
    }
    return 'float32';
}
exports.inferDtype = inferDtype;
/**
 * Duck-typed function check: `f` must be truthy and have truthy
 * `constructor`, `call`, and `apply` properties (read in that order).
 *
 * @param f Value to test.
 */
function isFunction(f) {
    if (!f || !f.constructor || !f.call) {
        return false;
    }
    return !!f.apply;
}
exports.isFunction = isFunction;
/**
 * Finds the smallest divisor of `size` in [start, size).
 *
 * @param size Number to divide.
 * @param start First candidate to try.
 * @returns That divisor, or `size` itself if none is found.
 */
function nearestDivisor(size, start) {
    var candidate = start;
    while (candidate < size) {
        if (size % candidate === 0) {
            return candidate;
        }
        ++candidate;
    }
    return size;
}
exports.nearestDivisor = nearestDivisor;
/**
 * Row-major element strides for every axis except the last, whose stride
 * of 1 is implicit. strides[i] is the product of shape[i+1..rank-1].
 *
 * @param shape Dimension sizes.
 * @returns Array of length rank - 1 (empty for rank < 2).
 */
function computeStrides(shape) {
    var rank = shape.length;
    if (rank < 2) {
        return [];
    }
    var strides = new Array(rank - 1);
    var running = shape[rank - 1];
    strides[rank - 2] = running;
    for (var axis = rank - 3; axis >= 0; --axis) {
        running *= shape[axis + 1];
        strides[axis] = running;
    }
    return strides;
}
exports.computeStrides = computeStrides;
/**
 * Converts (possibly nested) array data into the typed array for `dtype`.
 *
 * @param a Nested JS array or typed array of values.
 * @param dtype Target dtype. 'float32', 'complex64' and null/undefined
 *     produce a Float32Array; 'int32' an Int32Array; 'bool' a Uint8Array
 *     holding 1 wherever Math.round(value) !== 0.
 * @param debugMode When truthy, values are validated with
 *     checkConversionForErrors before conversion.
 * @returns `a` itself when it already has the matching typed-array type,
 *     otherwise a new typed array.
 * @throws Error for dtype 'string' or an unknown dtype.
 */
function toTypedArray(a, dtype, debugMode) {
    if (dtype === 'string') {
        throw new Error('Cannot convert a string[] to a TypedArray');
    }
    var flat = Array.isArray(a) ? flatten(a) : a;
    if (debugMode) {
        checkConversionForErrors(flat, dtype);
    }
    if (noConversionNeeded(flat, dtype)) {
        return flat;
    }
    if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
        return new Float32Array(flat);
    }
    if (dtype === 'int32') {
        return new Int32Array(flat);
    }
    if (dtype === 'bool') {
        var mask = new Uint8Array(flat.length);
        for (var k = 0; k < mask.length; ++k) {
            if (Math.round(flat[k]) !== 0) {
                mask[k] = 1;
            }
        }
        return mask;
    }
    throw new Error("Unknown data type " + dtype);
}
exports.toTypedArray = toTypedArray;
/**
 * Whether `a` already has the typed-array type used for `dtype`
 * (Float32Array/'float32', Int32Array/'int32', Uint8Array/'bool').
 *
 * @param a Candidate data.
 * @param dtype Target dtype.
 */
function noConversionNeeded(a, dtype) {
    switch (dtype) {
        case 'float32':
            return a instanceof Float32Array;
        case 'int32':
            return a instanceof Int32Array;
        case 'bool':
            return a instanceof Uint8Array;
        default:
            return false;
    }
}
/**
 * Allocates a typed array for `dtype` (via makeZerosTypedArray) and sets
 * every element to 1.
 *
 * @param size Number of elements.
 * @param dtype Target dtype, as accepted by makeZerosTypedArray.
 */
function makeOnesTypedArray(size, dtype) {
    var ones = makeZerosTypedArray(size, dtype);
    var k = ones.length;
    while (k--) {
        ones[k] = 1;
    }
    return ones;
}
exports.makeOnesTypedArray = makeOnesTypedArray;
/**
 * Allocates a zero-filled typed array for `dtype`.
 *
 * @param size Number of elements.
 * @param dtype 'float32', 'complex64' or null/undefined (Float32Array),
 *     'int32' (Int32Array), or 'bool' (Uint8Array).
 * @throws Error for any other dtype.
 */
function makeZerosTypedArray(size, dtype) {
    switch (dtype == null ? 'float32' : dtype) {
        case 'float32':
        case 'complex64':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        default:
            throw new Error("Unknown data type " + dtype);
    }
}
exports.makeZerosTypedArray = makeZerosTypedArray;
/**
 * Current high-resolution time in milliseconds. Uses `performance.now()`
 * when available, otherwise Node's `process.hrtime()`.
 *
 * @throws Error when neither timer exists.
 */
function now() {
    if (typeof performance !== 'undefined') {
        return performance.now();
    }
    if (typeof process !== 'undefined') {
        var hr = process.hrtime();
        var seconds = hr[0];
        var nanos = hr[1];
        return seconds * 1000 + nanos / 1000000;
    }
    throw new Error('Cannot measure time in this environment. You should run tf.js ' +
        'in the browser or in Node.js');
}
exports.now = now;
//# sourceMappingURL=util.js.map