@hoff97/tensor-js
Version:
PyTorch-like deep learning inference library
36 lines • 1.44 kB
JavaScript
import { CPUTensor } from '../../../tensor/cpu/tensor';
import { SparseTensor } from '../../../tensor/sparse/tensor';
import { WASMTensor } from '../../../tensor/wasm/tensor';
import { compareShapes } from '../../../util/shape';
import { addIndexCPU } from './cpu';
import { addIndexGPU } from './gpu';
import { addIndexWASM } from './wasm';
/**
 * Concatenates two sparse tensors along one of their sparse axes.
 *
 * Both operands must have identical shapes and the same number of sparse
 * dimensions. The values of `b` are appended to those of `a`, and the indices
 * of `b` are run through `addIndex` with `a.shape[axis]` before being appended
 * (presumably shifting the `axis` coordinate so the entries of `b` land after
 * those of `a` - confirm against the backend implementations).
 *
 * @param a First sparse tensor (reads `shape`, `sparseDims`, `denseDims`,
 *     `values`, `indices`).
 * @param b Second sparse tensor, appended after `a`.
 * @param {number} axis Zero-based axis to concatenate along; must be a sparse
 *     axis, i.e. `axis < a.sparseDims`.
 * @returns A new SparseTensor whose extent along `axis` is the sum of both
 *     inputs' extents.
 * @throws {Error} If shapes or sparse dimension counts differ, or if `axis`
 *     addresses a dense axis.
 */
export function concat(a, b, axis) {
    if (!compareShapes(a.shape, b.shape) || a.sparseDims !== b.sparseDims) {
        throw new Error('Sparse tensors can only be concatenated with the same shape and number of sparse dims');
    }
    // Axes are zero-based, so the sparse axes are 0..sparseDims-1 and
    // `axis === sparseDims` is already the first dense axis.
    if (axis >= a.sparseDims) {
        throw new Error('Concatenation along dense axis of sparse tensor not supported yet');
    }

    const values = a.values.concat(b.values, 0);
    const indexAdded = addIndex(b.indices, axis, a.shape[axis]);
    let indices;
    try {
        indices = a.indices.concat(indexAdded, 0);
    }
    finally {
        // The offset copy of b's indices is only an intermediate; release it
        // even when the concatenation fails.
        indexAdded.delete();
    }

    // Copy the shape so the input tensor's shape array is not mutated.
    const resultShape = [...a.shape];
    resultShape[axis] += b.shape[axis];
    return new SparseTensor(values, indices, resultShape, a.denseDims);
}
/**
 * Routes the index-offset operation to the implementation matching the
 * backend of `indices`: the CPU variant for CPUTensor, the WASM variant for
 * WASMTensor, and the GPU variant for anything else.
 *
 * @param indices Index tensor to process.
 * @param {number} axis Axis argument forwarded to the backend implementation.
 * @param {number} count Count argument forwarded to the backend implementation.
 * @returns Whatever the selected backend implementation returns.
 */
function addIndex(indices, axis, count) {
    if (indices instanceof CPUTensor) {
        return addIndexCPU(indices, axis, count);
    }
    if (indices instanceof WASMTensor) {
        return addIndexWASM(indices, axis, count);
    }
    // Neither CPU nor WASM: fall through to the GPU implementation.
    return addIndexGPU(indices, axis, count);
}
//# sourceMappingURL=concat.js.map