UNPKG

playcanvas

Version:

Open-source WebGL/WebGPU 3D engine for the web

153 lines (152 loc) 5.62 kB
import { Compute } from "../../platform/graphics/compute.js";
import { Shader } from "../../platform/graphics/shader.js";
import { StorageBuffer } from "../../platform/graphics/storage-buffer.js";
import { BindGroupFormat, BindStorageBufferFormat, BindUniformBufferFormat } from "../../platform/graphics/bind-group-format.js";
import { UniformBufferFormat, UniformFormat } from "../../platform/graphics/uniform-buffer-format.js";
import { SHADERLANGUAGE_WGSL, SHADERSTAGE_COMPUTE, UNIFORMTYPE_UINT } from "../../platform/graphics/constants.js";
import { prefixSumSource } from "../shader-lib/wgsl/chunks/radix-sort/compute-prefix-sum.js";

// Workgroup layout; these values are substituted into the WGSL source through `cdefines`.
const WORKGROUP_SIZE_X = 16;
const WORKGROUP_SIZE_Y = 16;
const THREADS_PER_WORKGROUP = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y;
// Each workgroup covers two items per thread.
const ITEMS_PER_WORKGROUP = 2 * THREADS_PER_WORKGROUP;

/**
 * Builds and dispatches a multi-level chain of compute passes over a storage buffer, using the
 * `reduce_downsweep` and `add_block_sums` entry points of the prefix-sum WGSL chunk.
 *
 * Every level owns a "block sums" buffer with one slot per workgroup. When a level needs more
 * than one workgroup, its block sums buffer becomes the data buffer of the next level.
 */
class PrefixSumKernel {
  /** Graphics device used to create shaders, buffers and compute instances. */
  device;

  /**
   * One entry per level, outermost first. Each entry holds `scanCompute`, `addBlockCompute`
   * (null when the level fits in a single workgroup), `blockSumBuffer`, `dispatchX`,
   * `dispatchY`, `count` and `allocatedCount`.
   */
  passes = [];

  _uniformBufferFormat = null;

  _bindGroupFormat = null;

  _scanShader = null;

  _addBlockShader = null;

  /**
   * @param {object} device - Graphics device; stored and used for all resource creation.
   */
  constructor(device) {
    this.device = device;
    this._createFormatsAndShaders();
  }

  /**
   * Releases the passes, both shaders and the bind group format, then clears the references.
   */
  destroy() {
    this.destroyPasses();
    this._scanShader?.destroy();
    this._addBlockShader?.destroy();
    this._bindGroupFormat?.destroy();
    this._scanShader = null;
    this._addBlockShader = null;
    this._bindGroupFormat = null;
    this._uniformBufferFormat = null;
  }

  /**
   * Creates the uniform buffer format (a single uint `elementCount`), the bind group format
   * (`items`, `blockSums`, `uniforms`) and the two shaders sharing them.
   */
  _createFormatsAndShaders() {
    this._uniformBufferFormat = new UniformBufferFormat(this.device, [
      new UniformFormat("elementCount", UNIFORMTYPE_UINT)
    ]);
    this._bindGroupFormat = new BindGroupFormat(this.device, [
      new BindStorageBufferFormat("items", SHADERSTAGE_COMPUTE, false),
      new BindStorageBufferFormat("blockSums", SHADERSTAGE_COMPUTE, false),
      new BindUniformBufferFormat("uniforms", SHADERSTAGE_COMPUTE)
    ]);
    this._scanShader = this._createShader("PrefixSumScan", "reduce_downsweep");
    this._addBlockShader = this._createShader("PrefixSumAddBlock", "add_block_sums");
  }

  /**
   * Appends the pass for `dataBuffer` to `this.passes` and, when more than one workgroup is
   * needed, recurses on the freshly created block sums buffer before attaching the add-block
   * compute to this level.
   *
   * @param {object} dataBuffer - Storage buffer bound as `items` for this level.
   * @param {number} count - Number of elements at this level.
   */
  createPassesRecursive(dataBuffer, count) {
    const workgroupCount = Math.ceil(count / ITEMS_PER_WORKGROUP);
    const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);

    // One slot per workgroup; presumably 4 bytes each for a u32 sum - confirm against the shader.
    const blockSumBuffer = new StorageBuffer(this.device, workgroupCount * 4);

    const scanCompute = new Compute(this.device, this._scanShader, "PrefixSumScan");
    scanCompute.setParameter("items", dataBuffer);
    scanCompute.setParameter("blockSums", blockSumBuffer);

    const pass = {
      scanCompute,
      addBlockCompute: null,
      blockSumBuffer,
      dispatchX,
      dispatchY,
      count,
      // Capacity this level's block sums buffer was sized for; consulted by resize().
      allocatedCount: count
    };
    this.passes.push(pass);

    if (workgroupCount > 1) {
      // The block sums of this level are the input data of the next level.
      this.createPassesRecursive(blockSumBuffer, workgroupCount);

      const addBlockCompute = new Compute(this.device, this._addBlockShader, "PrefixSumAddBlock");
      addBlockCompute.setParameter("items", dataBuffer);
      addBlockCompute.setParameter("blockSums", blockSumBuffer);
      pass.addBlockCompute = addBlockCompute;
    }
  }

  /**
   * Creates a WGSL compute shader from the shared prefix-sum source.
   *
   * @param {string} name - Shader name.
   * @param {string} entryPoint - Compute entry point inside the shared source.
   * @returns {Shader} The created shader.
   */
  _createShader(name, entryPoint) {
    const cdefines = /* @__PURE__ */ new Map();
    cdefines.set("{WORKGROUP_SIZE_X}", WORKGROUP_SIZE_X);
    cdefines.set("{WORKGROUP_SIZE_Y}", WORKGROUP_SIZE_Y);
    cdefines.set("{THREADS_PER_WORKGROUP}", THREADS_PER_WORKGROUP);
    cdefines.set("{ITEMS_PER_WORKGROUP}", ITEMS_PER_WORKGROUP);

    return new Shader(this.device, {
      name,
      shaderLanguage: SHADERLANGUAGE_WGSL,
      cshader: prefixSumSource,
      cdefines,
      computeEntryPoint: entryPoint,
      computeBindGroupFormat: this._bindGroupFormat,
      computeUniformBufferFormats: {
        uniforms: this._uniformBufferFormat
      }
    });
  }

  /**
   * Picks a dispatch grid with at least `workgroupCount` workgroups.
   *
   * @param {number} workgroupCount - Number of workgroups required.
   * @returns {{x: number, y: number}} `{ x: workgroupCount, y: 1 }` when that fits within the
   * per-dimension limit, otherwise a roughly square grid where `x * y >= workgroupCount`.
   */
  findOptimalDispatchSize(workgroupCount) {
    // Falls back to 65535 when the device does not report the limit.
    const maxDimension = this.device.limits.maxComputeWorkgroupsPerDimension || 65535;
    if (workgroupCount <= maxDimension) {
      return { x: workgroupCount, y: 1 };
    }
    const x = Math.floor(Math.sqrt(workgroupCount));
    const y = Math.ceil(workgroupCount / x);
    return { x, y };
  }

  /**
   * Adapts the kernel to a new element count.
   *
   * The passes are destroyed and rebuilt against `dataBuffer` when more levels are needed than
   * exist, or when `count` exceeds the count the existing passes were created for. Otherwise the
   * existing passes are kept with their current buffer bindings, and only their element counts
   * and dispatch sizes are updated; levels below the first single-workgroup level are left
   * untouched.
   *
   * @param {object} dataBuffer - Storage buffer to scan; only used when passes are rebuilt.
   * @param {number} count - New number of elements.
   */
  resize(dataBuffer, count) {
    const requiredPasses = this._countPassesNeeded(count);
    const topPass = this.passes[0];

    // The number of levels alone is not a capacity check: a wide range of counts share the same
    // level count while needing different block sums buffer sizes. Since deeper levels derive
    // from the top one, comparing against the top level's allocation covers all of them.
    const exceedsCapacity = topPass !== undefined && count > topPass.allocatedCount;

    if (requiredPasses > this.passes.length || exceedsCapacity) {
      this.destroyPasses();
      this.createPassesRecursive(dataBuffer, count);
      return;
    }

    let levelCount = count;
    for (const pass of this.passes) {
      const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
      const { x, y } = this.findOptimalDispatchSize(workgroupCount);
      pass.count = levelCount;
      pass.dispatchX = x;
      pass.dispatchY = y;
      levelCount = workgroupCount;
      if (workgroupCount <= 1) {
        break;
      }
    }
  }

  /**
   * Destroys the block sums buffer of every pass and empties the pass list.
   */
  destroyPasses() {
    for (const pass of this.passes) {
      pass.blockSumBuffer?.destroy();
    }
    this.passes.length = 0;
  }

  /**
   * Counts how many levels are required for `count` elements.
   *
   * @param {number} count - Number of elements.
   * @returns {number} 0 when `count` is not positive, otherwise the number of levels until a
   * level fits in a single workgroup.
   */
  _countPassesNeeded(count) {
    let passes = 0;
    let levelCount = count;
    while (levelCount > 0) {
      passes++;
      const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
      if (workgroupCount <= 1) break;
      levelCount = workgroupCount;
    }
    return passes;
  }

  /**
   * Issues the compute dispatches: the scan of every level from first to last, followed by the
   * add-block step of every level that has one, from last to first.
   *
   * @param {object} device - Device whose `computeDispatch` is called for each step.
   */
  dispatch(device) {
    for (let i = 0; i < this.passes.length; i++) {
      const pass = this.passes[i];
      pass.scanCompute.setParameter("elementCount", pass.count);
      pass.scanCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
      device.computeDispatch([pass.scanCompute], "PrefixSumScan");
    }

    for (let i = this.passes.length - 1; i >= 0; i--) {
      const pass = this.passes[i];
      if (pass.addBlockCompute) {
        pass.addBlockCompute.setParameter("elementCount", pass.count);
        pass.addBlockCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
        device.computeDispatch([pass.addBlockCompute], "PrefixSumAddBlock");
      }
    }
  }
}

export { PrefixSumKernel };