/*
 * @tensorflow/tfjs-core
 * Version:
 * Hardware-accelerated JavaScript library for machine intelligence
 * 421 lines • 12.2 kB
 * JavaScript
 */
;
Object.defineProperty(exports, "__esModule", { value: true });
/**
 * Shuffles `array` in place using a Fisher-Yates pass driven by `Math.random`.
 *
 * @param {Object} array Indexable, writable collection with a numeric
 *     `length` (plain or typed array). It is mutated; nothing is returned.
 */
function shuffle(array) {
    for (var remaining = array.length; remaining > 0;) {
        // Choose a slot in [0, remaining); `| 0` truncates to a 32-bit integer.
        var pick = (Math.random() * remaining) | 0;
        remaining -= 1;
        // Swap the chosen element into the last not-yet-fixed position.
        var held = array[remaining];
        array[remaining] = array[pick];
        array[pick] = held;
    }
}
exports.shuffle = shuffle;
/**
 * Restricts `x` to the range [`min`, `max`].
 *
 * The upper bound is applied first and the lower bound last, so `min` wins
 * when `min > max`.
 *
 * @param {number} min Lower bound.
 * @param {number} x Value to clamp.
 * @param {number} max Upper bound.
 * @returns {number} The clamped value.
 */
function clamp(min, x, max) {
    var cappedAbove = Math.min(x, max);
    return Math.max(min, cappedAbove);
}
exports.clamp = clamp;
/**
 * Draws a value from `Math.random()` scaled to start at `a` and span `b - a`.
 *
 * @param {number} a Start of the range.
 * @param {number} b End of the range.
 * @returns {number} `a + Math.random() * (b - a)`.
 */
function randUniform(a, b) {
    var unit = Math.random();
    var width = b - a;
    return unit * width + a;
}
exports.randUniform = randUniform;
/**
 * Sums the squared element-wise differences between `a` and `b`.
 *
 * Iterates over `a.length` entries; each element is converted with `Number`
 * before subtracting.
 *
 * @param {Object} a First indexable sequence; its length drives the loop.
 * @param {Object} b Second indexable sequence.
 * @returns {number} The squared distance (0 when `a` is empty).
 */
function distSquared(a, b) {
    var total = 0;
    var k = 0;
    while (k < a.length) {
        var delta = Number(a[k]) - Number(b[k]);
        total += delta * delta;
        k += 1;
    }
    return total;
}
exports.distSquared = distSquared;
/**
 * Throws an `Error` when `expr` is falsy.
 *
 * @param {*} expr Condition to check; nothing happens when it is truthy.
 * @param {string|Function} msg Error message, or a function that is called
 *     (only on failure) to produce it.
 * @throws {Error} With the resolved message when `expr` is falsy.
 */
function assert(expr, msg) {
    if (expr) {
        return;
    }
    // A non-string `msg` is treated as a lazy message factory.
    var text = typeof msg === 'string' ? msg : msg();
    throw new Error(text);
}
exports.assert = assert;
/**
 * Asserts that two shapes are element-wise identical.
 *
 * @param {Object} shapeA First shape.
 * @param {Object} shapeB Second shape.
 * @param {string} [errorMessagePrefix] Text placed before the standard
 *     message; defaults to the empty string when `undefined`.
 */
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) {
    var prefix = errorMessagePrefix === undefined ? '' : errorMessagePrefix;
    var same = arraysEqual(shapeA, shapeB);
    var detail = " Shapes " + shapeA + " and " + shapeB + " must match";
    assert(same, prefix + detail);
}
exports.assertShapesMatch = assertShapesMatch;
/**
 * Asserts that `a` is neither `null` nor `undefined`.
 *
 * @param {*} a Value to check.
 */
function assertNonNull(a) {
    // Loose comparison on purpose: matches both null and undefined.
    var isPresent = a != null;
    assert(isPresent, "The input to the tensor constructor must be a non-null value.");
}
exports.assertNonNull = assertNonNull;
/**
 * Recursively flattens nested arrays into a single flat list, depth-first.
 *
 * @param {*} arr Value to flatten. Anything that is not an `Array` (per
 *     `Array.isArray`) is treated as a leaf and appended as-is.
 * @param {Array} [ret] Accumulator the leaves are pushed onto; a new array
 *     is created when it is `undefined`.
 * @returns {Array} The accumulator.
 */
function flatten(arr, ret) {
    if (ret === undefined) {
        ret = [];
    }
    if (!Array.isArray(arr)) {
        ret.push(arr);
        return ret;
    }
    for (var k = 0; k < arr.length; k += 1) {
        flatten(arr[k], ret);
    }
    return ret;
}
exports.flatten = flatten;
/**
 * Infers the shape of a value from its nesting.
 *
 * - Typed arrays (per `isTypedArray`) yield `[length]`.
 * - Non-arrays yield `[]`.
 * - Nested arrays yield the lengths found by following the first element at
 *   each level; the result is then validated against every element with
 *   `deepAssertShapeConsistency`.
 *
 * @param {*} val Value to inspect.
 * @returns {number[]} The inferred shape.
 */
function inferShape(val) {
    if (isTypedArray(val)) {
        return [val.length];
    }
    if (!Array.isArray(val)) {
        return [];
    }
    var dims = [];
    // Walk down the chain of first elements, recording each level's length.
    for (var cursor = val; cursor instanceof Array; cursor = cursor[0]) {
        dims.push(cursor.length);
    }
    if (val instanceof Array) {
        deepAssertShapeConsistency(val, dims, []);
    }
    return dims;
}
exports.inferShape = inferShape;
/**
 * Recursively checks that `val` has exactly the nesting described by `shape`.
 *
 * @param {*} val Value (array or leaf) at the current position.
 * @param {number[]} shape Expected lengths for this level and below.
 * @param {number[]} [indices] Path of indices leading to `val`, used only in
 *     error messages; defaults to `[]` when falsy.
 */
function deepAssertShapeConsistency(val, shape, indices) {
    var path = indices || [];
    // Built lazily so the string is only assembled when an assertion fails.
    var where = function () { return "Element arr[" + path.join('][') + "]"; };
    if (!(val instanceof Array)) {
        assert(shape.length === 0, function () {
            return where() + " is a primitive, but should be an array of " + shape[0] + " elements";
        });
        return;
    }
    assert(shape.length > 0, function () {
        return where() + " should be a primitive, but is an array of " + val.length + " elements";
    });
    assert(val.length === shape[0], function () {
        return where() + " should have " + shape[0] + " elements, but has " + val.length + " elements";
    });
    var innerShape = shape.slice(1);
    for (var k = 0; k < val.length; ++k) {
        deepAssertShapeConsistency(val[k], innerShape, path.concat(k));
    }
}
/**
 * Returns the product of all entries of `shape`.
 *
 * @param {Object} shape Indexable list of dimension sizes.
 * @returns {number} The product, or 1 for an empty shape.
 */
function sizeFromShape(shape) {
    var rank = shape.length;
    if (rank === 0) {
        return 1;
    }
    var total = shape[0];
    var axis = 1;
    while (axis < rank) {
        total *= shape[axis];
        axis += 1;
    }
    return total;
}
exports.sizeFromShape = sizeFromShape;
/**
 * Reports whether `shape` has no dimensions.
 *
 * @param {Object} shape Shape to test.
 * @returns {boolean} `true` when `shape.length` is exactly 0.
 */
function isScalarShape(shape) {
    var rank = shape.length;
    return rank === 0;
}
exports.isScalarShape = isScalarShape;
/**
 * Compares two indexable sequences for equal length and strictly equal
 * (`===`) elements at every index.
 *
 * @param {Object} n1 First sequence.
 * @param {Object} n2 Second sequence.
 * @returns {boolean} `true` when both sequences match element for element.
 */
function arraysEqual(n1, n2) {
    var len = n1.length;
    if (len !== n2.length) {
        return false;
    }
    var k = 0;
    while (k < len) {
        if (n1[k] !== n2[k]) {
            return false;
        }
        k += 1;
    }
    return true;
}
exports.arraysEqual = arraysEqual;
/**
 * Reports whether `a` has no fractional part, i.e. `a % 1 === 0`.
 *
 * @param {number} a Value to test.
 * @returns {boolean} `true` when the remainder of `a` divided by 1 is 0.
 */
function isInt(a) {
    var fractional = a % 1;
    return fractional === 0;
}
exports.isInt = isInt;
/**
 * Hyperbolic tangent of `x`.
 *
 * Delegates to the native `Math.tanh` when the runtime provides it. Otherwise
 * a fallback evaluates `(1 - e) / (1 + e)` with `e = exp(-2|x|)` and restores
 * the sign of `x`.
 *
 * The fallback deliberately uses the negative exponent: `exp(2x)` overflows
 * to `Infinity` for large finite `x`, which would make the quotient
 * `Infinity / Infinity` (NaN). `exp(-2|x|)` instead underflows to 0, so the
 * result saturates cleanly at +/-1.
 *
 * @param {number} x Input value.
 * @returns {number} A value in [-1, 1], or NaN when `x` is NaN.
 */
function tanh(x) {
    if (Math.tanh != null) {
        return Math.tanh(x);
    }
    if (x === Infinity) {
        return 1;
    }
    if (x === -Infinity) {
        return -1;
    }
    var decay = Math.exp(-2 * Math.abs(x));
    var magnitude = (1 - decay) / (1 + decay);
    return x < 0 ? -magnitude : magnitude;
}
exports.tanh = tanh;
/**
 * Splits `size` into a two-factor shape `[a, size / a]` that is as close to
 * square as possible.
 *
 * Searches downward from `floor(sqrt(size))` for the first divisor greater
 * than 1 and falls back to `[1, size]` when there is none.
 *
 * @param {number} size Number of elements.
 * @returns {number[]} Two-element shape whose product equals `size`.
 */
function sizeToSquarishShape(size) {
    var candidate = Math.floor(Math.sqrt(size));
    while (candidate > 1) {
        if (size % candidate === 0) {
            return [candidate, size / candidate];
        }
        candidate -= 1;
    }
    return [1, size];
}
exports.sizeToSquarishShape = sizeToSquarishShape;
/**
 * Builds a `Uint32Array` holding the integers `0 .. n-1` and shuffles it in
 * place with `shuffle`.
 *
 * @param {number} n Number of indices to generate.
 * @returns {Uint32Array} The shuffled indices.
 */
function createShuffledIndices(n) {
    var indices = new Uint32Array(n);
    var k = 0;
    while (k < n) {
        indices[k] = k;
        k += 1;
    }
    shuffle(indices);
    return indices;
}
exports.createShuffledIndices = createShuffledIndices;
/**
 * Pads `a` on the right with spaces until it is `size` characters long.
 *
 * @param {string} a Text to pad.
 * @param {number} size Target length.
 * @returns {string} `a` unchanged when it is already at least `size` long,
 *     otherwise `a` followed by the missing number of spaces.
 */
function rightPad(a, size) {
    var missing = size - a.length;
    if (missing <= 0) {
        return a;
    }
    var padding = ' '.repeat(missing);
    return a + padding;
}
exports.rightPad = rightPad;
/**
 * Polls `checkFn` until it returns a truthy value.
 *
 * The first check runs synchronously; later ones are scheduled with
 * `setTimeout` using the delay returned by `delayFn`.
 *
 * @param {Function} checkFn Called with no arguments on every attempt.
 * @param {Function} [delayFn] Receives the number of failed attempts so far
 *     and returns the delay before the next attempt. Defaults to a function
 *     that always returns 0.
 * @param {number} [maxCounter] When not null/undefined, the promise rejects
 *     (with no reason) once this many attempts have failed.
 * @returns {Promise} Resolves with no value as soon as `checkFn` succeeds.
 */
function repeatedTry(checkFn, delayFn, maxCounter) {
    if (delayFn === undefined) {
        delayFn = function (counter) { return 0; };
    }
    return new Promise(function (resolve, reject) {
        var failures = 0;
        function attempt() {
            if (checkFn()) {
                resolve();
                return;
            }
            failures += 1;
            // The delay is requested before the limit check, so `delayFn`
            // is also invoked for the final, rejected attempt.
            var waitTime = delayFn(failures);
            var exhausted = maxCounter != null && failures >= maxCounter;
            if (exhausted) {
                reject();
                return;
            }
            setTimeout(attempt, waitTime);
        }
        attempt();
    });
}
exports.repeatedTry = repeatedTry;
/**
 * Resolves a shape that may contain a single `-1` placeholder so that its
 * product equals `size`.
 *
 * @param {number[]} shape Requested shape; at most one entry may be `-1`.
 * @param {number} size Total number of elements.
 * @returns {number[]} `shape` itself when it has no placeholder, otherwise a
 *     copy with the placeholder replaced by the inferred dimension.
 * @throws {Error} When there is more than one `-1`, when an entry is
 *     negative but not `-1`, when a fully specified shape does not multiply
 *     to a positive `size`, when the known dimensions multiply to 0, or when
 *     `size` is not divisible by the known product.
 */
function inferFromImplicitShape(shape, size) {
    var knownProduct = 1;
    var wildcardAt = -1;
    for (var d = 0; d < shape.length; ++d) {
        var extent = shape[d];
        if (extent >= 0) {
            knownProduct *= extent;
            continue;
        }
        if (extent === -1) {
            if (wildcardAt !== -1) {
                throw Error("Shapes can only have 1 implicit size. " +
                    "Found -1 at dim " + wildcardAt + " and dim " + d);
            }
            wildcardAt = d;
            continue;
        }
        if (extent < 0) {
            throw Error("Shapes can not be < 0. Found " + extent + " at dim " + d);
        }
    }
    var hasWildcard = wildcardAt !== -1;
    if (!hasWildcard) {
        // The mismatch check is skipped when `size` is not positive.
        if (size > 0 && size !== knownProduct) {
            throw Error("Size(" + size + ") must match the product of shape " + shape);
        }
        return shape;
    }
    if (knownProduct === 0) {
        throw Error("Cannot infer the missing size in [" + shape + "] when " +
            "there are 0 elements");
    }
    if (size % knownProduct !== 0) {
        throw Error("The implicit shape can't be a fractional number. " +
            "Got " + size + " / " + knownProduct);
    }
    var resolved = shape.slice();
    resolved[wildcardAt] = size / knownProduct;
    return resolved;
}
exports.inferFromImplicitShape = inferFromImplicitShape;
/**
 * Removes size-1 dimensions from `shape`.
 *
 * When `axis` is null/undefined every dimension equal to 1 is removed.
 * Otherwise only the dimensions listed in `axis` are removed, and each of
 * them must have size 1.
 *
 * Axis membership is looked up per dimension, so `axis` may be given in any
 * order and may contain duplicates. A single forward-moving cursor over
 * `axis` would only be correct for sorted, unique input: an unsorted list
 * such as `[2, 0]` would leave dimension 0 in place and could skip the
 * "is not 1" check. Entries that do not name an existing dimension (negative
 * or >= rank) are ignored.
 *
 * @param {number[]} shape Input shape.
 * @param {number[]} [axis] Dimensions to squeeze.
 * @returns {{newShape: number[], keptDims: number[]}} The squeezed shape and
 *     the indices of the dimensions of `shape` that were kept.
 * @throws {Error} When a listed axis has a size other than 1.
 */
function squeezeShape(shape, axis) {
    // Lookup of requested axes; stays null when every size-1 dim is squeezed.
    var requested = null;
    if (axis != null) {
        requested = {};
        for (var k = 0; k < axis.length; ++k) {
            requested[axis[k]] = true;
        }
    }
    var newShape = [];
    var keptDims = [];
    for (var i = 0; i < shape.length; ++i) {
        var isUnit = shape[i] === 1;
        if (requested !== null && requested[i] === true) {
            if (!isUnit) {
                throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1");
            }
            continue;
        }
        if (requested === null && isUnit) {
            continue;
        }
        newShape.push(shape[i]);
        keptDims.push(i);
    }
    return { newShape: newShape, keptDims: keptDims };
}
exports.squeezeShape = squeezeShape;
/**
 * Allocates a typed array of `size` elements for the given dtype.
 *
 * @param {string} [dtype] `'float32'` (also used for null/undefined),
 *     `'int32'` or `'bool'`.
 * @param {number} size Number of elements.
 * @returns {Float32Array|Int32Array|Uint8Array} `Float32Array` for float32,
 *     `Int32Array` for int32, `Uint8Array` for bool.
 * @throws {Error} For any other dtype.
 */
function getTypedArrayFromDType(dtype, size) {
    if (dtype == null) {
        return new Float32Array(size);
    }
    switch (dtype) {
        case 'float32':
            return new Float32Array(size);
        case 'int32':
            return new Int32Array(size);
        case 'bool':
            return new Uint8Array(size);
        default:
            throw new Error("Unknown data type " + dtype);
    }
}
exports.getTypedArrayFromDType = getTypedArrayFromDType;
/**
 * Throws when a float32 result contains a NaN. Any other dtype is not
 * inspected.
 *
 * @param {Object} vals Indexable values to scan.
 * @param {string} dtype Data type of `vals`.
 * @param {string} name Label for the computation, included in the message.
 * @throws {Error} When `dtype` is `'float32'` and some element is NaN
 *     according to the global `isNaN`.
 */
function checkComputationForNaN(vals, dtype, name) {
    if (dtype === 'float32') {
        for (var k = 0; k < vals.length; k += 1) {
            if (isNaN(vals[k])) {
                throw Error("The result of the '" + name + "' has NaNs.");
            }
        }
    }
}
exports.checkComputationForNaN = checkComputationForNaN;
/**
 * Throws when values destined for a non-float32 dtype contain a NaN. Values
 * for `'float32'` are not inspected.
 *
 * @param {Object} vals Indexable values to scan.
 * @param {string} dtype Target data type.
 * @throws {Error} When `dtype` is not `'float32'` and some element is NaN
 *     according to the global `isNaN`.
 */
function checkConversionForNaN(vals, dtype) {
    if (dtype === 'float32') {
        return;
    }
    var k = 0;
    while (k < vals.length) {
        if (isNaN(vals[k])) {
            throw Error("NaN is not a valid value for dtype: '" + dtype + "'.");
        }
        k += 1;
    }
}
exports.checkConversionForNaN = checkConversionForNaN;
/**
 * Reports whether converting from `oldType` to `newType` is treated as
 * lossy.
 *
 * @param {string} oldType Source dtype.
 * @param {string} newType Target dtype.
 * @returns {boolean} `false` for any conversion to `'float32'`, for
 *     conversions to `'int32'` from anything but `'float32'`, and for bool to
 *     bool; `true` in every other case.
 */
function hasEncodingLoss(oldType, newType) {
    switch (newType) {
        case 'float32':
            return false;
        case 'int32':
            return oldType === 'float32';
        case 'bool':
            return oldType !== 'bool';
        default:
            return true;
    }
}
exports.hasEncodingLoss = hasEncodingLoss;
/**
 * Copies `array` into a new typed array matching `dtype`.
 *
 * @param {Object} array Source values (plain or typed array).
 * @param {string} [dtype] `'float32'` (also used for null/undefined),
 *     `'int32'` or `'bool'`.
 * @param {boolean} [debugMode] When truthy, int32 input is first checked
 *     with `checkConversionForNaN`.
 * @returns {Float32Array|Int32Array|Uint8Array} The converted copy. For
 *     `'bool'`, an element becomes 1 when `Math.round` of it is not 0, and 0
 *     otherwise.
 * @throws {Error} For an unrecognized dtype.
 */
function copyTypedArray(array, dtype, debugMode) {
    if (dtype == null) {
        return new Float32Array(array);
    }
    switch (dtype) {
        case 'float32':
            return new Float32Array(array);
        case 'int32':
            if (debugMode) {
                checkConversionForNaN(array, dtype);
            }
            return new Int32Array(array);
        case 'bool': {
            var flags = new Uint8Array(array.length);
            for (var k = 0; k < flags.length; ++k) {
                // Entries start at 0, so only the "truthy" ones are written.
                if (Math.round(array[k]) !== 0) {
                    flags[k] = 1;
                }
            }
            return flags;
        }
        default:
            throw new Error("Unknown data type " + dtype);
    }
}
/**
 * Reports whether `a` is an instance of one of the three typed-array classes
 * checked here: `Float32Array`, `Int32Array` or `Uint8Array`.
 *
 * @param {*} a Value to test.
 * @returns {boolean} `true` for instances of those classes.
 */
function isTypedArray(a) {
    if (a instanceof Float32Array) {
        return true;
    }
    if (a instanceof Int32Array) {
        return true;
    }
    return a instanceof Uint8Array;
}
exports.isTypedArray = isTypedArray;
/**
 * Returns the number of bytes used per element for a dtype.
 *
 * @param {string} dtype `'float32'`, `'int32'` or `'bool'`.
 * @returns {number} 4 for float32 and int32, 1 for bool.
 * @throws {Error} For any other dtype.
 */
function bytesPerElement(dtype) {
    switch (dtype) {
        case 'float32':
        case 'int32':
            return 4;
        case 'bool':
            return 1;
        default:
            throw new Error("Unknown dtype " + dtype);
    }
}
exports.bytesPerElement = bytesPerElement;
/**
 * Duck-type check for a callable: `f` must be truthy and expose truthy
 * `constructor`, `call` and `apply` properties.
 *
 * @param {*} f Value to test.
 * @returns {boolean} `true` when all of those checks pass.
 */
function isFunction(f) {
    if (!f) {
        return false;
    }
    return Boolean(f.constructor && f.call && f.apply);
}
exports.isFunction = isFunction;
/**
 * Finds the smallest divisor of `size` in the range [`start`, `size`).
 *
 * @param {number} size Number to divide.
 * @param {number} start First candidate divisor.
 * @returns {number} The first candidate that divides `size` evenly, or
 *     `size` itself when no candidate below `size` does.
 */
function nearestDivisor(size, start) {
    var candidate = start;
    while (candidate < size) {
        if (size % candidate === 0) {
            return candidate;
        }
        candidate += 1;
    }
    return size;
}
exports.nearestDivisor = nearestDivisor;
/**
 * Computes the strides for all but the innermost dimension of `shape`.
 *
 * Entry `i` of the result is the product of `shape[i + 1]` through the last
 * dimension. The innermost dimension is left out, so the result has
 * `rank - 1` entries.
 *
 * @param {number[]} shape Dimension sizes.
 * @returns {number[]} The strides, or `[]` when the rank is below 2.
 */
function computeStrides(shape) {
    var rank = shape.length;
    if (rank < 2) {
        return [];
    }
    var strides = new Array(rank - 1);
    // Accumulate the running product while walking from the inside out.
    var running = shape[rank - 1];
    strides[rank - 2] = running;
    for (var i = rank - 3; i >= 0; --i) {
        running = running * shape[i + 1];
        strides[i] = running;
    }
    return strides;
}
exports.computeStrides = computeStrides;
/**
 * Converts `a` to the typed array class associated with `dtype`.
 *
 * @param {Object} a Plain (possibly nested) array or typed array.
 * @param {string} dtype Target dtype.
 * @param {boolean} [debugMode] Forwarded to `copyTypedArray`.
 * @returns {Object} `a` itself when `noConversionNeeded` says it already
 *     matches; otherwise the result of `copyTypedArray`, applied to the
 *     flattened input when `a` is a plain array.
 */
function toTypedArray(a, dtype, debugMode) {
    if (noConversionNeeded(a, dtype)) {
        return a;
    }
    var source = Array.isArray(a) ? flatten(a) : a;
    return copyTypedArray(source, dtype, debugMode);
}
exports.toTypedArray = toTypedArray;
/**
 * Reports whether `a` is already an instance of the typed array class that
 * corresponds to `dtype`.
 *
 * @param {*} a Candidate value.
 * @param {string} dtype `'float32'`, `'int32'` or `'bool'`.
 * @returns {boolean} `true` for the pairs Float32Array/float32,
 *     Int32Array/int32 and Uint8Array/bool; `false` otherwise.
 */
function noConversionNeeded(a, dtype) {
    switch (dtype) {
        case 'float32':
            return a instanceof Float32Array;
        case 'int32':
            return a instanceof Int32Array;
        case 'bool':
            return a instanceof Uint8Array;
        default:
            return false;
    }
}
/**
 * Creates a typed array of `size` elements, each set to 1.
 *
 * The array is allocated by `makeZerosTypedArray`, which also decides the
 * concrete array class from `dtype`.
 *
 * @param {number} size Number of elements.
 * @param {string} [dtype] Data type forwarded to `makeZerosTypedArray`.
 * @returns {Object} The filled array.
 */
function makeOnesTypedArray(size, dtype) {
    var ones = makeZerosTypedArray(size, dtype);
    var k = 0;
    while (k < ones.length) {
        ones[k] = 1;
        k += 1;
    }
    return ones;
}
exports.makeOnesTypedArray = makeOnesTypedArray;
/**
 * Allocates a typed array of `size` elements for the given dtype. Newly
 * constructed typed arrays are zero-filled.
 *
 * @param {number} size Number of elements.
 * @param {string} [dtype] `'float32'` (also used for null/undefined),
 *     `'int32'` or `'bool'`.
 * @returns {Float32Array|Int32Array|Uint8Array} `Float32Array` for float32,
 *     `Int32Array` for int32, `Uint8Array` for bool.
 * @throws {Error} For any other dtype.
 */
function makeZerosTypedArray(size, dtype) {
    var wantsFloat = dtype == null || dtype === 'float32';
    if (wantsFloat) {
        return new Float32Array(size);
    }
    if (dtype === 'int32') {
        return new Int32Array(size);
    }
    if (dtype === 'bool') {
        return new Uint8Array(size);
    }
    throw new Error("Unknown data type " + dtype);
}
exports.makeZerosTypedArray = makeZerosTypedArray;
/**
 * Returns a timestamp in milliseconds.
 *
 * Uses `performance.now()` when a global `performance` exists; otherwise
 * converts the `[seconds, nanoseconds]` pair from `process.hrtime()`.
 *
 * @returns {number} Milliseconds on the clock of whichever source was used.
 * @throws {Error} When neither `performance` nor `process` is defined.
 */
function now() {
    if (typeof performance !== 'undefined') {
        return performance.now();
    }
    if (typeof process === 'undefined') {
        throw new Error('Can not measure time in this environment. You should run tf.js ' +
            'in the browser or in Node.js');
    }
    var hr = process.hrtime();
    // hr[0] is whole seconds, hr[1] the remaining nanoseconds.
    return hr[0] * 1000 + hr[1] / 1000000;
}
exports.now = now;
//# sourceMappingURL=util.js.map