UNPKG

@thi.ng/tensors

Version:

0D/1D/2D/3D/4D tensors with extensible polymorphic operations and customizable storage

120 lines (119 loc) 2.87 kB
import { identity } from "@thi.ng/api/fn";
import { broadcast } from "./broadcast.js";

/**
 * Higher-order factory for binary reduction ops over two tensors.
 *
 * The returned function walks both tensors in lock-step and folds every
 * element-index pair into one accumulator. It visits indices lexicographically:
 * the first axis is outermost and the last axis is innermost. There are
 * dedicated kernels for ranks 0 through 4. A rank outside that range has no
 * kernel in the dispatch table, so the call fails.
 *
 * @param rfn - reducer `(acc, aData, bData, aIndex, bIndex) => acc`. It is
 *   called once per element with the current accumulator, both tensors' data
 *   arrays and the computed element indices into each.
 * @param init - zero-arg function producing the initial accumulator. It is
 *   called once per invocation of the returned op.
 * @param complete - post-processing `(acc, a, b) => result` applied to the
 *   final accumulator (default: `identity`).
 * @param useBroadcast - if true (default), both inputs are first passed through
 *   `broadcast()` and the kernel is chosen by the broadcast shape's length.
 *   If false, the kernel is chosen by `a.dim`, and `b` is assumed (not
 *   checked) to be shape-compatible with `a`.
 * @returns `(a, b) => result`
 */
const defOpRTT = (rfn, init, complete = identity, useBroadcast = true) => {
	// 0D: a single element located at each tensor's offset.
	const reduce0 = (a, b) =>
		complete(rfn(init(), a.data, b.data, a.offset, b.offset), a, b);

	// 1D: iterate along x; the loop bounds come from `a` only.
	const reduce1 = (a, b) => {
		const {
			data: aData,
			offset: aOff,
			shape: [nx],
			stride: [aSx],
		} = a;
		const {
			data: bData,
			offset: bOff,
			stride: [bSx],
		} = b;
		let acc = init();
		for (let x = 0; x < nx; x++) {
			acc = rfn(acc, aData, bData, aOff + x * aSx, bOff + x * bSx);
		}
		return complete(acc, a, b);
	};

	// 2D: x outer, y inner; per-row base indices are computed once per x.
	const reduce2 = (a, b) => {
		const {
			data: aData,
			offset: aOff,
			shape: [nx, ny],
			stride: [aSx, aSy],
		} = a;
		const {
			data: bData,
			offset: bOff,
			stride: [bSx, bSy],
		} = b;
		let acc = init();
		for (let x = 0; x < nx; x++) {
			const ax = aOff + x * aSx;
			const bx = bOff + x * bSx;
			for (let y = 0; y < ny; y++) {
				acc = rfn(acc, aData, bData, ax + y * aSy, bx + y * bSy);
			}
		}
		return complete(acc, a, b);
	};

	// 3D: x -> y -> z nesting; partial base indices are kept per level.
	const reduce3 = (a, b) => {
		const {
			data: aData,
			offset: aOff,
			shape: [nx, ny, nz],
			stride: [aSx, aSy, aSz],
		} = a;
		const {
			data: bData,
			offset: bOff,
			stride: [bSx, bSy, bSz],
		} = b;
		let acc = init();
		for (let x = 0; x < nx; x++) {
			const ax = aOff + x * aSx;
			const bx = bOff + x * bSx;
			for (let y = 0; y < ny; y++) {
				const ay = ax + y * aSy;
				const by = bx + y * bSy;
				for (let z = 0; z < nz; z++) {
					acc = rfn(acc, aData, bData, ay + z * aSz, by + z * bSz);
				}
			}
		}
		return complete(acc, a, b);
	};

	// 4D: x -> y -> z -> w nesting; partial base indices are kept per level.
	const reduce4 = (a, b) => {
		const {
			data: aData,
			offset: aOff,
			shape: [nx, ny, nz, nw],
			stride: [aSx, aSy, aSz, aSw],
		} = a;
		const {
			data: bData,
			offset: bOff,
			stride: [bSx, bSy, bSz, bSw],
		} = b;
		let acc = init();
		for (let x = 0; x < nx; x++) {
			const ax = aOff + x * aSx;
			const bx = bOff + x * bSx;
			for (let y = 0; y < ny; y++) {
				const ay = ax + y * aSy;
				const by = bx + y * bSy;
				for (let z = 0; z < nz; z++) {
					const az = ay + z * aSz;
					const bz = by + z * bSz;
					for (let w = 0; w < nw; w++) {
						acc = rfn(acc, aData, bData, az + w * aSw, bz + w * bSw);
					}
				}
			}
		}
		return complete(acc, a, b);
	};

	// Dispatch table indexed by tensor rank.
	const kernels = [reduce0, reduce1, reduce2, reduce3, reduce4];

	if (!useBroadcast) {
		// No broadcasting: trust `a`'s rank and assume `b` matches.
		return (a, b) => kernels[a.dim](a, b);
	}
	return (a, b) => {
		// Align both inputs to a common shape first, then dispatch on its rank.
		const { shape, a: alignedA, b: alignedB } = broadcast(a, b);
		return kernels[shape.length](alignedA, alignedB);
	};
};

export { defOpRTT };