@thi.ng/tensors
Version:
0D/1D/2D/3D/4D tensors with extensible polymorphic operations and customizable storage
55 lines (54 loc) • 1.35 kB
JavaScript
import { illegalArgs } from "@thi.ng/errors";
import { equals } from "@thi.ng/vectors";
import { max } from "@thi.ng/vectors/max";
/**
 * Computes a common (broadcast) shape for tensors `a` and `b` and returns
 * both operands adapted to it:
 *
 * - The lower-rank operand is left-padded with size-1 axes whose stride is 0.
 * - Along each axis, if one size is smaller than the other, that smaller size
 *   must not exceed 1 (otherwise `__broadcastError` is called). The smaller
 *   side's stride on that axis is set to 0.
 * - The resulting shape is the element-wise maximum of both padded shapes.
 *
 * If both shapes are already equal, the inputs are returned as-is and
 * `shape` is `a.shape` itself (not a copy).
 *
 * The caller's `shape`/`stride` arrays are never mutated; copies are used.
 *
 * @param a - first tensor (reads `shape`, `stride`, `dim` and calls `broadcast()`)
 * @param b - second tensor (same requirements)
 * @returns `{ shape, a, b }`. An operand is replaced by
 *   `x.broadcast(shape, stride)` only if its rank or any axis had to be
 *   broadcast; otherwise the original tensor object is returned.
 */
const broadcast = (a, b) => {
	if (equals(a.shape, b.shape)) return { shape: a.shape, a, b };
	const rank = Math.max(a.dim, b.dim);
	// Prepend `rank - dim` copies of `value` to a fresh array (for the
	// higher-rank operand this prepends nothing and simply copies `src`).
	const lift = (src, dim, value) => [
		...new Array(rank - dim).fill(value),
		...src,
	];
	const shapeA = lift(a.shape, a.dim, 1);
	const strideA = lift(a.stride, a.dim, 0);
	const shapeB = lift(b.shape, b.dim, 1);
	const strideB = lift(b.stride, b.dim, 0);
	// Raising the rank alone already means the operand needs a broadcast view.
	let needA = a.dim < rank;
	let needB = b.dim < rank;
	for (let axis = 0; axis < rank; axis++) {
		const lenA = shapeA[axis];
		const lenB = shapeB[axis];
		if (lenA < lenB) {
			if (lenA > 1) __broadcastError(shapeA, shapeB);
			// repeat A's single element along this axis
			strideA[axis] = 0;
			needA = true;
		} else if (lenB < lenA) {
			if (lenB > 1) __broadcastError(shapeA, shapeB);
			// repeat B's single element along this axis
			strideB[axis] = 0;
			needB = true;
		}
	}
	const shape = max([], shapeA, shapeB);
	const outA = needA ? a.broadcast(shape, strideA) : a;
	const outB = needB ? b.broadcast(shape, strideB) : b;
	return { shape, a: outA, b: outB };
};
/**
 * Reports two shapes that cannot be broadcast against each other by
 * delegating to `illegalArgs` with both shapes serialized as JSON.
 *
 * @param ashape - (padded) shape of the first operand
 * @param bshape - (padded) shape of the second operand
 * @returns whatever `illegalArgs` returns (it is expected to throw)
 */
const __broadcastError = (ashape, bshape) => {
	const fmt = JSON.stringify;
	return illegalArgs(`incompatible shapes: ${fmt(ashape)} vs ${fmt(bshape)}`);
};
export {
broadcast
};