UNPKG

@thi.ng/tensors

Version:

0D/1D/2D/3D/4D tensors with extensible polymorphic operations and customizable storage

112 lines (111 loc) 2.46 kB
import { top } from "./top.js";

/**
 * Builds a tensor operation of the form `op(out, a, n)`. The operation
 * applies `fn(x, n)` to every element `x` of tensor `a`, passing the same
 * extra argument `n` for each element. Each result goes to the
 * corresponding position of `out`, and the operation returns `out`.
 *
 * If `out` is falsy, the operation runs in place and writes into `a`.
 *
 * Iteration covers `a.shape`. Element addresses are computed from each
 * tensor's own `offset` and `stride`, so strided views are supported.
 * NOTE(review): `out` is assumed to be at least as large as `a` in every
 * dimension; this is not checked here.
 *
 * One kernel per rank (0..4) is built and handed to `top` together with
 * the argument index `1`.
 * NOTE(review): presumably `top` selects a kernel from the rank of
 * argument 1 (`a`); confirm against ./top.js.
 *
 * @param fn - element function, called as `fn(value, n)`
 * @returns the dispatching operation produced by `top`
 */
const defOpTN = (fn) => {
  // Rank 0: a single element located at each tensor's offset.
  const op0 = (out, a, n) => {
    if (!out) out = a;
    out.data[out.offset] = fn(a.data[a.offset], n);
    return out;
  };

  // Rank 1: walk the single axis of `a`.
  const op1 = (out, a, n) => {
    if (!out) out = a;
    const { data: dst, offset: dstBase, stride: [dsx] } = out;
    const { data: src, offset: srcBase, shape: [nx], stride: [ssx] } = a;
    for (let i = 0; i < nx; i++) {
      dst[dstBase + i * dsx] = fn(src[srcBase + i * ssx], n);
    }
    return out;
  };

  // Rank 2: outer axis first, so elements are visited in the same order
  // as a nested x/y traversal (this matters when `out` aliases `a`).
  const op2 = (out, a, n) => {
    if (!out) out = a;
    const { data: dst, offset: dstBase, stride: [dsx, dsy] } = out;
    const { data: src, offset: srcBase, shape: [nx, ny], stride: [ssx, ssy] } = a;
    for (let i = 0; i < nx; i++) {
      const di = dstBase + i * dsx;
      const si = srcBase + i * ssx;
      for (let j = 0; j < ny; j++) {
        dst[di + j * dsy] = fn(src[si + j * ssy], n);
      }
    }
    return out;
  };

  // Rank 3: same traversal, one more nesting level.
  const op3 = (out, a, n) => {
    if (!out) out = a;
    const { data: dst, offset: dstBase, stride: [dsx, dsy, dsz] } = out;
    const {
      data: src,
      offset: srcBase,
      shape: [nx, ny, nz],
      stride: [ssx, ssy, ssz],
    } = a;
    for (let i = 0; i < nx; i++) {
      const di = dstBase + i * dsx;
      const si = srcBase + i * ssx;
      for (let j = 0; j < ny; j++) {
        const dj = di + j * dsy;
        const sj = si + j * ssy;
        for (let k = 0; k < nz; k++) {
          dst[dj + k * dsz] = fn(src[sj + k * ssz], n);
        }
      }
    }
    return out;
  };

  // Rank 4: same traversal, innermost axis last.
  const op4 = (out, a, n) => {
    if (!out) out = a;
    const { data: dst, offset: dstBase, stride: [dsx, dsy, dsz, dsw] } = out;
    const {
      data: src,
      offset: srcBase,
      shape: [nx, ny, nz, nw],
      stride: [ssx, ssy, ssz, ssw],
    } = a;
    for (let i = 0; i < nx; i++) {
      const di = dstBase + i * dsx;
      const si = srcBase + i * ssx;
      for (let j = 0; j < ny; j++) {
        const dj = di + j * dsy;
        const sj = si + j * ssy;
        for (let k = 0; k < nz; k++) {
          const dk = dj + k * dsz;
          const sk = sj + k * ssz;
          for (let l = 0; l < nw; l++) {
            dst[dk + l * dsw] = fn(src[sk + l * ssw], n);
          }
        }
      }
    }
    return out;
  };

  return top(1, op0, op1, op2, op3, op4);
};

export { defOpTN };