// three-mesh-bvh
// Version: (not specified)
// A BVH implementation to speed up raycasting against three.js meshes.
// Original listing: 277 lines (193 loc), 7.52 kB, JavaScript
/**
 * GLSL source declaring the `BVH` struct that the traversal functions in
 * `shaderIntersectFunction` take as a parameter, so this string has to be
 * included in a shader before that one. Every member is a texture:
 * - index: texel i holds the three vertex indices of triangle i in `.xyz`.
 * - position: vertex positions in `.rgb`, looked up by those vertex indices.
 * - bvhBounds: two texels per BVH node, the bounds min corner then the max corner (`.xyz`).
 * - bvhContents: per-node leaf / interior-node information in `.xy`, read during traversal.
 */
export const shaderStructs = /* glsl */`
struct BVH {
usampler2D index;
sampler2D position;
sampler2D bvhBounds;
usampler2D bvhContents;
};
// Note that a struct cannot be used for the hit record including faceIndices, faceNormal, barycoord,
// side, and dist because on some mobile GPUS (such as Adreno) numbers are afforded less precision specifically
// when in a struct leading to inaccurate hit results. See KhronosGroup/WebGL#3351 for more details.
`;
export const shaderIntersectFunction = /* glsl */`
uvec4 uTexelFetch1D( usampler2D tex, uint index ) {

	// Address the unsigned-integer texture as a row-major 1D array: the linear
	// index wraps every `rowWidth` texels (width of mip level 0).
	uint rowWidth = uint( textureSize( tex, 0 ).x );
	ivec2 texel = ivec2( index % rowWidth, index / rowWidth );
	return texelFetch( tex, texel, 0 );

}
ivec4 iTexelFetch1D( isampler2D tex, uint index ) {

	// Address the signed-integer texture as a row-major 1D array: the linear
	// index wraps every `rowWidth` texels (width of mip level 0).
	uint rowWidth = uint( textureSize( tex, 0 ).x );
	ivec2 texel = ivec2( index % rowWidth, index / rowWidth );
	return texelFetch( tex, texel, 0 );

}
vec4 texelFetch1D( sampler2D tex, uint index ) {

	// Address the float texture as a row-major 1D array: the linear index wraps
	// every `rowWidth` texels (width of mip level 0). No filtering is applied.
	uint rowWidth = uint( textureSize( tex, 0 ).x );
	ivec2 texel = ivec2( index % rowWidth, index / rowWidth );
	return texelFetch( tex, texel, 0 );

}
vec4 textureSampleBarycoord( sampler2D tex, vec3 barycoord, uvec3 faceIndices ) {

	// Fetch the per-vertex value for each corner of the face, then blend the three
	// values using the barycentric weights (x -> first corner, y -> second, z -> third).
	vec4 cornerA = texelFetch1D( tex, faceIndices.x );
	vec4 cornerB = texelFetch1D( tex, faceIndices.y );
	vec4 cornerC = texelFetch1D( tex, faceIndices.z );
	return barycoord.x * cornerA + barycoord.y * cornerB + barycoord.z * cornerC;

}
void ndcToCameraRay(
	vec2 coord, mat4 cameraWorld, mat4 invProjectionMatrix,
	out vec3 rayOrigin, out vec3 rayDirection
) {

	// Builds a world-space ray through the NDC position `coord`, starting on the
	// camera's near clip plane. Note that rayDirection is NOT normalized.

	// camera position and its -Z (look) axis in world space
	vec4 cameraPosition = cameraWorld * vec4( 0.0, 0.0, 0.0, 1.0 );
	vec4 forwardAxis = cameraWorld * vec4( 0.0, 0.0, - 1.0, 0.0 );

	// near-plane distance, recovered by unprojecting NDC depth -1 (the near plane
	// under the OpenGL / three.js depth convention)
	vec4 nearPoint = invProjectionMatrix * vec4( 0.0, 0.0, - 1.0, 1.0 );
	float nearDistance = abs( nearPoint.z / nearPoint.w );

	// unproject the pixel at NDC depth 0.5 and express it as a world-space offset
	// from the camera position
	vec4 viewPoint = invProjectionMatrix * vec4( coord, 0.5, 1.0 );
	viewPoint /= viewPoint.w;
	vec4 toTarget = cameraWorld * viewPoint - cameraPosition;

	// slide the origin along the ray until it reaches the near clip plane
	rayOrigin = cameraPosition.xyz + toTarget.xyz * nearDistance / dot( toTarget, forwardAxis );
	rayDirection = toTarget.xyz;

}
float intersectsBounds( vec3 rayOrigin, vec3 rayDirection, vec3 boundsMin, vec3 boundsMax ) {

	// Slab test against an axis-aligned box. Returns the ray parameter at which the
	// box is entered (0.0 if the origin is inside it) or INFINITY on a miss.
	// https://www.reddit.com/r/opengl/comments/8ntzz5/fast_glsl_ray_box_intersection/
	// https://tavianator.com/2011/ray_box.html
	vec3 reciprocalDir = 1.0 / rayDirection;

	// ray parameter where the ray crosses each of the six bounding planes
	vec3 minPlaneT = reciprocalDir * ( boundsMin - rayOrigin );
	vec3 maxPlaneT = reciprocalDir * ( boundsMax - rayOrigin );

	// per-axis entry / exit parameters regardless of the direction's sign
	vec3 enterT = min( maxPlaneT, minPlaneT );
	vec3 exitT = max( maxPlaneT, minPlaneT );

	// the box is entered at the latest slab entry and left at the earliest slab exit
	vec2 enterPair = max( enterT.xx, enterT.yz );
	vec2 exitPair = min( exitT.xx, exitT.yz );
	float tEnter = max( enterPair.x, enterPair.y );
	float tExit = min( exitPair.x, exitPair.y );

	// clamp so that a ray starting inside the box reports a distance of 0.0
	float hitT = max( tEnter, 0.0 );
	if ( tExit >= hitT ) {

		return hitT;

	}

	return INFINITY;

}
// Intersects the ray with triangle (a, b, c).
// All outputs are written on every call, even when false is returned:
// - barycoord: weights for a, b and c at the hit point, i.e. (1 - u - v, u, v)
// - norm: normalized cross( b - a, c - a ) multiplied by `side`, so for side = +/-1 it
//   points against rayDirection
// - dist: ray parameter t of the hit (hit = rayOrigin + t * rayDirection), measured in
//   multiples of rayDirection, which is not normalized here
// - side: sign of the determinant; +1.0 when the ray travels against cross( b - a, c - a ),
//   -1.0 when along it, 0.0 when the ray is parallel to the triangle's plane
// Returns true when u, v, 1 - u - v and dist are all >= -TRI_INTERSECT_EPSILON.
bool intersectsTriangle(
vec3 rayOrigin, vec3 rayDirection, vec3 a, vec3 b, vec3 c,
out vec3 barycoord, out vec3 norm, out float dist, out float side
) {
// https://stackoverflow.com/questions/42740765/intersection-between-line-and-triangle-in-3d
vec3 edge1 = b - a;
vec3 edge2 = c - a;
norm = cross( edge1, edge2 );
float det = - dot( rayDirection, norm );
// NOTE(review): when det == 0.0 (ray parallel to the plane) this divides by zero; the miss
// then relies on the resulting inf / NaN values failing the final comparison. GLSL ES leaves
// divide-by-zero results unspecified, so confirm on target GPUs.
float invdet = 1.0 / det;
vec3 AO = rayOrigin - a;
vec3 DAO = cross( AO, rayDirection );
// x = u, y = v, z = t (ray parameter), w = 1 - u - v
vec4 uvt;
uvt.x = dot( edge2, DAO ) * invdet;
uvt.y = - dot( edge1, DAO ) * invdet;
uvt.z = dot( AO, norm ) * invdet;
uvt.w = 1.0 - uvt.x - uvt.y;
// set the hit information
barycoord = uvt.wxy; // arranged in A, B, C order
dist = uvt.z;
side = sign( det );
norm = side * normalize( norm );
// add an epsilon to avoid misses between triangles
// (it is also added to uvt.z, so hits up to TRI_INTERSECT_EPSILON behind the origin pass)
uvt += vec4( TRI_INTERSECT_EPSILON );
return all( greaterThanEqual( uvt, vec4( 0.0 ) ) );
}
// Tests the ray against `count` consecutive triangles starting at triangle `offset` of
// bvh.index and keeps the closest hit whose distance is below minDistance.
//
// Returns true if this range produced a closer hit. In that case:
// - minDistance is lowered to that hit's distance;
// - faceIndices.xyz holds the triangle's vertex indices and faceIndices.w its triangle index;
// - faceNormal, barycoord, side and dist describe the hit.
//
// The hit outputs are `inout` rather than `out`. GLSL copies `out` parameters back to the
// caller on return even when the function never assigned them. A call that finds no closer
// hit would then overwrite a hit recorded by an earlier call with undefined values. As `inout`
// they are left unchanged when false is returned.
bool intersectTriangles(
	BVH bvh, vec3 rayOrigin, vec3 rayDirection, uint offset, uint count,
	inout float minDistance,

	// output variables
	inout uvec4 faceIndices, inout vec3 faceNormal, inout vec3 barycoord,
	inout float side, inout float dist
) {

	bool found = false;
	vec3 localBarycoord, localNormal;
	float localDist, localSide;
	for ( uint i = offset, l = offset + count; i < l; i ++ ) {

		// texel i of bvh.index holds the three vertex indices of triangle i
		uvec3 indices = uTexelFetch1D( bvh.index, i ).xyz;
		vec3 a = texelFetch1D( bvh.position, indices.x ).rgb;
		vec3 b = texelFetch1D( bvh.position, indices.y ).rgb;
		vec3 c = texelFetch1D( bvh.position, indices.z ).rgb;

		if (
			intersectsTriangle( rayOrigin, rayDirection, a, b, c, localBarycoord, localNormal, localDist, localSide )
			&& localDist < minDistance
		) {

			found = true;
			minDistance = localDist;

			faceIndices = uvec4( indices.xyz, i );
			faceNormal = localNormal;
			side = localSide;
			barycoord = localBarycoord;
			dist = localDist;

		}

	}

	return found;

}
float intersectsBVHNodeBounds( vec3 rayOrigin, vec3 rayDirection, BVH bvh, uint currNodeIndex ) {

	// Each node's box occupies two consecutive texels of bvh.bvhBounds:
	// the min corner first, then the max corner.
	uint firstTexel = 2u * currNodeIndex;
	return intersectsBounds(
		rayOrigin, rayDirection,
		texelFetch1D( bvh.bvhBounds, firstTexel ).xyz,
		texelFetch1D( bvh.bvhBounds, firstTexel + 1u ).xyz
	);

}
// Finds the closest triangle hit along the ray by walking the BVH with an explicit stack.
// Nodes whose bounds are missed, or start farther away than the best hit so far, are skipped.
//
// Returns true on a hit. Only then are the outputs meaningful:
// - faceIndices.xyz holds the vertex indices and faceIndices.w the triangle index;
// - faceNormal, barycoord, side and dist describe the closest hit.
//
// Node layout as read from bvh.bvhContents (.xy):
// - if the high 16 bits of .x are non-zero, the node is a leaf with (.x & 0xffff) triangles
//   starting at triangle .y;
// - otherwise the left child is the next node index, the right child is .y, and the split
//   axis is (.x & 0xffff).
bool bvhIntersectFirstHit(
	BVH bvh, vec3 rayOrigin, vec3 rayDirection,

	// output variables
	out uvec4 faceIndices, out vec3 faceNormal, out vec3 barycoord,
	out float side, out float dist
) {

	// Each interior node pops one entry and pushes two, so the stack grows by at most
	// one entry per tree level.
	const int stackDepth = 60;
	int ptr = 0;
	uint stack[ stackDepth ];
	stack[ 0 ] = 0u;

	float triangleDistance = 1e20;
	bool found = false;

	// Leaf results land in these temporaries and are copied to the outputs only when the
	// leaf produced a closer hit. Passing the outputs straight to intersectTriangles could
	// let a leaf without a closer hit overwrite the best hit with undefined values, because
	// GLSL copies `out` parameters back even when they were never assigned.
	uvec4 leafFaceIndices;
	vec3 leafFaceNormal, leafBarycoord;
	float leafSide, leafDist;

	while ( ptr > - 1 && ptr < stackDepth ) {

		uint currNodeIndex = stack[ ptr ];
		ptr --;

		// check if we intersect the current bounds, and skip it if it lies beyond the
		// closest hit found so far
		float boundsHitDistance = intersectsBVHNodeBounds( rayOrigin, rayDirection, bvh, currNodeIndex );
		if ( boundsHitDistance == INFINITY || boundsHitDistance > triangleDistance ) {

			continue;

		}

		uvec2 boundsInfo = uTexelFetch1D( bvh.bvhContents, currNodeIndex ).xy;
		bool isLeaf = bool( boundsInfo.x & 0xffff0000u );

		if ( isLeaf ) {

			uint count = boundsInfo.x & 0x0000ffffu;
			uint offset = boundsInfo.y;

			bool leafHit = intersectTriangles(
				bvh, rayOrigin, rayDirection, offset, count, triangleDistance,
				leafFaceIndices, leafFaceNormal, leafBarycoord, leafSide, leafDist
			);

			if ( leafHit ) {

				found = true;
				faceIndices = leafFaceIndices;
				faceNormal = leafFaceNormal;
				barycoord = leafBarycoord;
				side = leafSide;
				dist = leafDist;

			}

		} else {

			// Pushing two more entries would index past the end of `stack` (undefined
			// behavior in GLSL). End the traversal instead, accepting a possibly missed hit.
			if ( ptr + 2 >= stackDepth ) {

				break;

			}

			uint leftIndex = currNodeIndex + 1u;
			uint splitAxis = boundsInfo.x & 0x0000ffffu;
			uint rightIndex = boundsInfo.y;

			// Visit the left child first when the ray points along +splitAxis. This presumably
			// picks the nearer child, assuming the left child holds the lower half of the split.
			bool leftToRight = rayDirection[ splitAxis ] >= 0.0;
			uint c1 = leftToRight ? leftIndex : rightIndex;
			uint c2 = leftToRight ? rightIndex : leftIndex;

			// set c2 in the stack so we traverse it later. The second entry pushed (c1) is
			// the one that will be traversed first
			ptr ++;
			stack[ ptr ] = c2;
			ptr ++;
			stack[ ptr ] = c1;

		}

	}

	return found;

}
`;