cuda-wasm
Version:
High-performance CUDA to WebAssembly/WebGPU transpiler with Rust safety - Run GPU kernels in browsers and Node.js
269 lines (222 loc) • 8.2 kB
JavaScript
// WebAssembly Code Generator - Converts parsed CUDA to WASM
const fs = require('fs');
class WasmGenerator {
constructor() {
this.wasm = {
version: 1,
types: [],
functions: [],
memory: {
initial: 256, // 256 pages = 16MB
maximum: 1024 // 1024 pages = 64MB
},
exports: []
};
}
generate(parsedCuda) {
// Reset state
this.wasm.types = [];
this.wasm.functions = [];
this.wasm.exports = [];
// Convert each kernel to WASM function
for (const kernel of parsedCuda.kernels) {
this.convertKernel(kernel);
}
// Convert device functions
for (const func of parsedCuda.deviceFunctions) {
this.convertDeviceFunction(func);
}
return this.generateWAT();
}
convertKernel(kernel) {
// Create function type
const typeIdx = this.addFunctionType(kernel.parameters, kernel.returnType);
// Generate function body
const funcBody = this.generateFunctionBody(kernel);
// Add function
const funcIdx = this.wasm.functions.length;
this.wasm.functions.push({
name: kernel.name,
typeIdx,
locals: this.extractLocals(kernel.body),
body: funcBody
});
// Export kernel
this.wasm.exports.push({
name: kernel.name,
kind: 'func',
index: funcIdx
});
}
convertDeviceFunction(func) {
const typeIdx = this.addFunctionType(func.parameters, func.returnType);
const funcBody = this.generateFunctionBody(func);
this.wasm.functions.push({
name: func.name,
typeIdx,
locals: this.extractLocals(func.body),
body: funcBody
});
}
addFunctionType(params, returnType) {
const paramTypes = params.map(p => this.cudaTypeToWasm(p.type));
const resultTypes = returnType === 'void' ? [] : [this.cudaTypeToWasm(returnType)];
// Check if type already exists
for (let i = 0; i < this.wasm.types.length; i++) {
const type = this.wasm.types[i];
if (JSON.stringify(type.params) === JSON.stringify(paramTypes) &&
JSON.stringify(type.results) === JSON.stringify(resultTypes)) {
return i;
}
}
// Add new type
this.wasm.types.push({
params: paramTypes,
results: resultTypes
});
return this.wasm.types.length - 1;
}
cudaTypeToWasm(cudaType) {
const typeMap = {
'float': 'f32',
'double': 'f64',
'int': 'i32',
'unsigned int': 'i32',
'uint': 'i32',
'long': 'i64',
'unsigned long': 'i64',
'char': 'i32',
'unsigned char': 'i32',
'short': 'i32',
'unsigned short': 'i32',
'bool': 'i32'
};
// Handle pointers
if (cudaType.includes('*')) {
return 'i32'; // Pointers are 32-bit indices in WASM
}
return typeMap[cudaType.trim()] || 'i32';
}
extractLocals(body) {
const locals = [];
// Simple extraction of local variables
const varRegex = /\b(int|float|double|char|short|long|unsigned|bool)\s+(\w+)\s*[=;]/g;
let match;
while ((match = varRegex.exec(body)) !== null) {
locals.push({
name: match[2],
type: this.cudaTypeToWasm(match[1])
});
}
return locals;
}
generateFunctionBody(kernel) {
const instructions = [];
// Parse kernel body and generate WASM instructions
const lines = kernel.body.split('\n');
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
// Handle thread index calculations
if (trimmed.includes('threadIdx.x')) {
instructions.push('global.get $threadIdx_x');
}
if (trimmed.includes('blockIdx.x')) {
instructions.push('global.get $blockIdx_x');
}
if (trimmed.includes('blockDim.x')) {
instructions.push('global.get $blockDim_x');
}
// Handle array access
const arrayMatch = trimmed.match(/(\w+)\[([^\]]+)\]\s*=\s*([^;]+)/);
if (arrayMatch) {
const array = arrayMatch[1];
const index = arrayMatch[2];
const value = arrayMatch[3];
// Generate load/store instructions
instructions.push(`local.get $${array}`);
instructions.push(`${this.parseExpression(index)}`);
instructions.push('i32.const 4'); // sizeof(float)
instructions.push('i32.mul');
instructions.push('i32.add');
instructions.push(`${this.parseExpression(value)}`);
instructions.push('f32.store');
}
}
return instructions;
}
parseExpression(expr) {
// Simple expression parser
if (expr.includes('+')) {
const parts = expr.split('+');
return `${this.parseExpression(parts[0].trim())} ${this.parseExpression(parts[1].trim())} f32.add`;
}
if (expr.includes('*')) {
const parts = expr.split('*');
return `${this.parseExpression(parts[0].trim())} ${this.parseExpression(parts[1].trim())} f32.mul`;
}
// Variable reference
if (expr.match(/^\w+$/)) {
return `local.get $${expr}`;
}
// Constant
if (!isNaN(expr)) {
return `f32.const ${expr}`;
}
return '';
}
generateWAT() {
let wat = '(module\n';
// Add memory
wat += ` (memory $mem ${this.wasm.memory.initial} ${this.wasm.memory.maximum})\n`;
wat += ' (export "memory" (memory $mem))\n\n';
// Add globals for thread/block info
wat += ' (global $threadIdx_x (mut i32) (i32.const 0))\n';
wat += ' (global $blockIdx_x (mut i32) (i32.const 0))\n';
wat += ' (global $blockDim_x (mut i32) (i32.const 256))\n\n';
// Add types
this.wasm.types.forEach((type, idx) => {
wat += ` (type $t${idx} (func`;
if (type.params.length > 0) {
wat += ' (param';
type.params.forEach(p => wat += ` ${p}`);
wat += ')';
}
if (type.results.length > 0) {
wat += ' (result';
type.results.forEach(r => wat += ` ${r}`);
wat += ')';
}
wat += '))\n';
});
wat += '\n';
// Add functions
this.wasm.functions.forEach((func, idx) => {
wat += ` (func $${func.name} (type $t${func.typeIdx})`;
// Add locals
if (func.locals.length > 0) {
func.locals.forEach(local => {
wat += `\n (local $${local.name} ${local.type})`;
});
}
// Add body
wat += '\n';
func.body.forEach(inst => {
wat += ` ${inst}\n`;
});
wat += ' )\n\n';
});
// Add exports
this.wasm.exports.forEach(exp => {
wat += ` (export "${exp.name}" (func $${exp.name}))\n`;
});
wat += ')\n';
return wat;
}
generateBinary(wat) {
// In a real implementation, this would use wabt or similar to convert WAT to WASM
// For now, return the WAT text
return Buffer.from(wat);
}
}
module.exports = WasmGenerator;