@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
25 lines • 1.17 kB
JavaScript
import { CPUTensor } from '../../../../tensor/cpu/tensor';
import { SparseTensor } from '../../../../tensor/sparse/tensor';
import { WASMTensor } from '../../../../tensor/wasm/tensor';
import { poolResultShape } from '../../../util/pool';
import { sumSparseCPU } from './cpu';
import { sumSparseWASM } from './wasm';
/**
 * Sums a sparse tensor over the given axes.
 *
 * If any requested axis is one of the sparse dimensions (an axis index below
 * `tensor.sparseDims`), the reduction is delegated to `sumSparse`, which
 * selects a backend-specific implementation. Otherwise only dense dimensions
 * are reduced. In that case the sum is applied directly to `tensor.values`, and
 * the sparse indices are carried over as a copy.
 *
 * @param {SparseTensor} tensor - Sparse tensor to reduce.
 * @param {number[]} axes - Axes of the full tensor shape to sum over.
 *   NOTE(review): presumably already normalized (non-negative, unique) by
 *   the caller; confirm.
 * @param {boolean} keepDims - Whether reduced axes are kept (with size 1).
 * @returns {SparseTensor} The reduced tensor.
 */
export function sum(tensor, axes, keepDims) {
  const { sparseDims } = tensor;

  // Reductions touching any sparse axis go to a backend-specific
  // implementation instead of a plain sum over `values`.
  if (axes.some(ax => ax < sparseDims)) {
    return sumSparse(tensor, axes, keepDims);
  }

  // Only the result shape is needed; the index map is not used here.
  const [resultShape] = poolResultShape(tensor.shape, axes, keepDims);

  // Full-tensor axis `ax` maps to values axis `ax - sparseDims + 1`. That is,
  // `values` holds the dense dims behind one leading axis (presumably one
  // entry per stored sparse index; TODO confirm).
  const valueAxes = axes.map(ax => ax - sparseDims + 1);
  const summedValues = tensor.values.sum(valueAxes, keepDims);
  const denseDims = keepDims ? tensor.denseDims : tensor.denseDims - axes.length;

  return new SparseTensor(summedValues, tensor.indices.copy(), resultShape, denseDims);
}
/**
 * Sums over a set of axes that includes at least one sparse dimension, by
 * dispatching on the backend class of `tensor.values`.
 *
 * CPU- and WASM-backed values have dedicated implementations. Any other
 * backend falls through to an error whose message assumes the remaining
 * backend is WebGL.
 *
 * @param {SparseTensor} tensor - Sparse tensor to reduce.
 * @param {number[]} axes - Axes to sum over.
 * @param {boolean} keepDims - Whether reduced axes are kept (with size 1).
 * @returns Whatever the selected backend implementation returns
 *   (presumably a SparseTensor; confirm).
 * @throws {Error} If `tensor.values` is neither a CPUTensor nor a WASMTensor.
 */
function sumSparse(tensor, axes, keepDims) {
  // Checked in order; the first matching backend wins.
  const implementations = [
    [CPUTensor, sumSparseCPU],
    [WASMTensor, sumSparseWASM],
  ];

  for (const [Backend, sumImpl] of implementations) {
    if (tensor.values instanceof Backend) {
      return sumImpl(tensor, axes, keepDims);
    }
  }

  throw new Error('Sum over sparse dimensions not implemented in WebGL yet');
}
//# sourceMappingURL=sum.js.map