@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
30 lines • 1.22 kB
JavaScript
import { SparseTensor } from '../../../tensor/sparse/tensor';
import { WASMTensor } from '../../../tensor/wasm/tensor';
import { compareShapes, getSize } from '../../../util/shape';
/**
 * Reshapes a sparse tensor whose indices are backed by a WASM tensor.
 *
 * The leading dimensions of `shape` are taken as the new sparse dimensions
 * until their product reaches the old sparse size; all remaining dimensions
 * become dense dimensions of each stored value.
 *
 * @param tensor  The sparse tensor being reshaped; only `nnz` and
 *                `getSparseShape()` are read here.
 * @param values  Values tensor; reshaped to `[nnz, ...denseShape]` via its
 *                `reshape(shape, copy)` method.
 * @param indices Indices tensor; either reused as-is or transformed through
 *                `indices.wasmTensor.reshape_sparse_indices`.
 * @param shape   The requested full shape of the result.
 * @param copy    When false and the sparse shape is unchanged, `indices` is
 *                reused without recomputation; also forwarded to `values.reshape`.
 * @returns A new SparseTensor built from the reshaped values/indices and `shape`.
 */
export function reshapeWasm(tensor, values, indices, shape, copy) {
    const oldSparseShape = tensor.getSparseShape();
    const oldSparseSize = getSize(oldSparseShape);

    const sparseShape = [];
    const denseShape = [];
    let sparseSize = 1;
    for (const dim of shape) {
        // Keep consuming leading dims as sparse until they cover the old sparse size.
        if (sparseSize < oldSparseSize) {
            sparseSize *= dim;
            sparseShape.push(dim);
        }
        else {
            denseShape.push(dim);
        }
    }

    // Each old non-zero entry expands into sparseSize / oldSparseSize new entries.
    // Multiply before dividing so the count stays an exact integer whenever it is
    // mathematically integral (dividing first gives e.g. 100 * 1.15 = 114.99999999999999).
    const nnz = (tensor.nnz * sparseSize) / oldSparseSize;
    const newValues = values.reshape([nnz, ...denseShape], copy);

    let newIndices;
    if (!copy && compareShapes(sparseShape, oldSparseShape)) {
        // Sparse layout is unchanged, so the existing indices remain valid.
        newIndices = indices;
    }
    else {
        // NOTE(review): the WASM routine receives the old sparse shape and the full
        // new shape (not just `sparseShape`) — confirm this matches its contract.
        newIndices = new WASMTensor(indices.wasmTensor.reshape_sparse_indices(new Uint32Array(oldSparseShape), new Uint32Array(shape)), undefined, 'uint32');
    }
    // presumably the last argument is the number of dense dimensions — verify
    // against the SparseTensor constructor.
    return new SparseTensor(newValues, newIndices, shape, denseShape.length);
}
//# sourceMappingURL=wasm.js.map