1 //===---- CGOpenMPRuntimeNVPTX.cpp - Interface to OpenMP NVPTX Runtimes ---===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This provides a class for OpenMP runtime code generation specialized to NVPTX
11 // targets.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "CGOpenMPRuntimeNVPTX.h"
16 #include "clang/AST/DeclOpenMP.h"
17 #include "CodeGenFunction.h"
18 #include "clang/AST/StmtOpenMP.h"
19
20 using namespace clang;
21 using namespace CodeGen;
22
23 namespace {
24 enum OpenMPRTLFunctionNVPTX {
25   /// \brief Call to void __kmpc_kernel_init(kmp_int32 thread_limit);
26   OMPRTL_NVPTX__kmpc_kernel_init,
27   /// \brief Call to void __kmpc_kernel_deinit();
28   OMPRTL_NVPTX__kmpc_kernel_deinit,
29   /// \brief Call to void __kmpc_kernel_prepare_parallel(void
30   /// *outlined_function);
31   OMPRTL_NVPTX__kmpc_kernel_prepare_parallel,
32   /// \brief Call to bool __kmpc_kernel_parallel(void **outlined_function);
33   OMPRTL_NVPTX__kmpc_kernel_parallel,
34   /// \brief Call to void __kmpc_kernel_end_parallel();
35   OMPRTL_NVPTX__kmpc_kernel_end_parallel,
36   /// Call to void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
37   /// global_tid);
38   OMPRTL_NVPTX__kmpc_serialized_parallel,
39   /// Call to void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32
40   /// global_tid);
41   OMPRTL_NVPTX__kmpc_end_serialized_parallel,
42 };
43
44 /// Pre(post)-action for different OpenMP constructs specialized for NVPTX.
45 class NVPTXActionTy final : public PrePostActionTy {
46   llvm::Value *EnterCallee;
47   ArrayRef<llvm::Value *> EnterArgs;
48   llvm::Value *ExitCallee;
49   ArrayRef<llvm::Value *> ExitArgs;
50   bool Conditional;
51   llvm::BasicBlock *ContBlock = nullptr;
52
53 public:
54   NVPTXActionTy(llvm::Value *EnterCallee, ArrayRef<llvm::Value *> EnterArgs,
55                 llvm::Value *ExitCallee, ArrayRef<llvm::Value *> ExitArgs,
56                 bool Conditional = false)
57       : EnterCallee(EnterCallee), EnterArgs(EnterArgs), ExitCallee(ExitCallee),
58         ExitArgs(ExitArgs), Conditional(Conditional) {}
59   void Enter(CodeGenFunction &CGF) override {
60     llvm::Value *EnterRes = CGF.EmitRuntimeCall(EnterCallee, EnterArgs);
61     if (Conditional) {
62       llvm::Value *CallBool = CGF.Builder.CreateIsNotNull(EnterRes);
63       auto *ThenBlock = CGF.createBasicBlock("omp_if.then");
64       ContBlock = CGF.createBasicBlock("omp_if.end");
65       // Generate the branch (If-stmt)
66       CGF.Builder.CreateCondBr(CallBool, ThenBlock, ContBlock);
67       CGF.EmitBlock(ThenBlock);
68     }
69   }
70   void Done(CodeGenFunction &CGF) {
71     // Emit the rest of blocks/branches
72     CGF.EmitBranch(ContBlock);
73     CGF.EmitBlock(ContBlock, true);
74   }
75   void Exit(CodeGenFunction &CGF) override {
76     CGF.EmitRuntimeCall(ExitCallee, ExitArgs);
77   }
78 };
79 } // anonymous namespace
80
81 /// Get the GPU warp size.
82 static llvm::Value *getNVPTXWarpSize(CodeGenFunction &CGF) {
83   CGBuilderTy &Bld = CGF.Builder;
84   return Bld.CreateCall(
85       llvm::Intrinsic::getDeclaration(
86           &CGF.CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_warpsize),
87       llvm::None, "nvptx_warp_size");
88 }
89
90 /// Get the id of the current thread on the GPU.
91 static llvm::Value *getNVPTXThreadID(CodeGenFunction &CGF) {
92   CGBuilderTy &Bld = CGF.Builder;
93   return Bld.CreateCall(
94       llvm::Intrinsic::getDeclaration(
95           &CGF.CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x),
96       llvm::None, "nvptx_tid");
97 }
98
99 /// Get the maximum number of threads in a block of the GPU.
100 static llvm::Value *getNVPTXNumThreads(CodeGenFunction &CGF) {
101   CGBuilderTy &Bld = CGF.Builder;
102   return Bld.CreateCall(
103       llvm::Intrinsic::getDeclaration(
104           &CGF.CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x),
105       llvm::None, "nvptx_num_threads");
106 }
107
108 /// Emit a barrier to synchronize all threads in a block.
109 static void getNVPTXCTABarrier(CodeGenFunction &CGF) {
110   CGBuilderTy &Bld = CGF.Builder;
111   Bld.CreateCall(llvm::Intrinsic::getDeclaration(
112       &CGF.CGM.getModule(), llvm::Intrinsic::nvvm_barrier0));
113 }
114
115 /// Synchronize all GPU threads in a block.
116 static void syncCTAThreads(CodeGenFunction &CGF) { getNVPTXCTABarrier(CGF); }
117
118 /// Get the value of the thread_limit clause in the teams directive.
119 /// The runtime encodes thread_limit in the launch parameters: each team is
120 /// launched with thread_limit+warpSize threads, reserving one warp for the master.
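/// For example (assuming a warp size of 32), a team launched with 128 threads
/// yields a thread_limit of 96 worker threads.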
121 static llvm::Value *getThreadLimit(CodeGenFunction &CGF) {
122   CGBuilderTy &Bld = CGF.Builder;
123   return Bld.CreateSub(getNVPTXNumThreads(CGF), getNVPTXWarpSize(CGF),
124                        "thread_limit");
125 }
126
127 /// Get the thread id of the OMP master thread.
128 /// The master thread id is the first thread (lane) of the last warp in the
129 /// GPU block.  Warp size is assumed to be some power of 2.
130 /// Thread id is 0 indexed.
131 /// E.g: If NumThreads is 33, master id is 32.
132 ///      If NumThreads is 64, master id is 32.
133 ///      If NumThreads is 1024, master id is 992.
134 static llvm::Value *getMasterThreadID(CodeGenFunction &CGF) {
135   CGBuilderTy &Bld = CGF.Builder;
136   llvm::Value *NumThreads = getNVPTXNumThreads(CGF);
137
138   // We assume that the warp size is a power of 2.
139   llvm::Value *Mask = Bld.CreateSub(getNVPTXWarpSize(CGF), Bld.getInt32(1));
140
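  // Masking out the low bits of (NumThreads - 1) rounds the highest thread id
  // down to a multiple of the warp size, i.e. lane 0 of the last warp.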
141   return Bld.CreateAnd(Bld.CreateSub(NumThreads, Bld.getInt32(1)),
142                        Bld.CreateNot(Mask), "master_tid");
143 }
144
145 CGOpenMPRuntimeNVPTX::WorkerFunctionState::WorkerFunctionState(
146     CodeGenModule &CGM)
147     : WorkerFn(nullptr), CGFI(nullptr) {
148   createWorkerFunction(CGM);
149 }
150
151 void CGOpenMPRuntimeNVPTX::WorkerFunctionState::createWorkerFunction(
152     CodeGenModule &CGM) {
153   // Create a worker function with no arguments.
154   CGFI = &CGM.getTypes().arrangeNullaryFunction();
155
156   WorkerFn = llvm::Function::Create(
157       CGM.getTypes().GetFunctionType(*CGFI), llvm::GlobalValue::InternalLinkage,
158       /* placeholder */ "_worker", &CGM.getModule());
159   CGM.SetInternalFunctionAttributes(/*D=*/nullptr, WorkerFn, *CGFI);
160 }
161
162 void CGOpenMPRuntimeNVPTX::emitGenericKernel(const OMPExecutableDirective &D,
163                                              StringRef ParentName,
164                                              llvm::Function *&OutlinedFn,
165                                              llvm::Constant *&OutlinedFnID,
166                                              bool IsOffloadEntry,
167                                              const RegionCodeGenTy &CodeGen) {
168   EntryFunctionState EST;
169   WorkerFunctionState WST(CGM);
170   Work.clear();
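  // Work collects the outlined parallel functions emitted for this target
  // region; the worker loop dispatches on this list.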
171
172   // Emit target region as a standalone region.
173   class NVPTXPrePostActionTy : public PrePostActionTy {
174     CGOpenMPRuntimeNVPTX &RT;
175     CGOpenMPRuntimeNVPTX::EntryFunctionState &EST;
176     CGOpenMPRuntimeNVPTX::WorkerFunctionState &WST;
177
178   public:
179     NVPTXPrePostActionTy(CGOpenMPRuntimeNVPTX &RT,
180                          CGOpenMPRuntimeNVPTX::EntryFunctionState &EST,
181                          CGOpenMPRuntimeNVPTX::WorkerFunctionState &WST)
182         : RT(RT), EST(EST), WST(WST) {}
183     void Enter(CodeGenFunction &CGF) override {
184       RT.emitGenericEntryHeader(CGF, EST, WST);
185     }
186     void Exit(CodeGenFunction &CGF) override {
187       RT.emitGenericEntryFooter(CGF, EST);
188     }
189   } Action(*this, EST, WST);
190   CodeGen.setAction(Action);
191   emitTargetOutlinedFunctionHelper(D, ParentName, OutlinedFn, OutlinedFnID,
192                                    IsOffloadEntry, CodeGen);
193
194   // Create the worker function
195   emitWorkerFunction(WST);
196
197   // Now change the name of the worker function to correspond to this target
198   // region's entry function.
199   WST.WorkerFn->setName(OutlinedFn->getName() + "_worker");
200 }
201
202 // Set up NVPTX threads for the master-worker OpenMP scheme.
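// The emitted control flow is roughly:
//   if (tid < thread_limit)       -> .worker  (spin in the worker loop)
//   else if (tid == master_tid)   -> .master  (init the runtime, run the region)
//   else                          -> .exit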
203 void CGOpenMPRuntimeNVPTX::emitGenericEntryHeader(CodeGenFunction &CGF,
204                                                   EntryFunctionState &EST,
205                                                   WorkerFunctionState &WST) {
206   CGBuilderTy &Bld = CGF.Builder;
207
208   llvm::BasicBlock *WorkerBB = CGF.createBasicBlock(".worker");
209   llvm::BasicBlock *MasterCheckBB = CGF.createBasicBlock(".mastercheck");
210   llvm::BasicBlock *MasterBB = CGF.createBasicBlock(".master");
211   EST.ExitBB = CGF.createBasicBlock(".exit");
212
213   auto *IsWorker =
214       Bld.CreateICmpULT(getNVPTXThreadID(CGF), getThreadLimit(CGF));
215   Bld.CreateCondBr(IsWorker, WorkerBB, MasterCheckBB);
216
217   CGF.EmitBlock(WorkerBB);
218   CGF.EmitCallOrInvoke(WST.WorkerFn, llvm::None);
219   CGF.EmitBranch(EST.ExitBB);
220
221   CGF.EmitBlock(MasterCheckBB);
222   auto *IsMaster =
223       Bld.CreateICmpEQ(getNVPTXThreadID(CGF), getMasterThreadID(CGF));
224   Bld.CreateCondBr(IsMaster, MasterBB, EST.ExitBB);
225
226   CGF.EmitBlock(MasterBB);
227   // First action in sequential region:
228   // Initialize the state of the OpenMP runtime library on the GPU.
229   llvm::Value *Args[] = {getThreadLimit(CGF)};
230   CGF.EmitRuntimeCall(
231       createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_init), Args);
232 }
233
234 void CGOpenMPRuntimeNVPTX::emitGenericEntryFooter(CodeGenFunction &CGF,
235                                                   EntryFunctionState &EST) {
236   if (!EST.ExitBB)
237     EST.ExitBB = CGF.createBasicBlock(".exit");
238
239   llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".termination.notifier");
240   CGF.EmitBranch(TerminateBB);
241
242   CGF.EmitBlock(TerminateBB);
243   // Signal termination condition.
244   CGF.EmitRuntimeCall(
245       createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_deinit), None);
246   // Barrier to terminate worker threads.
247   syncCTAThreads(CGF);
248   // Master thread jumps to exit point.
249   CGF.EmitBranch(EST.ExitBB);
250
251   CGF.EmitBlock(EST.ExitBB);
252   EST.ExitBB = nullptr;
253 }
254
255 void CGOpenMPRuntimeNVPTX::emitWorkerFunction(WorkerFunctionState &WST) {
256   auto &Ctx = CGM.getContext();
257
258   CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
259   CGF.disableDebugInfo();
260   CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, WST.WorkerFn, *WST.CGFI, {});
261   emitWorkerLoop(CGF, WST);
262   CGF.FinishFunction();
263 }
264
265 void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF,
266                                           WorkerFunctionState &WST) {
267   //
268   // The workers enter this loop and wait for parallel work from the master.
269   // When the master encounters a parallel region it sets up the work + variable
270   // arguments, and wakes up the workers.  The workers first check to see if
271   // they are required for the parallel region, i.e., within the # of requested
272   // parallel threads.  The activated workers load the variable arguments and
273   // execute the parallel work.
274   //
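  //
  // Roughly, each worker executes:
  //
  //   for (;;) {
  //     syncCTAThreads();                           // wait for work from master
  //     bool active = __kmpc_kernel_parallel(&fn);
  //     if (!fn) break;                             // master signalled termination
  //     if (active) { fn(&zero, &zero); __kmpc_kernel_end_parallel(); }
  //     syncCTAThreads();                           // barrier closing the region
  //   }
  //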
275
276   CGBuilderTy &Bld = CGF.Builder;
277
278   llvm::BasicBlock *AwaitBB = CGF.createBasicBlock(".await.work");
279   llvm::BasicBlock *SelectWorkersBB = CGF.createBasicBlock(".select.workers");
280   llvm::BasicBlock *ExecuteBB = CGF.createBasicBlock(".execute.parallel");
281   llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".terminate.parallel");
282   llvm::BasicBlock *BarrierBB = CGF.createBasicBlock(".barrier.parallel");
283   llvm::BasicBlock *ExitBB = CGF.createBasicBlock(".exit");
284
285   CGF.EmitBranch(AwaitBB);
286
287   // Workers wait for work from master.
288   CGF.EmitBlock(AwaitBB);
289   // Wait for parallel work
290   syncCTAThreads(CGF);
291
292   Address WorkFn =
293       CGF.CreateDefaultAlignTempAlloca(CGF.Int8PtrTy, /*Name=*/"work_fn");
294   Address ExecStatus =
295       CGF.CreateDefaultAlignTempAlloca(CGF.Int8Ty, /*Name=*/"exec_status");
296   CGF.InitTempAlloca(ExecStatus, Bld.getInt8(/*C=*/0));
297   CGF.InitTempAlloca(WorkFn, llvm::Constant::getNullValue(CGF.Int8PtrTy));
298
299   llvm::Value *Args[] = {WorkFn.getPointer()};
300   llvm::Value *Ret = CGF.EmitRuntimeCall(
301       createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_parallel), Args);
302   Bld.CreateStore(Bld.CreateZExt(Ret, CGF.Int8Ty), ExecStatus);
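  // A true return value from __kmpc_kernel_parallel means this worker was
  // requested for the parallel region; WorkFn now holds the function to run.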
303
304   // On the termination condition (work function is null), exit the loop.
305   llvm::Value *ShouldTerminate =
306       Bld.CreateIsNull(Bld.CreateLoad(WorkFn), "should_terminate");
307   Bld.CreateCondBr(ShouldTerminate, ExitBB, SelectWorkersBB);
308
309   // Activate requested workers.
310   CGF.EmitBlock(SelectWorkersBB);
311   llvm::Value *IsActive =
312       Bld.CreateIsNotNull(Bld.CreateLoad(ExecStatus), "is_active");
313   Bld.CreateCondBr(IsActive, ExecuteBB, BarrierBB);
314
315   // Signal start of parallel region.
316   CGF.EmitBlock(ExecuteBB);
317
318   // Process work items: outlined parallel functions.
319   for (auto *W : Work) {
320     // Try to match this outlined function.
321     auto *ID = Bld.CreatePointerBitCastOrAddrSpaceCast(W, CGM.Int8PtrTy);
322
323     llvm::Value *WorkFnMatch =
324         Bld.CreateICmpEQ(Bld.CreateLoad(WorkFn), ID, "work_match");
325
326     llvm::BasicBlock *ExecuteFNBB = CGF.createBasicBlock(".execute.fn");
327     llvm::BasicBlock *CheckNextBB = CGF.createBasicBlock(".check.next");
328     Bld.CreateCondBr(WorkFnMatch, ExecuteFNBB, CheckNextBB);
329
330     // Execute this outlined function.
331     CGF.EmitBlock(ExecuteFNBB);
332
333     // Insert call to work function.
334     // FIXME: Pass arguments to outlined function from master thread.
335     auto *Fn = cast<llvm::Function>(W);
336     Address ZeroAddr =
337         CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, /*Name=*/".zero.addr");
338     CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C=*/0));
339     llvm::Value *FnArgs[] = {ZeroAddr.getPointer(), ZeroAddr.getPointer()};
340     CGF.EmitCallOrInvoke(Fn, FnArgs);
341
342     // Go to end of parallel region.
343     CGF.EmitBranch(TerminateBB);
344
345     CGF.EmitBlock(CheckNextBB);
346   }
347
348   // Signal end of parallel region.
349   CGF.EmitBlock(TerminateBB);
350   CGF.EmitRuntimeCall(
351       createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_end_parallel),
352       llvm::None);
353   CGF.EmitBranch(BarrierBB);
354
355   // All active and inactive workers wait at a barrier after parallel region.
356   CGF.EmitBlock(BarrierBB);
357   // Barrier after parallel region.
358   syncCTAThreads(CGF);
359   CGF.EmitBranch(AwaitBB);
360
361   // Exit target region.
362   CGF.EmitBlock(ExitBB);
363 }
364
365 /// \brief Returns specified OpenMP runtime function for the current OpenMP
366 /// implementation.  Specialized for the NVPTX device.
367 /// \param Function OpenMP runtime function.
368 /// \return Specified function.
369 llvm::Constant *
370 CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) {
371   llvm::Constant *RTLFn = nullptr;
372   switch (static_cast<OpenMPRTLFunctionNVPTX>(Function)) {
373   case OMPRTL_NVPTX__kmpc_kernel_init: {
374     // Build void __kmpc_kernel_init(kmp_int32 thread_limit);
375     llvm::Type *TypeParams[] = {CGM.Int32Ty};
376     llvm::FunctionType *FnTy =
377         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
378     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_init");
379     break;
380   }
381   case OMPRTL_NVPTX__kmpc_kernel_deinit: {
382     // Build void __kmpc_kernel_deinit();
383     llvm::FunctionType *FnTy =
384         llvm::FunctionType::get(CGM.VoidTy, llvm::None, /*isVarArg*/ false);
385     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_deinit");
386     break;
387   }
388   case OMPRTL_NVPTX__kmpc_kernel_prepare_parallel: {
389     // Build void __kmpc_kernel_prepare_parallel(
390     // void *outlined_function);
391     llvm::Type *TypeParams[] = {CGM.Int8PtrTy};
392     llvm::FunctionType *FnTy =
393         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
394     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_prepare_parallel");
395     break;
396   }
397   case OMPRTL_NVPTX__kmpc_kernel_parallel: {
398     // Build bool __kmpc_kernel_parallel(void **outlined_function);
399     llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy};
400     llvm::Type *RetTy = CGM.getTypes().ConvertType(CGM.getContext().BoolTy);
401     llvm::FunctionType *FnTy =
402         llvm::FunctionType::get(RetTy, TypeParams, /*isVarArg*/ false);
403     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_parallel");
404     break;
405   }
406   case OMPRTL_NVPTX__kmpc_kernel_end_parallel: {
407     // Build void __kmpc_kernel_end_parallel();
408     llvm::FunctionType *FnTy =
409         llvm::FunctionType::get(CGM.VoidTy, llvm::None, /*isVarArg*/ false);
410     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_end_parallel");
411     break;
412   }
413   case OMPRTL_NVPTX__kmpc_serialized_parallel: {
414     // Build void __kmpc_serialized_parallel(ident_t *loc, kmp_int32
415     // global_tid);
416     llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
417     llvm::FunctionType *FnTy =
418         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
419     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_serialized_parallel");
420     break;
421   }
422   case OMPRTL_NVPTX__kmpc_end_serialized_parallel: {
423     // Build void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32
424     // global_tid);
425     llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
426     llvm::FunctionType *FnTy =
427         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
428     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_end_serialized_parallel");
429     break;
430   }
431   }
432   return RTLFn;
433 }
434
435 void CGOpenMPRuntimeNVPTX::createOffloadEntry(llvm::Constant *ID,
436                                               llvm::Constant *Addr,
437                                               uint64_t Size, int32_t) {
438   auto *F = dyn_cast<llvm::Function>(Addr);
439   // TODO: Add support for global variables on the device once declare target
440   // support is available.
441   if (!F)
442     return;
443   llvm::Module *M = F->getParent();
444   llvm::LLVMContext &Ctx = M->getContext();
445
446   // Get "nvvm.annotations" metadata node
447   llvm::NamedMDNode *MD = M->getOrInsertNamedMetadata("nvvm.annotations");
448
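  // Annotate the entry function as a kernel so the NVPTX backend emits it as a
  // GPU entry point; the operand added below is roughly
  //   !{<entry_fn>, !"kernel", i32 1}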
449   llvm::Metadata *MDVals[] = {
450       llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "kernel"),
451       llvm::ConstantAsMetadata::get(
452           llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
453   // Append metadata to nvvm.annotations
454   MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
455 }
456
457 void CGOpenMPRuntimeNVPTX::emitTargetOutlinedFunction(
458     const OMPExecutableDirective &D, StringRef ParentName,
459     llvm::Function *&OutlinedFn, llvm::Constant *&OutlinedFnID,
460     bool IsOffloadEntry, const RegionCodeGenTy &CodeGen) {
461   if (!IsOffloadEntry) // Nothing to do.
462     return;
463
464   assert(!ParentName.empty() && "Invalid target region parent name!");
465
466   emitGenericKernel(D, ParentName, OutlinedFn, OutlinedFnID, IsOffloadEntry,
467                     CodeGen);
468 }
469
470 CGOpenMPRuntimeNVPTX::CGOpenMPRuntimeNVPTX(CodeGenModule &CGM)
471     : CGOpenMPRuntime(CGM) {
472   if (!CGM.getLangOpts().OpenMPIsDevice)
473     llvm_unreachable("OpenMP NVPTX can only handle device code.");
474 }
475
476 void CGOpenMPRuntimeNVPTX::emitNumTeamsClause(CodeGenFunction &CGF,
477                                               const Expr *NumTeams,
478                                               const Expr *ThreadLimit,
479                                               SourceLocation Loc) {}
480
481 llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOrTeamsOutlinedFunction(
482     const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
483     OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
484
485   llvm::Function *OutlinedFun = nullptr;
486   if (isa<OMPTeamsDirective>(D)) {
487     llvm::Value *OutlinedFunVal =
488         CGOpenMPRuntime::emitParallelOrTeamsOutlinedFunction(
489             D, ThreadIDVar, InnermostKind, CodeGen);
490     OutlinedFun = cast<llvm::Function>(OutlinedFunVal);
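    // Force inlining of the teams outlined function so it is folded into the
    // target kernel entry rather than emitted as a separate device call.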
491     OutlinedFun->removeFnAttr(llvm::Attribute::NoInline);
492     OutlinedFun->addFnAttr(llvm::Attribute::AlwaysInline);
493   } else {
494     llvm::Value *OutlinedFunVal =
495         CGOpenMPRuntime::emitParallelOrTeamsOutlinedFunction(
496             D, ThreadIDVar, InnermostKind, CodeGen);
497     OutlinedFun = cast<llvm::Function>(OutlinedFunVal);
498   }
499
500   return OutlinedFun;
501 }
502
503 void CGOpenMPRuntimeNVPTX::emitTeamsCall(CodeGenFunction &CGF,
504                                          const OMPExecutableDirective &D,
505                                          SourceLocation Loc,
506                                          llvm::Value *OutlinedFn,
507                                          ArrayRef<llvm::Value *> CapturedVars) {
508   if (!CGF.HaveInsertPoint())
509     return;
510
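  // The outlined teams function expects global_tid and bound_tid pointers as
  // its first two arguments; on the device both point at a zero placeholder.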
511   Address ZeroAddr =
512       CGF.CreateTempAlloca(CGF.Int32Ty, CharUnits::fromQuantity(4),
513                            /*Name*/ ".zero.addr");
514   CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
515   llvm::SmallVector<llvm::Value *, 16> OutlinedFnArgs;
516   OutlinedFnArgs.push_back(ZeroAddr.getPointer());
517   OutlinedFnArgs.push_back(ZeroAddr.getPointer());
518   OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end());
519   CGF.EmitCallOrInvoke(OutlinedFn, OutlinedFnArgs);
520 }
521
522 void CGOpenMPRuntimeNVPTX::emitParallelCall(
523     CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
524     ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
525   if (!CGF.HaveInsertPoint())
526     return;
527
528   emitGenericParallelCall(CGF, Loc, OutlinedFn, CapturedVars, IfCond);
529 }
530
531 void CGOpenMPRuntimeNVPTX::emitGenericParallelCall(
532     CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
533     ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
534   llvm::Function *Fn = cast<llvm::Function>(OutlinedFn);
535
536   auto &&L0ParallelGen = [this, Fn, &CapturedVars](CodeGenFunction &CGF,
537                                                    PrePostActionTy &) {
538     CGBuilderTy &Bld = CGF.Builder;
539
540     // Prepare for parallel region. Indicate the outlined function.
541     llvm::Value *Args[] = {Bld.CreateBitOrPointerCast(Fn, CGM.Int8PtrTy)};
542     CGF.EmitRuntimeCall(
543         createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
544         Args);
545
546     // Activate workers. This barrier is used by the master to signal
547     // work for the workers.
548     syncCTAThreads(CGF);
549
550     // OpenMP [2.5, Parallel Construct, p.49]
551     // There is an implied barrier at the end of a parallel region. After the
552     // end of a parallel region, only the master thread of the team resumes
553     // execution of the enclosing task region.
554     //
555     // The master waits at this barrier until all workers are done.
556     syncCTAThreads(CGF);
557
558     // Remember for post-processing in worker loop.
559     Work.push_back(Fn);
560   };
561
562   auto *RTLoc = emitUpdateLocation(CGF, Loc);
563   auto *ThreadID = getThreadID(CGF, Loc);
564   llvm::Value *Args[] = {RTLoc, ThreadID};
565
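  // Fallback when the if-clause is false: run the region serially on the
  // master thread, bracketed by __kmpc_serialized_parallel and
  // __kmpc_end_serialized_parallel.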
566   auto &&SeqGen = [this, Fn, &CapturedVars, &Args](CodeGenFunction &CGF,
567                                                    PrePostActionTy &) {
568     auto &&CodeGen = [this, Fn, &CapturedVars, &Args](CodeGenFunction &CGF,
569                                                       PrePostActionTy &Action) {
570       Action.Enter(CGF);
571
572       llvm::SmallVector<llvm::Value *, 16> OutlinedFnArgs;
573       OutlinedFnArgs.push_back(
574           llvm::ConstantPointerNull::get(CGM.Int32Ty->getPointerTo()));
575       OutlinedFnArgs.push_back(
576           llvm::ConstantPointerNull::get(CGM.Int32Ty->getPointerTo()));
577       OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end());
578       CGF.EmitCallOrInvoke(Fn, OutlinedFnArgs);
579     };
580
581     RegionCodeGenTy RCG(CodeGen);
582     NVPTXActionTy Action(
583         createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_serialized_parallel),
584         Args,
585         createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_end_serialized_parallel),
586         Args);
587     RCG.setAction(Action);
588     RCG(CGF);
589   };
590
591   if (IfCond)
592     emitOMPIfClause(CGF, IfCond, L0ParallelGen, SeqGen);
593   else {
594     CodeGenFunction::RunCleanupsScope Scope(CGF);
595     RegionCodeGenTy ThenRCG(L0ParallelGen);
596     ThenRCG(CGF);
597   }
598 }