UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

30 lines 1.22 kB
import { SparseTensor } from '../../../tensor/sparse/tensor';
import { WASMTensor } from '../../../tensor/wasm/tensor';
import { compareShapes, getSize } from '../../../util/shape';

/**
 * Reshapes a sparse tensor whose indices live in a WASMTensor.
 *
 * The leading entries of `shape` are assigned to the new sparse part until
 * their running product is no longer smaller than the old sparse size; every
 * remaining entry becomes a dense dimension. The stored-entry count is scaled
 * by the ratio of the new sparse size to the old sparse size.
 * NOTE(review): assumes the new sparse boundary lines up with the old
 * sparse/dense layout so that this ratio is meaningful — confirm with callers.
 *
 * @param tensor  Source sparse tensor (reads `getSparseShape()` and `nnz`).
 * @param values  Values tensor; reshaped to `[scaledNnz, ...denseShape]`,
 *                forwarding `copy`.
 * @param indices Indices tensor; reused as-is when `copy` is false and the
 *                sparse shape is unchanged, otherwise recomputed through
 *                `indices.wasmTensor.reshape_sparse_indices`.
 * @param shape   Target full shape.
 * @param copy    Forwarded to `values.reshape`; when true the indices are
 *                always recomputed rather than shared.
 * @returns A new SparseTensor with shape `shape` and as many dense
 *          dimensions as were split off from `shape`.
 */
export function reshapeWasm(tensor, values, indices, shape, copy) {
  const previousSparseShape = tensor.getSparseShape();
  const previousSparseSize = getSize(previousSparseShape);

  // Advance through `shape` while the sparse product is still too small;
  // once it reaches the old size it never changes again, so everything
  // after the split point is dense.
  let split = 0;
  let sparseSize = 1;
  while (split < shape.length && sparseSize < previousSparseSize) {
    sparseSize *= shape[split];
    split += 1;
  }
  const sparseShape = shape.slice(0, split);
  const denseShape = shape.slice(split);

  // Same evaluation order as a precomputed fraction: divide first, then scale nnz.
  const scaledNnz = tensor.nnz * (sparseSize / previousSparseSize);
  const reshapedValues = values.reshape([scaledNnz, ...denseShape], copy);

  // Sharing the existing indices is only safe when no copy was requested
  // and the sparse layout is untouched.
  const shareIndices = !copy && compareShapes(sparseShape, previousSparseShape);
  const reshapedIndices = shareIndices
    ? indices
    : new WASMTensor(
        indices.wasmTensor.reshape_sparse_indices(
          new Uint32Array(previousSparseShape),
          new Uint32Array(shape)
        ),
        undefined,
        'uint32'
      );

  return new SparseTensor(reshapedValues, reshapedIndices, shape, denseShape.length);
}
//# sourceMappingURL=wasm.js.map