@luma.gl/core
Version:
The luma.gl core Device API
158 lines (137 loc) • 5.45 kB
text/typescript
// luma.gl
// SPDX-License-Identifier: MIT
// Copyright (c) vis.gl contributors
import type {
Bindings,
BindingsByGroup,
ComputeShaderLayout,
ShaderLayout
} from '../adapter/types/shader-layout';
import type {Device} from '../adapter/device';
import type {ComputePipeline} from '../adapter/resources/compute-pipeline';
import type {RenderPipeline} from '../adapter/resources/render-pipeline';
import {normalizeBindingsByGroup} from '../adapter-utils/bind-groups';
/** Either kind of pipeline that bind groups can be resolved for. */
type AnyPipeline = RenderPipeline | ComputePipeline;
/** Shader layout of either pipeline kind; only its `bindings` list is read in this module. */
type AnyShaderLayout = ShaderLayout | ComputeShaderLayout;
/** Optional per-group objects whose identity is used as the key for cached bind groups. */
type BindGroupCacheKeys = Partial<Record<number, object>>;
/** Resolved bind groups keyed by group index; groups may be absent. */
type BindGroupMap = Partial<Record<number, unknown>>;
/** Bind group layouts of a single pipeline, keyed by group index. */
type LayoutCache = Partial<Record<number, object>>;
/** Bind groups that have been created for one bind group layout. */
type LayoutBindGroupCache = {
// Bind groups keyed by the caller-supplied cache key object; a stored `null` is a valid cached result.
bindGroupsBySource: WeakMap<object, unknown | null>;
// Bind group created with no bindings for this layout, once it has been requested.
emptyBindGroup?: unknown;
};
/**
 * Resolves bind groups for render and compute pipelines on WebGPU devices.
 *
 * Bind group layouts are cached per pipeline and bind groups are cached per layout.
 * Both caches are `WeakMap`s, so cached entries do not keep their key objects alive.
 */
export class BindGroupFactory {
  /** Device used to create bind group layouts and bind groups. */
  readonly device: Device;
  /** Bind group layouts by group index, cached per pipeline. */
  private readonly _layoutCacheByPipeline: WeakMap<AnyPipeline, LayoutCache> = new WeakMap();
  /** Bind groups that have been created for each bind group layout. */
  private readonly _bindGroupCacheByLayout: WeakMap<object, LayoutBindGroupCache> = new WeakMap();

  constructor(device: Device) {
    this.device = device;
  }

  /**
   * Resolves the bind groups of `pipeline`, keyed by group index.
   *
   * Every group index from 0 up to the highest group used by the shader layout is visited:
   * - a group the layout declares no bindings for receives a cached empty bind group;
   * - a group the layout declares bindings for, but for which none were supplied, is omitted;
   * - a group with supplied bindings receives a bind group, which is cached under
   *   `bindGroupCacheKeys[group]` when such a key is given and created anew otherwise.
   *
   * @param pipeline Pipeline whose shader layout describes the bindings.
   * @param bindings Bindings, either flat or already split by group.
   * @param bindGroupCacheKeys Optional per-group objects used as cache keys.
   * @returns The resolved bind groups; an empty map when the device is not WebGPU
   *   or the shader layout has no bindings.
   */
  getBindGroups(
    pipeline: AnyPipeline,
    bindings?: Bindings | BindingsByGroup,
    bindGroupCacheKeys?: BindGroupCacheKeys
  ): BindGroupMap {
    if (this.device.type !== 'webgpu' || pipeline.shaderLayout.bindings.length === 0) {
      return {};
    }

    const bindingsByGroup = normalizeBindingsByGroup(pipeline.shaderLayout, bindings);
    const resolvedBindGroups: BindGroupMap = {};

    for (const group of getBindGroupIndicesUpToMax(pipeline.shaderLayout.bindings)) {
      const groupBindings = bindingsByGroup[group];
      const bindGroupLayout = this._getBindGroupLayout(pipeline, group);
      const bindGroupLabel = getBindGroupLabel(pipeline, pipeline.shaderLayout, group);

      if (!groupBindings || Object.keys(groupBindings).length === 0) {
        // Only groups that the layout itself leaves unused are filled with an empty bind group.
        if (!hasBindingsInGroup(pipeline.shaderLayout.bindings, group)) {
          resolvedBindGroups[group] = this._getEmptyBindGroup(
            bindGroupLayout,
            pipeline.shaderLayout,
            group,
            bindGroupLabel
          );
        }
        continue;
      }

      const bindGroupCacheKey = bindGroupCacheKeys?.[group];
      if (bindGroupCacheKey) {
        const layoutCache = this._getLayoutBindGroupCache(bindGroupLayout);
        // `has` rather than `get`, so that a cached `null` result is also reused.
        if (layoutCache.bindGroupsBySource.has(bindGroupCacheKey)) {
          resolvedBindGroups[group] = layoutCache.bindGroupsBySource.get(bindGroupCacheKey) || null;
          continue;
        }

        const bindGroup = this.device._createBindGroupWebGPU(
          bindGroupLayout,
          pipeline.shaderLayout,
          groupBindings,
          group,
          bindGroupLabel
        );
        layoutCache.bindGroupsBySource.set(bindGroupCacheKey, bindGroup);
        resolvedBindGroups[group] = bindGroup;
      } else {
        // Without a cache key there is nothing to cache under: create a new bind group.
        resolvedBindGroups[group] = this.device._createBindGroupWebGPU(
          bindGroupLayout,
          pipeline.shaderLayout,
          groupBindings,
          group,
          bindGroupLabel
        );
      }
    }

    return resolvedBindGroups;
  }

  /** Returns the bind group layout of `pipeline` for `group`, creating and caching it on first use. */
  private _getBindGroupLayout(pipeline: AnyPipeline, group: number): object {
    let layoutCache = this._layoutCacheByPipeline.get(pipeline);
    if (!layoutCache) {
      layoutCache = {};
      this._layoutCacheByPipeline.set(pipeline, layoutCache);
    }
    layoutCache[group] ||= this.device._createBindGroupLayoutWebGPU(pipeline, group) as object;
    return layoutCache[group];
  }

  /**
   * Returns the bind group with no bindings for `bindGroupLayout`, resolving it at most once
   * per layout.
   *
   * @returns The cached bind group, or `null` when the device produced none.
   */
  private _getEmptyBindGroup(
    bindGroupLayout: object,
    shaderLayout: AnyShaderLayout,
    group: number,
    label: string
  ): unknown {
    const layoutCache = this._getLayoutBindGroupCache(bindGroupLayout);
    // `undefined` means "not attempted yet". A stored `null` is a cached result and must not
    // trigger another creation attempt, which a truthiness check (`||=`) would do.
    if (layoutCache.emptyBindGroup === undefined) {
      layoutCache.emptyBindGroup =
        this.device._createBindGroupWebGPU(bindGroupLayout, shaderLayout, {}, group, label) || null;
    }
    return layoutCache.emptyBindGroup;
  }

  /** Returns the bind group cache of `bindGroupLayout`, creating an empty one on first use. */
  private _getLayoutBindGroupCache(bindGroupLayout: object): LayoutBindGroupCache {
    let layoutCache = this._bindGroupCacheByLayout.get(bindGroupLayout);
    if (!layoutCache) {
      layoutCache = {bindGroupsBySource: new WeakMap()};
      this._bindGroupCacheByLayout.set(bindGroupLayout, layoutCache);
    }
    return layoutCache;
  }
}
/**
 * Returns the bind group factory stored on `device._factories`, creating and storing
 * one first if the device does not have one yet.
 */
export function _getDefaultBindGroupFactory(device: Device): BindGroupFactory {
  const factories = device._factories;
  if (!factories.bindGroupFactory) {
    factories.bindGroupFactory = new BindGroupFactory(device);
  }
  return factories.bindGroupFactory as BindGroupFactory;
}
/**
 * Lists every group index from 0 through the highest `group` found in `bindings`,
 * including indices that no binding uses. Returns an empty array for no bindings.
 */
function getBindGroupIndicesUpToMax(bindings: AnyShaderLayout['bindings']): number[] {
  let highestGroup = -1;
  for (const binding of bindings) {
    highestGroup = Math.max(highestGroup, binding.group);
  }

  const groupIndices: number[] = [];
  for (let group = 0; group <= highestGroup; group++) {
    groupIndices.push(group);
  }
  return groupIndices;
}
/** Reports whether at least one entry of `bindings` belongs to `group`. */
function hasBindingsInGroup(bindings: AnyShaderLayout['bindings'], group: number): boolean {
  for (const binding of bindings) {
    if (binding.group === group) {
      return true;
    }
  }
  return false;
}
/**
 * Builds the label `<pipeline id>/group<N>[<names>]` for a bind group, where `<names>` is the
 * comma-separated names of the group's bindings ordered by `location`, or `empty` when the
 * group has no bindings. The layout's own `bindings` array is not reordered.
 */
function getBindGroupLabel(
  pipeline: AnyPipeline,
  shaderLayout: AnyShaderLayout,
  group: number
): string {
  // `filter` yields a fresh array, so the in-place sort below leaves `shaderLayout.bindings` intact.
  const bindingsInGroup = shaderLayout.bindings.filter(candidate => candidate.group === group);
  bindingsInGroup.sort((a, b) => a.location - b.location);

  const names = bindingsInGroup.map(entry => entry.name);
  let suffix = 'empty';
  if (names.length > 0) {
    suffix = names.join(',');
  }
  return `${pipeline.id}/group${group}[${suffix}]`;
}