/*
 * playcanvas
 * Version:
 * Open-source WebGL/WebGPU 3D engine for the web
 * 256 lines (253 loc) • 12.2 kB
 * JavaScript
 */
import { Vec2 } from '../../core/math/vec2.js';
import { Compute } from '../../platform/graphics/compute.js';
import { Shader } from '../../platform/graphics/shader.js';
import { StorageBuffer } from '../../platform/graphics/storage-buffer.js';
import { BindStorageBufferFormat, BindUniformBufferFormat, BindGroupFormat } from '../../platform/graphics/bind-group-format.js';
import { UniformBufferFormat, UniformFormat } from '../../platform/graphics/uniform-buffer-format.js';
import { SHADERSTAGE_COMPUTE, BUFFERUSAGE_COPY_SRC, BUFFERUSAGE_COPY_DST, SHADERLANGUAGE_WGSL, UNIFORMTYPE_UINT } from '../../platform/graphics/constants.js';
import { PrefixSumKernel } from './prefix-sum-kernel.js';
import { radixSort4bitSource } from '../shader-lib/wgsl/chunks/radix-sort/compute-radix-sort-4bit.js';
import { radixSortReorderSource } from '../shader-lib/wgsl/chunks/radix-sort/compute-radix-sort-reorder.js';
// Number of key bits consumed by a single sort pass.
const BITS_PER_PASS = 4;
// Number of distinct digit values a pass can see (2^BITS_PER_PASS).
const BUCKET_COUNT = 1 << BITS_PER_PASS;

// Compute workgroup dimensions; substituted into the WGSL sources as
// {WORKGROUP_SIZE_X} / {WORKGROUP_SIZE_Y}.
const WORKGROUP_SIZE_X = 16;
const WORKGROUP_SIZE_Y = 16;
const THREADS_PER_WORKGROUP = WORKGROUP_SIZE_Y * WORKGROUP_SIZE_X;

// Elements handled by each thread, and therefore by each whole workgroup.
const ELEMENTS_PER_THREAD = 8;
const ELEMENTS_PER_WORKGROUP = ELEMENTS_PER_THREAD * THREADS_PER_WORKGROUP;
/**
 * Radix sort driven by compute shaders. Keys are processed BITS_PER_PASS bits at a time; every
 * pass runs a histogram dispatch, a prefix sum over the per-workgroup bucket counts, and a
 * reorder dispatch that writes keys and values to their new positions.
 *
 * The caller-supplied key buffer (and the optional initial values buffer) is only ever bound as
 * a pass input; all intermediate and final results live in buffers owned by this instance.
 */
class ComputeRadixSort {
    /**
     * @param {object} device - Graphics device used to create buffers, shaders and to submit
     * compute dispatches. Must expose `limits` and `computeDispatch()`.
     */
    constructor(device) {
        this.device = device;

        // element count passed to the most recent sort call
        this._elementCount = 0;
        // workgroup count used for dispatches / block sum sizing
        this._workgroupCount = 0;
        // workgroup count the internal buffers are currently sized for
        this._allocatedWorkgroupCount = 0;
        // element count requested at the last (re)allocation; acts as a lower bound for
        // subsequent allocations so the buffers never shrink
        this.capacity = 0;
        // number of key bits the current passes were created for
        this._numBits = 0;

        // internal ping-pong buffers and results
        this._keys0 = null;
        this._keys1 = null;
        this._values0 = null;
        this._values1 = null;
        this._blockSums = null;
        this._sortedIndices = null;
        this._prefixSumKernel = null;

        this._dispatchSize = new Vec2(1, 1);

        this._histogramBindGroupFormat = null;
        this._reorderBindGroupFormat = null;

        // per-pass { histogramCompute, reorderCompute } pairs
        this._passes = [];

        // configuration the current passes / bind group formats were built with
        this._indirect = false;
        this._hasInitialValues = false;
        this._skipLastPassKeyWrite = false;

        // uniforms shared by the histogram and reorder shaders
        this._uniformBufferFormat = new UniformBufferFormat(device, [
            new UniformFormat('workgroupCount', UNIFORMTYPE_UINT),
            new UniformFormat('elementCount', UNIFORMTYPE_UINT)
        ]);
    }

    /**
     * Releases all buffers, passes and bind group formats owned by this instance.
     */
    destroy() {
        this._destroyBuffers();
        this._destroyPasses();
        this._histogramBindGroupFormat?.destroy();
        this._reorderBindGroupFormat?.destroy();
        this._histogramBindGroupFormat = null;
        this._reorderBindGroupFormat = null;
        this._uniformBufferFormat = null;
    }

    /**
     * Destroys the shaders of all passes and forgets the passes. Resetting `_numBits` guarantees
     * the next sort recreates them.
     *
     * @private
     */
    _destroyPasses() {
        for (const pass of this._passes) {
            pass.histogramCompute.shader?.destroy();
            pass.reorderCompute.shader?.destroy();
        }
        this._passes.length = 0;
        this._numBits = 0;
    }

    /**
     * Destroys the internal storage buffers and the prefix sum kernel, and resets the workgroup
     * bookkeeping. `capacity` is intentionally left untouched.
     *
     * @private
     */
    _destroyBuffers() {
        this._keys0?.destroy();
        this._keys1?.destroy();
        this._values0?.destroy();
        this._values1?.destroy();
        this._blockSums?.destroy();
        this._sortedIndices?.destroy();
        this._prefixSumKernel?.destroy();
        this._keys0 = null;
        this._keys1 = null;
        this._values0 = null;
        this._values1 = null;
        this._blockSums = null;
        this._sortedIndices = null;
        this._prefixSumKernel = null;
        this._workgroupCount = 0;
        this._allocatedWorkgroupCount = 0;
    }

    /**
     * Buffer receiving the values written by the last pass, or null before the first sort.
     *
     * @type {object|null}
     */
    get sortedIndices() {
        return this._sortedIndices;
    }

    /**
     * Internal key buffer that was the key output target of the last pass, or null before the
     * first sort. Passes alternate between `_keys0` (even pass index) and `_keys1` (odd pass
     * index), so an even number of passes ends in `_keys1`.
     *
     * NOTE(review): when the sort ran with `skipLastPassKeyWrite`, the last pass presumably did
     * not write this buffer - confirm against the reorder shader.
     *
     * @type {object|null}
     */
    get sortedKeys() {
        if (!this._keys0) return null;
        const numPasses = this._numBits / BITS_PER_PASS;
        return numPasses % 2 === 0 ? this._keys1 : this._keys0;
    }

    /**
     * Creates the histogram and reorder bind group formats, or keeps the existing ones when they
     * were already built for the same `indirect` mode.
     *
     * @param {boolean} indirect - Whether the formats need the extra `sortElementCount` buffer.
     * @private
     */
    _ensureBindGroupFormats(indirect) {
        if (this._histogramBindGroupFormat && this._indirect === indirect) {
            return;
        }
        this._histogramBindGroupFormat?.destroy();
        this._reorderBindGroupFormat?.destroy();
        const device = this.device;
        const histogramEntries = [
            new BindStorageBufferFormat('input', SHADERSTAGE_COMPUTE, true),
            new BindStorageBufferFormat('block_sums', SHADERSTAGE_COMPUTE, false),
            new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
        ];
        const reorderEntries = [
            new BindStorageBufferFormat('inputKeys', SHADERSTAGE_COMPUTE, true),
            new BindStorageBufferFormat('outputKeys', SHADERSTAGE_COMPUTE, false),
            new BindStorageBufferFormat('prefix_block_sum', SHADERSTAGE_COMPUTE, true),
            new BindStorageBufferFormat('inputValues', SHADERSTAGE_COMPUTE, true),
            new BindStorageBufferFormat('outputValues', SHADERSTAGE_COMPUTE, false),
            new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
        ];
        if (indirect) {
            histogramEntries.push(new BindStorageBufferFormat('sortElementCount', SHADERSTAGE_COMPUTE, true));
            reorderEntries.push(new BindStorageBufferFormat('sortElementCount', SHADERSTAGE_COMPUTE, true));
        }
        this._histogramBindGroupFormat = new BindGroupFormat(device, histogramEntries);
        this._reorderBindGroupFormat = new BindGroupFormat(device, reorderEntries);
    }

    /**
     * Rebuilds the per-pass compute instances (one histogram and one reorder per BITS_PER_PASS
     * bits) for the given configuration and records that configuration.
     *
     * @param {number} numBits - Number of key bits to sort.
     * @param {boolean} indirect - Whether the passes use indirect dispatch.
     * @param {boolean} hasInitialValues - Whether the first pass reads caller-provided values.
     * @param {boolean} skipLastPassKeyWrite - Forwarded to the last reorder shader as
     * IS_LAST_PASS.
     * @private
     */
    _createPasses(numBits, indirect, hasInitialValues, skipLastPassKeyWrite) {
        this._destroyPasses();
        this._numBits = numBits;

        // must run before `_indirect` is updated, as it compares against the previous mode
        this._ensureBindGroupFormats(indirect);
        this._indirect = indirect;
        this._hasInitialValues = hasInitialValues;
        this._skipLastPassKeyWrite = skipLastPassKeyWrite;

        const numPasses = numBits / BITS_PER_PASS;
        const suffix = indirect ? '-Indirect' : '';
        for (let pass = 0; pass < numPasses; pass++) {
            const bitOffset = pass * BITS_PER_PASS;

            // IS_FIRST_PASS is only set when no initial values were supplied; presumably the
            // shader then generates the values itself - confirm in the reorder WGSL chunk
            const isFirstPass = pass === 0 && !hasInitialValues;
            const isLastPass = skipLastPassKeyWrite && pass === numPasses - 1;

            const histogramShader = this._createShader(`RadixSort4bit-Histogram${suffix}-${bitOffset}`, radixSort4bitSource, bitOffset, false, false, this._histogramBindGroupFormat, indirect);
            const reorderShader = this._createShader(`RadixSort4bit-Reorder${suffix}-${bitOffset}`, radixSortReorderSource, bitOffset, isFirstPass, isLastPass, this._reorderBindGroupFormat, indirect);
            const histogramCompute = new Compute(this.device, histogramShader, `RadixSort4bit-Histogram${suffix}-${bitOffset}`);
            const reorderCompute = new Compute(this.device, reorderShader, `RadixSort4bit-Reorder${suffix}-${bitOffset}`);
            this._passes.push({
                histogramCompute,
                reorderCompute
            });
        }
    }

    /**
     * Makes sure the internal buffers can hold the requested number of elements and that the
     * passes match the requested configuration.
     *
     * Buffers are reallocated only when the required workgroup count changes. To keep that
     * check sufficient, per-element buffers are sized for whole workgroups: any element count
     * that maps to the allocated workgroup count is guaranteed to fit.
     *
     * @param {number} elementCount - Number of elements to sort.
     * @param {number} numBits - Number of key bits to sort.
     * @param {boolean} indirect - Whether indirect dispatch is used.
     * @param {boolean} hasInitialValues - Whether initial values are supplied by the caller.
     * @param {boolean} skipLastPassKeyWrite - Whether the last pass is flagged as IS_LAST_PASS.
     * @private
     */
    _allocateBuffers(elementCount, numBits, indirect, hasInitialValues, skipLastPassKeyWrite) {
        // never shrink below the previously requested capacity
        const effectiveCount = Math.max(elementCount, this.capacity);
        const workgroupCount = Math.ceil(effectiveCount / ELEMENTS_PER_WORKGROUP);
        const buffersNeedRealloc = workgroupCount !== this._allocatedWorkgroupCount || !this._keys0;
        const passesNeedRecreate = numBits !== this._numBits || indirect !== this._indirect || hasInitialValues !== this._hasInitialValues || skipLastPassKeyWrite !== this._skipLastPassKeyWrite;

        if (buffersNeedRealloc) {
            this._destroyBuffers();
            this._allocatedWorkgroupCount = workgroupCount;
            this.capacity = effectiveCount;

            // Round the element storage up to whole workgroups. Sizing it by effectiveCount
            // alone would leave the buffers too small when a later sort has more elements but
            // still maps to the same workgroup count (and so does not trigger a reallocation).
            const elementSize = workgroupCount * ELEMENTS_PER_WORKGROUP * 4;
            const blockSumSize = BUCKET_COUNT * workgroupCount * 4;
            const usage = BUFFERUSAGE_COPY_SRC | BUFFERUSAGE_COPY_DST;

            this._keys0 = new StorageBuffer(this.device, elementSize, usage);
            this._keys1 = new StorageBuffer(this.device, elementSize, usage);
            this._values0 = new StorageBuffer(this.device, elementSize, usage);
            this._values1 = new StorageBuffer(this.device, elementSize, usage);
            this._blockSums = new StorageBuffer(this.device, blockSumSize, usage);
            this._sortedIndices = new StorageBuffer(this.device, elementSize, usage);
            this._prefixSumKernel = new PrefixSumKernel(this.device);
        }

        this._workgroupCount = workgroupCount;
        Compute.calcDispatchSize(workgroupCount, this._dispatchSize, this.device.limits.maxComputeWorkgroupsPerDimension || 65535);
        this._prefixSumKernel.resize(this._blockSums, BUCKET_COUNT * workgroupCount);

        if (passesNeedRecreate) {
            this._createPasses(numBits, indirect, hasInitialValues, skipLastPassKeyWrite);
        }
    }

    /**
     * Creates a WGSL compute shader from one of the radix sort sources, specialised through
     * `cdefines`.
     *
     * @param {string} name - Shader name.
     * @param {string} source - WGSL source.
     * @param {number} currentBit - Bit offset handled by the pass ({CURRENT_BIT}).
     * @param {boolean} isFirstPass - Value for {IS_FIRST_PASS}.
     * @param {boolean} isLastPass - Value for {IS_LAST_PASS}.
     * @param {object} bindGroupFormat - Bind group format of the shader.
     * @param {boolean} indirect - When true, USE_INDIRECT_SORT is defined.
     * @returns {object} The created shader.
     * @private
     */
    _createShader(name, source, currentBit, isFirstPass, isLastPass, bindGroupFormat, indirect) {
        const cdefines = new Map();
        cdefines.set('{WORKGROUP_SIZE_X}', WORKGROUP_SIZE_X);
        cdefines.set('{WORKGROUP_SIZE_Y}', WORKGROUP_SIZE_Y);
        cdefines.set('{THREADS_PER_WORKGROUP}', THREADS_PER_WORKGROUP);
        cdefines.set('{ELEMENTS_PER_THREAD}', ELEMENTS_PER_THREAD);
        cdefines.set('{CURRENT_BIT}', currentBit);
        cdefines.set('{IS_FIRST_PASS}', isFirstPass ? 1 : 0);
        cdefines.set('{IS_LAST_PASS}', isLastPass ? 1 : 0);
        if (indirect) {
            cdefines.set('USE_INDIRECT_SORT', '');
        }
        return new Shader(this.device, {
            name: name,
            shaderLanguage: SHADERLANGUAGE_WGSL,
            cshader: source,
            cdefines: cdefines,
            computeBindGroupFormat: bindGroupFormat,
            computeUniformBufferFormats: {
                uniforms: this._uniformBufferFormat
            }
        });
    }

    /**
     * Sorts using direct dispatch.
     *
     * @param {object} keysBuffer - Storage buffer with the keys; only read.
     * @param {number} elementCount - Number of elements to sort.
     * @param {number} [numBits] - Number of key bits to sort. Expected to be a multiple of
     * BITS_PER_PASS. Defaults to 16.
     * @param {object} [initialValues] - Optional storage buffer with the values read by the
     * first pass; only read.
     * @param {boolean} [skipLastPassKeyWrite] - Flags the last pass as IS_LAST_PASS in the
     * reorder shader. Defaults to false.
     * @returns {object} The buffer holding the sorted values (see {@link sortedIndices}).
     */
    sort(keysBuffer, elementCount, numBits = 16, initialValues, skipLastPassKeyWrite = false) {
        return this._execute(keysBuffer, elementCount, numBits, false, -1, null, initialValues, skipLastPassKeyWrite);
    }

    /**
     * Sorts using indirect dispatch.
     *
     * @param {object} keysBuffer - Storage buffer with the keys; only read.
     * @param {number} maxElementCount - Element count used to size buffers and uniforms.
     * @param {number} numBits - Number of key bits to sort. Expected to be a multiple of
     * BITS_PER_PASS.
     * @param {number} dispatchSlot - Slot passed to `setupIndirectDispatch()`.
     * @param {object} sortElementCountBuffer - Storage buffer bound as `sortElementCount`.
     * @param {object} [initialValues] - Optional storage buffer with the values read by the
     * first pass; only read.
     * @param {boolean} [skipLastPassKeyWrite] - Flags the last pass as IS_LAST_PASS in the
     * reorder shader. Defaults to false.
     * @returns {object} The buffer holding the sorted values (see {@link sortedIndices}).
     */
    sortIndirect(keysBuffer, maxElementCount, numBits, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite = false) {
        return this._execute(keysBuffer, maxElementCount, numBits, true, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite);
    }

    /**
     * Shared implementation of {@link sort} and {@link sortIndirect}: prepares resources, then
     * submits histogram, prefix sum and reorder work for every pass.
     *
     * @param {object} keysBuffer - Storage buffer with the keys.
     * @param {number} elementCount - Element count (maximum count for indirect sorts).
     * @param {number} numBits - Number of key bits to sort.
     * @param {boolean} indirect - Whether to dispatch indirectly.
     * @param {number} dispatchSlot - Indirect dispatch slot (unused for direct sorts).
     * @param {object|null} sortElementCountBuffer - Buffer bound as `sortElementCount` for
     * indirect sorts.
     * @param {object} [initialValues] - Optional values read by the first pass.
     * @param {boolean} [skipLastPassKeyWrite] - See {@link sort}.
     * @returns {object} The buffer holding the sorted values.
     * @private
     */
    _execute(keysBuffer, elementCount, numBits, indirect, dispatchSlot, sortElementCountBuffer, initialValues, skipLastPassKeyWrite = false) {
        this._elementCount = elementCount;
        const hasInitialValues = !!initialValues;
        this._allocateBuffers(elementCount, numBits, indirect, hasInitialValues, skipLastPassKeyWrite);

        const device = this.device;
        const numPasses = numBits / BITS_PER_PASS;
        const suffix = indirect ? '-Indirect' : '';

        // The first pass reads the caller's buffers; every later pass ping-pongs between the
        // internal buffers only.
        let currentKeys = keysBuffer;
        let currentValues = initialValues ?? this._values0;
        let nextKeys = this._keys0;
        let nextValues = this._values1;

        for (let pass = 0; pass < numPasses; pass++) {
            const { histogramCompute, reorderCompute } = this._passes[pass];
            const isLastPass = pass === numPasses - 1;

            if (indirect) {
                this._blockSums.clear();
            }

            // histogram: count digits per workgroup into block_sums
            histogramCompute.setParameter('input', currentKeys);
            histogramCompute.setParameter('block_sums', this._blockSums);
            histogramCompute.setParameter('workgroupCount', this._workgroupCount);
            histogramCompute.setParameter('elementCount', elementCount);
            if (indirect) {
                histogramCompute.setParameter('sortElementCount', sortElementCountBuffer);
                histogramCompute.setupIndirectDispatch(dispatchSlot);
            } else {
                histogramCompute.setupDispatch(this._dispatchSize.x, this._dispatchSize.y, 1);
            }
            device.computeDispatch([
                histogramCompute
            ], `RadixSort-Histogram${suffix}`);

            // prefix sum over the block sums, in place
            this._prefixSumKernel.dispatch(device);

            // reorder: the last pass writes its values straight to the result buffer
            const outputValues = isLastPass ? this._sortedIndices : nextValues;
            reorderCompute.setParameter('inputKeys', currentKeys);
            reorderCompute.setParameter('outputKeys', nextKeys);
            reorderCompute.setParameter('prefix_block_sum', this._blockSums);
            reorderCompute.setParameter('inputValues', currentValues);
            reorderCompute.setParameter('outputValues', outputValues);
            reorderCompute.setParameter('workgroupCount', this._workgroupCount);
            reorderCompute.setParameter('elementCount', elementCount);
            if (indirect) {
                reorderCompute.setParameter('sortElementCount', sortElementCountBuffer);
                reorderCompute.setupIndirectDispatch(dispatchSlot);
            } else {
                reorderCompute.setupDispatch(this._dispatchSize.x, this._dispatchSize.y, 1);
            }
            device.computeDispatch([
                reorderCompute
            ], `RadixSort-Reorder${suffix}`);

            if (!isLastPass) {
                currentKeys = nextKeys;
                nextKeys = currentKeys === this._keys0 ? this._keys1 : this._keys0;

                // Alternate between the internal value buffers only. A plain swap would turn
                // the caller's initialValues buffer into the output target of the second pass
                // and overwrite it with intermediate data.
                currentValues = nextValues;
                nextValues = currentValues === this._values0 ? this._values1 : this._values0;
            }
        }
        return this._sortedIndices;
    }
}
export { ComputeRadixSort, ELEMENTS_PER_WORKGROUP as RADIX_SORT_ELEMENTS_PER_WORKGROUP };