]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Merge lldb trunk r300422 and resolve conflicts.
[FreeBSD/FreeBSD.git] / contrib / llvm / lib / Transforms / Coroutines / CoroSplit.cpp
1 //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass builds the coroutine frame and outlines resume and destroy parts
10 // of the coroutine into separate functions.
11 //
12 // We present a coroutine to an LLVM as an ordinary function with suspension
13 // points marked up with intrinsics. We let the optimizer party on the coroutine
14 // as a single function for as long as possible. Shortly before the coroutine is
15 // eligible to be inlined into its callers, we split up the coroutine into parts
16 // corresponding to an initial, resume and destroy invocations of the coroutine,
17 // add them to the current SCC and restart the IPO pipeline to optimize the
18 // coroutine subfunctions we extracted before proceeding to the caller of the
19 // coroutine.
20 //===----------------------------------------------------------------------===//
21
22 #include "CoroInternal.h"
23 #include "llvm/Analysis/CallGraphSCCPass.h"
24 #include "llvm/IR/DebugInfoMetadata.h"
25 #include "llvm/IR/InstIterator.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/LegacyPassManager.h"
28 #include "llvm/IR/Verifier.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Utils/Cloning.h"
31 #include "llvm/Transforms/Utils/Local.h"
32 #include "llvm/Transforms/Utils/ValueMapper.h"
33
34 using namespace llvm;
35
36 #define DEBUG_TYPE "coro-split"
37
38 // Create an entry block for a resume function with a switch that will jump to
39 // suspend points.
40 static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
41   LLVMContext &C = F.getContext();
42
43   // resume.entry:
44   //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
45   //  i32 2
46   //  % index = load i32, i32* %index.addr
47   //  switch i32 %index, label %unreachable [
48   //    i32 0, label %resume.0
49   //    i32 1, label %resume.1
50   //    ...
51   //  ]
52
53   auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
54   auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
55
56   IRBuilder<> Builder(NewEntry);
57   auto *FramePtr = Shape.FramePtr;
58   auto *FrameTy = Shape.FrameTy;
59   auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
60       FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
61   auto *Index = Builder.CreateLoad(GepIndex, "index");
62   auto *Switch =
63       Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
64   Shape.ResumeSwitch = Switch;
65
66   size_t SuspendIndex = 0;
67   for (CoroSuspendInst *S : Shape.CoroSuspends) {
68     ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
69
70     // Replace CoroSave with a store to Index:
71     //    %index.addr = getelementptr %f.frame... (index field number)
72     //    store i32 0, i32* %index.addr1
73     auto *Save = S->getCoroSave();
74     Builder.SetInsertPoint(Save);
75     if (S->isFinal()) {
76       // Final suspend point is represented by storing zero in ResumeFnAddr.
77       auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
78                                                           0, "ResumeFn.addr");
79       auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
80           cast<PointerType>(GepIndex->getType())->getElementType()));
81       Builder.CreateStore(NullPtr, GepIndex);
82     } else {
83       auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
84           FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
85       Builder.CreateStore(IndexVal, GepIndex);
86     }
87     Save->replaceAllUsesWith(ConstantTokenNone::get(C));
88     Save->eraseFromParent();
89
90     // Split block before and after coro.suspend and add a jump from an entry
91     // switch:
92     //
93     //  whateverBB:
94     //    whatever
95     //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
96     //    switch i8 %0, label %suspend[i8 0, label %resume
97     //                                 i8 1, label %cleanup]
98     // becomes:
99     //
100     //  whateverBB:
101     //     whatever
102     //     br label %resume.0.landing
103     //
104     //  resume.0: ; <--- jump from the switch in the resume.entry
105     //     %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
106     //     br label %resume.0.landing
107     //
108     //  resume.0.landing:
109     //     %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
110     //     switch i8 % 1, label %suspend [i8 0, label %resume
111     //                                    i8 1, label %cleanup]
112
113     auto *SuspendBB = S->getParent();
114     auto *ResumeBB =
115         SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
116     auto *LandingBB = ResumeBB->splitBasicBlock(
117         S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
118     Switch->addCase(IndexVal, ResumeBB);
119
120     cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
121     auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
122     S->replaceAllUsesWith(PN);
123     PN->addIncoming(Builder.getInt8(-1), SuspendBB);
124     PN->addIncoming(S, ResumeBB);
125
126     ++SuspendIndex;
127   }
128
129   Builder.SetInsertPoint(UnreachBB);
130   Builder.CreateUnreachable();
131
132   return NewEntry;
133 }
134
135 // In Resumers, we replace fallthrough coro.end with ret void and delete the
136 // rest of the block.
137 static void replaceFallthroughCoroEnd(IntrinsicInst *End,
138                                       ValueToValueMapTy &VMap) {
139   auto *NewE = cast<IntrinsicInst>(VMap[End]);
140   ReturnInst::Create(NewE->getContext(), nullptr, NewE);
141
142   // Remove the rest of the block, by splitting it into an unreachable block.
143   auto *BB = NewE->getParent();
144   BB->splitBasicBlock(NewE);
145   BB->getTerminator()->eraseFromParent();
146 }
147
148 // In Resumers, we replace unwind coro.end with True to force the immediate
149 // unwind to caller.
150 static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
151   if (Shape.CoroEnds.empty())
152     return;
153
154   LLVMContext &Context = Shape.CoroEnds.front()->getContext();
155   auto *True = ConstantInt::getTrue(Context);
156   for (CoroEndInst *CE : Shape.CoroEnds) {
157     if (!CE->isUnwind())
158       continue;
159
160     auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
161
162     // If coro.end has an associated bundle, add cleanupret instruction.
163     if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
164       Value *FromPad = Bundle->Inputs[0];
165       auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
166       NewCE->getParent()->splitBasicBlock(NewCE);
167       CleanupRet->getParent()->getTerminator()->eraseFromParent();
168     }
169
170     NewCE->replaceAllUsesWith(True);
171     NewCE->eraseFromParent();
172   }
173 }
174
175 // Rewrite final suspend point handling. We do not use suspend index to
176 // represent the final suspend point. Instead we zero-out ResumeFnAddr in the
177 // coroutine frame, since it is undefined behavior to resume a coroutine
178 // suspended at the final suspend point. Thus, in the resume function, we can
179 // simply remove the last case (when coro::Shape is built, the final suspend
180 // point (if present) is always the last element of CoroSuspends array).
181 // In the destroy function, we add a code sequence to check if ResumeFnAddress
182 // is Null, and if so, jump to the appropriate label to handle cleanup from the
183 // final suspend point.
184 static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
185                                coro::Shape &Shape, SwitchInst *Switch,
186                                bool IsDestroy) {
187   assert(Shape.HasFinalSuspend);
188   auto FinalCaseIt = std::prev(Switch->case_end());
189   BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
190   Switch->removeCase(FinalCaseIt);
191   if (IsDestroy) {
192     BasicBlock *OldSwitchBB = Switch->getParent();
193     auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
194     Builder.SetInsertPoint(OldSwitchBB->getTerminator());
195     auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
196                                                         0, 0, "ResumeFn.addr");
197     auto *Load = Builder.CreateLoad(GepIndex);
198     auto *NullPtr =
199         ConstantPointerNull::get(cast<PointerType>(Load->getType()));
200     auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
201     Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
202     OldSwitchBB->getTerminator()->eraseFromParent();
203   }
204 }
205
206 // Create a resume clone by cloning the body of the original function, setting
207 // new entry block and replacing coro.suspend an appropriate value to force
208 // resume or cleanup pass for every suspend point.
209 static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
210                              BasicBlock *ResumeEntry, int8_t FnIndex) {
211   Module *M = F.getParent();
212   auto *FrameTy = Shape.FrameTy;
213   auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
214   auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
215
216   Function *NewF =
217       Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
218                        F.getName() + Suffix, M);
219   NewF->addAttribute(1, Attribute::NonNull);
220   NewF->addAttribute(1, Attribute::NoAlias);
221
222   ValueToValueMapTy VMap;
223   // Replace all args with undefs. The buildCoroutineFrame algorithm already
224   // rewritten access to the args that occurs after suspend points with loads
225   // and stores to/from the coroutine frame.
226   for (Argument &A : F.args())
227     VMap[&A] = UndefValue::get(A.getType());
228
229   SmallVector<ReturnInst *, 4> Returns;
230
231   if (DISubprogram *SP = F.getSubprogram()) {
232     // If we have debug info, add mapping for the metadata nodes that should not
233     // be cloned by CloneFunctionInfo.
234     auto &MD = VMap.MD();
235     MD[SP->getUnit()].reset(SP->getUnit());
236     MD[SP->getType()].reset(SP->getType());
237     MD[SP->getFile()].reset(SP->getFile());
238   }
239   CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
240
241   // Remove old returns.
242   for (ReturnInst *Return : Returns)
243     changeToUnreachable(Return, /*UseLLVMTrap=*/false);
244
245   // Remove old return attributes.
246   NewF->removeAttributes(
247       AttributeList::ReturnIndex,
248       AttributeList::get(
249           NewF->getContext(), AttributeList::ReturnIndex,
250           AttributeFuncs::typeIncompatible(NewF->getReturnType())));
251
252   // Make AllocaSpillBlock the new entry block.
253   auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
254   auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
255   Entry->moveBefore(&NewF->getEntryBlock());
256   Entry->getTerminator()->eraseFromParent();
257   BranchInst::Create(SwitchBB, Entry);
258   Entry->setName("entry" + Suffix);
259
260   // Clear all predecessors of the new entry block.
261   auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
262   Entry->replaceAllUsesWith(Switch->getDefaultDest());
263
264   IRBuilder<> Builder(&NewF->getEntryBlock().front());
265
266   // Remap frame pointer.
267   Argument *NewFramePtr = &*NewF->arg_begin();
268   Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
269   NewFramePtr->takeName(OldFramePtr);
270   OldFramePtr->replaceAllUsesWith(NewFramePtr);
271
272   // Remap vFrame pointer.
273   auto *NewVFrame = Builder.CreateBitCast(
274       NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
275   Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
276   OldVFrame->replaceAllUsesWith(NewVFrame);
277
278   // Rewrite final suspend handling as it is not done via switch (allows to
279   // remove final case from the switch, since it is undefined behavior to resume
280   // the coroutine suspended at the final suspend point.
281   if (Shape.HasFinalSuspend) {
282     auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
283     bool IsDestroy = FnIndex != 0;
284     handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
285   }
286
287   // Replace coro suspend with the appropriate resume index.
288   // Replacing coro.suspend with (0) will result in control flow proceeding to
289   // a resume label associated with a suspend point, replacing it with (1) will
290   // result in control flow proceeding to a cleanup label associated with this
291   // suspend point.
292   auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
293   for (CoroSuspendInst *CS : Shape.CoroSuspends) {
294     auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
295     MappedCS->replaceAllUsesWith(NewValue);
296     MappedCS->eraseFromParent();
297   }
298
299   // Remove coro.end intrinsics.
300   replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
301   replaceUnwindCoroEnds(Shape, VMap);
302   // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
303   // to suppress deallocation code.
304   coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
305                         /*Elide=*/FnIndex == 2);
306
307   NewF->setCallingConv(CallingConv::Fast);
308
309   return NewF;
310 }
311
312 static void removeCoroEnds(coro::Shape &Shape) {
313   if (Shape.CoroEnds.empty())
314     return;
315
316   LLVMContext &Context = Shape.CoroEnds.front()->getContext();
317   auto *False = ConstantInt::getFalse(Context);
318
319   for (CoroEndInst *CE : Shape.CoroEnds) {
320     CE->replaceAllUsesWith(False);
321     CE->eraseFromParent();
322   }
323 }
324
325 static void replaceFrameSize(coro::Shape &Shape) {
326   if (Shape.CoroSizes.empty())
327     return;
328
329   // In the same function all coro.sizes should have the same result type.
330   auto *SizeIntrin = Shape.CoroSizes.back();
331   Module *M = SizeIntrin->getModule();
332   const DataLayout &DL = M->getDataLayout();
333   auto Size = DL.getTypeAllocSize(Shape.FrameTy);
334   auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
335
336   for (CoroSizeInst *CS : Shape.CoroSizes) {
337     CS->replaceAllUsesWith(SizeConstant);
338     CS->eraseFromParent();
339   }
340 }
341
342 // Create a global constant array containing pointers to functions provided and
343 // set Info parameter of CoroBegin to point at this constant. Example:
344 //
345 //   @f.resumers = internal constant [2 x void(%f.frame*)*]
346 //                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
347 //   define void @f() {
348 //     ...
349 //     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
350 //                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
351 //
352 // Assumes that all the functions have the same signature.
353 static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
354                         std::initializer_list<Function *> Fns) {
355
356   SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
357   assert(!Args.empty());
358   Function *Part = *Fns.begin();
359   Module *M = Part->getParent();
360   auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
361
362   auto *ConstVal = ConstantArray::get(ArrTy, Args);
363   auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
364                                 GlobalVariable::PrivateLinkage, ConstVal,
365                                 F.getName() + Twine(".resumers"));
366
367   // Update coro.begin instruction to refer to this constant.
368   LLVMContext &C = F.getContext();
369   auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
370   CoroBegin->getId()->setInfo(BC);
371 }
372
373 // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
374 static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
375                             Function *DestroyFn, Function *CleanupFn) {
376
377   IRBuilder<> Builder(Shape.FramePtr->getNextNode());
378   auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
379       Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
380       "resume.addr");
381   Builder.CreateStore(ResumeFn, ResumeAddr);
382
383   Value *DestroyOrCleanupFn = DestroyFn;
384
385   CoroIdInst *CoroId = Shape.CoroBegin->getId();
386   if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
387     // If there is a CoroAlloc and it returns false (meaning we elide the
388     // allocation, use CleanupFn instead of DestroyFn).
389     DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
390   }
391
392   auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
393       Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
394       "destroy.addr");
395   Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
396 }
397
398 static void postSplitCleanup(Function &F) {
399   removeUnreachableBlocks(F);
400   llvm::legacy::FunctionPassManager FPM(F.getParent());
401
402   FPM.add(createVerifierPass());
403   FPM.add(createSCCPPass());
404   FPM.add(createCFGSimplificationPass());
405   FPM.add(createEarlyCSEPass());
406   FPM.add(createCFGSimplificationPass());
407
408   FPM.doInitialization();
409   FPM.run(F);
410   FPM.doFinalization();
411 }
412
413 // Coroutine has no suspend points. Remove heap allocation for the coroutine
414 // frame if possible.
415 static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
416   auto *CoroId = CoroBegin->getId();
417   auto *AllocInst = CoroId->getCoroAlloc();
418   coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
419   if (AllocInst) {
420     IRBuilder<> Builder(AllocInst);
421     // FIXME: Need to handle overaligned members.
422     auto *Frame = Builder.CreateAlloca(FrameTy);
423     auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
424     AllocInst->replaceAllUsesWith(Builder.getFalse());
425     AllocInst->eraseFromParent();
426     CoroBegin->replaceAllUsesWith(VFrame);
427   } else {
428     CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
429   }
430   CoroBegin->eraseFromParent();
431 }
432
433 // look for a very simple pattern
434 //    coro.save
435 //    no other calls
436 //    resume or destroy call
437 //    coro.suspend
438 //
439 // If there are other calls between coro.save and coro.suspend, they can
440 // potentially resume or destroy the coroutine, so it is unsafe to eliminate a
441 // suspend point.
442 static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
443                                  CoroBeginInst *CoroBegin) {
444   auto *Save = Suspend->getCoroSave();
445   auto *BB = Suspend->getParent();
446   if (BB != Save->getParent())
447     return false;
448
449   CallSite SingleCallSite;
450
451   // Check that we have only one CallSite.
452   for (Instruction *I = Save->getNextNode(); I != Suspend;
453        I = I->getNextNode()) {
454     if (isa<CoroFrameInst>(I))
455       continue;
456     if (isa<CoroSubFnInst>(I))
457       continue;
458     if (CallSite CS = CallSite(I)) {
459       if (SingleCallSite)
460         return false;
461       else
462         SingleCallSite = CS;
463     }
464   }
465   auto *CallInstr = SingleCallSite.getInstruction();
466   if (!CallInstr)
467     return false;
468
469   auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts();
470
471   // See if the callsite is for resumption or destruction of the coroutine.
472   auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
473   if (!SubFn)
474     return false;
475
476   // Does not refer to the current coroutine, we cannot do anything with it.
477   if (SubFn->getFrame() != CoroBegin)
478     return false;
479
480   // Replace llvm.coro.suspend with the value that results in resumption over
481   // the resume or cleanup path.
482   Suspend->replaceAllUsesWith(SubFn->getRawIndex());
483   Suspend->eraseFromParent();
484   Save->eraseFromParent();
485
486   // No longer need a call to coro.resume or coro.destroy.
487   CallInstr->eraseFromParent();
488
489   if (SubFn->user_empty())
490     SubFn->eraseFromParent();
491
492   return true;
493 }
494
495 // Remove suspend points that are simplified.
496 static void simplifySuspendPoints(coro::Shape &Shape) {
497   auto &S = Shape.CoroSuspends;
498   size_t I = 0, N = S.size();
499   if (N == 0)
500     return;
501   for (;;) {
502     if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
503       if (--N == I)
504         break;
505       std::swap(S[I], S[N]);
506       continue;
507     }
508     if (++I == N)
509       break;
510   }
511   S.resize(N);
512 }
513
514 static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
515   coro::Shape Shape(F);
516   if (!Shape.CoroBegin)
517     return;
518
519   simplifySuspendPoints(Shape);
520   buildCoroutineFrame(F, Shape);
521   replaceFrameSize(Shape);
522
523   // If there are no suspend points, no split required, just remove
524   // the allocation and deallocation blocks, they are not needed.
525   if (Shape.CoroSuspends.empty()) {
526     handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
527     removeCoroEnds(Shape);
528     postSplitCleanup(F);
529     coro::updateCallGraph(F, {}, CG, SCC);
530     return;
531   }
532
533   auto *ResumeEntry = createResumeEntryBlock(F, Shape);
534   auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
535   auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
536   auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);
537
538   // We no longer need coro.end in F.
539   removeCoroEnds(Shape);
540
541   postSplitCleanup(F);
542   postSplitCleanup(*ResumeClone);
543   postSplitCleanup(*DestroyClone);
544   postSplitCleanup(*CleanupClone);
545
546   // Store addresses resume/destroy/cleanup functions in the coroutine frame.
547   updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
548
549   // Create a constant array referring to resume/destroy/clone functions pointed
550   // by the last argument of @llvm.coro.info, so that CoroElide pass can
551   // determined correct function to call.
552   setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});
553
554   // Update call graph and add the functions we created to the SCC.
555   coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
556 }
557
558 // When we see the coroutine the first time, we insert an indirect call to a
559 // devirt trigger function and mark the coroutine that it is now ready for
560 // split.
561 static void prepareForSplit(Function &F, CallGraph &CG) {
562   Module &M = *F.getParent();
563 #ifndef NDEBUG
564   Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
565   assert(DevirtFn && "coro.devirt.trigger function not found");
566 #endif
567
568   F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
569
570   // Insert an indirect call sequence that will be devirtualized by CoroElide
571   // pass:
572   //    %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
573   //    %1 = bitcast i8* %0 to void(i8*)*
574   //    call void %1(i8* null)
575   coro::LowererBase Lowerer(M);
576   Instruction *InsertPt = F.getEntryBlock().getTerminator();
577   auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext()));
578   auto *DevirtFnAddr =
579       Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
580   auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt);
581
582   // Update CG graph with an indirect call we just added.
583   CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
584 }
585
586 // Make sure that there is a devirtualization trigger function that CoroSplit
587 // pass uses the force restart CGSCC pipeline. If devirt trigger function is not
588 // found, we will create one and add it to the current SCC.
589 static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
590   Module &M = CG.getModule();
591   if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
592     return;
593
594   LLVMContext &C = M.getContext();
595   auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
596                                  /*IsVarArgs=*/false);
597   Function *DevirtFn =
598       Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
599                        CORO_DEVIRT_TRIGGER_FN, &M);
600   DevirtFn->addFnAttr(Attribute::AlwaysInline);
601   auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
602   ReturnInst::Create(C, Entry);
603
604   auto *Node = CG.getOrInsertFunction(DevirtFn);
605
606   SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
607   Nodes.push_back(Node);
608   SCC.initialize(Nodes);
609 }
610
611 //===----------------------------------------------------------------------===//
612 //                              Top Level Driver
613 //===----------------------------------------------------------------------===//
614
615 namespace {
616
617 struct CoroSplit : public CallGraphSCCPass {
618   static char ID; // Pass identification, replacement for typeid
619   CoroSplit() : CallGraphSCCPass(ID) {}
620
621   bool Run = false;
622
623   // A coroutine is identified by the presence of coro.begin intrinsic, if
624   // we don't have any, this pass has nothing to do.
625   bool doInitialization(CallGraph &CG) override {
626     Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
627     return CallGraphSCCPass::doInitialization(CG);
628   }
629
630   bool runOnSCC(CallGraphSCC &SCC) override {
631     if (!Run)
632       return false;
633
634     // Find coroutines for processing.
635     SmallVector<Function *, 4> Coroutines;
636     for (CallGraphNode *CGN : SCC)
637       if (auto *F = CGN->getFunction())
638         if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
639           Coroutines.push_back(F);
640
641     if (Coroutines.empty())
642       return false;
643
644     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
645     createDevirtTriggerFunc(CG, SCC);
646
647     for (Function *F : Coroutines) {
648       Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
649       StringRef Value = Attr.getValueAsString();
650       DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
651                    << "' state: " << Value << "\n");
652       if (Value == UNPREPARED_FOR_SPLIT) {
653         prepareForSplit(*F, CG);
654         continue;
655       }
656       F->removeFnAttr(CORO_PRESPLIT_ATTR);
657       splitCoroutine(*F, CG, SCC);
658     }
659     return true;
660   }
661
662   void getAnalysisUsage(AnalysisUsage &AU) const override {
663     CallGraphSCCPass::getAnalysisUsage(AU);
664   }
665 };
666 }
667
668 char CoroSplit::ID = 0;
669 INITIALIZE_PASS(
670     CoroSplit, "coro-split",
671     "Split coroutine into a set of functions driving its state machine", false,
672     false)
673
674 Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }