playcanvas
Version:
Open-source WebGL/WebGPU 3D engine for the web
313 lines (312 loc) • 12.3 kB
TypeScript
/**
* A compute-based GPU radix sort implementation using 4-bit radix (16 buckets).
* Provides stable sorting of 32-bit unsigned integer keys, returning sorted indices.
* WebGPU only.
*
* **Performance characteristics:**
* - One pass per 4 bits of key: 4 passes for 16-bit keys, 8 passes for 32-bit keys
* - Each pass processes 4 bits (16 buckets)
* - Workgroup size: 16x16 = 256 threads, 8 elements per thread = 2048 elements/workgroup
*
* **Algorithm (per pass):**
* 1. **Histogram**: Each thread extracts 4-bit digits from its elements and
* contributes to a per-workgroup histogram using shared memory atomics.
* 2. **Prefix Sum**: Hierarchical Blelloch scan on block histograms to compute
* global offsets for each (digit, workgroup) pair.
* 3. **Ranked Scatter**: Re-reads keys in rounds, computes local ranks using
* per-digit 256-bit bitmasks and hardware popcount, then scatters using:
* `position = global_prefix[digit][workgroup] + cumulative_local_rank`
*
* Based on "Fast 4-way parallel radix sorting on GPUs" algorithm, implemented
* following [WebGPU-Radix-Sort](https://github.com/kishimisu/WebGPU-Radix-Sort)
* by kishimisu (MIT License).
*
* @example
* // Create the radix sort instance (reusable)
* const radixSort = new ComputeRadixSort(device);
*
* // Create a storage buffer with keys to sort
* const keys = new Uint32Array([5, 2, 8, 1, 9, 3]);
* const keysBuffer = new StorageBuffer(device, keys.byteLength, BUFFERUSAGE_COPY_DST);
* keysBuffer.write(keys);
*
* // Sort and get indices buffer (keys with values [5,2,8,1,9,3] → indices [3,1,5,0,2,4]).
* // All keys fit in 16 bits, so a 16-bit sort (4 passes) is sufficient.
* const sortedIndices = radixSort.sort(keysBuffer, keys.length, 16);
*
* // Use sortedIndices buffer in subsequent GPU operations
* // Clean up when done
* radixSort.destroy();
*
* @category Graphics
* @ignore
*/
export class ComputeRadixSort {
/**
* Creates a new ComputeRadixSort instance.
*
* @param {GraphicsDevice} device - The graphics device (must support compute).
*/
constructor(device: GraphicsDevice);
/**
* The graphics device this sorter was created with.
*
* @type {GraphicsDevice}
*/
device: GraphicsDevice;
/**
* Current element count.
*
* @type {number}
*/
_elementCount: number;
/**
* Number of workgroups for current sort.
*
* @type {number}
*/
_workgroupCount: number;
/**
* Allocated workgroup capacity. Tracks the last allocated size; reallocation is triggered
* when the effective workgroup count (derived from element count and capacity) differs.
*
* @type {number}
*/
_allocatedWorkgroupCount: number;
/**
* Minimum element capacity for internal buffers. When set, `_allocateBuffers` uses
* `max(elementCount, capacity)` as the effective size. The caller can lower this value
* to request shrinkage; actual reallocation is deferred to the next sort call.
* After allocation, this is updated to reflect the effective element count.
*
* @type {number}
*/
capacity: number;
/**
* Current number of bits for which passes are created.
*
* @type {number}
*/
_numBits: number;
/**
* Internal keys buffer 0 (ping-pong).
*
* @type {StorageBuffer|null}
*/
_keys0: StorageBuffer | null;
/**
* Internal keys buffer 1 (ping-pong).
*
* @type {StorageBuffer|null}
*/
_keys1: StorageBuffer | null;
/**
* Internal values/indices buffer 0 (ping-pong).
*
* @type {StorageBuffer|null}
*/
_values0: StorageBuffer | null;
/**
* Internal values/indices buffer 1 (ping-pong).
*
* @type {StorageBuffer|null}
*/
_values1: StorageBuffer | null;
/**
* Block sums buffer (16 per workgroup, matching the 16 radix buckets).
*
* @type {StorageBuffer|null}
*/
_blockSums: StorageBuffer | null;
/**
* Output sorted indices buffer.
*
* @type {StorageBuffer|null}
*/
_sortedIndices: StorageBuffer | null;
/**
* Prefix sum kernel for block sums.
*
* @type {PrefixSumKernel|null}
*/
_prefixSumKernel: PrefixSumKernel | null;
/**
* Dispatch dimensions.
* NOTE(review): presumably the x/y workgroup counts used when dispatching the passes -
* confirm against the implementation.
*
* @type {Vec2}
*/
_dispatchSize: Vec2;
/**
* Cached bind group format for histogram shader (created lazily for the current mode,
* i.e. direct or indirect sort).
*
* @type {BindGroupFormat|null}
*/
_histogramBindGroupFormat: BindGroupFormat | null;
/**
* Cached bind group format for reorder shader (created lazily for the current mode,
* i.e. direct or indirect sort).
*
* @type {BindGroupFormat|null}
*/
_reorderBindGroupFormat: BindGroupFormat | null;
/**
* Uniform buffer format for runtime uniforms.
*
* @type {UniformBufferFormat|null}
*/
_uniformBufferFormat: UniformBufferFormat | null;
/**
* Cached compute passes. Each entry contains {histogramCompute, reorderCompute} for one
* 4-bit pass.
*
* @type {Array<{histogramCompute: Compute, reorderCompute: Compute}>}
*/
_passes: Array<{
histogramCompute: Compute;
reorderCompute: Compute;
}>;
/**
* Whether the current passes are for indirect sort mode.
*
* @type {boolean}
*/
_indirect: boolean;
/**
* Whether the current passes expect caller-supplied initial values on pass 0.
*
* @type {boolean}
*/
_hasInitialValues: boolean;
/**
* Whether the last pass skips writing sorted keys (only values are written).
* When true, `sortedKeys` will contain stale data after sorting.
*
* @type {boolean}
*/
_skipLastPassKeyWrite: boolean;
/**
* Destroys the ComputeRadixSort instance and releases all resources.
*/
destroy(): void;
/**
* Destroys all cached passes and their shaders.
*
* @private
*/
private _destroyPasses;
/**
* Destroys internal buffers (not passes or bind group formats).
*
* @private
*/
private _destroyBuffers;
/**
* Gets the sorted indices (or values) buffer. The declared type allows `null`.
* NOTE(review): presumably null until a sort has allocated the buffer - confirm.
*
* @type {StorageBuffer|null}
*/
get sortedIndices(): StorageBuffer | null;
/**
* Gets the sorted keys buffer after the last sort operation. The keys end up
* in one of the internal ping-pong buffers depending on the number of passes.
* If the last pass skipped its key write (see `_skipLastPassKeyWrite`), the returned
* buffer contains stale data.
*
* @type {StorageBuffer|null}
*/
get sortedKeys(): StorageBuffer | null;
/**
* Ensures bind group formats exist for the given mode. Destroys and recreates
* them if switching between direct and indirect modes.
*
* @param {boolean} indirect - Whether to create indirect sort formats.
* @private
*/
private _ensureBindGroupFormats;
/**
* Creates cached compute passes for all bit offsets.
*
* @param {number} numBits - Number of bits to sort.
* @param {boolean} indirect - Whether to create indirect sort passes.
* @param {boolean} hasInitialValues - Whether pass 0 reads from caller-supplied initial values.
* @param {boolean} skipLastPassKeyWrite - Whether the last pass skips writing keys.
* @private
*/
private _createPasses;
/**
* Allocates or resizes internal buffers and creates passes if needed.
*
* @param {number} elementCount - Number of elements to sort.
* @param {number} numBits - Number of bits to sort.
* @param {boolean} indirect - Whether passes should use indirect dispatch.
* @param {boolean} hasInitialValues - Whether pass 0 reads caller-supplied initial values.
* @param {boolean} skipLastPassKeyWrite - Whether the last pass skips writing keys.
* @private
*/
private _allocateBuffers;
/**
* Creates a shader with constants embedded.
*
* @param {string} name - Shader name.
* @param {string} source - Shader source.
* @param {number} currentBit - Current bit offset for this pass.
* @param {boolean} isFirstPass - Whether this is the first pass (uses GID for indices).
* @param {BindGroupFormat} bindGroupFormat - Bind group format.
* @param {boolean} indirect - Whether to add the USE_INDIRECT_SORT define.
* @returns {Shader} The created shader.
* @private
*/
private _createShader;
/**
* Executes the GPU radix sort using direct dispatch.
*
* @param {StorageBuffer} keysBuffer - Input storage buffer containing u32 keys.
* @param {number} elementCount - Number of elements to sort.
* @param {number} [numBits] - Number of bits to sort (must be multiple of 4). Defaults to 16.
* @param {StorageBuffer} [initialValues] - Optional buffer of initial values for pass 0.
* When provided, the sort produces output values derived from this buffer instead of
* sequential indices. The buffer is only read, never modified.
* @param {boolean} [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
* keys for a small performance gain. Only use when sorted keys are not needed after sorting,
* because `sortedKeys` then contains stale data.
* @returns {StorageBuffer} Storage buffer containing sorted indices (or values if
* initialValues was provided).
*/
sort(keysBuffer: StorageBuffer, elementCount: number, numBits?: number, initialValues?: StorageBuffer, skipLastPassKeyWrite?: boolean): StorageBuffer;
/**
* Executes the GPU radix sort using indirect dispatch. Only sorts the number of elements
* stored in `sortElementCountBuffer` (GPU-written visible count) instead of the full
* buffer, reducing sort cost proportionally.
*
* @param {StorageBuffer} keysBuffer - Input storage buffer containing u32 keys.
* @param {number} maxElementCount - Maximum number of elements (buffer allocation size).
* @param {number} numBits - Number of bits to sort (must be multiple of 4).
* @param {number} dispatchSlot - Slot index in the device's indirect dispatch buffer.
* @param {StorageBuffer} sortElementCountBuffer - GPU-written buffer containing visible count.
* @param {StorageBuffer} [initialValues] - Optional buffer of initial values for pass 0.
* When provided, the sort produces output values derived from this buffer instead of
* sequential indices. The buffer is only read, never modified.
* @param {boolean} [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
* keys for a small performance gain. Only use when sorted keys are not needed after sorting,
* because `sortedKeys` then contains stale data.
* @returns {StorageBuffer} Storage buffer containing sorted indices (or values if
* initialValues was provided).
*/
sortIndirect(keysBuffer: StorageBuffer, maxElementCount: number, numBits: number, dispatchSlot: number, sortElementCountBuffer: StorageBuffer, initialValues?: StorageBuffer, skipLastPassKeyWrite?: boolean): StorageBuffer;
/**
* Shared execution logic for both direct and indirect radix sort.
*
* @param {StorageBuffer} keysBuffer - Input keys buffer.
* @param {number} elementCount - Number of elements (or max elements for indirect).
* @param {number} numBits - Number of bits to sort.
* @param {boolean} indirect - Whether to use indirect dispatch.
* @param {number} dispatchSlot - Indirect dispatch slot index (-1 for direct).
* @param {StorageBuffer|null} sortElementCountBuffer - GPU-written element count (null for direct).
* @param {StorageBuffer} [initialValues] - Optional initial values buffer for pass 0.
* @param {boolean} [skipLastPassKeyWrite] - When true, the last pass skips writing sorted
* keys for a small performance gain. Only use when sorted keys are not needed after sorting.
* @returns {StorageBuffer} Storage buffer containing sorted indices (or values if
* initialValues was provided).
* @private
*/
private _execute;
}
/**
* Number of elements handled by a single workgroup. Exported from this module under the
* name `RADIX_SORT_ELEMENTS_PER_WORKGROUP`.
* NOTE(review): the ComputeRadixSort docs describe 256 threads x 8 elements per thread =
* 2048 elements per workgroup; presumably this constant holds that value - confirm against
* the implementation.
*
* @type {number}
*/
declare const ELEMENTS_PER_WORKGROUP: number;
import type { GraphicsDevice } from '../../platform/graphics/graphics-device.js';
import { StorageBuffer } from '../../platform/graphics/storage-buffer.js';
import { PrefixSumKernel } from './prefix-sum-kernel.js';
import { Vec2 } from '../../core/math/vec2.js';
import { BindGroupFormat } from '../../platform/graphics/bind-group-format.js';
import { UniformBufferFormat } from '../../platform/graphics/uniform-buffer-format.js';
import { Compute } from '../../platform/graphics/compute.js';
// Expose the module-local ELEMENTS_PER_WORKGROUP constant under a radix-sort-specific name.
export { ELEMENTS_PER_WORKGROUP as RADIX_SORT_ELEMENTS_PER_WORKGROUP };