@thi.ng/tensors
Version:
0D/1D/2D/3D/4D tensors with extensible polymorphic operations and customizable storage
120 lines (119 loc) • 2.87 kB
JavaScript
import { identity } from "@thi.ng/api/fn";
import { broadcast } from "./broadcast.js";
/**
 * Builds a binary tensor reduction ("reduce two tensors to one value").
 *
 * The returned function visits every element position of its two tensor
 * arguments in lockstep, outermost axis first, and threads an accumulator
 * through `rfn`. There is one hand-unrolled kernel per rank (0D to 4D).
 *
 * Tensor fields read: `data`, `offset`, `shape`, `stride` (and `dim` when not
 * broadcasting). Only `a`'s shape drives the loop extents.
 * NOTE(review): in non-broadcast mode `b` is assumed to be at least as large
 * as `a` along every axis; this is not checked here.
 *
 * @param {Function} rfn - reducer `(acc, dataA, dataB, indexA, indexB) => acc`.
 *   It is called once per element position with both flat data arrays and the
 *   resolved flat indices into them.
 * @param {Function} init - zero-arg factory for the initial accumulator. It is
 *   invoked afresh on every call of the returned op.
 * @param {Function} [complete=identity] - finalizer `(acc, a, b) => result`,
 *   applied to the final accumulator together with the iterated tensors.
 * @param {boolean} [useBroadcast=true] - when truthy, the inputs go through
 *   `broadcast()` first. The broadcast tensors are iterated, passed to
 *   `complete`, and the kernel is chosen by the broadcast shape's length.
 *   Otherwise the inputs are used directly and the kernel is chosen by `a.dim`.
 * @returns {Function} `(a, b) => result`
 */
const defOpRTT = (rfn, init, complete = identity, useBroadcast = true) => {
  // 0D: exactly one element, located at each tensor's offset.
  const reduce0 = (a, b) => {
    const acc = rfn(init(), a.data, b.data, a.offset, b.offset);
    return complete(acc, a, b);
  };

  const reduce1 = (a, b) => {
    const { data: dataA, offset: offA, shape: [n0], stride: [a0] } = a;
    const { data: dataB, offset: offB, stride: [b0] } = b;
    let acc = init();
    for (let i = 0; i < n0; i++) {
      acc = rfn(acc, dataA, dataB, offA + i * a0, offB + i * b0);
    }
    return complete(acc, a, b);
  };

  const reduce2 = (a, b) => {
    const {
      data: dataA,
      offset: offA,
      shape: [n0, n1],
      stride: [a0, a1],
    } = a;
    const { data: dataB, offset: offB, stride: [b0, b1] } = b;
    let acc = init();
    for (let i = 0; i < n0; i++) {
      // Per-axis base offsets are computed once per outer iteration.
      const pa0 = offA + i * a0;
      const pb0 = offB + i * b0;
      for (let j = 0; j < n1; j++) {
        acc = rfn(acc, dataA, dataB, pa0 + j * a1, pb0 + j * b1);
      }
    }
    return complete(acc, a, b);
  };

  const reduce3 = (a, b) => {
    const {
      data: dataA,
      offset: offA,
      shape: [n0, n1, n2],
      stride: [a0, a1, a2],
    } = a;
    const { data: dataB, offset: offB, stride: [b0, b1, b2] } = b;
    let acc = init();
    for (let i = 0; i < n0; i++) {
      const pa0 = offA + i * a0;
      const pb0 = offB + i * b0;
      for (let j = 0; j < n1; j++) {
        const pa1 = pa0 + j * a1;
        const pb1 = pb0 + j * b1;
        for (let k = 0; k < n2; k++) {
          acc = rfn(acc, dataA, dataB, pa1 + k * a2, pb1 + k * b2);
        }
      }
    }
    return complete(acc, a, b);
  };

  const reduce4 = (a, b) => {
    const {
      data: dataA,
      offset: offA,
      shape: [n0, n1, n2, n3],
      stride: [a0, a1, a2, a3],
    } = a;
    const { data: dataB, offset: offB, stride: [b0, b1, b2, b3] } = b;
    let acc = init();
    for (let i = 0; i < n0; i++) {
      const pa0 = offA + i * a0;
      const pb0 = offB + i * b0;
      for (let j = 0; j < n1; j++) {
        const pa1 = pa0 + j * a1;
        const pb1 = pb0 + j * b1;
        for (let k = 0; k < n2; k++) {
          const pa2 = pa1 + k * a2;
          const pb2 = pb1 + k * b2;
          for (let l = 0; l < n3; l++) {
            acc = rfn(acc, dataA, dataB, pa2 + l * a3, pb2 + l * b3);
          }
        }
      }
    }
    return complete(acc, a, b);
  };

  // Kernel lookup table indexed by rank (0..4).
  const byRank = [reduce0, reduce1, reduce2, reduce3, reduce4];

  if (!useBroadcast) {
    return (a, b) => byRank[a.dim](a, b);
  }
  return (a, b) => {
    const bc = broadcast(a, b);
    return byRank[bc.shape.length](bc.a, bc.b);
  };
};

export { defOpRTT };