UNPKG

@thi.ng/tensors

Version:

0D/1D/2D/3D/4D tensors with extensible polymorphic operations and customizable storage

192 lines (191 loc) 4.28 kB
import { broadcast } from "./broadcast.js";
import { ensureShape } from "./errors.js";
import { tensor } from "./tensor.js";

/**
 * Higher-order function. Takes a 3-argument scalar function `fn` and returns
 * a tensor operation `(out, a, b, c) => out` which applies `fn` component-wise
 * to three input tensors and writes the results into `out`.
 *
 * The returned operation broadcasts `a`, `b` and `c` against each other and
 * then dispatches to a dimension-specific kernel (0 to 4 dimensions), selected
 * by the length of the broadcast shape.
 *
 * If `out` is falsy, a new tensor is created via `tensor()` using `a`'s `type`
 * and `storage`. Otherwise `out` is checked against the broadcast shape via
 * `ensureShape()`.
 *
 * @param fn - scalar function receiving one component each of `a`, `b`, `c`
 * @returns function `(out, a, b, c) => out`
 */
const defOpTTT = (fn) => {
  // All kernels below read their iteration extents from `a.shape` only and
  // index the other operands purely via their own offset & strides. They rely
  // on the wrapper having aligned every operand to the same shape beforehand.

  // 0-dimensional kernel: single component located at each tensor's offset
  const f0 = (out, a, b, c) => {
    out.data[out.offset] = fn(
      a.data[a.offset],
      b.data[b.offset],
      c.data[c.offset]
    );
    return out;
  };

  // 1-dimensional kernel
  const f1 = (out, a, b, c) => {
    const { data: odata, offset: oo, stride: [txo] } = out;
    const { data: adata, offset: oa, shape: [sx], stride: [txa] } = a;
    const { data: bdata, offset: ob, stride: [txb] } = b;
    const { data: cdata, offset: oc, stride: [txc] } = c;
    for (let x = 0; x < sx; x++) {
      odata[oo + x * txo] = fn(
        adata[oa + x * txa],
        bdata[ob + x * txb],
        cdata[oc + x * txc]
      );
    }
    return out;
  };

  // 2-dimensional kernel; per-row base offsets are hoisted out of the inner loop
  const f2 = (out, a, b, c) => {
    const { data: odata, offset: oo, stride: [txo, tyo] } = out;
    const { data: adata, offset: oa, shape: [sx, sy], stride: [txa, tya] } = a;
    const { data: bdata, offset: ob, stride: [txb, tyb] } = b;
    const { data: cdata, offset: oc, stride: [txc, tyc] } = c;
    let oox, oax, obx, ocx;
    for (let x = 0; x < sx; x++) {
      oox = oo + x * txo;
      oax = oa + x * txa;
      obx = ob + x * txb;
      ocx = oc + x * txc;
      for (let y = 0; y < sy; y++) {
        odata[oox + y * tyo] = fn(
          adata[oax + y * tya],
          bdata[obx + y * tyb],
          cdata[ocx + y * tyc]
        );
      }
    }
    return out;
  };

  // 3-dimensional kernel
  const f3 = (out, a, b, c) => {
    const { data: odata, offset: oo, stride: [txo, tyo, tzo] } = out;
    const {
      data: adata,
      offset: oa,
      shape: [sx, sy, sz],
      stride: [txa, tya, tza],
    } = a;
    const { data: bdata, offset: ob, stride: [txb, tyb, tzb] } = b;
    const { data: cdata, offset: oc, stride: [txc, tyc, tzc] } = c;
    let oox, oax, obx, ocx, ooy, oay, oby, ocy;
    for (let x = 0; x < sx; x++) {
      oox = oo + x * txo;
      oax = oa + x * txa;
      obx = ob + x * txb;
      ocx = oc + x * txc;
      for (let y = 0; y < sy; y++) {
        ooy = oox + y * tyo;
        oay = oax + y * tya;
        oby = obx + y * tyb;
        ocy = ocx + y * tyc;
        for (let z = 0; z < sz; z++) {
          odata[ooy + z * tzo] = fn(
            adata[oay + z * tza],
            bdata[oby + z * tzb],
            cdata[ocy + z * tzc]
          );
        }
      }
    }
    return out;
  };

  // 4-dimensional kernel
  const f4 = (out, a, b, c) => {
    const { data: odata, offset: oo, stride: [txo, tyo, tzo, two] } = out;
    const {
      data: adata,
      offset: oa,
      shape: [sx, sy, sz, sw],
      stride: [txa, tya, tza, twa],
    } = a;
    const { data: bdata, offset: ob, stride: [txb, tyb, tzb, twb] } = b;
    const { data: cdata, offset: oc, stride: [txc, tyc, tzc, twc] } = c;
    let oox, oax, obx, ocx, ooy, oay, oby, ocy, ooz, oaz, obz, ocz;
    for (let x = 0; x < sx; x++) {
      oox = oo + x * txo;
      oax = oa + x * txa;
      obx = ob + x * txb;
      ocx = oc + x * txc;
      for (let y = 0; y < sy; y++) {
        ooy = oox + y * tyo;
        oay = oax + y * tya;
        oby = obx + y * tyb;
        ocy = ocx + y * tyc;
        for (let z = 0; z < sz; z++) {
          ooz = ooy + z * tzo;
          oaz = oay + z * tza;
          obz = oby + z * tzb;
          ocz = ocy + z * tzc;
          for (let w = 0; w < sw; w++) {
            odata[ooz + w * two] = fn(
              adata[oaz + w * twa],
              bdata[obz + w * twb],
              cdata[ocz + w * twc]
            );
          }
        }
      }
    }
    return out;
  };

  // Kernel lookup table, indexed by number of dimensions of the result shape
  const impls = [f0, f1, f2, f3, f4];

  const wrapper = (out, a, b, c) => {
    // Step 1: align `a` and `b` with each other
    const { a: $a1, b: $b1 } = broadcast(a, b);
    // Step 2: align the (a, b) result with `c`; this yields the final shape
    const { shape, a: $a2, b: $c } = broadcast($a1, c);
    // Step 3: `$b1` has only been aligned with `a` so far. If `c` enlarged the
    // result shape in step 2 (more dimensions or a stretched size-1 axis),
    // `$b1` would still carry the intermediate shape & strides and the kernels
    // would index it incorrectly. Align it with the final shape as well.
    // NOTE(review): assumes broadcast() returns operands which both have the
    // returned `shape` — confirm against broadcast.js
    const { a: $b2 } = broadcast($b1, $c);
    if (out) {
      ensureShape(out, shape);
    } else {
      out = tensor(a.type, shape, { storage: a.storage });
    }
    return impls[shape.length](out, $a2, $b2, $c);
  };
  return wrapper;
};

export { defOpTTT };