@hoff97/tensor-js
Version:
PyTorch like deep learning inferrence library
85 lines • 3.83 kB
JavaScript
import { CPUTensor } from '../../tensor/cpu/tensor';
import { getSize, incrementIndex } from '../../util/shape';
import { outputDimsSize } from '../util/conv';
export function conv(x, w, dilations, group, pads, strides, activation, bias) {
const N = x.shape[0];
const C = x.shape[1];
const D = x.shape.slice(2);
const W = w.shape.slice(2);
const M = w.shape[0];
const CG = C / group;
const kernelSize = getSize(W);
const R = outputDimsSize(D, W, pads.slice(0, pads.length / 2), pads.slice(pads.length / 2), dilations, strides);
const outputSize = getSize(R);
let outputShape = [N, M];
outputShape = outputShape.concat(R);
const Y = new CPUTensor(outputShape, undefined, x.dtype);
const dataRank = R.length;
// Iterate over all batches
for (let n = 0; n < N; n++) {
// Iterate over all output channels
for (let m = 0; m < M; m++) {
if (bias) {
const b = bias ? bias.get([m]) : 0;
const outputIndices = new Array(R.length).fill(0);
outputIndices.unshift(n, m);
for (let oIx = 0; oIx < outputSize; oIx++) {
Y.set(outputIndices, b);
incrementIndex(outputIndices, Y.shape);
}
}
for (let cg = 0; cg < CG; cg++) {
const c = (m * CG + cg) % C;
const outputIndices = new Array(R.length).fill(0);
outputIndices.unshift(n, m);
for (let oIx = 0; oIx < outputSize; oIx++) {
let result = Y.get(outputIndices);
const kernelIndices = new Array(R.length).fill(0);
kernelIndices.unshift(m, cg);
for (let kIx = 0; kIx < kernelSize; kIx++) {
const inputIx = [n, c];
let skip = false;
for (let axis = 0; axis < dataRank; axis++) {
const stride = strides.length === 0 ? 1 : strides[axis];
const pad = pads.length === 0 ? 0 : pads[axis];
const dilation = dilations.length === 0 ? 1 : dilations[axis];
const ix = outputIndices[axis + 2] * stride -
pad +
kernelIndices[axis + 2] * dilation;
if (ix < 0 || ix >= D[axis]) {
skip = true;
break;
}
inputIx.push(ix);
}
if (!skip) {
const Wi = w.get(kernelIndices);
const Xi = x.get(inputIx);
result += Wi * Xi;
}
incrementIndex(kernelIndices, w.shape);
}
Y.set(outputIndices, result);
incrementIndex(outputIndices, Y.shape);
}
}
if (activation !== 'id') {
const outputIndices = new Array(R.length).fill(0);
outputIndices.unshift(n, m);
for (let oIx = 0; oIx < outputSize; oIx++) {
let result = Y.get(outputIndices);
if (activation === 'relu') {
result = Math.max(0, result);
}
else if (activation === 'relu6') {
result = Math.min(Math.max(0, result), 6);
}
Y.set(outputIndices, result);
incrementIndex(outputIndices, Y.shape);
}
}
}
}
return Y;
}
//# sourceMappingURL=conv.js.map