// playcanvas
// Version: (not specified in this listing)
// Open-source WebGL/WebGPU 3D engine for the web
// 153 lines (152 loc) • 5.62 kB
// JavaScript
import { Compute } from "../../platform/graphics/compute.js";
import { Shader } from "../../platform/graphics/shader.js";
import { StorageBuffer } from "../../platform/graphics/storage-buffer.js";
import { BindGroupFormat, BindStorageBufferFormat, BindUniformBufferFormat } from "../../platform/graphics/bind-group-format.js";
import { UniformBufferFormat, UniformFormat } from "../../platform/graphics/uniform-buffer-format.js";
import { SHADERLANGUAGE_WGSL, SHADERSTAGE_COMPUTE, UNIFORMTYPE_UINT } from "../../platform/graphics/constants.js";
import { prefixSumSource } from "../shader-lib/wgsl/chunks/radix-sort/compute-prefix-sum.js";
// Number of threads along X and Y in a single compute workgroup.
const WORKGROUP_SIZE_X = 16;
const WORKGROUP_SIZE_Y = 16;

// Total thread count of one workgroup (X * Y).
const THREADS_PER_WORKGROUP = WORKGROUP_SIZE_Y * WORKGROUP_SIZE_X;

// Element count covered by one workgroup: two items per thread.
const ITEMS_PER_WORKGROUP = THREADS_PER_WORKGROUP * 2;
/**
 * Drives a multi-level prefix sum over a storage buffer using two compute shaders built from the
 * same WGSL source: a scan shader (entry point `reduce_downsweep`) and an add-block shader (entry
 * point `add_block_sums`).
 *
 * Level 0 works on the caller's data buffer. Every level owns a block-sum buffer with 4 bytes per
 * workgroup; when a level needs more than one workgroup, that block-sum buffer becomes the data
 * buffer of the next level, and an add-block pass is attached to the level.
 */
class PrefixSumKernel {
    /** Graphics device used to create all GPU resources of this kernel. */
    device;

    /**
     * One entry per level, level 0 first. Each entry holds `scanCompute`, `addBlockCompute`
     * (null when the level fits in a single workgroup), `blockSumBuffer`, `dispatchX`,
     * `dispatchY`, `count` (current element count) and `allocatedCount` (element count the
     * level's buffers were created for).
     */
    passes = [];

    /** Format of the uniform buffer carrying `elementCount`. */
    _uniformBufferFormat = null;

    /** Bind group format shared by both shaders (items, blockSums, uniforms). */
    _bindGroupFormat = null;

    /** Shader used for the per-level scan dispatch. */
    _scanShader = null;

    /** Shader used for the per-level add-block dispatch. */
    _addBlockShader = null;

    /**
     * @param {object} device - Graphics device; must expose `limits` and `computeDispatch` as
     * used by this class.
     */
    constructor(device) {
        this.device = device;
        this._createFormatsAndShaders();
    }

    /**
     * Releases all passes, both shaders and the bind group format, and clears the references.
     */
    destroy() {
        this.destroyPasses();
        this._scanShader?.destroy();
        this._addBlockShader?.destroy();
        this._bindGroupFormat?.destroy();
        this._scanShader = null;
        this._addBlockShader = null;
        this._bindGroupFormat = null;
        this._uniformBufferFormat = null;
    }

    /**
     * Creates the uniform buffer format, the bind group format and the two compute shaders.
     */
    _createFormatsAndShaders() {
        this._uniformBufferFormat = new UniformBufferFormat(this.device, [
            new UniformFormat("elementCount", UNIFORMTYPE_UINT)
        ]);
        this._bindGroupFormat = new BindGroupFormat(this.device, [
            new BindStorageBufferFormat("items", SHADERSTAGE_COMPUTE, false),
            new BindStorageBufferFormat("blockSums", SHADERSTAGE_COMPUTE, false),
            new BindUniformBufferFormat("uniforms", SHADERSTAGE_COMPUTE)
        ]);
        this._scanShader = this._createShader("PrefixSumScan", "reduce_downsweep");
        this._addBlockShader = this._createShader("PrefixSumAddBlock", "add_block_sums");
    }

    /**
     * Appends the pass for `dataBuffer` to `this.passes` and, when more than one workgroup is
     * needed, recurses on the freshly created block-sum buffer.
     *
     * @param {StorageBuffer} dataBuffer - Buffer bound as `items` for this level.
     * @param {number} count - Number of elements to process at this level.
     */
    createPassesRecursive(dataBuffer, count) {
        const workgroupCount = Math.ceil(count / ITEMS_PER_WORKGROUP);
        const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);

        // 4 bytes per workgroup of this level.
        const blockSumBuffer = new StorageBuffer(this.device, workgroupCount * 4);

        const scanCompute = new Compute(this.device, this._scanShader, "PrefixSumScan");
        scanCompute.setParameter("items", dataBuffer);
        scanCompute.setParameter("blockSums", blockSumBuffer);

        const pass = {
            scanCompute,
            addBlockCompute: null,
            blockSumBuffer,
            dispatchX,
            dispatchY,
            count,
            allocatedCount: count
        };
        this.passes.push(pass);

        if (workgroupCount > 1) {
            // The block sums are scanned by the next level before being applied to this one.
            this.createPassesRecursive(blockSumBuffer, workgroupCount);

            const addBlockCompute = new Compute(this.device, this._addBlockShader, "PrefixSumAddBlock");
            addBlockCompute.setParameter("items", dataBuffer);
            addBlockCompute.setParameter("blockSums", blockSumBuffer);
            pass.addBlockCompute = addBlockCompute;
        }
    }

    /**
     * Builds a WGSL compute shader from the shared prefix-sum source.
     *
     * @param {string} name - Shader name.
     * @param {string} entryPoint - Compute entry point inside the shared source.
     * @returns {Shader} The created shader.
     */
    _createShader(name, entryPoint) {
        const cdefines = new Map();
        cdefines.set("{WORKGROUP_SIZE_X}", WORKGROUP_SIZE_X);
        cdefines.set("{WORKGROUP_SIZE_Y}", WORKGROUP_SIZE_Y);
        cdefines.set("{THREADS_PER_WORKGROUP}", THREADS_PER_WORKGROUP);
        cdefines.set("{ITEMS_PER_WORKGROUP}", ITEMS_PER_WORKGROUP);

        return new Shader(this.device, {
            name,
            shaderLanguage: SHADERLANGUAGE_WGSL,
            cshader: prefixSumSource,
            cdefines,
            computeEntryPoint: entryPoint,
            computeBindGroupFormat: this._bindGroupFormat,
            computeUniformBufferFormats: { uniforms: this._uniformBufferFormat }
        });
    }

    /**
     * Picks a dispatch size for the given number of workgroups. A single row is used while the
     * count fits the device's per-dimension limit (65535 when the limit is not reported);
     * otherwise a roughly square grid whose product is at least `workgroupCount`.
     *
     * @param {number} workgroupCount - Number of workgroups that must be covered.
     * @returns {{x: number, y: number}} Dispatch size in X and Y.
     */
    findOptimalDispatchSize(workgroupCount) {
        const maxDimension = this.device.limits.maxComputeWorkgroupsPerDimension || 65535;
        if (workgroupCount <= maxDimension) {
            return { x: workgroupCount, y: 1 };
        }
        const x = Math.floor(Math.sqrt(workgroupCount));
        const y = Math.ceil(workgroupCount / x);
        return { x, y };
    }

    /**
     * Prepares the kernel for `count` elements of `dataBuffer`.
     *
     * Passes are recreated when more levels are required than exist, or when `count` exceeds the
     * element count level 0 was allocated for. The second check matters because block-sum
     * buffers are sized from the count at creation time: a larger count can need more block-sum
     * slots without needing an additional level. Otherwise the existing passes are reused,
     * level 0 is rebound to `dataBuffer`, and per-level counts and dispatch sizes are updated.
     *
     * @param {StorageBuffer} dataBuffer - Buffer to run the prefix sum on.
     * @param {number} count - Number of elements in `dataBuffer` to process.
     */
    resize(dataBuffer, count) {
        const topLevel = this.passes[0];
        const needsMoreLevels = this._countPassesNeeded(count) > this.passes.length;
        // Level counts shrink monotonically with the level-0 count, so checking level 0 is
        // enough to guarantee every deeper block-sum buffer is large enough too.
        const exceedsCapacity = topLevel !== undefined && count > topLevel.allocatedCount;

        if (needsMoreLevels || exceedsCapacity) {
            this.destroyPasses();
            this.createPassesRecursive(dataBuffer, count);
            return;
        }

        if (topLevel) {
            // The caller may hand in a different buffer than the one bound at creation.
            topLevel.scanCompute.setParameter("items", dataBuffer);
            topLevel.addBlockCompute?.setParameter("items", dataBuffer);
        }

        let levelCount = count;
        for (let i = 0; i < this.passes.length; i++) {
            const pass = this.passes[i];
            const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
            const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);

            pass.count = levelCount;
            pass.dispatchX = dispatchX;
            pass.dispatchY = dispatchY;

            levelCount = workgroupCount;
            if (workgroupCount <= 1) {
                // Deeper levels keep their previous count and dispatch size; they stay
                // allocated and are still executed by dispatch().
                break;
            }
        }
    }

    /**
     * Destroys the block-sum buffer of every pass and empties the pass list.
     */
    destroyPasses() {
        for (const pass of this.passes) {
            pass.blockSumBuffer?.destroy();
        }
        this.passes.length = 0;
    }

    /**
     * Counts the levels createPassesRecursive would create for `count` elements.
     *
     * @param {number} count - Number of elements at level 0.
     * @returns {number} Number of levels; 0 when `count` is not positive.
     */
    _countPassesNeeded(count) {
        let passes = 0;
        let levelCount = count;
        while (levelCount > 0) {
            passes++;
            const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
            if (workgroupCount <= 1) break;
            levelCount = workgroupCount;
        }
        return passes;
    }

    /**
     * Submits the compute work: the scan of every level from level 0 downwards, then the
     * add-block pass of every level that has one, from the deepest level back up to level 0.
     *
     * @param {object} device - Device whose `computeDispatch` receives the work.
     */
    dispatch(device) {
        for (let i = 0; i < this.passes.length; i++) {
            const pass = this.passes[i];
            pass.scanCompute.setParameter("elementCount", pass.count);
            pass.scanCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
            device.computeDispatch([pass.scanCompute], "PrefixSumScan");
        }

        for (let i = this.passes.length - 1; i >= 0; i--) {
            const pass = this.passes[i];
            if (pass.addBlockCompute) {
                pass.addBlockCompute.setParameter("elementCount", pass.count);
                pass.addBlockCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
                device.computeDispatch([pass.addBlockCompute], "PrefixSumAddBlock");
            }
        }
    }
}
export {
PrefixSumKernel
};