UNPKG

@thi.ng/tensors

Version:

0D/1D/2D/3D/4D tensors with extensible polymorphic operations and customizable storage

164 lines (163 loc) 4.17 kB
import { top } from "./top.js";

/**
 * Returns true if `idx` lies within the inclusive range `[0, last]`.
 */
const isInside = (idx, last) => !(idx < 0 || idx > last);

/**
 * Clamps `idx` to the inclusive range `[0, last]`.
 */
const clampIndex = (idx, last) => (idx < 0 ? 0 : idx > last ? last : idx);

/**
 * 1D implementation: for every position `x` of `a`, folds the kernel-sized
 * neighborhood around `x` via `init()` / `reduce()` / `complete()` and stores
 * the result at the matching position of `out`.
 *
 * The neighborhood is offset by `-(kernelSize >> 1)`, i.e. kernel index
 * `kernelSize >> 1` lines up with the current position.
 *
 * @param out - destination tensor; if falsy, `a.empty()` is used instead
 * @param a - source tensor (reads `data`, `offset`, `shape`, `stride`)
 * @param kernel - object providing `init`, `reduce`, `complete` and `shape`;
 *   `reduce` is called as `reduce(acc, value, i)`
 * @param pad - if exactly `true`, neighborhood positions outside the source
 *   are clamped to the nearest edge and that edge value is read; any other
 *   value is itself passed to `reduce` in place of out-of-range samples
 * @returns the destination tensor
 */
const applyKernel1 = (out, a, { init, reduce, complete, shape: [kernelX] }, pad = true) => {
	if (!out) {
		out = a.empty();
	}
	const { data: dest, offset: destOffset, stride: [destStrideX] } = out;
	const { data: src, offset: srcOffset, shape: [sizeX], stride: [srcStrideX] } = a;
	const halfX = kernelX >> 1;
	const lastX = sizeX - 1;
	const clampToEdge = pad === true;
	for (let x = 0; x < sizeX; x++) {
		let acc = init();
		for (let i = 0; i < kernelX; i++) {
			const rawX = x + i - halfX;
			const insideX = isInside(rawX, lastX);
			const sx = clampIndex(rawX, lastX);
			// source data is only read for in-range (or edge-clamped) positions
			const sample =
				insideX || clampToEdge ? src[srcOffset + sx * srcStrideX] : pad;
			acc = reduce(acc, sample, i);
		}
		dest[destOffset + x * destStrideX] = complete(acc);
	}
	return out;
};

/**
 * 2D implementation: same scheme as the 1D version, iterating both axes of
 * `a` and both axes of the kernel. A sample counts as in-range only if its
 * position is inside the source along *both* axes.
 *
 * @param out - destination tensor; if falsy, `a.empty()` is used instead
 * @param a - source tensor (reads `data`, `offset`, `shape`, `stride`)
 * @param kernel - object providing `init`, `reduce`, `complete` and `shape`;
 *   `reduce` is called as `reduce(acc, value, i, j)`
 * @param pad - if exactly `true`, out-of-range positions are clamped to the
 *   nearest edge per axis; any other value is passed to `reduce` instead of
 *   out-of-range samples
 * @returns the destination tensor
 */
const applyKernel2 = (out, a, { init, reduce, complete, shape: [kernelX, kernelY] }, pad = true) => {
	if (!out) {
		out = a.empty();
	}
	const { data: dest, offset: destOffset, stride: [destStrideX, destStrideY] } = out;
	const { data: src, offset: srcOffset, shape: [sizeX, sizeY], stride: [srcStrideX, srcStrideY] } = a;
	const halfX = kernelX >> 1;
	const halfY = kernelY >> 1;
	const lastX = sizeX - 1;
	const lastY = sizeY - 1;
	const clampToEdge = pad === true;
	for (let x = 0; x < sizeX; x++) {
		const destBaseX = destOffset + x * destStrideX;
		for (let y = 0; y < sizeY; y++) {
			let acc = init();
			for (let i = 0; i < kernelX; i++) {
				const rawX = x + i - halfX;
				const insideX = isInside(rawX, lastX);
				const srcBaseX = srcOffset + clampIndex(rawX, lastX) * srcStrideX;
				for (let j = 0; j < kernelY; j++) {
					const rawY = y + j - halfY;
					// in-range only if inside along X *and* Y
					const insideXY = isInside(rawY, lastY) && insideX;
					const sy = clampIndex(rawY, lastY);
					const sample =
						insideXY || clampToEdge
							? src[srcBaseX + sy * srcStrideY]
							: pad;
					acc = reduce(acc, sample, i, j);
				}
			}
			dest[destBaseX + y * destStrideY] = complete(acc);
		}
	}
	return out;
};

/**
 * 3D implementation: same scheme as the 1D/2D versions, iterating all three
 * axes of `a` and of the kernel. A sample counts as in-range only if its
 * position is inside the source along all three axes.
 *
 * @param out - destination tensor; if falsy, `a.empty()` is used instead
 * @param a - source tensor (reads `data`, `offset`, `shape`, `stride`)
 * @param kernel - object providing `init`, `reduce`, `complete` and `shape`;
 *   `reduce` is called as `reduce(acc, value, i, j, k)`
 * @param pad - if exactly `true`, out-of-range positions are clamped to the
 *   nearest edge per axis; any other value is passed to `reduce` instead of
 *   out-of-range samples
 * @returns the destination tensor
 */
const applyKernel3 = (out, a, { init, reduce, complete, shape: [kernelX, kernelY, kernelZ] }, pad = true) => {
	if (!out) {
		out = a.empty();
	}
	const { data: dest, offset: destOffset, stride: [destStrideX, destStrideY, destStrideZ] } = out;
	const { data: src, offset: srcOffset, shape: [sizeX, sizeY, sizeZ], stride: [srcStrideX, srcStrideY, srcStrideZ] } = a;
	const halfX = kernelX >> 1;
	const halfY = kernelY >> 1;
	const halfZ = kernelZ >> 1;
	const lastX = sizeX - 1;
	const lastY = sizeY - 1;
	const lastZ = sizeZ - 1;
	const clampToEdge = pad === true;
	for (let x = 0; x < sizeX; x++) {
		const destBaseX = destOffset + x * destStrideX;
		for (let y = 0; y < sizeY; y++) {
			const destBaseXY = destBaseX + y * destStrideY;
			for (let z = 0; z < sizeZ; z++) {
				let acc = init();
				for (let i = 0; i < kernelX; i++) {
					const rawX = x + i - halfX;
					const insideX = isInside(rawX, lastX);
					const srcBaseX =
						srcOffset + clampIndex(rawX, lastX) * srcStrideX;
					for (let j = 0; j < kernelY; j++) {
						const rawY = y + j - halfY;
						const insideXY = isInside(rawY, lastY) && insideX;
						const srcBaseXY =
							srcBaseX + clampIndex(rawY, lastY) * srcStrideY;
						for (let k = 0; k < kernelZ; k++) {
							const rawZ = z + k - halfZ;
							// in-range only if inside along X, Y *and* Z
							const insideXYZ = isInside(rawZ, lastZ) && insideXY;
							const sz = clampIndex(rawZ, lastZ);
							const sample =
								insideXYZ || clampToEdge
									? src[srcBaseXY + sz * srcStrideZ]
									: pad;
							acc = reduce(acc, sample, i, j, k);
						}
					}
				}
				dest[destBaseXY + z * destStrideZ] = complete(acc);
			}
		}
	}
	return out;
};

/**
 * Public entry point, built by handing the 1D, 2D and 3D implementations to
 * `top()`.
 *
 * NOTE(review): presumably `top` returns a dispatcher which selects the
 * implementation by tensor dimension; the meaning of the leading `1` and
 * `undefined` arguments is not visible here — confirm against ./top.js.
 */
const applyKernel = top(1, undefined, applyKernel1, applyKernel2, applyKernel3);

export { applyKernel };