@dxzmpk/js-algorithms-data-structures
Version:
Algorithms and data structures implemented in JavaScript
310 lines (276 loc) • 7.46 kB
JavaScript
/**
* @typedef {number} Cell
* @typedef {Cell[][]|Cell[][][]} Matrix
* @typedef {number[]} Shape
* @typedef {number[]} CellIndices
*/
/**
 * Gets the matrix's shape.
 *
 * The shape is read by descending through the first element of every
 * dimension, so it describes a rectangular matrix; ragged inner arrays are
 * not detected.
 *
 * @param {Matrix} m
 * @returns {Shape} the length of each dimension, outermost first
 *   (an empty array when `m` is not an array).
 */
export const shape = (m) => {
  const shapes = [];
  let dimension = m;
  while (Array.isArray(dimension)) {
    shapes.push(dimension.length);
    // Read the first element directly; copying the array just to index it is wasted work.
    dimension = dimension.length > 0 ? dimension[0] : null;
  }
  return shapes;
};
/**
 * Ensures the value looks like a matrix: an array whose first element is
 * itself an array.
 *
 * @param {Matrix} m
 * @throws {Error} when `m` is not an array or its first element is not an array.
 */
const validateType = (m) => {
  // Array.isArray is false for null/undefined, so no separate truthiness check is needed.
  const looksLikeMatrix = Array.isArray(m) && Array.isArray(m[0]);
  if (!looksLikeMatrix) {
    throw new Error('Invalid matrix format');
  }
};
/**
 * Ensures the matrix has exactly two dimensions.
 *
 * @param {Matrix} m
 * @throws {Error} when `m` is not matrix-like or its shape has a rank other than 2.
 */
const validate2D = (m) => {
  validateType(m);
  const rank = shape(m).length;
  if (rank !== 2) {
    throw new Error('Matrix is not of 2D shape');
  }
};
/**
 * Validates that matrices are of the same shape.
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @throws {Error} when either argument is not matrix-like, when the matrices
 *   have a different number of dimensions, or when any dimension length differs.
 */
export const validateSameShape = (a, b) => {
  validateType(a);
  validateType(b);

  const aShape = shape(a);
  const bShape = shape(b);

  if (aShape.length !== bShape.length) {
    throw new Error('Matrices have different dimensions');
  }

  // Ranks are equal here, so compare dimension lengths pairwise.
  const isSameShape = aShape.every((size, dimIdx) => size === bShape[dimIdx]);
  if (!isSameShape) {
    throw new Error('Matrices have different shapes');
  }
};
/**
 * Generates the matrix of specific shape with specific values.
 *
 * @param {Shape} mShape - the shape of the matrix to generate
 * @param {function(CellIndices): Cell} fill - called once per cell with that
 *   cell's full index path; its return value becomes the cell value.
 * @returns {Matrix}
 */
export const generate = (mShape, fill) => {
  /**
   * Builds the sub-matrix for the remaining dimensions.
   *
   * @param {Shape} recShape - the dimensions still to generate
   * @param {CellIndices} recIndices - index path of the sub-matrix being built
   * @returns {Matrix}
   */
  const generateRecursively = (recShape, recIndices) => {
    // Innermost dimension: produce the actual cells.
    if (recShape.length === 1) {
      return Array(recShape[0])
        .fill(null)
        .map((_, cellIndex) => fill([...recIndices, cellIndex]));
    }

    // Every child shares the same remaining shape, so compute it once.
    const innerShape = recShape.slice(1);
    const m = [];
    for (let i = 0; i < recShape[0]; i += 1) {
      m.push(generateRecursively(innerShape, [...recIndices, i]));
    }
    return m;
  };

  return generateRecursively(mShape, []);
};
/**
 * Builds a matrix of the requested shape with every cell set to 0.
 *
 * @param {Shape} mShape - desired length of each dimension, outermost first
 * @returns {Matrix}
 */
export const zeros = (mShape) => generate(mShape, () => 0);
/**
 * Computes the matrix product of two 2D matrices (not element-wise).
 *
 * Assumes both matrices are rectangular; their shapes are read from the first
 * row of each.
 *
 * @param {Matrix} a - matrix of shape [n, k]
 * @param {Matrix} b - matrix of shape [k, p]
 * @return {Matrix} matrix of shape [n, p] where
 *   c[i][j] = sum over x of a[i][x] * b[x][j]
 * @throws {Error} when either input is not a 2D matrix or when the column
 *   count of `a` differs from the row count of `b`.
 */
export const dot = (a, b) => {
  // Validate inputs.
  validate2D(a);
  validate2D(b);

  // The column count of `a` must match the row count of `b`.
  const aShape = shape(a);
  const bShape = shape(b);
  if (aShape[1] !== bShape[0]) {
    throw new Error('Matrices have incompatible shape for multiplication');
  }

  const [rowCount, sharedCount] = aShape;
  const colCount = bShape[1];
  const c = zeros([rowCount, colCount]);

  // Rows are the outer loop, so each output row is filled left to right.
  for (let row = 0; row < rowCount; row += 1) {
    for (let col = 0; col < colCount; col += 1) {
      let cellSum = 0;
      for (let k = 0; k < sharedCount; k += 1) {
        cellSum += a[row][k] * b[k][col];
      }
      c[row][col] = cellSum;
    }
  }

  return c;
};
/**
 * Returns the transpose of a 2D matrix: cell [r][c] of the input becomes
 * cell [c][r] of the output. The input is not modified.
 *
 * @param {Matrix} m
 * @returns {Matrix}
 * @throws {Error} when `m` is not a 2D matrix.
 */
export const t = (m) => {
  validate2D(m);

  const [rowCount, colCount] = shape(m);
  const transposed = zeros([colCount, rowCount]);

  // Fill the output one row at a time (i.e. one input column at a time).
  for (let outRow = 0; outRow < colCount; outRow += 1) {
    for (let outCol = 0; outCol < rowCount; outCol += 1) {
      transposed[outRow][outCol] = m[outCol][outRow];
    }
  }

  return transposed;
};
/**
 * Traverses every cell of the matrix, visiting indices in ascending
 * (lexicographic) order.
 *
 * @param {Matrix} m
 * @param {function(indices: CellIndices, c: Cell)} visit - called once per cell
 *   with the cell's full index path and its value.
 */
export const walk = (m, visit) => {
  /**
   * Traverses the matrix recursively.
   *
   * @param {*} recM - the current sub-matrix
   * @param {CellIndices} cellIndices - index path leading to `recM`
   */
  const recWalk = (recM, cellIndices) => {
    // Non-arrays are not sub-matrices; there is nothing to traverse.
    if (!Array.isArray(recM)) {
      return;
    }

    // Innermost dimension: the elements are cells. Visit them and stop, so
    // the cells themselves are never recursed into. Recursing into a string
    // cell would otherwise loop forever, because 'a'[0] === 'a'.
    if (!Array.isArray(recM[0])) {
      for (let i = 0; i < recM.length; i += 1) {
        visit([...cellIndices, i], recM[i]);
      }
      return;
    }

    for (let i = 0; i < recM.length; i += 1) {
      recWalk(recM[i], [...cellIndices, i]);
    }
  };

  recWalk(m, []);
};
/**
 * Gets the matrix value at a specific index path.
 *
 * A full index path (one index per dimension) returns a cell. A shorter path
 * returns the sub-matrix at that position; for example, `[i]` returns row `i`.
 * An empty path returns `m` itself.
 *
 * @param {Matrix} m - Matrix to read from
 * @param {CellIndices} cellIndices - Array of indices, outermost dimension first
 * @return {Cell|Matrix}
 */
export const getCellAtIndex = (m, cellIndices) => (
  // Descend one dimension per index.
  cellIndices.reduce((subMatrix, index) => subMatrix[index], m)
);
/**
 * Update the matrix cell at specific index, mutating `m` in place.
 *
 * A full index path (one index per dimension) replaces a single cell. A
 * shorter path replaces the sub-matrix at that position; for example, `[i]`
 * replaces row `i`.
 *
 * @param {Matrix} m - Matrix that contains the cell that needs to be updated
 * @param {CellIndices} cellIndices - Array of indices, outermost dimension first
 * @param {Cell} cellValue - New cell value
 * @throws {Error} when `cellIndices` is empty, because there is no containing
 *   array to write into.
 */
export const updateCellAtIndex = (m, cellIndices, cellValue) => {
  if (cellIndices.length === 0) {
    throw new Error('Cell indices must not be empty');
  }

  // Descend to the array that directly holds the target element, keeping a
  // reference to it so the assignment below mutates `m`.
  const container = cellIndices
    .slice(0, -1)
    .reduce((subMatrix, index) => subMatrix[index], m);

  container[cellIndices[cellIndices.length - 1]] = cellValue;
};
/**
 * Combines two matrices of the same shape cell by cell.
 *
 * The result starts as a copy of `a`. Then, for every cell of `b`, the result
 * cell at the same indices is replaced by `operation(resultCell, bCell)`.
 * Neither input is modified.
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @param {function(Cell, Cell): Cell} operation - combines an `a` cell with the matching `b` cell
 * @return {Matrix}
 * @throws {Error} when either input is not matrix-like or the shapes differ.
 */
const combineElementWise = (a, b, operation) => {
  validateSameShape(a, b);

  const result = zeros(shape(a));

  // Copy `a` into the result.
  walk(a, (cellIndices, cellValue) => {
    updateCellAtIndex(result, cellIndices, cellValue);
  });

  // Fold `b` into the result.
  walk(b, (cellIndices, cellValue) => {
    const currentCellValue = getCellAtIndex(result, cellIndices);
    updateCellAtIndex(result, cellIndices, operation(currentCellValue, cellValue));
  });

  return result;
};

/**
 * Adds two matrices element-wise.
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @return {Matrix}
 * @throws {Error} when either input is not matrix-like or the shapes differ.
 */
export const add = (a, b) => combineElementWise(a, b, (x, y) => x + y);

/**
 * Multiplies two matrices element-wise (Hadamard product).
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @return {Matrix}
 * @throws {Error} when either input is not matrix-like or the shapes differ.
 */
export const mul = (a, b) => combineElementWise(a, b, (x, y) => x * y);

/**
 * Subtracts two matrices element-wise (`a - b`).
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @return {Matrix}
 * @throws {Error} when either input is not matrix-like or the shapes differ.
 */
export const sub = (a, b) => combineElementWise(a, b, (x, y) => x - y);