d3-force-webgpu
Version:
Force-directed graph layout using velocity Verlet integration with WebGPU acceleration.
691 lines (597 loc) • 14.9 kB
JavaScript
// Shader source code as strings
// These are embedded directly to avoid async loading issues
export const manyBodyShader = `
// Many-Body Force Compute Shader
// Implements N-body gravitational/repulsive force using tile-based algorithm
struct Node {
x: f32,
y: f32,
vx: f32,
vy: f32,
fx: f32,
fy: f32,
strength: f32,
radius: f32,
}
struct Params {
alpha: f32,
velocityDecay: f32,
nodeCount: u32,
linkCount: u32,
centerX: f32,
centerY: f32,
centerStrength: f32,
theta2: f32,
distanceMin2: f32,
distanceMax2: f32,
iterations: u32,
collisionRadius: f32,
collisionStrength: f32,
collisionIterations: u32,
forceXTarget: f32,
forceXStrength: f32,
forceYTarget: f32,
forceYStrength: f32,
radialX: f32,
radialY: f32,
radialRadius: f32,
radialStrength: f32,
_pad1: f32,
_pad2: f32,
}
@group(0) @binding(0) var<storage, read_write> nodes: array<Node>;
@group(0) @binding(1) var<uniform> params: Params;
const TILE_SIZE: u32 = 256u;
var<workgroup> tile: array<vec4<f32>, 256>;
fn jiggle(seed: u32) -> f32 {
let s = (seed * 1103515245u + 12345u) & 0x7fffffffu;
return (f32(s) / f32(0x7fffffff) - 0.5) * 1e-6;
}
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>
) {
let i = global_id.x;
let nodeCount = params.nodeCount;
let isValid = i < nodeCount;
// Load node data (use defaults for out-of-bounds threads)
var myPos = vec2<f32>(0.0, 0.0);
var myVx: f32 = 0.0;
var myVy: f32 = 0.0;
if (isValid) {
let node = nodes[i];
myPos = vec2<f32>(node.x, node.y);
myVx = node.vx;
myVy = node.vy;
}
var forceX: f32 = 0.0;
var forceY: f32 = 0.0;
let alpha = params.alpha;
let distanceMin2 = params.distanceMin2;
let distanceMax2 = params.distanceMax2;
let numTiles = (nodeCount + TILE_SIZE - 1u) / TILE_SIZE;
for (var t: u32 = 0u; t < numTiles; t++) {
// All threads participate in loading the tile
let tileIdx = t * TILE_SIZE + local_id.x;
if (tileIdx < nodeCount) {
let other = nodes[tileIdx];
tile[local_id.x] = vec4<f32>(other.x, other.y, other.strength, 0.0);
} else {
tile[local_id.x] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
// All threads must hit this barrier
workgroupBarrier();
// Only valid threads compute forces
if (isValid) {
let tileEnd = min(TILE_SIZE, nodeCount - t * TILE_SIZE);
for (var j: u32 = 0u; j < tileEnd; j++) {
let otherIdx = t * TILE_SIZE + j;
if (otherIdx != i) {
let other = tile[j];
var dx = other.x - myPos.x;
var dy = other.y - myPos.y;
var l2 = dx * dx + dy * dy;
if (l2 < distanceMax2) {
if (dx == 0.0) {
dx = jiggle(i * nodeCount + otherIdx);
l2 += dx * dx;
}
if (dy == 0.0) {
dy = jiggle(i * nodeCount + otherIdx + 1u);
l2 += dy * dy;
}
if (l2 < distanceMin2) {
l2 = sqrt(distanceMin2 * l2);
}
let strength = other.z;
let force = strength * alpha / l2;
forceX += dx * force;
forceY += dy * force;
}
}
}
}
// All threads must hit this barrier
workgroupBarrier();
}
// Only valid threads write results
if (isValid) {
nodes[i].vx = myVx + forceX;
nodes[i].vy = myVy + forceY;
}
}
`;
export const linkShader = `
// Link Force Compute Shader
struct Node {
x: f32,
y: f32,
vx: f32,
vy: f32,
fx: f32,
fy: f32,
strength: f32,
radius: f32,
}
struct Link {
sourceIdx: f32,
targetIdx: f32,
distance: f32,
strength: f32,
bias: f32,
_pad1: f32,
_pad2: f32,
_pad3: f32,
}
struct Params {
alpha: f32,
velocityDecay: f32,
nodeCount: u32,
linkCount: u32,
centerX: f32,
centerY: f32,
centerStrength: f32,
theta2: f32,
distanceMin2: f32,
distanceMax2: f32,
iterations: u32,
collisionRadius: f32,
collisionStrength: f32,
collisionIterations: u32,
forceXTarget: f32,
forceXStrength: f32,
forceYTarget: f32,
forceYStrength: f32,
radialX: f32,
radialY: f32,
radialRadius: f32,
radialStrength: f32,
_pad1: f32,
_pad2: f32,
}
@group(0) @binding(0) var<storage, read_write> nodes: array<Node>;
@group(0) @binding(1) var<storage, read> links: array<Link>;
@group(0) @binding(2) var<uniform> params: Params;
fn jiggle(seed: u32) -> f32 {
let s = (seed * 1103515245u + 12345u) & 0x7fffffffu;
return (f32(s) / f32(0x7fffffff) - 0.5) * 1e-6;
}
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let nodeIdx = global_id.x;
let nodeCount = params.nodeCount;
let linkCount = params.linkCount;
let alpha = params.alpha;
if (nodeIdx >= nodeCount) {
return;
}
let node = nodes[nodeIdx];
var dvx: f32 = 0.0;
var dvy: f32 = 0.0;
for (var i: u32 = 0u; i < linkCount; i++) {
let link = links[i];
let sourceIdx = u32(link.sourceIdx);
let targetIdx = u32(link.targetIdx);
if (sourceIdx != nodeIdx && targetIdx != nodeIdx) {
continue;
}
let srcNode = nodes[sourceIdx];
let dstNode = nodes[targetIdx];
var dx = dstNode.x + dstNode.vx - srcNode.x - srcNode.vx;
var dy = dstNode.y + dstNode.vy - srcNode.y - srcNode.vy;
if (dx == 0.0 && dy == 0.0) {
dx = jiggle(i * 2u);
dy = jiggle(i * 2u + 1u);
}
let l = sqrt(dx * dx + dy * dy);
let force = (l - link.distance) / l * alpha * link.strength;
let fx = dx * force;
let fy = dy * force;
if (nodeIdx == targetIdx) {
dvx -= fx * link.bias;
dvy -= fy * link.bias;
} else {
dvx += fx * (1.0 - link.bias);
dvy += fy * (1.0 - link.bias);
}
}
nodes[nodeIdx].vx = node.vx + dvx;
nodes[nodeIdx].vy = node.vy + dvy;
}
`;
export const integrateShader = `
// Position Integration Compute Shader
struct Node {
x: f32,
y: f32,
vx: f32,
vy: f32,
fx: f32,
fy: f32,
strength: f32,
radius: f32,
}
struct Params {
alpha: f32,
velocityDecay: f32,
nodeCount: u32,
linkCount: u32,
centerX: f32,
centerY: f32,
centerStrength: f32,
theta2: f32,
distanceMin2: f32,
distanceMax2: f32,
iterations: u32,
collisionRadius: f32,
collisionStrength: f32,
collisionIterations: u32,
forceXTarget: f32,
forceXStrength: f32,
forceYTarget: f32,
forceYStrength: f32,
radialX: f32,
radialY: f32,
radialRadius: f32,
radialStrength: f32,
_pad1: f32,
_pad2: f32,
}
@group(0) @binding(0) var<storage, read_write> nodes: array<Node>;
@group(0) @binding(1) var<uniform> params: Params;
fn isNaN(v: f32) -> bool {
return !(v == v);
}
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let i = global_id.x;
let nodeCount = params.nodeCount;
if (i >= nodeCount) {
return;
}
var node = nodes[i];
let velocityDecay = params.velocityDecay;
if (isNaN(node.fx)) {
node.vx = node.vx * velocityDecay;
node.x = node.x + node.vx;
} else {
node.x = node.fx;
node.vx = 0.0;
}
if (isNaN(node.fy)) {
node.vy = node.vy * velocityDecay;
node.y = node.y + node.vy;
} else {
node.y = node.fy;
node.vy = 0.0;
}
nodes[i] = node;
}
`;
export const collisionShader = `
// Collision Force Compute Shader
struct Node {
x: f32,
y: f32,
vx: f32,
vy: f32,
fx: f32,
fy: f32,
strength: f32,
radius: f32,
}
struct Params {
alpha: f32,
velocityDecay: f32,
nodeCount: u32,
linkCount: u32,
centerX: f32,
centerY: f32,
centerStrength: f32,
theta2: f32,
distanceMin2: f32,
distanceMax2: f32,
iterations: u32,
collisionRadius: f32,
collisionStrength: f32,
collisionIterations: u32,
forceXTarget: f32,
forceXStrength: f32,
forceYTarget: f32,
forceYStrength: f32,
radialX: f32,
radialY: f32,
radialRadius: f32,
radialStrength: f32,
_pad1: f32,
_pad2: f32,
}
@group(0) @binding(0) var<storage, read_write> nodes: array<Node>;
@group(0) @binding(1) var<uniform> params: Params;
const TILE_SIZE: u32 = 256u;
var<workgroup> tile: array<vec4<f32>, 256>;
fn jiggle(seed: u32) -> f32 {
let s = (seed * 1103515245u + 12345u) & 0x7fffffffu;
return (f32(s) / f32(0x7fffffff) - 0.5) * 1e-6;
}
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>
) {
let i = global_id.x;
let nodeCount = params.nodeCount;
let isValid = i < nodeCount;
// Load node data (use defaults for out-of-bounds threads)
var myX: f32 = 0.0;
var myY: f32 = 0.0;
var myRadius: f32 = 0.0;
var myVx: f32 = 0.0;
var myVy: f32 = 0.0;
if (isValid) {
let node = nodes[i];
myX = node.x;
myY = node.y;
myRadius = node.radius;
myVx = node.vx;
myVy = node.vy;
}
var dvx: f32 = 0.0;
var dvy: f32 = 0.0;
let strength = params.collisionStrength;
let numTiles = (nodeCount + TILE_SIZE - 1u) / TILE_SIZE;
for (var t: u32 = 0u; t < numTiles; t++) {
// All threads participate in loading the tile
let tileIdx = t * TILE_SIZE + local_id.x;
if (tileIdx < nodeCount) {
let other = nodes[tileIdx];
tile[local_id.x] = vec4<f32>(other.x, other.y, other.radius, 0.0);
} else {
tile[local_id.x] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
// All threads must hit this barrier
workgroupBarrier();
// Only valid threads compute collisions
if (isValid) {
let tileEnd = min(TILE_SIZE, nodeCount - t * TILE_SIZE);
for (var j: u32 = 0u; j < tileEnd; j++) {
let otherIdx = t * TILE_SIZE + j;
if (otherIdx != i) {
let other = tile[j];
let otherRadius = other.z;
let combinedRadius = myRadius + otherRadius;
var dx = myX - other.x;
var dy = myY - other.y;
var l2 = dx * dx + dy * dy;
let minDist2 = combinedRadius * combinedRadius;
if (l2 < minDist2) {
if (dx == 0.0) {
dx = jiggle(min(i, otherIdx) * nodeCount + max(i, otherIdx));
l2 += dx * dx;
}
if (dy == 0.0) {
dy = jiggle(min(i, otherIdx) * nodeCount + max(i, otherIdx) + 1u);
l2 += dy * dy;
}
let l = sqrt(l2);
let overlap = combinedRadius - l;
let totalRadius = myRadius + otherRadius;
let myWeight = otherRadius / totalRadius;
let impulse = overlap * strength * 0.5;
let nx = dx / l;
let ny = dy / l;
dvx += nx * impulse * myWeight;
dvy += ny * impulse * myWeight;
}
}
}
}
// All threads must hit this barrier
workgroupBarrier();
}
// Only valid threads write results
if (isValid) {
nodes[i].vx = myVx + dvx;
nodes[i].vy = myVy + dvy;
}
}
`;
export const forceXShader = `
// X-Positioning Force Compute Shader
// Pushes nodes toward a target x position
struct Node {
x: f32,
y: f32,
vx: f32,
vy: f32,
fx: f32,
fy: f32,
strength: f32,
radius: f32,
}
struct Params {
alpha: f32,
velocityDecay: f32,
nodeCount: u32,
linkCount: u32,
centerX: f32,
centerY: f32,
centerStrength: f32,
theta2: f32,
distanceMin2: f32,
distanceMax2: f32,
iterations: u32,
collisionRadius: f32,
collisionStrength: f32,
collisionIterations: u32,
forceXTarget: f32,
forceXStrength: f32,
forceYTarget: f32,
forceYStrength: f32,
radialX: f32,
radialY: f32,
radialRadius: f32,
radialStrength: f32,
_pad1: f32,
_pad2: f32,
}
@group(0) @binding(0) var<storage, read_write> nodes: array<Node>;
@group(0) @binding(1) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let i = global_id.x;
if (i >= params.nodeCount) {
return;
}
let node = nodes[i];
let alpha = params.alpha;
let targetX = params.forceXTarget;
let strength = params.forceXStrength;
let dx = targetX - node.x;
nodes[i].vx = node.vx + dx * strength * alpha;
}
`;
export const forceYShader = `
// Y-Positioning Force Compute Shader
// Pushes nodes toward a target y position
struct Node {
x: f32,
y: f32,
vx: f32,
vy: f32,
fx: f32,
fy: f32,
strength: f32,
radius: f32,
}
struct Params {
alpha: f32,
velocityDecay: f32,
nodeCount: u32,
linkCount: u32,
centerX: f32,
centerY: f32,
centerStrength: f32,
theta2: f32,
distanceMin2: f32,
distanceMax2: f32,
iterations: u32,
collisionRadius: f32,
collisionStrength: f32,
collisionIterations: u32,
forceXTarget: f32,
forceXStrength: f32,
forceYTarget: f32,
forceYStrength: f32,
radialX: f32,
radialY: f32,
radialRadius: f32,
radialStrength: f32,
_pad1: f32,
_pad2: f32,
}
@group(0) @binding(0) var<storage, read_write> nodes: array<Node>;
@group(0) @binding(1) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let i = global_id.x;
if (i >= params.nodeCount) {
return;
}
let node = nodes[i];
let alpha = params.alpha;
let targetY = params.forceYTarget;
let strength = params.forceYStrength;
let dy = targetY - node.y;
nodes[i].vy = node.vy + dy * strength * alpha;
}
`;
export const forceRadialShader = `
// Radial Force Compute Shader
// Pushes nodes toward a target radius from a center point
struct Node {
x: f32,
y: f32,
vx: f32,
vy: f32,
fx: f32,
fy: f32,
strength: f32,
radius: f32,
}
struct Params {
alpha: f32,
velocityDecay: f32,
nodeCount: u32,
linkCount: u32,
centerX: f32,
centerY: f32,
centerStrength: f32,
theta2: f32,
distanceMin2: f32,
distanceMax2: f32,
iterations: u32,
collisionRadius: f32,
collisionStrength: f32,
collisionIterations: u32,
forceXTarget: f32,
forceXStrength: f32,
forceYTarget: f32,
forceYStrength: f32,
radialX: f32,
radialY: f32,
radialRadius: f32,
radialStrength: f32,
_pad1: f32,
_pad2: f32,
}
@group(0) @binding(0) var<storage, read_write> nodes: array<Node>;
@group(0) @binding(1) var<uniform> params: Params;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let i = global_id.x;
if (i >= params.nodeCount) {
return;
}
let node = nodes[i];
let alpha = params.alpha;
let cx = params.radialX;
let cy = params.radialY;
let targetRadius = params.radialRadius;
let strength = params.radialStrength;
let dx = node.x - cx;
let dy = node.y - cy;
let r = sqrt(dx * dx + dy * dy);
if (r > 0.0) {
let k = (targetRadius - r) * strength * alpha / r;
nodes[i].vx = node.vx + dx * k;
nodes[i].vy = node.vy + dy * k;
}
}
`;