UNPKG

@hoff97/tensor-js

Version:

PyTorch like deep learning inferrence library

107 lines 2.54 kB
export function getSize(shape, zeroSize = 0) { if (shape.length === 0) { return zeroSize; } let size = 1; for (let i = 0; i < shape.length; i += 1) { size *= shape[i]; } return size; } export function computeStrides(shape) { const rank = shape.length; if (rank === 0) { return []; } if (rank === 1) { if (shape[0] === 1) { return [0]; } else { return [1]; } } const strides = new Array(rank); strides[rank - 1] = 1; if (shape[rank - 1] === 1) { strides[rank - 1] = 0; } let lastStride = 1; for (let i = rank - 2; i >= 0; i -= 1) { lastStride = shape[i + 1] * lastStride; if (shape[i] === 1) { strides[i] = 0; } else { strides[i] = lastStride; } } return strides; } export function indexToPos(index, strides, shape) { let ix = 0; for (let i = 0; i < index.length; i += 1) { if (shape) { if (index[i] < 0 || (index[i] >= shape[i] && shape[i] !== 1)) { throw new Error('Invalid index'); } } ix += index[i] * strides[i]; } return ix; } export function posToIndex(pos, strides) { let res = pos; const rank = strides.length; const index = new Array(rank); for (let i = 0; i < index.length; i += 1) { index[i] = Math.floor(res / strides[i]); res %= strides[i]; } return index; } export function compareShapes(a, b) { if (a.length !== b.length) { return false; } for (let i = 0; i < a.length; i += 1) { if (a[i] !== b[i]) { return false; } } return true; } export function checkEquivShapes(a, b) { if (a.length !== b.length) { return false; } for (let i = 0; i < a.length; i += 1) { if (a[i] !== b[i] && a[i] !== 1 && b[i] !== 1) { return false; } } return true; } export function incrementIndex(index, shape) { for (let i = index.length - 1; i >= 0; i--) { index[i] += 1; if (index[i] >= shape[i]) { index[i] = 0; } else { break; } } } export function decrementIndex(index, shape) { for (let i = index.length - 1; i >= 0; i--) { index[i] -= 1; if (index[i] < 0) { index[i] = shape[i] - 1; } else { break; } } } //# sourceMappingURL=shape.js.map