UNPKG

@specs-feup/clava

Version:

A C/C++ source-to-source compiler written in TypeScript

506 lines (424 loc) 13.2 kB
import Io from "@specs-feup/lara/api/lara/Io.js";
import Strings from "@specs-feup/lara/api/lara/Strings.js";
import {
  BuiltinType,
  Call,
  FileJp,
  FunctionJp,
  Statement,
  Type,
} from "../../Joinpoints.js";
import ClavaJoinPoints from "../ClavaJoinPoints.js";

/**
 * Configuration for replacing a call with an OpenCL kernel launch.
 *
 * NOTE(review): not used inside this module; presumably mirrors the
 * {@link KernelReplacer} constructor arguments plus the parameters to pass to
 * `setOutput` — confirm against callers.
 */
export interface OpenClKernelReplacerConfiguration {
  kernelName: string;
  kernelFile: string;
  bufferSizes: Record<string, string>;
  localSize: string[] | string;
  iterNumbers: string[] | string;
  outputBuffers: string[];
}

/**
 * Replaces a C/C++ function call with host code that runs an OpenCL kernel
 * in its place, using the OpenCL C++ bindings (`CL/cl.hpp`) with exceptions
 * enabled.
 *
 * Every array or pointer parameter of the called function becomes a device
 * buffer. By default a buffer is input-only: it is created read-only and
 * copied to the device before the launch. {@link setOutput} and
 * {@link setInputOutput} change how a buffer is transferred.
 */
export default class KernelReplacer {
  private readonly call: Call;
  private readonly function: FunctionJp;
  private readonly stmt: Statement;
  private readonly file: FileJp;
  private readonly bufferSizes: Map<string, string>;
  private readonly inBuffers: Map<string, Buffer>;
  private readonly outBuffers: Map<string, Buffer> = new Map();
  private readonly inOutBuffers: Map<string, Buffer> = new Map();
  private readonly kernelCode: string;
  private readonly kernelName: string;
  private deviceType: DeviceType;
  private errorHandling: ErrorHandling;
  private readonly localSize: string[];
  private readonly iter: string[];

  /**
   * @param $call The call to replace. The called function must have a
   *   definition, since its parameter types decide which arguments become
   *   device buffers.
   * @param kernelName Name of the kernel function inside the OpenCL source.
   * @param kernelCodePath Path of the OpenCL source, combined with the `path`
   *   of the file containing `$call` through `Io.getPath`.
   * @param bufferSizes C expression for the size in bytes of each buffer,
   *   keyed by parameter name. Every array/pointer parameter needs an entry.
   * @param localSize Work-group size, one C expression per dimension.
   * @param numIters Number of iterations per dimension. Each is rounded up to
   *   a multiple of the matching local size to form the global size.
   * @throws Error if `$call` has no enclosing statement/file or no function
   *   definition, if a buffer size or argument is missing, if the OpenCL file
   *   does not exist, or if the dimension counts are invalid.
   */
  constructor(
    $call: Call,
    kernelName: string,
    kernelCodePath: string,
    bufferSizes: Record<string, string>,
    localSize: string[] | string,
    numIters: string[] | string
  ) {
    // join points
    this.call = $call;

    const $function = $call.definition;
    if ($function == undefined) {
      throw new Error(
        `KernelReplacer(): no definition found for the function called in '${$call.code}'`
      );
    }
    this.function = $function;

    const $stmt = $call.getAncestor("statement");
    const $file = $call.getAncestor("file");
    if ($stmt == undefined || $file == undefined) {
      throw new Error(
        `KernelReplacer(): call '${$call.code}' is not inside a statement of a file`
      );
    }
    this.stmt = $stmt as Statement;
    this.file = $file as FileJp;

    // buffer information (needs function, call and buffer sizes)
    this.bufferSizes = new Map(Object.entries(bufferSizes));
    this.inBuffers = this.makeInBuffers();

    // kernel name and source
    const kernelFile = Io.getPath(this.file.path, kernelCodePath);
    if (!Io.isFile(kernelFile)) {
      throw new Error(
        "[KernelReplacer] Cannot read OpenCL file in location " +
          kernelFile.toString() +
          " (" +
          kernelCodePath +
          ")"
      );
    }
    this.kernelCode = Io.readFile(kernelFile);
    this.kernelName = kernelName;

    // device type
    this.deviceType = DeviceType.CL_DEVICE_TYPE_ALL; // TODO: make a setter for this

    // error handling
    this.errorHandling = ErrorHandling.EXIT; // TODO: make setter for this

    // local size and number of iterations (global size), normalized to arrays
    this.localSize = Array.isArray(localSize) ? [...localSize] : [localSize];
    this.iter = Array.isArray(numIters) ? [...numIters] : [numIters];

    if (this.localSize.length !== this.iter.length) {
      throw new Error(
        "KernelReplacer(): localSize and numIters must have the same number of dimensions"
      );
    }
    // cl::NDRange only has 1-, 2- and 3-dimensional constructors; an empty
    // list would also emit `cl::NDRange localSize();`, a function declaration.
    if (this.localSize.length < 1 || this.localSize.length > 3) {
      throw new Error(
        `KernelReplacer(): OpenCL supports 1 to 3 work dimensions, got ${this.localSize.length}`
      );
    }
  }

  /* ------------------------- PUBLIC METHODS ------------------------- */

  /**
   * Marks a pointer/array parameter as output-only. Its buffer is created
   * write-only, is not copied to the device, and is read back into the
   * argument after the kernel finishes.
   *
   * @param paramName Name of a parameter that is currently an input buffer.
   * @throws Error if `paramName` is not a pending input buffer.
   */
  setOutput(paramName: string): void {
    this.moveInputBuffer(paramName, this.outBuffers, BufferKind.OUTPUT);
  }

  /**
   * Marks a pointer/array parameter as input/output. Its buffer is created
   * read-write, copied to the device before the launch, and read back into the
   * argument after the kernel finishes.
   *
   * @param paramName Name of a parameter that is currently an input buffer.
   * @throws Error if `paramName` is not a pending input buffer.
   */
  setInputOutput(paramName: string): void {
    this.moveInputBuffer(
      paramName,
      this.inOutBuffers,
      BufferKind.INPUT_OUTPUT
    );
  }

  /**
   * Rewrites the enclosing file:
   * - adds `#include <CL/cl.hpp>` and `#define __CL_ENABLE_EXCEPTIONS`;
   * - adds a global `const char *` holding the kernel source;
   * - replaces the statement containing the call with the generated host code.
   */
  replaceCall(): void {
    const $parentFile = this.file;

    $parentFile.addInclude("CL/cl.hpp", true);
    $parentFile.insertBegin("#define __CL_ENABLE_EXCEPTIONS");

    const type = ClavaJoinPoints.typeLiteral("const char *");
    const sourceStringName = this.kernelName + "_source_code";
    $parentFile.addGlobal(sourceStringName, type, this.makeKernelCode());

    const code = this.makeCode(sourceStringName);
    const $codeStmt = ClavaJoinPoints.stmtLiteral(code);
    this.stmt.replaceWith($codeStmt);
  }

  /* ------------------------- PRIVATE METHODS ------------------------ */

  /**
   * Moves an input buffer into `target`, changing its memory flags to `kind`.
   */
  private moveInputBuffer(
    paramName: string,
    target: Map<string, Buffer>,
    kind: BufferKind
  ): void {
    const buffer = this.inBuffers.get(paramName);
    if (buffer === undefined) {
      throw new Error(`No input buffer found for parameter '${paramName}'`);
    }
    buffer._kind = kind;
    target.set(paramName, buffer);
    this.inBuffers.delete(paramName);
  }

  /**
   * Returns the kernel source as a double-quoted C string literal.
   *
   * NOTE(review): escaping relies on `Strings.escapeJson`; JSON escapes
   * (e.g. `\/`, `\uXXXX`) are not identical to C escapes — confirm this is
   * adequate for the kernels in use.
   */
  private makeKernelCode(): string {
    return '"' + Strings.escapeJson(this.kernelCode) + '"';
  }

  /**
   * Builds the complete host code that replaces the call statement. The code
   * is one `try` block with setup, buffers, kernel creation, argument binding,
   * launch and read-back, followed by the `cl::Error` handler.
   */
  private makeCode(sourceStringName: string): string {
    let code = "// start of OpenCL code\n";
    code += "cl::Program program;\n";
    code += "std::vector<cl::Device> devices;\n";
    code += "try {\n";

    code += KernelReplacer.SetupCode(
      this.deviceType,
      this.makeErrorHandlingCode()
    );
    code += this.makeBuffersCode();
    code += KernelReplacer.KernelCreation(sourceStringName, this.kernelName);
    code += this.makeArgBindCode();
    code += KernelReplacer.SizesDecl(
      this.localSize.join(", "),
      this.makeGlobalSizeCode()
    );
    code += KernelReplacer.EnqueueKernel();
    code += this.makeOutputBuffersCode();
    code += KernelReplacer.ExceptionCode();

    code += "\n// end of OpenCL code\n\n";
    return code;
  }

  /**
   * Emits the C++ `globalSize` declaration. For each dimension, the number of
   * iterations is rounded up to the next multiple of that dimension's local
   * size, because OpenCL 1.x requires the global size to be divisible by the
   * local size.
   *
   * Iteration counts and local sizes are assumed to be integer-valued C
   * expressions. Integer round-up avoids the precision loss of a `float`-based
   * `ceil`, which can drop work items above 2^24.
   */
  private makeGlobalSizeCode(): string {
    const dims = this.localSize.map((local, i) => {
      const global = this.iter[i];
      return `(((${global}) + (${local}) - 1) / (${local})) * (${local})`;
    });
    return "cl::NDRange globalSize(" + dims.join(", ") + ");";
  }

  /**
   * Emits blocking reads that copy input/output and output buffers back into
   * their host arguments.
   */
  private makeOutputBuffersCode(): string {
    let code = "\n// Read back buffers\n";
    for (const buffer of [
      ...this.inOutBuffers.values(),
      ...this.outBuffers.values(),
    ]) {
      code += KernelReplacer.BufferCopyOut(
        buffer._bufferName,
        buffer._size,
        buffer._argName
      );
    }
    return code;
  }

  /**
   * Emits one `kernel.setArg` per call argument, using the argument's
   * position as the kernel argument index.
   *
   * Arguments backed by a device buffer bind the buffer. All others
   * (including arguments beyond the declared parameters) bind the argument
   * expression as written in the call.
   */
  private makeArgBindCode(): string {
    const $args = this.call.args;
    const $params = this.function.params;

    let code = "";
    for (let index = 0; index < $args.length; index++) {
      const $arg = $args[index];
      const paramName = $params[index]?.name;
      const buffer =
        paramName === undefined ? undefined : this.findBuffer(paramName);

      code += KernelReplacer.ArgBind(
        String(index),
        buffer?._bufferName ?? $arg.code
      );
    }

    return "\n// Bind kernel arguments to kernel\n" + code;
  }

  /** Looks up the buffer for `paramName` in input, input/output, then output. */
  private findBuffer(paramName: string): Buffer | undefined {
    return (
      this.inBuffers.get(paramName) ??
      this.inOutBuffers.get(paramName) ??
      this.outBuffers.get(paramName)
    );
  }

  /**
   * Emits the `cl::Buffer` declarations for every buffer. It then emits
   * blocking writes for buffers the kernel reads (input and input/output).
   */
  private makeBuffersCode(): string {
    let code = "\n// Create device memory buffers\n";
    for (const buffer of [
      ...this.inBuffers.values(),
      ...this.outBuffers.values(),
      ...this.inOutBuffers.values(),
    ]) {
      code += KernelReplacer.BufferDecl(
        buffer._bufferName,
        buffer._kind,
        buffer._size
      );
    }

    code += "\n// Bind memory buffers\n";
    for (const buffer of [
      ...this.inBuffers.values(),
      ...this.inOutBuffers.values(),
    ]) {
      code += KernelReplacer.BufferCopyIn(
        buffer._bufferName,
        buffer._size,
        buffer._argName
      );
    }

    return code;
  }

  /** C statement emitted when no OpenCL platform is found. */
  private makeErrorHandlingCode(): string {
    switch (this.errorHandling) {
      case ErrorHandling.EXIT:
        return "exit(EXIT_FAILURE);";
      default:
        return "exit(EXIT_FAILURE);";
    }
  }

  /**
   * Creates an input buffer for every array or pointer parameter of the
   * called function, keyed by parameter name.
   *
   * @throws Error if a buffer size is missing, if no builtin base type is
   *   found, or if the call passes no argument for the parameter.
   */
  private makeInBuffers(): Map<string, Buffer> {
    const buffers = new Map<string, Buffer>();
    const $params = this.function.params;
    const $args = this.call.args;

    $params.forEach(($param, i) => {
      if (!($param.type.isArray || $param.type.isPointer)) {
        return;
      }

      const bufferSize = this.getBufferSize($param.name);
      const $baseType = this.getBaseType($param.type);

      const $arg = $args[i];
      if ($arg == undefined) {
        throw new Error(
          `No argument passed for buffer parameter '${$param.name}'`
        );
      }

      buffers.set(
        $param.name,
        new Buffer(
          BufferKind.INPUT,
          $param.name,
          $baseType,
          i,
          bufferSize,
          $arg.code,
          $param.name + "_buffer"
        )
      );
    });

    return buffers;
  }

  /** Returns the user-provided size expression for a buffer parameter. */
  private getBufferSize(paramName: string): string {
    const bufferSize = this.bufferSizes.get(paramName);
    if (bufferSize == undefined) {
      throw new Error(`No buffer size found for parameter '${paramName}'`);
    }
    return bufferSize;
  }

  /**
   * Repeatedly unwraps `$type` (via `Type.unwrap`) until a `BuiltinType` is
   * reached.
   *
   * @throws Error if unwrapping stops before reaching a builtin type, e.g.
   *   for pointers to structs.
   */
  private getBaseType($type: Type): BuiltinType {
    let $current = $type;
    while (!($current instanceof BuiltinType)) {
      const $next = $current.unwrap;
      if ($next == undefined || $next === $current) {
        throw new Error(
          `KernelReplacer(): could not find a builtin base type for '${$type.code}'`
        );
      }
      $current = $next;
    }
    return $current;
  }

  /* ---------------------------- CODEDEFS ---------------------------- */

  /** Platform query, context and command-queue creation. */
  private static SetupCode(deviceType: string, errorHandling: string): string {
    // `\\n` so the emitted C++ contains the escape `\n`, not a raw newline
    // inside the string literal.
    return `// Query platforms
std::vector<cl::Platform> platforms;
cl::Platform::get(&platforms);
if (platforms.size() == 0) {
    std::cout << "Platform size 0\\n";
    ${errorHandling}
}

// Get list of devices on default platform and create context
cl_context_properties properties[] =
    { CL_CONTEXT_PLATFORM, (cl_context_properties)(platforms[0])(), 0};
cl::Context context(${deviceType}, properties);
devices = context.getInfo<CL_CONTEXT_DEVICES>();

// Create command queue for first device
cl::CommandQueue queue(context, devices[0], 0);
`;
  }

  /**
   * Closes the `try` block. It prints build logs on build failures and
   * rethrows any other `cl::Error`.
   */
  private static ExceptionCode(): string {
    return `} catch (cl::Error err) {
    std::cerr << "ERROR: "<<err.what()<<"("<<err.err()<<")"<<std::endl;

    if (err.err() == CL_BUILD_PROGRAM_FAILURE) {
        for (cl::Device dev : devices) {
            // Check the build status
            cl_build_status status = program.getBuildInfo<CL_PROGRAM_BUILD_STATUS>(dev);
            if (status != CL_BUILD_ERROR)
                continue;

            // Get the build log
            std::string name = dev.getInfo<CL_DEVICE_NAME>();
            std::string buildlog = program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(dev);
            std::cerr << "Build log for " << name << ":" << std::endl
                      << buildlog << std::endl;
        }
    } else {
        throw err;
    }
}
`;
  }

  /** Blocking read of a device buffer back into host memory. */
  private static BufferCopyOut(
    bufferName: string,
    bufferSize: string,
    argName: string
  ): string {
    return `queue.enqueueReadBuffer(${bufferName}, CL_TRUE, 0, ${bufferSize}, ${argName});
`;
  }

  /** Kernel launch followed by a wait on its completion event. */
  private static EnqueueKernel(): string {
    return `// Enqueue kernel
cl::Event event;
queue.enqueueNDRangeKernel(
    kernel,
    cl::NullRange,
    globalSize,
    localSize,
    NULL,
    &event);

// Block until kernel completion
event.wait();
`;
  }

  /** Declarations of the `localSize` and `globalSize` NDRanges. */
  private static SizesDecl(localsize: string, globalsizeCode: string): string {
    return `// Number of work items in each local work group
cl::NDRange localSize(${localsize});
// Number of total work items - localSize must be devisor
${globalsizeCode}
`;
  }

  /** Binds one kernel argument. */
  private static ArgBind(index: string, arg: string): string {
    return `kernel.setArg(${index}, ${arg});
`;
  }

  /** Builds the program from the source string and creates the kernel. */
  private static KernelCreation(
    sourceString: string,
    kernelName: string
  ): string {
    return `//Build kernel from source string
cl::Program::Sources source(1,
    std::make_pair(${sourceString},strlen(${sourceString})));
program = cl::Program(context, source);
program.build(devices);

// Create kernel object
cl::Kernel kernel(program, "${kernelName}");
`;
  }

  /** Declaration of one device buffer. */
  private static BufferDecl(
    bufferName: string,
    bufferKind: string,
    bufferSize: string
  ): string {
    return `cl::Buffer ${bufferName} = cl::Buffer(context, ${bufferKind}, ${bufferSize});
`;
  }

  /** Blocking write of host memory into a device buffer. */
  private static BufferCopyIn(
    bufferName: string,
    bufferSize: string,
    argName: string
  ): string {
    return `queue.enqueueWriteBuffer(${bufferName}, CL_TRUE, 0, ${bufferSize}, ${argName});
`;
  }
}

/* ------------------------- PRIVATE CLASSES ------------------------ */

/** Host-side description of one device buffer backing a function parameter. */
class Buffer {
  /**
   * @param _kind OpenCL memory flags used when creating the buffer.
   * @param _paramName Name of the parameter the buffer replaces.
   * @param _baseType Builtin element type found by unwrapping the parameter type.
   * @param _index Position of the parameter in the function signature.
   * @param _size C expression for the buffer size passed to OpenCL.
   * @param _argName Code of the call argument copied to/from the buffer.
   * @param _bufferName Name of the generated `cl::Buffer` variable.
   */
  constructor(
    public _kind: BufferKind,
    public readonly _paramName: string,
    public readonly _baseType: BuiltinType,
    public readonly _index: number,
    public readonly _size: string,
    public readonly _argName: string,
    public readonly _bufferName: string
  ) {}
}

/* ------------------------------ ENUMS ----------------------------- */

/** OpenCL memory flags for each buffer direction. */
enum BufferKind {
  INPUT = "CL_MEM_READ_ONLY",
  OUTPUT = "CL_MEM_WRITE_ONLY",
  INPUT_OUTPUT = "CL_MEM_READ_WRITE",
}

/** OpenCL device types accepted by `cl::Context`. */
enum DeviceType {
  CL_DEVICE_TYPE_ALL = "CL_DEVICE_TYPE_ALL",
  CL_DEVICE_TYPE_CPU = "CL_DEVICE_TYPE_CPU",
  CL_DEVICE_TYPE_GPU = "CL_DEVICE_TYPE_GPU",
  CL_DEVICE_TYPE_ACCELERATOR = "CL_DEVICE_TYPE_ACCELERATOR",
  CL_DEVICE_TYPE_DEFAULT = "CL_DEVICE_TYPE_DEFAULT",
}

/**
 * Strategy for the generated code when setup fails. Only `EXIT` is
 * currently emitted; the other values fall back to it.
 */
enum ErrorHandling {
  EXIT = 0,
  RETURN = 1,
  USER = 2,
}