@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
52 lines • 2.21 kB
JavaScript
import { CPUTensor } from '../../../tensor/cpu/tensor';
import { getSize, incrementIndex, indexToPos } from '../../../util/shape';
import { poolResultShape } from '../../util/pool';
/**
 * Aggregates the explicitly stored values of a sparse tensor and writes the
 * result into a newly created CPUTensor.
 *
 * Only the `nnz` stored entries (and, for each, every element of its dense
 * block) are visited. Positions that are implicitly zero in the sparse tensor
 * are never passed to `op` or `init`.
 *
 * @param tensor Sparse tensor. It must expose `shape`, `indices`, `values`,
 *   `nnz`, `sparseDims`, `denseDims` and `getDenseShape()`.
 * @param axes Axes to aggregate over; forwarded to `poolResultShape`.
 * @param keepDims Forwarded to `poolResultShape`.
 * @param op `(accumulator, value) => accumulator`. Folds a stored value into
 *   the current value at its output position.
 * @param init Optional `(value) => accumulator`. When given, the first value
 *   that reaches an output position goes through `init` instead of `op`.
 * @param postProcess Optional `(accumulator, count) => value`. Applied to every
 *   output position after aggregation. `count` is the number of stored values
 *   that were aggregated into that position, and is 0 when none were.
 * @returns The aggregated CPUTensor, created with the dtype of `tensor.values`.
 */
export function aggregateSparseCPU(tensor, axes, keepDims, op, init, postProcess) {
  // ixMap[k] is the input axis whose coordinate becomes output coordinate k.
  const [resultShape, ixMap] = poolResultShape(tensor.shape, axes, keepDims);
  // NOTE(review): output positions start at whatever CPUTensor holds when no
  // values are passed (presumably 0) — confirm. Plain `op` without `init`
  // folds into that starting value.
  const result = new CPUTensor(resultShape, undefined, tensor.values.dtype);

  const sparseDims = tensor.sparseDims;
  const denseShape = tensor.getDenseShape();
  const denseSize = getSize(denseShape, 1);
  const outRank = ixMap.length;

  // Per-output hit counts are only needed to detect the first value (`init`)
  // or to pass the count on to `postProcess`.
  const count = init !== undefined || postProcess !== undefined
    ? new Array(result.size).fill(0)
    : undefined;

  const aIndices = tensor.indices;
  const aValues = tensor.values;

  // Scratch buffers are reused for every stored value. This avoids allocating
  // new index arrays in the innermost loop.
  const sparseIx = new Array(sparseDims);
  const denseIx = new Array(tensor.denseDims);
  const mappedResultIx = new Array(outRank);

  for (let i = 0; i < tensor.nnz; i++) {
    // The sparse coordinates of stored entry i are laid out contiguously.
    for (let j = 0; j < sparseDims; j++) {
      sparseIx[j] = aIndices.get(i * sparseDims + j);
    }
    denseIx.fill(0);

    for (let j = 0; j < denseSize; j++) {
      // The full input index is [...sparseIx, ...denseIx]. Pick each output
      // coordinate straight from the right half without building that array.
      for (let k = 0; k < outRank; k++) {
        const axis = ixMap[k];
        mappedResultIx[k] = axis < sparseDims ? sparseIx[axis] : denseIx[axis - sparseDims];
      }
      const pos = indexToPos(mappedResultIx, result.strides);
      // Each stored entry owns `denseSize` consecutive values.
      const value = aValues.get(i * denseSize + j);

      if (init !== undefined && count[pos] === 0) {
        result.set(pos, init(value));
      } else {
        result.set(pos, op(result.get(pos), value));
      }
      if (count !== undefined) {
        count[pos]++;
      }

      incrementIndex(denseIx, denseShape);
    }
  }

  if (postProcess !== undefined) {
    for (let i = 0; i < result.size; i++) {
      result.set(i, postProcess(result.get(i), count[i]));
    }
  }
  return result;
}
//# sourceMappingURL=cpu.js.map