UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

33 lines 1.41 kB
import { CPUTensor } from '../../../tensor/cpu/tensor';
import { SparseTensor } from '../../../tensor/sparse/tensor';
import { WASMTensor } from '../../../tensor/wasm/tensor';
import { repeatIndicesCPU } from './cpu';
import { repeatIndexGPU } from './gpu';
import { repeatIndicesWASM } from './wasm';

/**
 * Repeats a sparse tensor along its dimensions.
 *
 * The first `tensor.sparseDims` entries of `repeats` are treated as the
 * repeat counts for the sparse dimensions; the remaining entries are passed
 * on to the dense dimensions of the value tensor.
 *
 * @param tensor - Sparse tensor to repeat. Reads `sparseDims`, `denseDims`,
 *   `values`, `indices`, `shape` and `getSparseShape()`.
 * @param repeats - Repeat count per dimension.
 *   NOTE(review): presumably expected to have one entry per entry of
 *   `tensor.shape`; this is not validated here - confirm against callers.
 * @returns A new SparseTensor whose shape is `tensor.shape` multiplied
 *   element-wise by `repeats`.
 */
export function repeat(tensor, repeats) {
  const { sparseDims } = tensor;
  const sparsePart = repeats.slice(0, sparseDims);
  const densePart = repeats.slice(sparseDims);

  // Total repetition factor over all sparse dimensions.
  let sparseFactor = 1;
  for (const count of sparsePart) {
    sparseFactor *= count;
  }

  // The value tensor is repeated by the combined sparse factor along its
  // first axis, and by the dense repeat counts along the remaining axes.
  const repeatedValues = tensor.values.repeat([sparseFactor, ...densePart]);

  // Without any sparse repetition a plain copy of the indices is used.
  const repeatedIndices =
    sparseFactor > 1
      ? repeatIndices(
          tensor.indices,
          sparsePart,
          tensor.getSparseShape(),
          sparseFactor
        )
      : tensor.indices.copy();

  const resultShape = tensor.shape.map((extent, axis) => extent * repeats[axis]);

  return new SparseTensor(
    repeatedValues,
    repeatedIndices,
    resultShape,
    tensor.denseDims
  );
}

/**
 * Forwards to the index-repeat implementation that matches the concrete
 * class of `indices`.
 *
 * @param indices - Index tensor of the sparse tensor.
 * @param repeats - Repeat counts for the sparse dimensions.
 * @param shape - Sparse shape of the source tensor.
 * @param repeatsProd - Product of all entries in `repeats`.
 * @returns Whatever the selected backend implementation returns.
 */
function repeatIndices(indices, repeats, shape, repeatsProd) {
  if (indices instanceof CPUTensor) {
    return repeatIndicesCPU(indices, repeats, shape, repeatsProd);
  }
  if (indices instanceof WASMTensor) {
    return repeatIndicesWASM(indices, repeats, shape, repeatsProd);
  }
  // Anything that is neither a CPU nor a WASM tensor goes to the GPU path.
  return repeatIndexGPU(indices, repeats, shape, repeatsProd);
}
//# sourceMappingURL=repeat.js.map