UNPKG

@thi.ng/tensors

Version:

0D/1D/2D/3D/4D tensors with extensible polymorphic operations and customizable storage

55 lines (54 loc) 1.35 kB
import { illegalArgs } from "@thi.ng/errors";
import { equals } from "@thi.ng/vectors";
import { max } from "@thi.ng/vectors/max";

/**
 * Computes a common broadcast shape for tensors `a` and `b` and returns
 * (possibly) broadcast versions of both.
 *
 * @remarks
 * Shapes are right-aligned: the lower-dimensional tensor first gets size-1
 * axes (with zero stride) prepended until both have the same number of
 * dimensions. Then, per axis, the two sizes must either be equal, or one of
 * them must be 1. A size-1 axis is stretched to the other size by setting its
 * stride to 0. Any other combination is rejected. In particular, a size-0
 * axis is only compatible with another size-0 axis or with a size-1 axis, and
 * in the latter case the result axis has size 0.
 *
 * If both shapes are already equal, the original tensors are returned as is.
 * Otherwise, each tensor that needed padding or stretching is replaced by the
 * result of calling its `.broadcast(shape, stride)` method. A tensor that
 * needed neither is returned unchanged.
 *
 * The inputs' `shape`/`stride` arrays are copied before modification. The
 * input tensors themselves are not mutated here.
 *
 * NOTE(review): assumes `tensor.dim === tensor.shape.length` and
 * `stride.length === shape.length` — confirm against the tensor impl.
 *
 * @param a - first tensor
 * @param b - second tensor
 * @returns object with the broadcast `shape` and the (possibly) broadcast
 * tensors `a` and `b`
 * @throws an error via `illegalArgs` if the shapes are incompatible
 */
const broadcast = (a, b) => {
	if (equals(a.shape, b.shape)) return { shape: a.shape, a, b };
	const ashape = a.shape.slice();
	const astride = a.stride.slice();
	const bshape = b.shape.slice();
	const bstride = b.stride.slice();
	let da = a.dim;
	let db = b.dim;
	let bcastA = da < db;
	let bcastB = db < da;
	// right-align both shapes by prepending size-1 axes (zero stride) to the
	// tensor with fewer dimensions
	if (bcastA) {
		while (da < db) {
			ashape.unshift(1);
			astride.unshift(0);
			da++;
		}
	} else if (bcastB) {
		while (db < da) {
			bshape.unshift(1);
			bstride.unshift(0);
			db++;
		}
	}
	// resolve each axis: sizes must match, or one side must be exactly 1.
	// (checking `=== 1` rather than "smaller size <= 1" ensures a size-0 axis
	// is never stretched into a non-empty one)
	const shape = [];
	for (let i = 0; i < da; i++) {
		const sa = ashape[i];
		const sb = bshape[i];
		if (sa === sb) {
			shape.push(sa);
		} else if (sa === 1) {
			astride[i] = 0;
			bcastA = true;
			shape.push(sb);
		} else if (sb === 1) {
			bstride[i] = 0;
			bcastB = true;
			shape.push(sa);
		} else {
			__broadcastError(ashape, bshape);
		}
	}
	return {
		shape,
		a: bcastA ? a.broadcast(shape, astride) : a,
		b: bcastB ? b.broadcast(shape, bstride) : b,
	};
};

/**
 * Raises an error via `illegalArgs` (from @thi.ng/errors, which throws an
 * `IllegalArgumentError`) describing the two incompatible shapes.
 *
 * @param ashape - shape of the first tensor (after dimension padding)
 * @param bshape - shape of the second tensor (after dimension padding)
 */
const __broadcastError = (ashape, bshape) =>
	illegalArgs(
		`incompatible shapes: ${JSON.stringify(ashape)} vs ${JSON.stringify(
			bshape
		)}`
	);

export { broadcast };