]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Merge ^/head r318560 through r318657.
[FreeBSD/FreeBSD.git] / contrib / llvm / lib / Transforms / Coroutines / CoroSplit.cpp
1 //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass builds the coroutine frame and outlines resume and destroy parts
10 // of the coroutine into separate functions.
11 //
12 // We present a coroutine to an LLVM as an ordinary function with suspension
13 // points marked up with intrinsics. We let the optimizer party on the coroutine
14 // as a single function for as long as possible. Shortly before the coroutine is
15 // eligible to be inlined into its callers, we split up the coroutine into parts
16 // corresponding to an initial, resume and destroy invocations of the coroutine,
17 // add them to the current SCC and restart the IPO pipeline to optimize the
18 // coroutine subfunctions we extracted before proceeding to the caller of the
19 // coroutine.
20 //===----------------------------------------------------------------------===//
21
22 #include "CoroInternal.h"
23 #include "llvm/Analysis/CallGraphSCCPass.h"
24 #include "llvm/IR/DebugInfoMetadata.h"
25 #include "llvm/IR/InstIterator.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/LegacyPassManager.h"
28 #include "llvm/IR/Verifier.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Utils/Cloning.h"
31 #include "llvm/Transforms/Utils/Local.h"
32 #include "llvm/Transforms/Utils/ValueMapper.h"
33
34 using namespace llvm;
35
36 #define DEBUG_TYPE "coro-split"
37
38 // Create an entry block for a resume function with a switch that will jump to
39 // suspend points.
40 static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
41   LLVMContext &C = F.getContext();
42
43   // resume.entry:
44   //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
45   //  i32 2
46   //  % index = load i32, i32* %index.addr
47   //  switch i32 %index, label %unreachable [
48   //    i32 0, label %resume.0
49   //    i32 1, label %resume.1
50   //    ...
51   //  ]
52
53   auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
54   auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
55
56   IRBuilder<> Builder(NewEntry);
57   auto *FramePtr = Shape.FramePtr;
58   auto *FrameTy = Shape.FrameTy;
59   auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
60       FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
61   auto *Index = Builder.CreateLoad(GepIndex, "index");
62   auto *Switch =
63       Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
64   Shape.ResumeSwitch = Switch;
65
66   size_t SuspendIndex = 0;
67   for (CoroSuspendInst *S : Shape.CoroSuspends) {
68     ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
69
70     // Replace CoroSave with a store to Index:
71     //    %index.addr = getelementptr %f.frame... (index field number)
72     //    store i32 0, i32* %index.addr1
73     auto *Save = S->getCoroSave();
74     Builder.SetInsertPoint(Save);
75     if (S->isFinal()) {
76       // Final suspend point is represented by storing zero in ResumeFnAddr.
77       auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
78                                                           0, "ResumeFn.addr");
79       auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
80           cast<PointerType>(GepIndex->getType())->getElementType()));
81       Builder.CreateStore(NullPtr, GepIndex);
82     } else {
83       auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
84           FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
85       Builder.CreateStore(IndexVal, GepIndex);
86     }
87     Save->replaceAllUsesWith(ConstantTokenNone::get(C));
88     Save->eraseFromParent();
89
90     // Split block before and after coro.suspend and add a jump from an entry
91     // switch:
92     //
93     //  whateverBB:
94     //    whatever
95     //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
96     //    switch i8 %0, label %suspend[i8 0, label %resume
97     //                                 i8 1, label %cleanup]
98     // becomes:
99     //
100     //  whateverBB:
101     //     whatever
102     //     br label %resume.0.landing
103     //
104     //  resume.0: ; <--- jump from the switch in the resume.entry
105     //     %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
106     //     br label %resume.0.landing
107     //
108     //  resume.0.landing:
109     //     %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
110     //     switch i8 % 1, label %suspend [i8 0, label %resume
111     //                                    i8 1, label %cleanup]
112
113     auto *SuspendBB = S->getParent();
114     auto *ResumeBB =
115         SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
116     auto *LandingBB = ResumeBB->splitBasicBlock(
117         S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
118     Switch->addCase(IndexVal, ResumeBB);
119
120     cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
121     auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
122     S->replaceAllUsesWith(PN);
123     PN->addIncoming(Builder.getInt8(-1), SuspendBB);
124     PN->addIncoming(S, ResumeBB);
125
126     ++SuspendIndex;
127   }
128
129   Builder.SetInsertPoint(UnreachBB);
130   Builder.CreateUnreachable();
131
132   return NewEntry;
133 }
134
135 // In Resumers, we replace fallthrough coro.end with ret void and delete the
136 // rest of the block.
137 static void replaceFallthroughCoroEnd(IntrinsicInst *End,
138                                       ValueToValueMapTy &VMap) {
139   auto *NewE = cast<IntrinsicInst>(VMap[End]);
140   ReturnInst::Create(NewE->getContext(), nullptr, NewE);
141
142   // Remove the rest of the block, by splitting it into an unreachable block.
143   auto *BB = NewE->getParent();
144   BB->splitBasicBlock(NewE);
145   BB->getTerminator()->eraseFromParent();
146 }
147
148 // In Resumers, we replace unwind coro.end with True to force the immediate
149 // unwind to caller.
150 static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
151   if (Shape.CoroEnds.empty())
152     return;
153
154   LLVMContext &Context = Shape.CoroEnds.front()->getContext();
155   auto *True = ConstantInt::getTrue(Context);
156   for (CoroEndInst *CE : Shape.CoroEnds) {
157     if (!CE->isUnwind())
158       continue;
159
160     auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
161
162     // If coro.end has an associated bundle, add cleanupret instruction.
163     if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
164       Value *FromPad = Bundle->Inputs[0];
165       auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
166       NewCE->getParent()->splitBasicBlock(NewCE);
167       CleanupRet->getParent()->getTerminator()->eraseFromParent();
168     }
169
170     NewCE->replaceAllUsesWith(True);
171     NewCE->eraseFromParent();
172   }
173 }
174
175 // Rewrite final suspend point handling. We do not use suspend index to
176 // represent the final suspend point. Instead we zero-out ResumeFnAddr in the
177 // coroutine frame, since it is undefined behavior to resume a coroutine
178 // suspended at the final suspend point. Thus, in the resume function, we can
179 // simply remove the last case (when coro::Shape is built, the final suspend
180 // point (if present) is always the last element of CoroSuspends array).
181 // In the destroy function, we add a code sequence to check if ResumeFnAddress
182 // is Null, and if so, jump to the appropriate label to handle cleanup from the
183 // final suspend point.
184 static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
185                                coro::Shape &Shape, SwitchInst *Switch,
186                                bool IsDestroy) {
187   assert(Shape.HasFinalSuspend);
188   auto FinalCaseIt = std::prev(Switch->case_end());
189   BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
190   Switch->removeCase(FinalCaseIt);
191   if (IsDestroy) {
192     BasicBlock *OldSwitchBB = Switch->getParent();
193     auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
194     Builder.SetInsertPoint(OldSwitchBB->getTerminator());
195     auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
196                                                         0, 0, "ResumeFn.addr");
197     auto *Load = Builder.CreateLoad(GepIndex);
198     auto *NullPtr =
199         ConstantPointerNull::get(cast<PointerType>(Load->getType()));
200     auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
201     Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
202     OldSwitchBB->getTerminator()->eraseFromParent();
203   }
204 }
205
206 // Create a resume clone by cloning the body of the original function, setting
207 // new entry block and replacing coro.suspend an appropriate value to force
208 // resume or cleanup pass for every suspend point.
209 static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
210                              BasicBlock *ResumeEntry, int8_t FnIndex) {
211   Module *M = F.getParent();
212   auto *FrameTy = Shape.FrameTy;
213   auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
214   auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
215
216   Function *NewF =
217       Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
218                        F.getName() + Suffix, M);
219   NewF->addParamAttr(0, Attribute::NonNull);
220   NewF->addParamAttr(0, Attribute::NoAlias);
221
222   ValueToValueMapTy VMap;
223   // Replace all args with undefs. The buildCoroutineFrame algorithm already
224   // rewritten access to the args that occurs after suspend points with loads
225   // and stores to/from the coroutine frame.
226   for (Argument &A : F.args())
227     VMap[&A] = UndefValue::get(A.getType());
228
229   SmallVector<ReturnInst *, 4> Returns;
230
231   if (DISubprogram *SP = F.getSubprogram()) {
232     // If we have debug info, add mapping for the metadata nodes that should not
233     // be cloned by CloneFunctionInfo.
234     auto &MD = VMap.MD();
235     MD[SP->getUnit()].reset(SP->getUnit());
236     MD[SP->getType()].reset(SP->getType());
237     MD[SP->getFile()].reset(SP->getFile());
238   }
239   CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
240
241   // Remove old returns.
242   for (ReturnInst *Return : Returns)
243     changeToUnreachable(Return, /*UseLLVMTrap=*/false);
244
245   // Remove old return attributes.
246   NewF->removeAttributes(
247       AttributeList::ReturnIndex,
248       AttributeFuncs::typeIncompatible(NewF->getReturnType()));
249
250   // Make AllocaSpillBlock the new entry block.
251   auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
252   auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
253   Entry->moveBefore(&NewF->getEntryBlock());
254   Entry->getTerminator()->eraseFromParent();
255   BranchInst::Create(SwitchBB, Entry);
256   Entry->setName("entry" + Suffix);
257
258   // Clear all predecessors of the new entry block.
259   auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
260   Entry->replaceAllUsesWith(Switch->getDefaultDest());
261
262   IRBuilder<> Builder(&NewF->getEntryBlock().front());
263
264   // Remap frame pointer.
265   Argument *NewFramePtr = &*NewF->arg_begin();
266   Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
267   NewFramePtr->takeName(OldFramePtr);
268   OldFramePtr->replaceAllUsesWith(NewFramePtr);
269
270   // Remap vFrame pointer.
271   auto *NewVFrame = Builder.CreateBitCast(
272       NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
273   Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
274   OldVFrame->replaceAllUsesWith(NewVFrame);
275
276   // Rewrite final suspend handling as it is not done via switch (allows to
277   // remove final case from the switch, since it is undefined behavior to resume
278   // the coroutine suspended at the final suspend point.
279   if (Shape.HasFinalSuspend) {
280     auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
281     bool IsDestroy = FnIndex != 0;
282     handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
283   }
284
285   // Replace coro suspend with the appropriate resume index.
286   // Replacing coro.suspend with (0) will result in control flow proceeding to
287   // a resume label associated with a suspend point, replacing it with (1) will
288   // result in control flow proceeding to a cleanup label associated with this
289   // suspend point.
290   auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
291   for (CoroSuspendInst *CS : Shape.CoroSuspends) {
292     auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
293     MappedCS->replaceAllUsesWith(NewValue);
294     MappedCS->eraseFromParent();
295   }
296
297   // Remove coro.end intrinsics.
298   replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
299   replaceUnwindCoroEnds(Shape, VMap);
300   // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
301   // to suppress deallocation code.
302   coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
303                         /*Elide=*/FnIndex == 2);
304
305   NewF->setCallingConv(CallingConv::Fast);
306
307   return NewF;
308 }
309
310 static void removeCoroEnds(coro::Shape &Shape) {
311   if (Shape.CoroEnds.empty())
312     return;
313
314   LLVMContext &Context = Shape.CoroEnds.front()->getContext();
315   auto *False = ConstantInt::getFalse(Context);
316
317   for (CoroEndInst *CE : Shape.CoroEnds) {
318     CE->replaceAllUsesWith(False);
319     CE->eraseFromParent();
320   }
321 }
322
323 static void replaceFrameSize(coro::Shape &Shape) {
324   if (Shape.CoroSizes.empty())
325     return;
326
327   // In the same function all coro.sizes should have the same result type.
328   auto *SizeIntrin = Shape.CoroSizes.back();
329   Module *M = SizeIntrin->getModule();
330   const DataLayout &DL = M->getDataLayout();
331   auto Size = DL.getTypeAllocSize(Shape.FrameTy);
332   auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
333
334   for (CoroSizeInst *CS : Shape.CoroSizes) {
335     CS->replaceAllUsesWith(SizeConstant);
336     CS->eraseFromParent();
337   }
338 }
339
340 // Create a global constant array containing pointers to functions provided and
341 // set Info parameter of CoroBegin to point at this constant. Example:
342 //
343 //   @f.resumers = internal constant [2 x void(%f.frame*)*]
344 //                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
345 //   define void @f() {
346 //     ...
347 //     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
348 //                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
349 //
350 // Assumes that all the functions have the same signature.
351 static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
352                         std::initializer_list<Function *> Fns) {
353
354   SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
355   assert(!Args.empty());
356   Function *Part = *Fns.begin();
357   Module *M = Part->getParent();
358   auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
359
360   auto *ConstVal = ConstantArray::get(ArrTy, Args);
361   auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
362                                 GlobalVariable::PrivateLinkage, ConstVal,
363                                 F.getName() + Twine(".resumers"));
364
365   // Update coro.begin instruction to refer to this constant.
366   LLVMContext &C = F.getContext();
367   auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
368   CoroBegin->getId()->setInfo(BC);
369 }
370
371 // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
372 static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
373                             Function *DestroyFn, Function *CleanupFn) {
374
375   IRBuilder<> Builder(Shape.FramePtr->getNextNode());
376   auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
377       Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
378       "resume.addr");
379   Builder.CreateStore(ResumeFn, ResumeAddr);
380
381   Value *DestroyOrCleanupFn = DestroyFn;
382
383   CoroIdInst *CoroId = Shape.CoroBegin->getId();
384   if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
385     // If there is a CoroAlloc and it returns false (meaning we elide the
386     // allocation, use CleanupFn instead of DestroyFn).
387     DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
388   }
389
390   auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
391       Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
392       "destroy.addr");
393   Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
394 }
395
396 static void postSplitCleanup(Function &F) {
397   removeUnreachableBlocks(F);
398   llvm::legacy::FunctionPassManager FPM(F.getParent());
399
400   FPM.add(createVerifierPass());
401   FPM.add(createSCCPPass());
402   FPM.add(createCFGSimplificationPass());
403   FPM.add(createEarlyCSEPass());
404   FPM.add(createCFGSimplificationPass());
405
406   FPM.doInitialization();
407   FPM.run(F);
408   FPM.doFinalization();
409 }
410
411 // Coroutine has no suspend points. Remove heap allocation for the coroutine
412 // frame if possible.
413 static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
414   auto *CoroId = CoroBegin->getId();
415   auto *AllocInst = CoroId->getCoroAlloc();
416   coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
417   if (AllocInst) {
418     IRBuilder<> Builder(AllocInst);
419     // FIXME: Need to handle overaligned members.
420     auto *Frame = Builder.CreateAlloca(FrameTy);
421     auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
422     AllocInst->replaceAllUsesWith(Builder.getFalse());
423     AllocInst->eraseFromParent();
424     CoroBegin->replaceAllUsesWith(VFrame);
425   } else {
426     CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
427   }
428   CoroBegin->eraseFromParent();
429 }
430
431 // look for a very simple pattern
432 //    coro.save
433 //    no other calls
434 //    resume or destroy call
435 //    coro.suspend
436 //
437 // If there are other calls between coro.save and coro.suspend, they can
438 // potentially resume or destroy the coroutine, so it is unsafe to eliminate a
439 // suspend point.
440 static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
441                                  CoroBeginInst *CoroBegin) {
442   auto *Save = Suspend->getCoroSave();
443   auto *BB = Suspend->getParent();
444   if (BB != Save->getParent())
445     return false;
446
447   CallSite SingleCallSite;
448
449   // Check that we have only one CallSite.
450   for (Instruction *I = Save->getNextNode(); I != Suspend;
451        I = I->getNextNode()) {
452     if (isa<CoroFrameInst>(I))
453       continue;
454     if (isa<CoroSubFnInst>(I))
455       continue;
456     if (CallSite CS = CallSite(I)) {
457       if (SingleCallSite)
458         return false;
459       else
460         SingleCallSite = CS;
461     }
462   }
463   auto *CallInstr = SingleCallSite.getInstruction();
464   if (!CallInstr)
465     return false;
466
467   auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts();
468
469   // See if the callsite is for resumption or destruction of the coroutine.
470   auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
471   if (!SubFn)
472     return false;
473
474   // Does not refer to the current coroutine, we cannot do anything with it.
475   if (SubFn->getFrame() != CoroBegin)
476     return false;
477
478   // Replace llvm.coro.suspend with the value that results in resumption over
479   // the resume or cleanup path.
480   Suspend->replaceAllUsesWith(SubFn->getRawIndex());
481   Suspend->eraseFromParent();
482   Save->eraseFromParent();
483
484   // No longer need a call to coro.resume or coro.destroy.
485   CallInstr->eraseFromParent();
486
487   if (SubFn->user_empty())
488     SubFn->eraseFromParent();
489
490   return true;
491 }
492
493 // Remove suspend points that are simplified.
494 static void simplifySuspendPoints(coro::Shape &Shape) {
495   auto &S = Shape.CoroSuspends;
496   size_t I = 0, N = S.size();
497   if (N == 0)
498     return;
499   for (;;) {
500     if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
501       if (--N == I)
502         break;
503       std::swap(S[I], S[N]);
504       continue;
505     }
506     if (++I == N)
507       break;
508   }
509   S.resize(N);
510 }
511
512 static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
513   coro::Shape Shape(F);
514   if (!Shape.CoroBegin)
515     return;
516
517   simplifySuspendPoints(Shape);
518   buildCoroutineFrame(F, Shape);
519   replaceFrameSize(Shape);
520
521   // If there are no suspend points, no split required, just remove
522   // the allocation and deallocation blocks, they are not needed.
523   if (Shape.CoroSuspends.empty()) {
524     handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
525     removeCoroEnds(Shape);
526     postSplitCleanup(F);
527     coro::updateCallGraph(F, {}, CG, SCC);
528     return;
529   }
530
531   auto *ResumeEntry = createResumeEntryBlock(F, Shape);
532   auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
533   auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
534   auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);
535
536   // We no longer need coro.end in F.
537   removeCoroEnds(Shape);
538
539   postSplitCleanup(F);
540   postSplitCleanup(*ResumeClone);
541   postSplitCleanup(*DestroyClone);
542   postSplitCleanup(*CleanupClone);
543
544   // Store addresses resume/destroy/cleanup functions in the coroutine frame.
545   updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
546
547   // Create a constant array referring to resume/destroy/clone functions pointed
548   // by the last argument of @llvm.coro.info, so that CoroElide pass can
549   // determined correct function to call.
550   setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});
551
552   // Update call graph and add the functions we created to the SCC.
553   coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
554 }
555
556 // When we see the coroutine the first time, we insert an indirect call to a
557 // devirt trigger function and mark the coroutine that it is now ready for
558 // split.
559 static void prepareForSplit(Function &F, CallGraph &CG) {
560   Module &M = *F.getParent();
561 #ifndef NDEBUG
562   Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
563   assert(DevirtFn && "coro.devirt.trigger function not found");
564 #endif
565
566   F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
567
568   // Insert an indirect call sequence that will be devirtualized by CoroElide
569   // pass:
570   //    %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
571   //    %1 = bitcast i8* %0 to void(i8*)*
572   //    call void %1(i8* null)
573   coro::LowererBase Lowerer(M);
574   Instruction *InsertPt = F.getEntryBlock().getTerminator();
575   auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext()));
576   auto *DevirtFnAddr =
577       Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
578   auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt);
579
580   // Update CG graph with an indirect call we just added.
581   CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
582 }
583
584 // Make sure that there is a devirtualization trigger function that CoroSplit
585 // pass uses the force restart CGSCC pipeline. If devirt trigger function is not
586 // found, we will create one and add it to the current SCC.
587 static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
588   Module &M = CG.getModule();
589   if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
590     return;
591
592   LLVMContext &C = M.getContext();
593   auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
594                                  /*IsVarArgs=*/false);
595   Function *DevirtFn =
596       Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
597                        CORO_DEVIRT_TRIGGER_FN, &M);
598   DevirtFn->addFnAttr(Attribute::AlwaysInline);
599   auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
600   ReturnInst::Create(C, Entry);
601
602   auto *Node = CG.getOrInsertFunction(DevirtFn);
603
604   SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
605   Nodes.push_back(Node);
606   SCC.initialize(Nodes);
607 }
608
609 //===----------------------------------------------------------------------===//
610 //                              Top Level Driver
611 //===----------------------------------------------------------------------===//
612
613 namespace {
614
615 struct CoroSplit : public CallGraphSCCPass {
616   static char ID; // Pass identification, replacement for typeid
617   CoroSplit() : CallGraphSCCPass(ID) {}
618
619   bool Run = false;
620
621   // A coroutine is identified by the presence of coro.begin intrinsic, if
622   // we don't have any, this pass has nothing to do.
623   bool doInitialization(CallGraph &CG) override {
624     Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
625     return CallGraphSCCPass::doInitialization(CG);
626   }
627
628   bool runOnSCC(CallGraphSCC &SCC) override {
629     if (!Run)
630       return false;
631
632     // Find coroutines for processing.
633     SmallVector<Function *, 4> Coroutines;
634     for (CallGraphNode *CGN : SCC)
635       if (auto *F = CGN->getFunction())
636         if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
637           Coroutines.push_back(F);
638
639     if (Coroutines.empty())
640       return false;
641
642     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
643     createDevirtTriggerFunc(CG, SCC);
644
645     for (Function *F : Coroutines) {
646       Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
647       StringRef Value = Attr.getValueAsString();
648       DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
649                    << "' state: " << Value << "\n");
650       if (Value == UNPREPARED_FOR_SPLIT) {
651         prepareForSplit(*F, CG);
652         continue;
653       }
654       F->removeFnAttr(CORO_PRESPLIT_ATTR);
655       splitCoroutine(*F, CG, SCC);
656     }
657     return true;
658   }
659
660   void getAnalysisUsage(AnalysisUsage &AU) const override {
661     CallGraphSCCPass::getAnalysisUsage(AU);
662   }
663 };
664 }
665
666 char CoroSplit::ID = 0;
667 INITIALIZE_PASS(
668     CoroSplit, "coro-split",
669     "Split coroutine into a set of functions driving its state machine", false,
670     false)
671
672 Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }