UNPKG

d3-force-webgpu

Version:

GPU-accelerated force-directed graph layout with adaptive CPU/GPU selection. Drop-in replacement for d3-force with WebGPU support.

184 lines (153 loc) 4.67 kB
/**
 * Creates a WebGPU centering force: on every invocation it translates all
 * nodes so that their mean position (center of mass) moves toward (x, y).
 *
 * The returned `force` function is async; it records one compute pass,
 * submits it and resolves once the queue reports the work as done. It is a
 * no-op until `force.initialize` has been called with a device and a node
 * buffer.
 *
 * @param {number} [x=0] - Target x-coordinate of the center.
 * @param {number} [y=0] - Target y-coordinate of the center.
 * @returns {Function} The force, with `initialize`, `x`, `y` and `strength`
 *   methods attached.
 */
export default function center(x, y) {
  let nodes;
  let device;
  let pipeline;
  let bindGroup;
  let centerParamsBuffer;
  let reductionBuffer;
  let strength = 1;

  if (x == null) x = 0;
  if (y == null) y = 0;

  // The shader runs as ONE workgroup of 64 invocations. Every loop over the
  // node array therefore strides by 64 so that all nodes are visited, in the
  // summation pass as well as in the write-back pass.
  const shaderCode = `
struct Node {
  position: vec2<f32>,
  velocity: vec2<f32>,
  fixedPosition: vec2<f32>,
  index: f32,
  _padding: f32,
}

struct CenterParams {
  center: vec2<f32>,
  strength: f32,
  nodeCount: f32,
}

@group(0) @binding(0) var<storage, read_write> nodes: array<Node>;
@group(0) @binding(1) var<uniform> params: CenterParams;
@group(0) @binding(2) var<storage, read_write> reduction: array<vec2<f32>>;

var<workgroup> shared_sum: array<vec2<f32>, 64>;

@compute @workgroup_size(64)
fn main(@builtin(local_invocation_id) local_id: vec3<u32>) {
  let tid = local_id.x;
  let nodeCount = u32(params.nodeCount);

  // First pass: each invocation sums a strided slice of the node positions.
  var sum = vec2<f32>(0.0, 0.0);
  for (var i = tid; i < nodeCount; i += 64u) {
    sum += nodes[i].position;
  }
  shared_sum[tid] = sum;
  workgroupBarrier();

  // Tree reduction in workgroup memory; the total ends up in shared_sum[0].
  for (var s = 32u; s > 0u; s >>= 1u) {
    if (tid < s) {
      shared_sum[tid] += shared_sum[tid + s];
    }
    workgroupBarrier();
  }

  // max() keeps the divisor non-zero when there are no nodes.
  let center_of_mass = shared_sum[0] / f32(max(nodeCount, 1u));
  if (tid == 0u) {
    reduction[0] = center_of_mass;
  }

  // Second pass: apply the centering offset to every node, not only to the
  // first 64.
  let delta = (params.center - center_of_mass) * params.strength;
  for (var i = tid; i < nodeCount; i += 64u) {
    nodes[i].position += delta;
  }
}`;

  /**
   * Encodes and submits one centering pass, then waits for the queue to
   * finish. Does nothing when the force has no GPU state.
   */
  async function force() {
    if (!device || !pipeline || !bindGroup) return;

    const commandEncoder = device.createCommandEncoder();
    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(pipeline);
    passEncoder.setBindGroup(0, bindGroup);
    passEncoder.dispatchWorkgroups(1); // single workgroup; the shader strides over all nodes
    passEncoder.end();
    device.queue.submit([commandEncoder.finish()]);

    // Wait for completion
    await device.queue.onSubmittedWorkDone();
  }

  // Destroys the buffers owned by this force and forgets the pipeline state,
  // so that re-initialization does not leak GPU memory.
  function releaseGpuResources() {
    if (centerParamsBuffer) centerParamsBuffer.destroy();
    if (reductionBuffer) reductionBuffer.destroy();
    centerParamsBuffer = undefined;
    reductionBuffer = undefined;
    pipeline = undefined;
    bindGroup = undefined;
  }

  // Uploads [x, y, strength, nodeCount] to the uniform buffer, if one exists.
  function updateCenterParams() {
    if (device && centerParamsBuffer && nodes) {
      const params = new Float32Array([x, y, strength, nodes.length]);
      device.queue.writeBuffer(centerParamsBuffer, 0, params);
    }
  }

  /**
   * (Re)builds the GPU pipeline, buffers and bind group for this force.
   * Buffers created by a previous call are destroyed first. When `_device`
   * or `nodeBuffer` is missing, the force is left without GPU state and
   * `force()` becomes a no-op.
   *
   * @param {Array} _nodes - Simulation nodes; only `length` is read here.
   * @param {Function} _random - Unused by this force.
   * @param {GPUDevice} [_device] - Device used to create all GPU objects.
   * @param {GPUBuffer} [nodeBuffer] - Storage buffer bound as the shader's
   *   `nodes` array (binding 0).
   */
  force.initialize = function(_nodes, _random, _device, nodeBuffer) {
    releaseGpuResources();
    nodes = _nodes;
    device = _device;
    if (!device || !nodeBuffer) return;

    const shaderModule = device.createShaderModule({
      label: 'Center Force Shader',
      code: shaderCode
    });

    centerParamsBuffer = device.createBuffer({
      label: 'Center Parameters',
      size: 4 * 4, // vec2 + 2 floats
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
    });
    updateCenterParams();

    reductionBuffer = device.createBuffer({
      label: 'Reduction Buffer',
      size: 8, // vec2
      usage: GPUBufferUsage.STORAGE
    });

    const bindGroupLayout = device.createBindGroupLayout({
      entries: [
        { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
        { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
        { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }
      ]
    });

    pipeline = device.createComputePipeline({
      layout: device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] }),
      compute: { module: shaderModule, entryPoint: 'main' }
    });

    bindGroup = device.createBindGroup({
      layout: bindGroupLayout,
      entries: [
        { binding: 0, resource: { buffer: nodeBuffer } },
        { binding: 1, resource: { buffer: centerParamsBuffer } },
        { binding: 2, resource: { buffer: reductionBuffer } }
      ]
    });
  };

  /** Gets the target x, or sets it (coerced to a number) and returns the force. */
  force.x = function(_) {
    if (!arguments.length) return x;
    x = +_;
    updateCenterParams();
    return force;
  };

  /** Gets the target y, or sets it (coerced to a number) and returns the force. */
  force.y = function(_) {
    if (!arguments.length) return y;
    y = +_;
    updateCenterParams();
    return force;
  };

  /** Gets the strength, or sets it (coerced to a number) and returns the force. */
  force.strength = function(_) {
    if (!arguments.length) return strength;
    strength = +_;
    updateCenterParams();
    return force;
  };

  return force;
}