greed.js
Version:
Lightweight, private alternative to Colab. Run PyTorch & NumPy in browser with GPU acceleration (8.8x speedup). Fast, secure, runs locally.
377 lines (311 loc) • 10.2 kB
JavaScript
/**
* PipelineCache - WebGPU compute pipeline management and caching
* Optimizes shader compilation and pipeline creation with intelligent caching
*/
import EventEmitter from '../../core/event-emitter.js';
import { WebGPUShaders } from './webgpu-shaders.js';
class PipelineCache extends EventEmitter {
constructor(device, options = {}) {
super();
this.device = device;
this.config = {
maxCacheSize: options.maxCacheSize || 100,
enableWarmup: options.enableWarmup !== false,
enableMetrics: options.enableMetrics !== false,
shaderOptimization: options.shaderOptimization || 'balanced', // 'speed', 'size', 'balanced'
...options
};
// Pipeline and shader caches
this.pipelines = new Map(); // operationKey -> ComputePipeline
this.shaderModules = new Map(); // shaderHash -> ShaderModule
this.bindGroupLayouts = new Map(); // layoutKey -> BindGroupLayout
// LRU tracking for cache eviction
this.accessOrder = new Map(); // key -> timestamp
// Compilation queue for async pipeline creation
this.compilationQueue = new Map(); // key -> Promise
// Statistics
this.stats = {
hits: 0,
misses: 0,
compilations: 0,
evictions: 0,
averageCompileTime: 0,
totalCompileTime: 0
};
// Shader templates for common operations
this.shaderTemplates = WebGPUShaders.getShaderTemplates();
}
/**
* Get or create compute pipeline for operation
*/
async get(operation, options = {}) {
const key = this._generateKey(operation, options);
// Check if already compiled or in progress
if (this.pipelines.has(key)) {
this._updateAccess(key);
this.stats.hits++;
this.emit('cache:hit', { operation, key });
return this.pipelines.get(key);
}
// Check if compilation is in progress
if (this.compilationQueue.has(key)) {
this.emit('cache:wait', { operation, key });
return await this.compilationQueue.get(key);
}
// Start compilation
this.stats.misses++;
this.emit('cache:miss', { operation, key });
const compilationPromise = this._compilePipeline(operation, options, key);
this.compilationQueue.set(key, compilationPromise);
try {
const pipeline = await compilationPromise;
this.compilationQueue.delete(key);
return pipeline;
} catch (error) {
this.compilationQueue.delete(key);
throw error;
}
}
/**
* Precompile common operations for faster first use
*/
async warmup(operations = null) {
if (!this.config.enableWarmup) {
return;
}
const commonOps = operations || [
'add', 'multiply', 'matmul', 'relu', 'sigmoid',
'softmax', 'conv2d', 'maxpool', 'transpose'
];
this.emit('warmup:start', { operations: commonOps });
const startTime = performance.now();
const warmupPromises = commonOps.map(async (op) => {
try {
await this.get(op, { warmup: true });
this.emit('warmup:operation', { operation: op });
} catch (error) {
this.emit('warmup:error', { operation: op, error });
}
});
await Promise.allSettled(warmupPromises);
const duration = performance.now() - startTime;
this.emit('warmup:complete', { operations: commonOps, duration });
}
/**
* Get bind group layout for operation
*/
getBindGroupLayout(operation, options = {}) {
const key = this._generateLayoutKey(operation, options);
if (this.bindGroupLayouts.has(key)) {
return this.bindGroupLayouts.get(key);
}
const layout = this._createBindGroupLayout(operation, options);
this.bindGroupLayouts.set(key, layout);
return layout;
}
/**
* Get optimal workgroup size for operation
*/
getOptimalWorkgroupSize(operation, tensorShape, deviceLimits) {
return WebGPUShaders.getOptimalWorkgroupSize(operation, tensorShape, deviceLimits);
}
/**
* Generate parameters for operation
*/
generateOperationParams(operation, tensors, options = {}) {
return WebGPUShaders.generateParams(operation, tensors, options);
}
/**
* Create shader module with caching
*/
async createShaderModule(shaderSource, options = {}) {
const hash = this._hashString(shaderSource);
if (this.shaderModules.has(hash)) {
return this.shaderModules.get(hash);
}
try {
const shaderModule = this.device.createShaderModule({
code: shaderSource,
...options
});
this.shaderModules.set(hash, shaderModule);
this.emit('shader:compiled', { hash, size: shaderSource.length });
return shaderModule;
} catch (error) {
this.emit('shader:error', { hash, error, source: shaderSource.substring(0, 100) });
throw error;
}
}
/**
* Generate optimized shader for operation
*/
generateShader(operation, options = {}) {
const template = this.shaderTemplates.get(operation);
if (!template) {
throw new Error(`No shader template found for operation: ${operation}`);
}
const shaderOptions = {
workgroupSize: options.workgroupSize || [8, 8, 1],
dataType: options.dataType || 'f32',
optimization: this.config.shaderOptimization,
...options
};
return template(shaderOptions);
}
/**
* Get cache statistics
*/
getStats() {
const hitRate = this.stats.hits + this.stats.misses > 0 ?
this.stats.hits / (this.stats.hits + this.stats.misses) : 0;
return {
...this.stats,
hitRate,
cacheSize: this.pipelines.size,
shaderCacheSize: this.shaderModules.size,
layoutCacheSize: this.bindGroupLayouts.size,
averageCompileTimeMs: Math.round(this.stats.averageCompileTime * 100) / 100
};
}
/**
* Clear cache and reset statistics
*/
clear() {
this.pipelines.clear();
this.shaderModules.clear();
this.bindGroupLayouts.clear();
this.accessOrder.clear();
this.compilationQueue.clear();
// Reset stats but keep historical data
this.stats.hits = 0;
this.stats.misses = 0;
this.emit('cache:cleared');
}
/**
* Cleanup resources
*/
cleanup() {
this.clear();
this.shaderTemplates.clear();
this.emit('cleanup:complete');
}
// Private methods
async _compilePipeline(operation, options, key) {
const startTime = performance.now();
try {
// Generate shader source
const shaderSource = this.generateShader(operation, options);
// Create shader module
const shaderModule = await this.createShaderModule(shaderSource);
// Create bind group layout
const bindGroupLayout = this.getBindGroupLayout(operation, options);
// Create pipeline layout
const pipelineLayout = this.device.createPipelineLayout({
bindGroupLayouts: [bindGroupLayout]
});
// Create compute pipeline
const pipeline = await this.device.createComputePipelineAsync({
layout: pipelineLayout,
compute: {
module: shaderModule,
entryPoint: 'main'
}
});
// Store in cache
this.pipelines.set(key, pipeline);
this._updateAccess(key);
this._enforceMaxCacheSize();
// Update statistics
const compileTime = performance.now() - startTime;
this.stats.compilations++;
this.stats.totalCompileTime += compileTime;
this.stats.averageCompileTime = this.stats.totalCompileTime / this.stats.compilations;
this.emit('pipeline:compiled', {
operation,
key,
compileTime,
cacheSize: this.pipelines.size
});
return pipeline;
} catch (error) {
this.emit('pipeline:error', { operation, key, error });
throw error;
}
}
_generateKey(operation, options) {
const keyParts = [
operation,
options.workgroupSize?.join(',') || '8,8,1',
options.dataType || 'f32',
options.inputCount || 2,
options.outputCount || 1,
JSON.stringify(options.constants || {})
];
return keyParts.join('|');
}
_generateLayoutKey(operation, options) {
return `${operation}|${options.inputCount || 2}|${options.outputCount || 1}`;
}
_createBindGroupLayout(operation, options) {
const layout = WebGPUShaders.getBufferLayout(operation, options.inputCount, options.outputCount);
const entries = [];
// Input buffers (read-only)
for (let i = 0; i < layout.inputs; i++) {
entries.push({
binding: i,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: 'read-only-storage' }
});
}
// Output buffers (write-only)
for (let i = 0; i < layout.outputs; i++) {
entries.push({
binding: layout.inputs + i,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: 'storage' }
});
}
// Uniform buffer for parameters
if (layout.uniforms > 0) {
entries.push({
binding: layout.inputs + layout.outputs,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: 'uniform' }
});
}
return this.device.createBindGroupLayout({ entries });
}
_updateAccess(key) {
this.accessOrder.set(key, performance.now());
}
_enforceMaxCacheSize() {
if (this.pipelines.size <= this.config.maxCacheSize) {
return;
}
// Find least recently used item
let lruKey = null;
let lruTime = Infinity;
for (const [key, time] of this.accessOrder.entries()) {
if (time < lruTime) {
lruTime = time;
lruKey = key;
}
}
if (lruKey) {
this.pipelines.delete(lruKey);
this.accessOrder.delete(lruKey);
this.stats.evictions++;
this.emit('cache:eviction', { key: lruKey, cacheSize: this.pipelines.size });
}
}
_hashString(str) {
let hash = 0;
for (let i = 0; i < str.length; i++) {
const char = str.charCodeAt(i);
hash = ((hash << 5) - hash) + char;
hash = hash & hash; // Convert to 32-bit integer
}
return hash.toString(36);
}
}
export default PipelineCache;