UNPKG

@specs-feup/clava

Version:

A C/C++ source-to-source compiler written in Typescript

332 lines (269 loc) 8.71 kB
import { FileJp, FunctionJp, Loop, Varref } from "../../Joinpoints.js"; import ClavaCode from "../ClavaCode.js"; import ClavaJoinPoints from "../ClavaJoinPoints.js"; import MpiAccessPattern from "./MpiAccessPattern.js"; import MpiUtils from "./MpiUtils.js"; import MpiAccessPatterns from "./patterns/MpiAccessPatterns.js"; /** * Applies an MPI scatter-gather strategy to loops. */ export default class MpiScatterGatherLoop { private $loop: Loop; private inputJps: Varref[] = []; private inputAccesses: MpiAccessPattern[] = []; private outputJps: Varref[] = []; private outputAccesses: MpiAccessPattern[] = []; constructor($loop: Loop) { this.$loop = $loop; // Check if loop can be parallelize if (this.$loop.iterationsExpr === undefined) { throw "Could not determine expression with number of iterations of the loop. Check if the loop is in the Canonical Loop Form, according to the OpenMP standard."; } } addInput(varName: string, accessPattern: MpiAccessPattern) { this.addVariable(varName, accessPattern, this.inputJps, this.inputAccesses); } addOutput(varName: string, accessPattern: MpiAccessPattern) { this.addVariable( varName, accessPattern, this.outputJps, this.outputAccesses ); } /** * Adapts code to use the MPI strategy. */ execute() { const $mainFunction = ClavaCode.getFunctionDefinition( "main", true ) as FunctionJp; const $mainFile = $mainFunction.getAncestor("file") as FileJp | undefined; if ($mainFile == undefined) { throw "Could not find file of main function"; } // Add include $mainFile.addInclude("mpi.h"); $mainFile.addInclude("iostream", true); // Add global variables const $intType = ClavaJoinPoints.builtinType("int"); $mainFile.addGlobal(MpiUtils.VAR_NUM_TASKS, $intType, "0"); $mainFile.addGlobal(MpiUtils.VAR_NUM_WORKERS, $intType, "0"); const $rankDecl = $mainFile.addGlobal(MpiUtils.VAR_RANK, $intType, "0"); // Create decl const mpiWorkerFunction = ClavaJoinPoints.declLiteral( this.buildMpiWorker() ); // Add MPI Worker $rankDecl.insertAfter(mpiWorkerFunction); // Replace loop with MPI Master routine this.replaceLoop(); // Add MPI initialization this.addMpiInit($mainFunction); } /** PRIVATE SECTION **/ private static FUNCTION_MPI_WORKER = "mpi_worker"; private static VAR_WORKER_NUM_ELEMS = "mpi_loop_num_elems"; private static VAR_MASTER_TOTAL_ITER = "clava_mpi_total_iter"; private replaceLoop() { let masterSend = ""; for (let i = 0; i < this.inputJps.length; i++) { const $inputJp = this.inputJps[i]; const accessPattern = this.inputAccesses[i]; masterSend += accessPattern.sendMaster( $inputJp, MpiScatterGatherLoop.VAR_MASTER_TOTAL_ITER ); masterSend += "\n"; } let masterReceive = ""; for (let i = 0; i < this.outputJps.length; i++) { const $inputJp = this.outputJps[i]; const accessPattern = this.outputAccesses[i]; masterReceive += accessPattern.receiveMaster( $inputJp, MpiScatterGatherLoop.VAR_MASTER_TOTAL_ITER ); masterReceive += "\n"; } this.$loop.replaceWith( MpiScatterGatherLoop.MpiMaster( MpiUtils.VAR_NUM_WORKERS, this.$loop.iterationsExpr.code, masterSend, masterReceive, MpiUtils._VAR_MPI_STATUS ) ); } private buildMpiWorker() { const workerLoopCode = this.getWorkerLoopCode(); let workerReceive = ""; for (let i = 0; i < this.inputJps.length; i++) { const $inputJp = this.inputJps[i]; const accessPattern = this.inputAccesses[i]; workerReceive += accessPattern.receiveWorker( $inputJp, MpiScatterGatherLoop.VAR_WORKER_NUM_ELEMS ); workerReceive += "\n"; } let outputDecl = ""; for (let i = 0; i < this.outputJps.length; i++) { const $outputJp = this.outputJps[i]; const accessPattern = this.outputAccesses[i]; outputDecl += accessPattern.outputDeclWorker( $outputJp, MpiScatterGatherLoop.VAR_WORKER_NUM_ELEMS ); outputDecl += "\n"; } let workerSend = ""; for (let i = 0; i < this.outputJps.length; i++) { const $outputJp = this.outputJps[i]; const accessPattern = this.outputAccesses[i]; workerSend += accessPattern.sendWorker( $outputJp, MpiScatterGatherLoop.VAR_WORKER_NUM_ELEMS ); workerSend += "\n"; } return MpiScatterGatherLoop.MpiWorker( MpiScatterGatherLoop.FUNCTION_MPI_WORKER, MpiUtils._VAR_MPI_STATUS, MpiScatterGatherLoop.VAR_WORKER_NUM_ELEMS, workerReceive, outputDecl, workerLoopCode, workerSend ); } private getWorkerLoopCode() { // Copy loop const $workerLoop = this.$loop.copy() as Loop; // Adjust start and end of loop $workerLoop.initValue = "0"; $workerLoop.endValue = MpiScatterGatherLoop.VAR_WORKER_NUM_ELEMS; // TODO: Adapt loop body, if needed return $workerLoop.code; } private addMpiInit($mainFunction: FunctionJp) { // Add params to main, if no params if ($mainFunction.params.length === 0) { $mainFunction.setParamsFromStrings(["int argc", "char** argv"]); } const numMainParams = $mainFunction.params.length; if (numMainParams !== 2) { throw `Expected main() function to have 2 paramters, has '${numMainParams}'`; } const argc = $mainFunction.params[0].name; const argv = $mainFunction.params[1].name; $mainFunction.body.insertBegin( MpiScatterGatherLoop.MpiInit( argc, argv, MpiUtils.VAR_RANK, MpiUtils.VAR_NUM_TASKS, MpiUtils.VAR_NUM_WORKERS, MpiScatterGatherLoop.FUNCTION_MPI_WORKER ) ); } private addVariable( varName: string, accessPattern: MpiAccessPattern | undefined = MpiAccessPatterns.SCALAR_PATTERN, namesArray: Varref[], accessesArray: MpiAccessPattern[] ) { // Check if loop contains a reference to the variable let firstVarref = undefined; for (const $v of this.$loop.getDescendants("varref")) { const $varref = $v as Varref; if ($varref.name === varName) { firstVarref = $varref; break; } } if (firstVarref === undefined) { throw `Could not find a reference to the variable '${varName}' in the loop located at ${this.$loop.location}`; } namesArray.push(firstVarref); accessesArray.push(accessPattern); } /** CODEDEFS **/ // TODO: std::cerr should not be hardcoded, lara.code.Logger should be used instead private static MpiInit( argc: string, argv: string, rank: string, numTasks: string, numWorkers: string, mpiWorker: string ) { return ` MPI_Init(&${argc}, &${argv}); MPI_Comm_rank(MPI_COMM_WORLD, &${rank}); MPI_Comm_size(MPI_COMM_WORLD, &${numTasks}); ${numWorkers} = ${numTasks} - 1; if(${numWorkers} == 0) { std::cerr << "This program does not support working with a single process." << std::endl; return 1; } if(${rank} > 0) { ${mpiWorker}(); MPI_Finalize(); return 0; } `; } private static MpiWorker( functionName: string, status: string, numElems: string, receiveData: string, outputDecl: string, loop: string, sendData: string ) { return ` void ${functionName}() { MPI_Status ${status}; // Number of loop iterations int ${numElems}; MPI_Recv(&${numElems}, 1, MPI_INT, 0, 1, MPI_COMM_WORLD, &${status}); ${receiveData} ${outputDecl} ${loop} ${sendData} } `; } private static MpiMaster( numWorkers: string, numIterations: string, masterSend: string, masterReceive: string, status: string ) { return ` // Master routine // split iterations of the loop int clava_mpi_total_iter = ${numIterations}; int clava_mpi_loop_limit = clava_mpi_total_iter; // A better distribution calculation could be used int clava_mpi_num_iter = clava_mpi_total_iter / ${numWorkers}; int clava_mpi_num_iter_last = clava_mpi_num_iter + clava_mpi_total_iter % ${numWorkers}; // int clava_mpi_num_iter_last = clava_mpi_num_iter + (clava_mpi_loop_limit - (clava_mpi_num_iter * ${numWorkers})); // send number of iterations for(int i=0; i<${numWorkers}-1; i++) { MPI_Send(&clava_mpi_num_iter, 1, MPI_INT, i+1, 1, MPI_COMM_WORLD); } MPI_Send(&clava_mpi_num_iter_last, 1, MPI_INT, ${numWorkers}, 1, MPI_COMM_WORLD); ${masterSend} MPI_Status ${status}; ${masterReceive} MPI_Finalize(); `; } }