playcanvas
Version:
Open-source WebGL/WebGPU 3D engine for the web
164 lines (161 loc) • 6.02 kB
JavaScript
import { Compute } from '../../platform/graphics/compute.js';
import { Shader } from '../../platform/graphics/shader.js';
import { StorageBuffer } from '../../platform/graphics/storage-buffer.js';
import { BindGroupFormat, BindStorageBufferFormat, BindUniformBufferFormat } from '../../platform/graphics/bind-group-format.js';
import { UniformBufferFormat, UniformFormat } from '../../platform/graphics/uniform-buffer-format.js';
import { UNIFORMTYPE_UINT, SHADERSTAGE_COMPUTE, SHADERLANGUAGE_WGSL } from '../../platform/graphics/constants.js';
import { prefixSumSource } from '../shader-lib/wgsl/chunks/radix-sort/compute-prefix-sum.js';
// Workgroup dimensions shared by both compute entry points of the prefix-sum shader
// (substituted into the WGSL source through cdefines).
const WORKGROUP_SIZE_X = 16;
const WORKGROUP_SIZE_Y = 16;

// Total invocations in one workgroup.
const THREADS_PER_WORKGROUP = WORKGROUP_SIZE_Y * WORKGROUP_SIZE_X;

// Items covered by one workgroup: two per thread.
const ITEMS_PER_WORKGROUP = THREADS_PER_WORKGROUP * 2;
/**
 * Multi-level GPU prefix sum (scan) built from two WGSL compute entry points:
 * `reduce_downsweep` (run once per level, producing one block-sum entry per workgroup)
 * and `add_block_sums` (run back down the levels to fold the scanned block sums of the
 * next level into the level below).
 *
 * Level 0 scans the caller's buffer; every further level scans the block-sum buffer of
 * the previous level, until a level fits into a single workgroup.
 */
class PrefixSumKernel {
    /**
     * Creates the kernel and its shaders. No block-sum buffers are allocated until
     * {@link PrefixSumKernel#resize} is called.
     *
     * @param {object} device - The graphics device used to create shaders, buffers and
     * compute instances.
     */
    constructor(device) {
        /**
         * One entry per scan level, level 0 first. Entries after the first level whose
         * `workgroupCount` is <= 1 are spare capacity left from an earlier, larger resize
         * and are not dispatched.
         *
         * @type {object[]}
         */
        this.passes = [];
        this._uniformBufferFormat = null;
        this._bindGroupFormat = null;
        this._scanShader = null;
        this._addBlockShader = null;
        this.device = device;
        this._createFormatsAndShaders();
    }

    /**
     * Releases the block-sum buffers, both shaders and the bind group format.
     */
    destroy() {
        this.destroyPasses();
        this._scanShader?.destroy();
        this._addBlockShader?.destroy();
        this._bindGroupFormat?.destroy();
        this._scanShader = null;
        this._addBlockShader = null;
        this._bindGroupFormat = null;
        this._uniformBufferFormat = null;
    }

    /**
     * Creates the uniform/bind group layouts shared by both shaders, then the scan and
     * add-block shaders themselves.
     *
     * @private
     */
    _createFormatsAndShaders() {
        this._uniformBufferFormat = new UniformBufferFormat(this.device, [
            new UniformFormat('elementCount', UNIFORMTYPE_UINT)
        ]);
        this._bindGroupFormat = new BindGroupFormat(this.device, [
            new BindStorageBufferFormat('items', SHADERSTAGE_COMPUTE, false),
            new BindStorageBufferFormat('blockSums', SHADERSTAGE_COMPUTE, false),
            new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
        ]);
        this._scanShader = this._createShader('PrefixSumScan', 'reduce_downsweep');
        this._addBlockShader = this._createShader('PrefixSumAddBlock', 'add_block_sums');
    }

    /**
     * Appends a pass scanning `count` values of `dataBuffer`, then recurses on that pass's
     * block-sum buffer until a level fits into a single workgroup.
     *
     * @param {StorageBuffer} dataBuffer - Buffer holding the values scanned at this level.
     * @param {number} count - Number of values to scan at this level.
     */
    createPassesRecursive(dataBuffer, count) {
        const workgroupCount = Math.ceil(count / ITEMS_PER_WORKGROUP);
        const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);

        // 4 bytes of block-sum storage per workgroup of this level
        const blockSumBuffer = new StorageBuffer(this.device, workgroupCount * 4);

        const scanCompute = new Compute(this.device, this._scanShader, 'PrefixSumScan');
        scanCompute.setParameter('items', dataBuffer);
        scanCompute.setParameter('blockSums', blockSumBuffer);

        const pass = {
            scanCompute,
            addBlockCompute: null,
            blockSumBuffer,
            dataBuffer,
            dispatchX,
            dispatchY,
            workgroupCount,
            count,
            // element count the block-sum buffers of this chain were sized for;
            // resize() must rebuild rather than reuse when asked for more
            allocatedCount: count
        };
        this.passes.push(pass);

        if (workgroupCount > 1) {
            // the block sums of this level become the input of the next level
            this.createPassesRecursive(blockSumBuffer, workgroupCount);

            const addBlockCompute = new Compute(this.device, this._addBlockShader, 'PrefixSumAddBlock');
            addBlockCompute.setParameter('items', dataBuffer);
            addBlockCompute.setParameter('blockSums', blockSumBuffer);
            pass.addBlockCompute = addBlockCompute;
        }
    }

    /**
     * Builds one of the two compute shaders from the shared WGSL source.
     *
     * @param {string} name - Shader name.
     * @param {string} entryPoint - WGSL compute entry point.
     * @returns {Shader} The created shader.
     * @private
     */
    _createShader(name, entryPoint) {
        const cdefines = new Map();
        cdefines.set('{WORKGROUP_SIZE_X}', WORKGROUP_SIZE_X);
        cdefines.set('{WORKGROUP_SIZE_Y}', WORKGROUP_SIZE_Y);
        cdefines.set('{THREADS_PER_WORKGROUP}', THREADS_PER_WORKGROUP);
        cdefines.set('{ITEMS_PER_WORKGROUP}', ITEMS_PER_WORKGROUP);
        return new Shader(this.device, {
            name: name,
            shaderLanguage: SHADERLANGUAGE_WGSL,
            cshader: prefixSumSource,
            cdefines: cdefines,
            computeEntryPoint: entryPoint,
            computeBindGroupFormat: this._bindGroupFormat,
            computeUniformBufferFormats: {
                uniforms: this._uniformBufferFormat
            }
        });
    }

    /**
     * Splits a workgroup count into a 2D dispatch when it exceeds the per-dimension limit.
     * The 2D shape may cover more workgroups than requested (x * y >= workgroupCount).
     *
     * @param {number} workgroupCount - Number of workgroups needed.
     * @returns {{x: number, y: number}} Dispatch dimensions.
     */
    findOptimalDispatchSize(workgroupCount) {
        const maxDimension = this.device.limits.maxComputeWorkgroupsPerDimension || 65535;
        if (workgroupCount <= maxDimension) {
            return {
                x: workgroupCount,
                y: 1
            };
        }
        const x = Math.floor(Math.sqrt(workgroupCount));
        const y = Math.ceil(workgroupCount / x);
        return {
            x,
            y
        };
    }

    /**
     * Prepares the passes to scan `count` values of `dataBuffer`. Existing passes are
     * reused when their block-sum buffers are large enough; otherwise they are destroyed
     * and rebuilt.
     *
     * @param {StorageBuffer} dataBuffer - Buffer holding the values to scan.
     * @param {number} count - Number of values to scan.
     */
    resize(dataBuffer, count) {
        const requiredPasses = this._countPassesNeeded(count);
        const firstPass = this.passes[0];

        // Block-sum buffers were sized for allocatedCount values. A larger count can need
        // the same number of levels but more workgroups per level, which would make the
        // shaders access past the end of those buffers, so rebuild in that case too.
        const exceedsCapacity = firstPass !== undefined && count > firstPass.allocatedCount;
        if (requiredPasses > this.passes.length || exceedsCapacity) {
            this.destroyPasses();
            this.createPassesRecursive(dataBuffer, count);
            return;
        }

        // Only level 0 reads the caller's buffer; deeper levels read our block-sum buffers.
        if (firstPass && firstPass.dataBuffer !== dataBuffer) {
            firstPass.scanCompute.setParameter('items', dataBuffer);
            firstPass.addBlockCompute?.setParameter('items', dataBuffer);
            firstPass.dataBuffer = dataBuffer;
        }

        let levelCount = count;
        for (const pass of this.passes) {
            const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
            const { x: dispatchX, y: dispatchY } = this.findOptimalDispatchSize(workgroupCount);
            pass.count = levelCount;
            pass.workgroupCount = workgroupCount;
            pass.dispatchX = dispatchX;
            pass.dispatchY = dispatchY;

            // this level fits in one workgroup; any further passes are spare capacity
            if (workgroupCount <= 1) {
                break;
            }
            levelCount = workgroupCount;
        }
    }

    /**
     * Destroys the block-sum buffers of all passes and empties the pass list.
     */
    destroyPasses() {
        for (const pass of this.passes) {
            pass.blockSumBuffer?.destroy();
        }
        this.passes.length = 0;
    }

    /**
     * Number of scan levels needed for `count` values (0 when `count` is 0).
     *
     * @param {number} count - Number of values to scan.
     * @returns {number} Number of levels.
     * @private
     */
    _countPassesNeeded(count) {
        let passes = 0;
        let levelCount = count;
        while (levelCount > 0) {
            passes++;
            const workgroupCount = Math.ceil(levelCount / ITEMS_PER_WORKGROUP);
            if (workgroupCount <= 1) break;
            levelCount = workgroupCount;
        }
        return passes;
    }

    /**
     * Records the scan: every active level is scanned from level 0 upwards, then each
     * level's scanned block sums are added back into the level below, from the top down.
     *
     * @param {object} device - Device receiving the compute dispatches.
     */
    dispatch(device) {
        const passes = this.passes;

        // The active chain ends at the first level that fits in a single workgroup;
        // passes after it are left over from a previous, larger resize.
        let activeCount = 0;
        while (activeCount < passes.length) {
            const { workgroupCount } = passes[activeCount++];
            if (workgroupCount <= 1) {
                break;
            }
        }

        for (let i = 0; i < activeCount; i++) {
            const pass = passes[i];
            pass.scanCompute.setParameter('elementCount', pass.count);
            pass.scanCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
            device.computeDispatch([
                pass.scanCompute
            ], 'PrefixSumScan');
        }

        for (let i = activeCount - 1; i >= 0; i--) {
            const pass = passes[i];
            // only fold block sums back when the next level was scanned in this dispatch
            if (pass.addBlockCompute && i + 1 < activeCount) {
                pass.addBlockCompute.setParameter('elementCount', pass.count);
                pass.addBlockCompute.setupDispatch(pass.dispatchX, pass.dispatchY, 1);
                device.computeDispatch([
                    pass.addBlockCompute
                ], 'PrefixSumAddBlock');
            }
        }
    }
}
export { PrefixSumKernel };