/*
 * @thi.ng/tensors
 * Version: (not specified)
 * 0D/1D/2D/3D/4D tensors with extensible polymorphic operations and customizable storage
 * 141 lines (140 loc) • 3.25 kB
 * JavaScript
 */
import { top } from "./top.js";
/**
 * Applies the 1D kernel `k` to the 1D tensor `a` and stores the result in
 * `out`.
 *
 * For every position `x` of `a`, kernel tap `i` is multiplied with the source
 * element at `x + i - (kernelSize >> 1)`; the kernel is used as-is (not
 * reversed). Source positions outside `[0, size - 1]` are clamped to the
 * nearest valid position, i.e. border values are repeated.
 *
 * All three arguments are read via `data`, `offset` and `stride` (plus
 * `shape` for `a` and `k`). The shape of `out` is never inspected: exactly
 * one value per element of `a` is written.
 *
 * @param out - Destination; when falsy, `a.empty()` is called to create one.
 * @param a - Source tensor.
 * @param k - Kernel tensor.
 * @returns The destination that was written to.
 */
const convolve1 = (out, a, k) => {
  const dest = out || a.empty();
  const {
    data: dst,
    offset: dstBase,
    stride: [dstStep],
  } = dest;
  const {
    data: src,
    offset: srcBase,
    shape: [size],
    stride: [srcStep],
  } = a;
  const {
    data: weights,
    offset: kBase,
    shape: [taps],
    stride: [kStep],
  } = k;
  const half = taps >> 1;
  const last = size - 1;
  for (let pos = 0; pos < size; pos++) {
    let acc = 0;
    for (let tap = 0; tap < taps; tap++) {
      // Kernel is centred on `pos`; out-of-range reads reuse the border value.
      const raw = pos + tap - half;
      const at = raw < 0 ? 0 : raw > last ? last : raw;
      acc += src[srcBase + at * srcStep] * weights[kBase + tap * kStep];
    }
    dst[dstBase + pos * dstStep] = acc;
  }
  return dest;
};
/**
 * Applies the 2D kernel `k` to the 2D tensor `a` and stores the result in
 * `out`.
 *
 * For every position `(x, y)` of `a`, kernel tap `(i, j)` is multiplied with
 * the source element at `(x + i - (kx >> 1), y + j - (ky >> 1))`, where
 * `kx`/`ky` are the kernel sizes along the first/second axis. The kernel is
 * used as-is (not reversed). Source positions outside the valid range of an
 * axis are clamped to the nearest valid position on that axis, i.e. border
 * values are repeated.
 *
 * All three arguments are read via `data`, `offset` and `stride` (plus
 * `shape` for `a` and `k`). The shape of `out` is never inspected: exactly
 * one value per element of `a` is written.
 *
 * @param out - Destination; when falsy, `a.empty()` is called to create one.
 * @param a - Source tensor.
 * @param k - Kernel tensor.
 * @returns The destination that was written to.
 */
const convolve2 = (out, a, k) => {
  const dest = out || a.empty();
  const {
    data: dst,
    offset: dstBase,
    stride: [dstStepX, dstStepY],
  } = dest;
  const {
    data: src,
    offset: srcBase,
    shape: [sizeX, sizeY],
    stride: [srcStepX, srcStepY],
  } = a;
  const {
    data: weights,
    offset: kBase,
    shape: [tapsX, tapsY],
    stride: [kStepX, kStepY],
  } = k;
  const halfX = tapsX >> 1;
  const halfY = tapsY >> 1;
  const lastX = sizeX - 1;
  const lastY = sizeY - 1;
  for (let x = 0; x < sizeX; x++) {
    const dstX = dstBase + x * dstStepX;
    for (let y = 0; y < sizeY; y++) {
      // Single accumulator over all taps, first axis outermost.
      let acc = 0;
      for (let i = 0; i < tapsX; i++) {
        const rawX = x + i - halfX;
        const atX = rawX < 0 ? 0 : rawX > lastX ? lastX : rawX;
        const srcX = srcBase + atX * srcStepX;
        const kX = kBase + i * kStepX;
        for (let j = 0; j < tapsY; j++) {
          const rawY = y + j - halfY;
          const atY = rawY < 0 ? 0 : rawY > lastY ? lastY : rawY;
          acc += src[srcX + atY * srcStepY] * weights[kX + j * kStepY];
        }
      }
      dst[dstX + y * dstStepY] = acc;
    }
  }
  return dest;
};
/**
 * Applies the 3D kernel `k` to the 3D tensor `a` and stores the result in
 * `out`.
 *
 * For every position `(x, y, z)` of `a`, kernel tap `(i, j, l)` is multiplied
 * with the source element at
 * `(x + i - (kx >> 1), y + j - (ky >> 1), z + l - (kz >> 1))`, where
 * `kx`/`ky`/`kz` are the kernel sizes along the three axes. The kernel is
 * used as-is (not reversed). Source positions outside the valid range of an
 * axis are clamped to the nearest valid position on that axis, i.e. border
 * values are repeated.
 *
 * All three arguments are read via `data`, `offset` and `stride` (plus
 * `shape` for `a` and `k`). The shape of `out` is never inspected: exactly
 * one value per element of `a` is written.
 *
 * @param out - Destination; when falsy, `a.empty()` is called to create one.
 * @param a - Source tensor.
 * @param k - Kernel tensor.
 * @returns The destination that was written to.
 */
const convolve3 = (out, a, k) => {
  const dest = out || a.empty();
  const {
    data: dst,
    offset: dstBase,
    stride: [dstStepX, dstStepY, dstStepZ],
  } = dest;
  const {
    data: src,
    offset: srcBase,
    shape: [sizeX, sizeY, sizeZ],
    stride: [srcStepX, srcStepY, srcStepZ],
  } = a;
  const {
    data: weights,
    offset: kBase,
    shape: [tapsX, tapsY, tapsZ],
    stride: [kStepX, kStepY, kStepZ],
  } = k;
  const halfX = tapsX >> 1;
  const halfY = tapsY >> 1;
  const halfZ = tapsZ >> 1;
  const lastX = sizeX - 1;
  const lastY = sizeY - 1;
  const lastZ = sizeZ - 1;
  for (let x = 0; x < sizeX; x++) {
    const dstX = dstBase + x * dstStepX;
    for (let y = 0; y < sizeY; y++) {
      const dstXY = dstX + y * dstStepY;
      for (let z = 0; z < sizeZ; z++) {
        // Single accumulator over all taps, first axis outermost.
        let acc = 0;
        for (let i = 0; i < tapsX; i++) {
          const rawX = x + i - halfX;
          const atX = rawX < 0 ? 0 : rawX > lastX ? lastX : rawX;
          const srcX = srcBase + atX * srcStepX;
          const kX = kBase + i * kStepX;
          for (let j = 0; j < tapsY; j++) {
            const rawY = y + j - halfY;
            const atY = rawY < 0 ? 0 : rawY > lastY ? lastY : rawY;
            const srcXY = srcX + atY * srcStepY;
            const kXY = kX + j * kStepY;
            for (let l = 0; l < tapsZ; l++) {
              const rawZ = z + l - halfZ;
              const atZ = rawZ < 0 ? 0 : rawZ > lastZ ? lastZ : rawZ;
              acc += src[srcXY + atZ * srcStepZ] * weights[kXY + l * kStepZ];
            }
          }
        }
        dst[dstXY + z * dstStepZ] = acc;
      }
    }
  }
  return dest;
};
/**
 * Public convolution entry point, assembled by `top()` from the 1D, 2D and 3D
 * implementations above (passed in that order).
 *
 * NOTE(review): the meaning of the first two `top()` arguments (`1` and
 * `undefined`) is defined in `./top.js` and not visible here — presumably the
 * index of the argument whose dimensionality selects the implementation and
 * an (absent) fallback; confirm against `./top.js`.
 */
export const convolve = top(1, undefined, convolve1, convolve2, convolve3);