@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
25 lines • 1.3 kB
JavaScript
import { CPUTensor } from '../../../../tensor/cpu/tensor';
import { SparseTensor } from '../../../../tensor/sparse/tensor';
import { WASMTensor } from '../../../../tensor/wasm/tensor';
import { poolResultShape } from '../../../util/pool';
import { reduceMeanSquareSparseCPU } from './cpu';
import { reduceMeanSquareSparseWASM } from './wasm';
/**
 * Applies the `reduceMeanSquare` reduction to a sparse tensor.
 *
 * When at least one requested axis is smaller than `tensor.sparseDims`, the
 * call is forwarded to `reduceMeanSquareSparse`. Otherwise the reduction is
 * run on `tensor.values` (with each axis shifted by `1 - tensor.sparseDims`)
 * and the result is wrapped in a new `SparseTensor` together with a copy of
 * the indices.
 *
 * @param tensor - Sparse tensor; this function reads `shape`, `sparseDims`,
 *   `denseDims`, `values` and `indices` from it.
 * @param axes - Axes to reduce over. NOTE(review): presumably already
 *   normalized to non-negative indices by the caller; confirm.
 * @param keepDims - Forwarded to the underlying reductions; when falsy the
 *   dense dimension count of the result shrinks by `axes.length`.
 * @returns The result of `reduceMeanSquareSparse`, or a new `SparseTensor`.
 */
export function reduceMeanSquare(tensor, axes, keepDims) {
    const touchesSparseAxis = axes.some(axis => axis < tensor.sparseDims);
    if (touchesSparseAxis) {
        return reduceMeanSquareSparse(tensor, axes, keepDims);
    }

    // Only the shape is needed here; the second tuple entry is ignored.
    const [resultShape] = poolResultShape(tensor.shape, axes, keepDims);

    // Translate tensor axes into axes of the values tensor.
    const valueAxes = axes.map(axis => axis - tensor.sparseDims + 1);
    const reducedValues = tensor.values.reduceMeanSquare(valueAxes, keepDims);
    const indicesCopy = tensor.indices.copy();

    const resultDenseDims = keepDims
        ? tensor.denseDims
        : tensor.denseDims - axes.length;

    return new SparseTensor(reducedValues, indicesCopy, resultShape, resultDenseDims);
}
/**
 * Dispatches to a backend-specific implementation based on the class of
 * `tensor.values`.
 *
 * @param tensor - Sparse tensor whose `values` selects the backend.
 * @param axes - Axes to reduce over, passed through unchanged.
 * @param keepDims - Passed through unchanged.
 * @returns Whatever the selected backend implementation returns.
 * @throws {Error} If `tensor.values` is neither a `CPUTensor` nor a
 *   `WASMTensor`.
 */
function reduceMeanSquareSparse(tensor, axes, keepDims) {
    const { values } = tensor;

    if (values instanceof CPUTensor) {
        return reduceMeanSquareSparseCPU(tensor, axes, keepDims);
    }
    if (values instanceof WASMTensor) {
        return reduceMeanSquareSparseWASM(tensor, axes, keepDims);
    }

    // Any other backend ends up here and is reported as unsupported.
    const message = 'Reduce mean squared over sparse dimensions not implemented in WebGL yet';
    throw new Error(message);
}
//# sourceMappingURL=reduceMeanSquare.js.map