image-palette-webgpu
Version:
A tiny, zero-dependency browser package that extracts the dominant color or a color palette from an image using the WebGPU API, with a choice of algorithms.
158 lines (138 loc) • 5.9 kB
JavaScript
import { setupBuildHistogram } from './pipelines/buildHistogram.js';
import { setupComputeMoments } from './pipelines/computeMoments.js';
import { setupCreateBox } from './pipelines/createBox.js';
import { setupCreateResult } from './pipelines/createResult.js';
import { floatArrayToHex } from '../utils/color_utils.js';
/**
 * Records and submits the GPU work for the Wu palette extraction and resolves
 * with the GPU buffer that holds the result.
 *
 * Submission order:
 *   1. One command buffer containing the histogram pass, the three per-axis
 *      moment passes and five buffer copies that pack the histogram-stage
 *      buffers back to back into the `momentsBuffer` owned by `setupCreateBox`.
 *   2. `K - 1` command buffers, each running the create-box pipeline once
 *      after the loop counter has been written to the cube-count uniform.
 *   3. One command buffer running the create-result pipeline.
 *
 * @param {GPUDevice} device - Device used to build pipelines and submit work.
 * @param {{width: number, height: number}} source - Image passed to
 *   `setupBuildHistogram`; `width`/`height` size the histogram dispatch.
 * @param {number} K - Number of colors requested; forwarded to the box and
 *   result setup and used as the bound of the box loop.
 * @returns {Promise<GPUBuffer>} The `resultsBuffer` produced by
 *   `setupCreateResult`, returned once `onSubmittedWorkDone()` has resolved.
 */
export async function extractDominantColorsWuGPU(device, source, K) {
  const TILE = 16;
  const { width, height } = source;
  // 35937 === 33 ** 3; presumably one cell per entry of a 33x33x33 color grid
  // (TODO confirm against the shaders).
  const CELL_COUNT = 35937;
  const SECTION_BYTES = CELL_COUNT * Uint32Array.BYTES_PER_ELEMENT;

  const histogram = await setupBuildHistogram(device, source);
  const moments = await setupComputeMoments(device, histogram.buildHistogramBindGroupLayout);
  const boxes = await setupCreateBox(device, K);
  const output = await setupCreateResult(
    device,
    K,
    boxes.momentsBindGroupLayout,
    boxes.cubesBuffer,
    boxes.totalCubesNumUniformBuffer,
  );

  // Encodes one compute pass that dispatches a single workgroup and submits it
  // as its own command buffer. `afterPipeline` runs right after `setPipeline`.
  const submitSinglePass = (pipeline, bindGroups, afterPipeline = () => {}) => {
    const commands = device.createCommandEncoder();
    const computePass = commands.beginComputePass();
    computePass.setPipeline(pipeline);
    afterPipeline();
    bindGroups.forEach((group, slot) => {
      computePass.setBindGroup(slot, group);
    });
    computePass.dispatchWorkgroups(1);
    computePass.end();
    device.queue.submit([commands.finish()]);
  };

  const commands = device.createCommandEncoder();

  // Histogram: one workgroup per TILE x TILE tile of the image.
  const histogramPass = commands.beginComputePass();
  histogramPass.setPipeline(histogram.buildHistogramPipeline);
  histogramPass.setBindGroup(0, histogram.inputBindGroup);
  histogramPass.setBindGroup(1, histogram.buildHistogramBindGroup);
  histogramPass.dispatchWorkgroups(Math.ceil(width / TILE), Math.ceil(height / TILE));
  histogramPass.end();

  // Moments: same pipeline dispatched once per axis, swapping bind group 1.
  const groupsPerDim = Math.ceil(32 / TILE);
  const axisPass = commands.beginComputePass();
  axisPass.setPipeline(moments.computeMomentsPipeline);
  axisPass.setBindGroup(0, histogram.buildHistogramBindGroup);
  for (const axisBindGroup of moments.computeMomentsAxisBindGroups.slice(0, 3)) {
    axisPass.setBindGroup(1, axisBindGroup);
    axisPass.dispatchWorkgroups(groupsPerDim, groupsPerDim);
  }
  axisPass.end();

  // Pack the five histogram-stage buffers into consecutive sections of the
  // create-box moments buffer; the section order is fixed by this list.
  const packedSections = [
    histogram.momentsRBuffer,
    histogram.momentsGBuffer,
    histogram.momentsBBuffer,
    histogram.weightsBuffer,
    histogram.momentsBuffer,
  ];
  packedSections.forEach((sectionBuffer, index) => {
    commands.copyBufferToBuffer(
      sectionBuffer, 0,
      boxes.momentsBuffer, index * SECTION_BYTES,
      SECTION_BYTES,
    );
  });
  device.queue.submit([commands.finish()]);

  // Each box pass is submitted separately so that the uniform value written
  // for iteration `cubeCount` is the one that pass observes.
  for (let cubeCount = 1; cubeCount < K; cubeCount++) {
    submitSinglePass(
      boxes.createBoxPipeline,
      [boxes.momentsBindGroup, boxes.cubesBindGroup],
      () => {
        device.queue.writeBuffer(
          boxes.totalCubesNumUniformBuffer,
          0,
          new Uint32Array([cubeCount]),
        );
      },
    );
  }

  submitSinglePass(
    output.createResultPipeline,
    [boxes.momentsBindGroup, output.cubesResultBindGroup, output.resultsBindGroup],
  );

  await device.queue.onSubmittedWorkDone();
  return output.resultsBuffer;
}
/**
 * Extracts dominant colors from an image source using the WebGPU API with the Wu algorithm.
 *
 * A GPU device is requested for the call and destroyed before the function
 * settles, and the decoded bitmap is closed, so repeated calls do not
 * accumulate GPU devices, GPU buffers or decoded image memory.
 *
 * @param {ImageBitmapSource} imageSource - The image source to process.
 * @param {number} K - The number of dominant colors to extract.
 * @returns {Promise<Array<string>>} A promise that resolves to the value
 *   returned by `floatArrayToHex` for the non-negative result components.
 * @throws {Error} 'WebGPU not supported' when no adapter or device can be obtained.
 */
export async function extractDominantColorsWu(imageSource, K) {
  // Read `navigator` through globalThis so that runtimes without it report
  // "not supported" instead of failing with a ReferenceError.
  const adapter = await globalThis.navigator?.gpu?.requestAdapter();
  const device = await adapter?.requestDevice();
  if (!device) {
    // `window` does not exist in workers; only alert where it is available so
    // the error below is what callers actually receive.
    if (typeof window !== 'undefined' && typeof window.alert === 'function') {
      window.alert('WebGPU not supported');
    }
    throw new Error('WebGPU not supported');
  }

  let source;
  try {
    source = await createImageBitmap(imageSource, { colorSpaceConversion: 'none' });
    const resultsBuffer = await extractDominantColorsWuGPU(device, source, K);

    // Three float32 components per requested color.
    const resultByteLength = 3 * K * Float32Array.BYTES_PER_ELEMENT;
    const stagingResultsBuffer = device.createBuffer({
      size: resultByteLength,
      usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
    });

    const encoder = device.createCommandEncoder();
    encoder.copyBufferToBuffer(
      resultsBuffer, 0,
      stagingResultsBuffer, 0,
      resultByteLength
    );
    device.queue.submit([encoder.finish()]);

    await stagingResultsBuffer.mapAsync(GPUMapMode.READ, 0, resultByteLength);
    // Copy out of the mapped range: the mapping is detached by unmap().
    const results = new Float32Array(stagingResultsBuffer.getMappedRange().slice(0));
    stagingResultsBuffer.unmap();

    // Negative components are dropped before conversion; presumably they mark
    // unused result slots (TODO confirm against the create-result shader).
    return floatArrayToHex(results.filter((x) => x >= 0));
  } finally {
    // Release the decoded bitmap and every GPU resource created on this
    // per-call device, on success and on failure alike.
    source?.close();
    device.destroy();
  }
}