UNPKG

playcanvas

Version:

Open-source WebGL/WebGPU 3D engine for the web

164 lines (161 loc) 6.02 kB
import { Compute } from '../../platform/graphics/compute.js';
import { Shader } from '../../platform/graphics/shader.js';
import { StorageBuffer } from '../../platform/graphics/storage-buffer.js';
import { BindGroupFormat, BindStorageBufferFormat, BindUniformBufferFormat } from '../../platform/graphics/bind-group-format.js';
import { UniformBufferFormat, UniformFormat } from '../../platform/graphics/uniform-buffer-format.js';
import { UNIFORMTYPE_UINT, SHADERSTAGE_COMPUTE, SHADERLANGUAGE_WGSL } from '../../platform/graphics/constants.js';
import { prefixSumSource } from '../shader-lib/wgsl/chunks/radix-sort/compute-prefix-sum.js';

// Workgroup layout; these values are injected into the WGSL source as cdefines (see _createShader).
const WORKGROUP_SIZE_X = 16;
const WORKGROUP_SIZE_Y = 16;
const THREADS_PER_WORKGROUP = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y;

// Number of items covered by a single workgroup (two per thread).
const ITEMS_PER_WORKGROUP = 2 * THREADS_PER_WORKGROUP;

/**
 * GPU prefix sum (scan) over a storage buffer, organized as a hierarchy of passes.
 *
 * Each level runs the `reduce_downsweep` entry point over its items. It binds a block-sum
 * buffer with one 4-byte slot per workgroup. When a level needs more than one workgroup,
 * its block-sum buffer becomes the input ("items") of the next level. After all scans
 * have run, the `add_block_sums` entry point is dispatched from the coarsest level back
 * to level 0, applying each level's block sums to its items.
 */
class PrefixSumKernel {
    /**
     * @param {*} device - Graphics device used to create the shaders, buffers and computes.
     */
    constructor(device) {
        /**
         * One entry per level, ordered from the input data (index 0) to the coarsest level.
         *
         * @type {object[]}
         */
        this.passes = [];
        this._uniformBufferFormat = null;
        this._bindGroupFormat = null;
        this._scanShader = null;
        this._addBlockShader = null;
        this.device = device;
        this._createFormatsAndShaders();
    }

    /**
     * Releases all pass buffers, both shaders and the bind group format.
     */
    destroy() {
        this.destroyPasses();
        this._scanShader?.destroy();
        this._addBlockShader?.destroy();
        this._bindGroupFormat?.destroy();
        this._scanShader = null;
        this._addBlockShader = null;
        this._bindGroupFormat = null;
        this._uniformBufferFormat = null;
    }

    /**
     * Creates the uniform/bind group formats shared by both entry points, then the two
     * compute shaders that use them.
     *
     * @private
     */
    _createFormatsAndShaders() {
        this._uniformBufferFormat = new UniformBufferFormat(this.device, [
            new UniformFormat('elementCount', UNIFORMTYPE_UINT)
        ]);
        this._bindGroupFormat = new BindGroupFormat(this.device, [
            new BindStorageBufferFormat('items', SHADERSTAGE_COMPUTE, false),
            new BindStorageBufferFormat('blockSums', SHADERSTAGE_COMPUTE, false),
            new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
        ]);
        this._scanShader = this._createShader('PrefixSumScan', 'reduce_downsweep');
        this._addBlockShader = this._createShader('PrefixSumAddBlock', 'add_block_sums');
    }

    /**
     * Appends one pass for `count` items of `dataBuffer`. If more than one workgroup is
     * needed, it recurses on the pass's block-sum buffer and adds an add-block compute.
     *
     * @param {StorageBuffer} dataBuffer - Buffer whose items this level scans.
     * @param {number} count - Number of items at this level.
     */
    createPassesRecursive(dataBuffer, count) {
        const workgroupCount = Math.ceil(count / ITEMS_PER_WORKGROUP);
        const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);

        // One 4-byte slot per workgroup. Keep at least one slot so the binding is never
        // zero-sized (count === 0), which could not satisfy the shader's minimum binding size.
        const blockSumBuffer = new StorageBuffer(this.device, Math.max(workgroupCount, 1) * 4);

        const scanCompute = new Compute(this.device, this._scanShader, 'PrefixSumScan');
        scanCompute.setParameter('items', dataBuffer);
        scanCompute.setParameter('blockSums', blockSumBuffer);

        const pass = {
            scanCompute,
            addBlockCompute: null,
            blockSumBuffer,
            // Buffer bound as 'items'; lets resize() detect a changed input buffer.
            dataBuffer,
            dispatchX,
            dispatchY,
            count,
            // Capacity the block-sum buffer was sized for; resize() must not exceed it.
            allocatedCount: count,
            workgroupCount,
            // False for deeper levels that a smaller resize() no longer needs.
            active: true
        };
        this.passes.push(pass);

        if (workgroupCount > 1) {
            // The next level scans this level's block sums; it is pushed before the
            // add-block compute is created, so passes stay ordered fine-to-coarse.
            this.createPassesRecursive(blockSumBuffer, workgroupCount);

            const addBlockCompute = new Compute(this.device, this._addBlockShader, 'PrefixSumAddBlock');
            addBlockCompute.setParameter('items', dataBuffer);
            addBlockCompute.setParameter('blockSums', blockSumBuffer);
            pass.addBlockCompute = addBlockCompute;
        }
    }

    /**
     * Builds a WGSL compute shader for one entry point of the prefix-sum source.
     *
     * @param {string} name - Debug name of the shader.
     * @param {string} entryPoint - WGSL compute entry point.
     * @returns {Shader} The created shader.
     * @private
     */
    _createShader(name, entryPoint) {
        const cdefines = new Map();
        cdefines.set('{WORKGROUP_SIZE_X}', WORKGROUP_SIZE_X);
        cdefines.set('{WORKGROUP_SIZE_Y}', WORKGROUP_SIZE_Y);
        cdefines.set('{THREADS_PER_WORKGROUP}', THREADS_PER_WORKGROUP);
        cdefines.set('{ITEMS_PER_WORKGROUP}', ITEMS_PER_WORKGROUP);
        return new Shader(this.device, {
            name: name,
            shaderLanguage: SHADERLANGUAGE_WGSL,
            cshader: prefixSumSource,
            cdefines: cdefines,
            computeEntryPoint: entryPoint,
            computeBindGroupFormat: this._bindGroupFormat,
            computeUniformBufferFormats: {
                uniforms: this._uniformBufferFormat
            }
        });
    }

    /**
     * Splits a 1D workgroup count into a 2D dispatch when it exceeds the device's
     * per-dimension limit (65535 if the device does not report one). The 2D result may
     * dispatch more workgroups than requested (x * y >= workgroupCount).
     *
     * @param {number} workgroupCount - Number of workgroups needed.
     * @returns {{x: number, y: number}} Dispatch dimensions.
     */
    findOptimalDispatchSize(workgroupCount) {
        const maxDimension = this.device.limits.maxComputeWorkgroupsPerDimension || 65535;
        if (workgroupCount <= maxDimension) {
            return { x: workgroupCount, y: 1 };
        }
        const x = Math.floor(Math.sqrt(workgroupCount));
        const y = Math.ceil(workgroupCount / x);
        return { x, y };
    }

    /**
     * Updates the passes for a new item count. Passes are rebuilt when the hierarchy needs
     * more levels, when `count` exceeds the capacity the block-sum buffers were allocated
     * for, or when `dataBuffer` differs from the bound one. Otherwise the existing buffers
     * are reused: levels the new count needs are refreshed, and deeper ones are deactivated.
     *
     * @param {StorageBuffer} dataBuffer - Buffer to scan.
     * @param {number} count - Number of items to scan.
     */
    resize(dataBuffer, count) {
        const first = this.passes[0];
        const needsRebuild = this._countPassesNeeded(count) > this.passes.length ||
            (first !== undefined && (first.dataBuffer !== dataBuffer || count > first.allocatedCount));

        if (needsRebuild) {
            this.destroyPasses();
            this.createPassesRecursive(dataBuffer, count);
            return;
        }

        // count <= allocatedCount, so every level's workgroup count fits the buffers sized
        // at allocation, and any level needing >1 workgroup has a next level and an add-block compute.
        let levelCount = count;
        let active = true;
        for (const pass of this.passes) {
            pass.active = active;
            if (!active) {
                continue;
            }
            const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
            const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);
            pass.count = levelCount;
            pass.workgroupCount = workgroupCount;
            pass.dispatchX = dispatchX;
            pass.dispatchY = dispatchY;
            levelCount = workgroupCount;
            // A single workgroup finishes the scan; deeper levels are not needed.
            active = workgroupCount > 1;
        }
    }

    /**
     * Destroys the block-sum buffers of all passes and clears the pass list.
     */
    destroyPasses() {
        for (const pass of this.passes) {
            pass.blockSumBuffer?.destroy();
        }
        this.passes.length = 0;
    }

    /**
     * Number of levels needed for `count` items (0 for count <= 0).
     *
     * @param {number} count - Number of items.
     * @returns {number} Level count.
     * @private
     */
    _countPassesNeeded(count) {
        let passes = 0;
        let levelCount = count;
        while (levelCount > 0) {
            passes++;
            const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
            if (workgroupCount <= 1) break;
            levelCount = workgroupCount;
        }
        return passes;
    }

    /**
     * Records the scan dispatches (fine to coarse), then the add-block dispatches
     * (coarse to fine), for the active levels only.
     *
     * @param {*} device - Device that executes the computes via computeDispatch().
     */
    dispatch(device) {
        for (const pass of this.passes) {
            if (!pass.active) {
                continue;
            }
            pass.scanCompute.setParameter('elementCount', pass.count);
            pass.scanCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
            device.computeDispatch([pass.scanCompute], 'PrefixSumScan');
        }

        for (let i = this.passes.length - 1; i >= 0; i--) {
            const pass = this.passes[i];
            // A level that currently fits in one workgroup has no block sums to apply,
            // even if its add-block compute exists from a larger earlier allocation.
            if (!pass.active || !pass.addBlockCompute || pass.workgroupCount <= 1) {
                continue;
            }
            pass.addBlockCompute.setParameter('elementCount', pass.count);
            pass.addBlockCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
            device.computeDispatch([pass.addBlockCompute], 'PrefixSumAddBlock');
        }
    }
}

export { PrefixSumKernel };