@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
33 lines • 1.41 kB
JavaScript
import { CPUTensor } from '../../../tensor/cpu/tensor';
import { SparseTensor } from '../../../tensor/sparse/tensor';
import { WASMTensor } from '../../../tensor/wasm/tensor';
import { repeatIndicesCPU } from './cpu';
import { repeatIndexGPU } from './gpu';
import { repeatIndicesWASM } from './wasm';
/**
 * Tiles a sparse tensor along each of its dimensions.
 *
 * The first `tensor.sparseDims` entries of `repeats` apply to the sparse
 * dimensions; they are realised by repeating the index tensor. The remaining
 * entries are forwarded to `values.repeat` for the dense part.
 *
 * @param {SparseTensor} tensor - Sparse tensor to tile. It must expose
 *   `values`, `indices`, `shape`, `sparseDims`, `denseDims` and `getSparseShape()`.
 * @param {number[]} repeats - Repeat count for every dimension of `tensor.shape`.
 * @returns {SparseTensor} A new sparse tensor with shape `tensor.shape[i] * repeats[i]`.
 */
export function repeat(tensor, repeats) {
  const { sparseDims } = tensor;
  const sparseRepeats = repeats.slice(0, sparseDims);
  const denseRepeats = repeats.slice(sparseDims);

  // Each sparse tile needs a full copy of the stored values, so the values
  // are repeated along their first axis by the product of the sparse
  // repeat counts. (Assumes axis 0 of `values` enumerates stored entries.
  // TODO confirm.)
  let tileCount = 1;
  for (const count of sparseRepeats) {
    tileCount *= count;
  }
  const repeatedValues = tensor.values.repeat([tileCount, ...denseRepeats]);

  // With a single tile the index set is unchanged, so a plain copy is enough.
  const repeatedIndices =
    tileCount > 1
      ? repeatIndices(tensor.indices, sparseRepeats, tensor.getSparseShape(), tileCount)
      : tensor.indices.copy();

  const tiledShape = tensor.shape.map((size, dim) => size * repeats[dim]);
  return new SparseTensor(repeatedValues, repeatedIndices, tiledShape, tensor.denseDims);
}
/**
 * Repeats a sparse tensor's index tensor using the backend kernel that
 * matches the index tensor's type.
 *
 * CPU and WASM index tensors go to their dedicated kernels. Any other tensor
 * type falls through to the GPU kernel.
 *
 * @param {Tensor} indices - Index tensor of the sparse tensor.
 * @param {number[]} repeats - Repeat counts for the sparse dimensions.
 * @param {number[]} shape - Sparse part of the tensor shape.
 * @param {number} repeatsProd - Product of `repeats`, precomputed by the caller.
 * @returns {Tensor} The index tensor produced by the selected kernel.
 */
function repeatIndices(indices, repeats, shape, repeatsProd) {
  let kernel = repeatIndexGPU;
  if (indices instanceof CPUTensor) {
    kernel = repeatIndicesCPU;
  } else if (indices instanceof WASMTensor) {
    kernel = repeatIndicesWASM;
  }
  return kernel(indices, repeats, shape, repeatsProd);
}
//# sourceMappingURL=repeat.js.map