playcanvas
Version:
Open-source WebGL/WebGPU 3D engine for the web
257 lines (254 loc) • 10.4 kB
JavaScript
import { Compute } from '../../platform/graphics/compute.js';
import { Shader } from '../../platform/graphics/shader.js';
import { StorageBuffer } from '../../platform/graphics/storage-buffer.js';
import { BindGroupFormat, BindStorageBufferFormat, BindUniformBufferFormat } from '../../platform/graphics/bind-group-format.js';
import { UniformBufferFormat, UniformFormat } from '../../platform/graphics/uniform-buffer-format.js';
import { UNIFORMTYPE_UINT, SHADERSTAGE_COMPUTE, SHADERLANGUAGE_WGSL } from '../../platform/graphics/constants.js';
import { prefixSumSource } from '../shader-lib/wgsl/chunks/radix-sort/compute-prefix-sum.js';
/**
* @import { GraphicsDevice } from '../../platform/graphics/graphics-device.js'
*/ // Workgroup configuration
const WORKGROUP_SIZE_X = 16;
const WORKGROUP_SIZE_Y = 16;
const THREADS_PER_WORKGROUP = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; // 256
const ITEMS_PER_WORKGROUP = 2 * THREADS_PER_WORKGROUP; // 512 (2 items per thread)
/**
* Helper class for recursive parallel prefix sum (scan) operations.
* Uses Blelloch algorithm with up-sweep and down-sweep phases.
*
* @ignore
*/ class PrefixSumKernel {
/**
* Destroys the kernel and releases resources.
*/ destroy() {
this.destroyPasses();
this._scanShader?.destroy();
this._addBlockShader?.destroy();
this._bindGroupFormat?.destroy();
this._scanShader = null;
this._addBlockShader = null;
this._bindGroupFormat = null;
this._uniformBufferFormat = null;
}
/**
* Creates bind group format and shaders (called once in constructor).
*
* @private
*/ _createFormatsAndShaders() {
// Create uniform buffer format
this._uniformBufferFormat = new UniformBufferFormat(this.device, [
new UniformFormat('elementCount', UNIFORMTYPE_UINT)
]);
// Create bind group format with uniform buffer
this._bindGroupFormat = new BindGroupFormat(this.device, [
new BindStorageBufferFormat('items', SHADERSTAGE_COMPUTE, false),
new BindStorageBufferFormat('blockSums', SHADERSTAGE_COMPUTE, false),
new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
]);
// Create shaders
this._scanShader = this._createShader('PrefixSumScan', 'reduce_downsweep');
this._addBlockShader = this._createShader('PrefixSumAddBlock', 'add_block_sums');
}
/**
* Recursively creates passes for the prefix sum.
*
* @param {StorageBuffer} dataBuffer - Buffer containing data to scan.
* @param {number} count - Number of elements.
* @private
*/ createPassesRecursive(dataBuffer, count) {
const workgroupCount = Math.ceil(count / ITEMS_PER_WORKGROUP);
const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);
// Create buffer for block sums
const blockSumBuffer = new StorageBuffer(this.device, workgroupCount * 4);
// Create scan compute instance using shared shader
const scanCompute = new Compute(this.device, this._scanShader, 'PrefixSumScan');
scanCompute.setParameter('items', dataBuffer);
scanCompute.setParameter('blockSums', blockSumBuffer);
const pass = {
scanCompute,
addBlockCompute: null,
blockSumBuffer,
dispatchX,
dispatchY,
count,
allocatedCount: count
};
this.passes.push(pass);
if (workgroupCount > 1) {
// Recursively create prefix sum on block sums
this.createPassesRecursive(blockSumBuffer, workgroupCount);
// Create add_block compute instance using shared shader
const addBlockCompute = new Compute(this.device, this._addBlockShader, 'PrefixSumAddBlock');
addBlockCompute.setParameter('items', dataBuffer);
addBlockCompute.setParameter('blockSums', blockSumBuffer);
pass.addBlockCompute = addBlockCompute;
}
}
/**
* Creates a shader for prefix sum operations.
*
* @param {string} name - Shader name.
* @param {string} entryPoint - Entry point function name.
* @returns {Shader} The created shader.
* @private
*/ _createShader(name, entryPoint) {
// Build defines map with {VARIABLE} keys for preprocessor injection
const cdefines = new Map();
cdefines.set('{WORKGROUP_SIZE_X}', WORKGROUP_SIZE_X);
cdefines.set('{WORKGROUP_SIZE_Y}', WORKGROUP_SIZE_Y);
cdefines.set('{THREADS_PER_WORKGROUP}', THREADS_PER_WORKGROUP);
cdefines.set('{ITEMS_PER_WORKGROUP}', ITEMS_PER_WORKGROUP);
return new Shader(this.device, {
name: name,
shaderLanguage: SHADERLANGUAGE_WGSL,
cshader: prefixSumSource,
cdefines: cdefines,
computeEntryPoint: entryPoint,
computeBindGroupFormat: this._bindGroupFormat,
computeUniformBufferFormats: {
uniforms: this._uniformBufferFormat
}
});
}
/**
* Find optimal dispatch dimensions to minimize unused workgroups.
*
* @param {number} workgroupCount - Total workgroups needed.
* @returns {{x: number, y: number}} Dispatch dimensions.
* @private
*/ findOptimalDispatchSize(workgroupCount) {
const maxDimension = this.device.limits.maxComputeWorkgroupsPerDimension || 65535;
if (workgroupCount <= maxDimension) {
return {
x: workgroupCount,
y: 1
};
}
const x = Math.floor(Math.sqrt(workgroupCount));
const y = Math.ceil(workgroupCount / x);
return {
x,
y
};
}
/**
* Resizes the kernel for a new element count. Grows capacity internally if needed.
*
* @param {StorageBuffer} dataBuffer - The buffer to perform prefix sum on.
* @param {number} count - New element count.
*/ resize(dataBuffer, count) {
// Check if we need more passes (count grew beyond current capacity)
const requiredPasses = this._countPassesNeeded(count);
const currentPasses = this.passes.length;
if (requiredPasses > currentPasses) {
// Need more passes - destroy old and recreate with new capacity
this.destroyPasses();
this.createPassesRecursive(dataBuffer, count);
return;
}
// Update counts for each pass level (shrinking or same size)
let levelCount = count;
for(let i = 0; i < this.passes.length; i++){
const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);
this.passes[i].count = levelCount;
this.passes[i].dispatchX = dispatchX;
this.passes[i].dispatchY = dispatchY;
levelCount = workgroupCount;
// If this level doesn't need block sums anymore, stop
if (workgroupCount <= 1) {
break;
}
}
}
/**
* Destroys passes but keeps shaders and formats.
*
* @ignore
*/ destroyPasses() {
for (const pass of this.passes){
pass.blockSumBuffer?.destroy();
}
this.passes.length = 0;
}
/**
* Counts how many recursive passes are needed for a given element count.
*
* @param {number} count - Element count.
* @returns {number} Number of passes needed.
* @private
*/ _countPassesNeeded(count) {
let passes = 0;
let levelCount = count;
while(levelCount > 0){
passes++;
const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
if (workgroupCount <= 1) break;
levelCount = workgroupCount;
}
return passes;
}
/**
* Dispatches all prefix sum passes.
*
* @param {GraphicsDevice} device - The graphics device.
*/ dispatch(device) {
// Process all passes in order
for(let i = 0; i < this.passes.length; i++){
const pass = this.passes[i];
// Set element count uniform for this pass level
pass.scanCompute.setParameter('elementCount', pass.count);
pass.scanCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
device.computeDispatch([
pass.scanCompute
], 'PrefixSumScan');
}
// Add block sums in reverse order (skip the last level which has no add_block)
for(let i = this.passes.length - 1; i >= 0; i--){
const pass = this.passes[i];
if (pass.addBlockCompute) {
// Set element count uniform for this pass level
pass.addBlockCompute.setParameter('elementCount', pass.count);
pass.addBlockCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
device.computeDispatch([
pass.addBlockCompute
], 'PrefixSumAddBlock');
}
}
}
/**
* Creates a new PrefixSumKernel instance.
* Call resize() to initialize passes with the desired count.
*
* @param {GraphicsDevice} device - The graphics device.
*/ constructor(device){
/**
* List of pipeline passes (scan + add_block for each level).
*
* @type {Array<{scanCompute: Compute, addBlockCompute: Compute|null, blockSumBuffer: StorageBuffer, dispatchX: number, dispatchY: number, count: number, allocatedCount: number}>}
*/ this.passes = [];
/**
* Uniform buffer format (shared across all passes).
*
* @type {UniformBufferFormat|null}
*/ this._uniformBufferFormat = null;
/**
* Bind group format (shared across all passes).
*
* @type {BindGroupFormat|null}
*/ this._bindGroupFormat = null;
/**
* Scan shader (shared, element count is a uniform).
*
* @type {Shader|null}
*/ this._scanShader = null;
/**
* Add block shader (shared, element count is a uniform).
*
* @type {Shader|null}
*/ this._addBlockShader = null;
this.device = device;
this._createFormatsAndShaders();
}
}
export { PrefixSumKernel };