UNPKG

@hoff97/tensor-js

Version:

PyTorch-like deep learning inference library

171 lines (164 loc) 6.06 kB
import { defaultAllocator } from '../../../tensor/gpu/gl';
import { gpuConstructor } from '../../../tensor/gpu/tensor';
import { SparseTensor } from '../../../tensor/sparse/tensor';
import { compareShapes, computeStrides, getSize } from '../../../util/shape';
import { Dispatcher } from '../../gpu/dispatcher';
import { Operation } from '../../gpu/operation';

/**
 * Splits a target shape into its leading sparse part and trailing dense part.
 *
 * Dimensions are taken from the front of `shape` for as long as the running
 * product of the dimensions taken so far is still smaller than
 * `oldSparseSize`; every dimension after that point is treated as dense.
 *
 * @param {number[]} shape - Target shape of the reshape.
 * @param {number} oldSparseSize - Product of the sparse dimensions before the reshape.
 * @returns {{sparseShape: number[], denseShape: number[], sparseSize: number}}
 *   The sparse dimensions, the remaining dense dimensions, and the product of
 *   the sparse dimensions (1 when no dimension was taken).
 */
function splitSparseShape(shape, oldSparseSize) {
  const sparseShape = [];
  const denseShape = [];
  let sparseSize = 1;
  for (let i = 0; i < shape.length; i++) {
    if (sparseSize < oldSparseSize) {
      sparseSize *= shape[i];
      sparseShape.push(shape[i]);
    } else {
      denseShape.push(shape[i]);
    }
  }
  return { sparseShape, denseShape, sparseSize };
}

/**
 * GPU operation that recomputes the index matrix of a sparse tensor whose
 * sparse dimensions are reshaped.
 *
 * The output has shape `[newNnz, newSparseDims]`. Every old nonzero entry is
 * expanded into `nnzFraction` new entries, where `nnzFraction` is the ratio
 * between the new and the old sparse size.
 */
export class ReshapeIndicesOperation extends Operation {
  /**
   * Forwards all arguments unchanged to the base `Operation` constructor.
   */
  constructor(tensorConstructor, dtype, allocator) {
    super(tensorConstructor, dtype, allocator);
  }

  /**
   * @returns {string} Shader declarations for the variables this operation
   *   adds: `nnzFraction`, `sparseDims` and the two stride arrays.
   */
  getVariables() {
    return `
    ${this.getVarModifier('nnzFraction')} int nnzFraction;
    ${this.getVarModifier('sparseDims')} int sparseDims;
    ${this.getVarModifier('oldSparseStrides')} int oldSparseStrides[${this.maxRank}];
    ${this.getVarModifier('newSparseStrides')} int newSparseStrides[${this.maxRank}];
    `;
  }

  /**
   * @returns {Array<{name: string, length?: number}>} Descriptors of the
   *   variables declared in `getVariables`; the stride arrays have
   *   `this.maxRank` entries.
   */
  getUniformAttrs() {
    return [
      { name: 'nnzFraction' },
      { name: 'sparseDims' },
      { name: 'oldSparseStrides', length: this.maxRank },
      { name: 'newSparseStrides', length: this.maxRank },
    ];
  }

  /**
   * Builds the fragment shader source.
   *
   * For the output element at `[newNNZ, ax]` the shader:
   *  1. derives the old nonzero row (`newNNZ / nnzFraction`) and the residual
   *     offset inside that row's expansion,
   *  2. reads the old sparse index of that row from texture `A`,
   *  3. linearizes it with `oldSparseStrides`, scales it by `nnzFraction` and
   *     adds the residual to get the new linear sparse position,
   *  4. converts that position back to an index with `newSparseStrides`, and
   *  5. returns component `ax` of the new index.
   *
   * @param {object} info - Compilation info; not read by this implementation.
   * @returns {string} Shader source.
   */
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  getFragmentShader(info) {
    return `
    float process(int index[${this.maxRank}]) {
      float res = 0.0;

      int newNNZ = int(index[0]);
      int oldNnzIx = newNNZ / nnzFraction;
      int residualNNZ = newNNZ - oldNnzIx*nnzFraction;

      int oldSparseIx[${this.maxRank}];
      ${this.initIndex('oldSparseIx')}
      for (int j = 0; j < ${this.maxRank}; j++) {
        if (j >= sparseDims) {
          break;
        }
        oldSparseIx[j] = int(getValueAtPos(oldNnzIx * sparseDims + j, widthA, heightA, A));
      }

      int oldSparsePos = indexToPos(oldSparseIx, oldSparseStrides);
      int newSparsePos = oldSparsePos * nnzFraction + residualNNZ;

      int newSparseIx[${this.maxRank}];
      ${this.initIndex('newSparseIx')}
      ${this.posToIndex('newSparseStrides', 'newSparseIx', 'newSparsePos')}

      int ax = int(index[1]);
      for (int j = 0; j < ${this.maxRank}; j++) {
        if (j == ax) {
          res = float(newSparseIx[j]);
          break;
        }
      }

      return res;
    }

    ${this.getDefaultMain()}
    `;
  }

  /**
   * @returns {string[]} Names of the input textures; only the index tensor `A`.
   */
  getTextureNames() {
    return ['A'];
  }

  /**
   * Runs the operation.
   *
   * @param {object} input - Object with the old index tensor `A`, the old
   *   `sparseShape`, the target `shape` and the old `nnz`.
   * @returns {*} Whatever `this.compute` returns for the new index matrix.
   */
  calc(input) {
    // When the operation is fully static and the output shape is already
    // known, no per-call values are passed to compute.
    if (this.fullyStatic && this.outputShape !== undefined) {
      return this.compute(this.outputShape, { A: input.A });
    }

    const outputShape = this.getOutputShape(input);
    const info = this.getCompilationInfo(input);

    return this.compute(
      outputShape,
      { A: input.A },
      {
        nnzFraction: info.nnzFraction,
        sparseDims: info.sparseDims,
        oldSparseStrides: this.pad(info.oldSparseStrides),
        newSparseStrides: this.pad(info.newSparseStrides),
      },
    );
  }

  /**
   * Computes the shape of the new index matrix.
   *
   * @param {object} input - Object providing `sparseShape`, `shape` and `nnz`.
   * @returns {number[]} `[newNnz, newSparseDims]`.
   */
  getOutputShape(input) {
    const oldSparseSize = getSize(input.sparseShape);
    const { sparseShape, sparseSize } = splitSparseShape(input.shape, oldSparseSize);

    const nnzFraction = sparseSize / oldSparseSize;
    const nnz = input.nnz * nnzFraction;

    return [nnz, sparseShape.length];
  }

  /**
   * Delegates to the base implementation without changes.
   */
  compile(info) {
    super.compile(info);
  }

  /**
   * Collects everything needed to compile and run the shader for `input`.
   *
   * @param {object} input - Object providing `A`, `sparseShape`, `shape` and `nnz`.
   * @returns {object} Shapes and texture dimensions of input and output plus
   *   `sparseDims`, `newSparseStrides`, `oldSparseStrides` and `nnzFraction`.
   */
  getCompilationInfo(input) {
    const outputShape = this.getOutputShape(input);
    const outputSize = defaultAllocator.getAllocationDimensions(
      getSize(outputShape),
      this.dtype,
    );

    const oldSparseSize = getSize(input.sparseShape);
    const { sparseShape, sparseSize } = splitSparseShape(input.shape, oldSparseSize);

    const oldSparseStrides = computeStrides(input.sparseShape);
    const newSparseStrides = computeStrides(sparseShape);
    const nnzFraction = sparseSize / oldSparseSize;

    return {
      shapeA: input.A.shape,
      widthA: input.A.memory.width,
      heightA: input.A.memory.height,
      shapeOutput: outputShape,
      widthOutput: outputSize.width,
      heightOutput: outputSize.height,
      sparseDims: input.sparseShape.length,
      newSparseStrides,
      oldSparseStrides,
      nnzFraction,
    };
  }

  /**
   * @param {object} input - Object providing `A`, `nnz`, `shape` and `sparseShape`.
   * @returns {string} String built from the index tensor shape, the nnz, the
   *   target shape and the old sparse shape.
   */
  getInputInfoString(input) {
    return `${input.A.shape}-${input.nnz}-${input.shape}-${input.sparseShape}`;
  }
}

/**
 * Dispatcher that creates a `ReshapeIndicesOperation` for a given dtype.
 */
export const defaultReshapeIndicesD = new Dispatcher(
  (dtype) => new ReshapeIndicesOperation(gpuConstructor, dtype),
);

/**
 * Reshapes a sparse tensor whose values and indices live on the GPU.
 *
 * @param {object} tensor - Sparse tensor being reshaped; its
 *   `getSparseShape()` and `nnz` are read.
 * @param {object} values - Values tensor; reshaped to `[newNnz, ...denseShape]`.
 * @param {object} indices - Index tensor of the sparse tensor.
 * @param {number[]} shape - Target shape.
 * @param {boolean} copy - Forwarded to `values.reshape`. When falsy and the
 *   sparse shape is unchanged, `indices` is reused as-is instead of being
 *   recomputed.
 * @returns {SparseTensor} New sparse tensor with the requested shape.
 */
export function reshapeGPU(tensor, values, indices, shape, copy) {
  const oldSparseShape = tensor.getSparseShape();
  const oldSparseSize = getSize(oldSparseShape);
  const { sparseShape, denseShape, sparseSize } = splitSparseShape(shape, oldSparseSize);

  const nnzFraction = sparseSize / oldSparseSize;
  const nnz = tensor.nnz * nnzFraction;

  const newValues = values.reshape([nnz, ...denseShape], copy);

  let newIndices;
  if (!copy && compareShapes(sparseShape, oldSparseShape)) {
    // Sparse dimensions are untouched, so the existing indices stay valid.
    newIndices = indices;
  } else {
    newIndices = defaultReshapeIndicesD.calc(
      {
        A: indices,
        sparseShape: oldSparseShape,
        shape: shape,
        nnz: tensor.nnz,
      },
      'uint32',
    );
  }

  return new SparseTensor(newValues, newIndices, shape, denseShape.length);
}
//# sourceMappingURL=gpu.js.map