UNPKG

playcanvas

Version:

Open-source WebGL/WebGPU 3D engine for the web

256 lines (253 loc) 12.2 kB
import { Vec2 } from '../../core/math/vec2.js';
import { Compute } from '../../platform/graphics/compute.js';
import { Shader } from '../../platform/graphics/shader.js';
import { StorageBuffer } from '../../platform/graphics/storage-buffer.js';
import { BindStorageBufferFormat, BindUniformBufferFormat, BindGroupFormat } from '../../platform/graphics/bind-group-format.js';
import { UniformBufferFormat, UniformFormat } from '../../platform/graphics/uniform-buffer-format.js';
import { SHADERSTAGE_COMPUTE, BUFFERUSAGE_COPY_SRC, BUFFERUSAGE_COPY_DST, SHADERLANGUAGE_WGSL, UNIFORMTYPE_UINT } from '../../platform/graphics/constants.js';
import { PrefixSumKernel } from './prefix-sum-kernel.js';
import { radixSort4bitSource } from '../shader-lib/wgsl/chunks/radix-sort/compute-radix-sort-4bit.js';
import { radixSortReorderSource } from '../shader-lib/wgsl/chunks/radix-sort/compute-radix-sort-reorder.js';

// Number of key bits consumed by each histogram + reorder pass.
const BITS_PER_PASS = 4;

// Number of buckets per pass (2 ^ BITS_PER_PASS).
const BUCKET_COUNT = 16;

// Workgroup dimensions handed to the shaders as defines.
const WORKGROUP_SIZE_X = 16;
const WORKGROUP_SIZE_Y = 16;
const THREADS_PER_WORKGROUP = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y;

// Elements handled by each thread, and therefore by each workgroup.
const ELEMENTS_PER_THREAD = 8;
const ELEMENTS_PER_WORKGROUP = THREADS_PER_WORKGROUP * ELEMENTS_PER_THREAD;

/**
 * Number of passes needed to cover `numBits` key bits. A bit count that is not a multiple of
 * BITS_PER_PASS is rounded up, so the pass count is always an integer and the final pass can be
 * identified reliably.
 *
 * @param {number} numBits - Number of key bits to sort.
 * @returns {number} The number of passes.
 */
const getPassCount = numBits => Math.ceil(numBits / BITS_PER_PASS);

/**
 * Radix sort executed with compute shaders, processing BITS_PER_PASS bits of the key per pass.
 * Each pass dispatches a histogram shader, runs a prefix sum over the per-workgroup block sums
 * and then dispatches a reorder shader. Keys and values ping-pong between internally owned
 * storage buffers; the last pass writes its values to the `sortedIndices` buffer.
 */
class ComputeRadixSort {
    /**
     * @param {object} device - The graphics device used to create buffers, shaders and to
     * dispatch compute work.
     */
    constructor(device) {
        this._elementCount = 0;
        this._workgroupCount = 0;
        this._allocatedWorkgroupCount = 0;

        // Number of elements the internal buffers are currently sized for.
        this.capacity = 0;
        this._numBits = 0;

        // Internal ping-pong buffers and per-sort resources, created lazily by _allocateBuffers.
        this._keys0 = null;
        this._keys1 = null;
        this._values0 = null;
        this._values1 = null;
        this._blockSums = null;
        this._sortedIndices = null;
        this._prefixSumKernel = null;

        this._dispatchSize = new Vec2(1, 1);
        this._histogramBindGroupFormat = null;
        this._reorderBindGroupFormat = null;
        this._uniformBufferFormat = null;

        // One { histogramCompute, reorderCompute } entry per pass.
        this._passes = [];

        // Configuration the current passes were built for; a change triggers a rebuild.
        this._indirect = false;
        this._hasInitialValues = false;
        this._skipLastPassKeyWrite = false;

        this.device = device;
        this._uniformBufferFormat = new UniformBufferFormat(device, [
            new UniformFormat('workgroupCount', UNIFORMTYPE_UINT),
            new UniformFormat('elementCount', UNIFORMTYPE_UINT)
        ]);
    }

    /**
     * Releases all buffers, shaders and bind group formats owned by this instance.
     */
    destroy() {
        this._destroyBuffers();
        this._destroyPasses();

        this._histogramBindGroupFormat?.destroy();
        this._reorderBindGroupFormat?.destroy();
        this._histogramBindGroupFormat = null;
        this._reorderBindGroupFormat = null;
        this._uniformBufferFormat = null;
    }

    /**
     * Destroys the shaders of all passes and clears the pass list. Resetting `_numBits` makes
     * the next sort rebuild the passes.
     */
    _destroyPasses() {
        for (const pass of this._passes) {
            pass.histogramCompute.shader?.destroy();
            pass.reorderCompute.shader?.destroy();
        }
        this._passes.length = 0;
        this._numBits = 0;
    }

    /**
     * Destroys the internal storage buffers and the prefix sum kernel, and resets the
     * workgroup bookkeeping so that the next sort reallocates them.
     */
    _destroyBuffers() {
        this._keys0?.destroy();
        this._keys1?.destroy();
        this._values0?.destroy();
        this._values1?.destroy();
        this._blockSums?.destroy();
        this._sortedIndices?.destroy();
        this._prefixSumKernel?.destroy();

        this._keys0 = null;
        this._keys1 = null;
        this._values0 = null;
        this._values1 = null;
        this._blockSums = null;
        this._sortedIndices = null;
        this._prefixSumKernel = null;

        this._workgroupCount = 0;
        this._allocatedWorkgroupCount = 0;
    }

    /**
     * The buffer the last pass writes its values to, or null before the first sort.
     */
    get sortedIndices() {
        return this._sortedIndices;
    }

    /**
     * The internal key buffer targeted by the last pass of the most recent configuration, or
     * null when no buffers are allocated. Passes alternate between `_keys0` (first pass) and
     * `_keys1`, so the parity of the pass count selects the buffer.
     */
    get sortedKeys() {
        if (!this._keys0) return null;
        return getPassCount(this._numBits) % 2 === 0 ? this._keys1 : this._keys0;
    }

    /**
     * Creates the histogram and reorder bind group formats, unless formats matching the
     * requested mode already exist.
     *
     * @param {boolean} indirect - True to append the `sortElementCount` storage buffer binding
     * to both formats.
     */
    _ensureBindGroupFormats(indirect) {
        if (this._histogramBindGroupFormat && this._indirect === indirect) {
            return;
        }

        this._histogramBindGroupFormat?.destroy();
        this._reorderBindGroupFormat?.destroy();

        const device = this.device;

        const histogramEntries = [
            new BindStorageBufferFormat('input', SHADERSTAGE_COMPUTE, true),
            new BindStorageBufferFormat('block_sums', SHADERSTAGE_COMPUTE, false),
            new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
        ];

        const reorderEntries = [
            new BindStorageBufferFormat('inputKeys', SHADERSTAGE_COMPUTE, true),
            new BindStorageBufferFormat('outputKeys', SHADERSTAGE_COMPUTE, false),
            new BindStorageBufferFormat('prefix_block_sum', SHADERSTAGE_COMPUTE, true),
            new BindStorageBufferFormat('inputValues', SHADERSTAGE_COMPUTE, true),
            new BindStorageBufferFormat('outputValues', SHADERSTAGE_COMPUTE, false),
            new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
        ];

        if (indirect) {
            histogramEntries.push(new BindStorageBufferFormat('sortElementCount', SHADERSTAGE_COMPUTE, true));
            reorderEntries.push(new BindStorageBufferFormat('sortElementCount', SHADERSTAGE_COMPUTE, true));
        }

        this._histogramBindGroupFormat = new BindGroupFormat(device, histogramEntries);
        this._reorderBindGroupFormat = new BindGroupFormat(device, reorderEntries);
    }

    /**
     * Rebuilds the per-pass histogram and reorder compute instances for the given
     * configuration, and records that configuration on the instance.
     *
     * @param {number} numBits - Number of key bits to sort.
     * @param {boolean} indirect - True to build the indirect-dispatch shader variants.
     * @param {boolean} hasInitialValues - True when the caller supplies the first pass values.
     * @param {boolean} skipLastPassKeyWrite - True to flag the final reorder pass with the
     * IS_LAST_PASS define.
     */
    _createPasses(numBits, indirect, hasInitialValues, skipLastPassKeyWrite) {
        this._destroyPasses();
        this._numBits = numBits;

        // Must run before this._indirect is updated, as it compares against the previous mode.
        this._ensureBindGroupFormats(indirect);
        this._indirect = indirect;
        this._hasInitialValues = hasInitialValues;
        this._skipLastPassKeyWrite = skipLastPassKeyWrite;

        const numPasses = getPassCount(numBits);
        const suffix = indirect ? '-Indirect' : '';

        for (let pass = 0; pass < numPasses; pass++) {
            const bitOffset = pass * BITS_PER_PASS;

            // Only the reorder shader receives the first / last pass flags.
            const isFirstPass = pass === 0 && !hasInitialValues;
            const isLastPass = skipLastPassKeyWrite && pass === numPasses - 1;

            const histogramName = `RadixSort4bit-Histogram${suffix}-${bitOffset}`;
            const reorderName = `RadixSort4bit-Reorder${suffix}-${bitOffset}`;

            const histogramShader = this._createShader(histogramName, radixSort4bitSource, bitOffset, false, false, this._histogramBindGroupFormat, indirect);
            const reorderShader = this._createShader(reorderName, radixSortReorderSource, bitOffset, isFirstPass, isLastPass, this._reorderBindGroupFormat, indirect);

            this._passes.push({
                histogramCompute: new Compute(this.device, histogramShader, histogramName),
                reorderCompute: new Compute(this.device, reorderShader, reorderName)
            });
        }
    }

    /**
     * Makes sure the internal buffers are large enough for `elementCount` elements and that
     * the passes match the requested configuration. Buffers only ever grow: the allocation
     * is sized for the larger of `elementCount` and the current capacity.
     *
     * @param {number} elementCount - Number of elements to sort.
     * @param {number} numBits - Number of key bits to sort.
     * @param {boolean} indirect - True when indirect dispatch is used.
     * @param {boolean} hasInitialValues - True when the caller supplies the first pass values.
     * @param {boolean} skipLastPassKeyWrite - True to flag the final reorder pass.
     */
    _allocateBuffers(elementCount, numBits, indirect, hasInitialValues, skipLastPassKeyWrite) {
        const effectiveCount = Math.max(elementCount, this.capacity);
        const workgroupCount = Math.ceil(effectiveCount / ELEMENTS_PER_WORKGROUP);

        const buffersNeedRealloc = workgroupCount !== this._allocatedWorkgroupCount || !this._keys0;
        const passesNeedRecreate = numBits !== this._numBits ||
            indirect !== this._indirect ||
            hasInitialValues !== this._hasInitialValues ||
            skipLastPassKeyWrite !== this._skipLastPassKeyWrite;

        if (buffersNeedRealloc) {
            this._destroyBuffers();
            this._allocatedWorkgroupCount = workgroupCount;
            this.capacity = effectiveCount;

            // 4 bytes per element / per block sum entry.
            const elementSize = effectiveCount * 4;
            const blockSumSize = BUCKET_COUNT * workgroupCount * 4;
            const usage = BUFFERUSAGE_COPY_SRC | BUFFERUSAGE_COPY_DST;

            this._keys0 = new StorageBuffer(this.device, elementSize, usage);
            this._keys1 = new StorageBuffer(this.device, elementSize, usage);
            this._values0 = new StorageBuffer(this.device, elementSize, usage);
            this._values1 = new StorageBuffer(this.device, elementSize, usage);
            this._blockSums = new StorageBuffer(this.device, blockSumSize, usage);
            this._sortedIndices = new StorageBuffer(this.device, elementSize, usage);
            this._prefixSumKernel = new PrefixSumKernel(this.device);
        }

        this._workgroupCount = workgroupCount;
        Compute.calcDispatchSize(workgroupCount, this._dispatchSize, this.device.limits.maxComputeWorkgroupsPerDimension || 65535);
        this._prefixSumKernel.resize(this._blockSums, BUCKET_COUNT * workgroupCount);

        if (passesNeedRecreate) {
            this._createPasses(numBits, indirect, hasInitialValues, skipLastPassKeyWrite);
        }
    }

    /**
     * Creates a WGSL compute shader for a single pass, specialised through defines.
     *
     * @param {string} name - Shader name.
     * @param {string} source - WGSL source of the compute shader.
     * @param {number} currentBit - Bit offset handled by the pass.
     * @param {boolean} isFirstPass - Value of the IS_FIRST_PASS define.
     * @param {boolean} isLastPass - Value of the IS_LAST_PASS define.
     * @param {BindGroupFormat} bindGroupFormat - Bind group format used by the shader.
     * @param {boolean} indirect - True to add the USE_INDIRECT_SORT define.
     * @returns {Shader} The created shader.
     */
    _createShader(name, source, currentBit, isFirstPass, isLastPass, bindGroupFormat, indirect) {
        const cdefines = new Map();
        cdefines.set('{WORKGROUP_SIZE_X}', WORKGROUP_SIZE_X);
        cdefines.set('{WORKGROUP_SIZE_Y}', WORKGROUP_SIZE_Y);
        cdefines.set('{THREADS_PER_WORKGROUP}', THREADS_PER_WORKGROUP);
        cdefines.set('{ELEMENTS_PER_THREAD}', ELEMENTS_PER_THREAD);
        cdefines.set('{CURRENT_BIT}', currentBit);
        cdefines.set('{IS_FIRST_PASS}', isFirstPass ? 1 : 0);
        cdefines.set('{IS_LAST_PASS}', isLastPass ? 1 : 0);
        if (indirect) {
            cdefines.set('USE_INDIRECT_SORT', '');
        }

        return new Shader(this.device, {
            name: name,
            shaderLanguage: SHADERLANGUAGE_WGSL,
            cshader: source,
            cdefines: cdefines,
            computeBindGroupFormat: bindGroupFormat,
            computeUniformBufferFormats: {
                uniforms: this._uniformBufferFormat
            }
        });
    }

    /**
     * Sorts using direct dispatch.
     *
     * @param {StorageBuffer} keysBuffer - Buffer holding the keys to sort. It is only read.
     * @param {number} elementCount - Number of elements to sort.
     * @param {number} [numBits] - Number of key bits to sort. Rounded up to a multiple of 4.
     * Defaults to 16.
     * @param {StorageBuffer} [initialValues] - Optional buffer of values fed into the first
     * pass. It is only read.
     * @param {boolean} [skipLastPassKeyWrite] - Forwarded to the final reorder pass as the
     * IS_LAST_PASS define. Defaults to false.
     * @returns {StorageBuffer} The buffer the last pass writes its values to.
     */
    sort(keysBuffer, elementCount, numBits = 16, initialValues, skipLastPassKeyWrite = false) {
        return this._execute(keysBuffer, elementCount, numBits, false, -1, null, initialValues, skipLastPassKeyWrite);
    }

    /**
     * Sorts using indirect dispatch, with an additional `sortElementCount` buffer bound to
     * the shaders.
     *
     * @param {StorageBuffer} keysBuffer - Buffer holding the keys to sort. It is only read.
     * @param {number} maxElementCount - Element count used to size the internal buffers.
     * @param {number} numBits - Number of key bits to sort. Rounded up to a multiple of 4.
     * @param {number} dispatchSlot - Slot passed to `setupIndirectDispatch`.
     * @param {StorageBuffer} sortElementCountBuffer - Buffer bound as `sortElementCount`.
     * @param {StorageBuffer} [initialValues] - Optional buffer of values fed into the first
     * pass. It is only read.
     * @param {boolean} [skipLastPassKeyWrite] - Forwarded to the final reorder pass as the
     * IS_LAST_PASS define. Defaults to false.
     * @returns {StorageBuffer} The buffer the last pass writes its values to.
     */
    sortIndirect(keysBuffer, maxElementCount, numBits, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite = false) {
        return this._execute(keysBuffer, maxElementCount, numBits, true, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite);
    }

    /**
     * Shared implementation of `sort` and `sortIndirect`: prepares resources, then for each
     * pass dispatches the histogram shader, the prefix sum and the reorder shader.
     *
     * @param {StorageBuffer} keysBuffer - Buffer holding the keys to sort.
     * @param {number} elementCount - Number of (or maximum number of) elements to sort.
     * @param {number} numBits - Number of key bits to sort.
     * @param {boolean} indirect - True to use indirect dispatch.
     * @param {number} dispatchSlot - Indirect dispatch slot; unused for direct dispatch.
     * @param {StorageBuffer|null} sortElementCountBuffer - Buffer bound as `sortElementCount`
     * when `indirect` is true.
     * @param {StorageBuffer} [initialValues] - Optional values for the first pass.
     * @param {boolean} [skipLastPassKeyWrite] - Flags the final reorder pass.
     * @returns {StorageBuffer} The buffer the last pass writes its values to.
     */
    _execute(keysBuffer, elementCount, numBits, indirect, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite = false) {
        this._elementCount = elementCount;
        const hasInitialValues = !!initialValues;

        this._allocateBuffers(elementCount, numBits, indirect, hasInitialValues, skipLastPassKeyWrite);

        const device = this.device;
        const numPasses = getPassCount(numBits);
        const suffix = indirect ? '-Indirect' : '';

        // The first pass reads the caller's buffers; all writes go to internal buffers.
        let currentKeys = keysBuffer;
        let currentValues = initialValues ?? this._values0;
        let nextKeys = this._keys0;
        let nextValues = this._values1;

        for (let pass = 0; pass < numPasses; pass++) {
            const { histogramCompute, reorderCompute } = this._passes[pass];
            const isLastPass = pass === numPasses - 1;

            if (indirect) {
                this._blockSums.clear();
            }

            histogramCompute.setParameter('input', currentKeys);
            histogramCompute.setParameter('block_sums', this._blockSums);
            histogramCompute.setParameter('workgroupCount', this._workgroupCount);
            histogramCompute.setParameter('elementCount', elementCount);
            if (indirect) {
                histogramCompute.setParameter('sortElementCount', sortElementCountBuffer);
                histogramCompute.setupIndirectDispatch(dispatchSlot);
            } else {
                histogramCompute.setupDispatch(this._dispatchSize.x, this._dispatchSize.y, 1);
            }
            device.computeDispatch([histogramCompute], `RadixSort-Histogram${suffix}`);

            this._prefixSumKernel.dispatch(device);

            // The final pass delivers its values straight into the result buffer.
            const outputValues = isLastPass ? this._sortedIndices : nextValues;

            reorderCompute.setParameter('inputKeys', currentKeys);
            reorderCompute.setParameter('outputKeys', nextKeys);
            reorderCompute.setParameter('prefix_block_sum', this._blockSums);
            reorderCompute.setParameter('inputValues', currentValues);
            reorderCompute.setParameter('outputValues', outputValues);
            reorderCompute.setParameter('workgroupCount', this._workgroupCount);
            reorderCompute.setParameter('elementCount', elementCount);
            if (indirect) {
                reorderCompute.setParameter('sortElementCount', sortElementCountBuffer);
                reorderCompute.setupIndirectDispatch(dispatchSlot);
            } else {
                reorderCompute.setupDispatch(this._dispatchSize.x, this._dispatchSize.y, 1);
            }
            device.computeDispatch([reorderCompute], `RadixSort-Reorder${suffix}`);

            if (!isLastPass) {
                // Ping-pong strictly between the internal buffers. A plain swap would make the
                // caller's initialValues buffer the write target of the following pass.
                currentKeys = nextKeys;
                nextKeys = currentKeys === this._keys0 ? this._keys1 : this._keys0;
                currentValues = nextValues;
                nextValues = currentValues === this._values0 ? this._values1 : this._values0;
            }
        }

        return this._sortedIndices;
    }
}

export { ComputeRadixSort, ELEMENTS_PER_WORKGROUP as RADIX_SORT_ELEMENTS_PER_WORKGROUP };