]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Merge llvm, clang, lld, lldb, compiler-rt and libc++ r304659, and update
[FreeBSD/FreeBSD.git] / contrib / llvm / lib / Transforms / Coroutines / CoroSplit.cpp
1 //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass builds the coroutine frame and outlines resume and destroy parts
10 // of the coroutine into separate functions.
11 //
12 // We present a coroutine to an LLVM as an ordinary function with suspension
13 // points marked up with intrinsics. We let the optimizer party on the coroutine
14 // as a single function for as long as possible. Shortly before the coroutine is
15 // eligible to be inlined into its callers, we split up the coroutine into parts
16 // corresponding to an initial, resume and destroy invocations of the coroutine,
17 // add them to the current SCC and restart the IPO pipeline to optimize the
18 // coroutine subfunctions we extracted before proceeding to the caller of the
19 // coroutine.
20 //===----------------------------------------------------------------------===//
21
22 #include "CoroInternal.h"
23 #include "llvm/Analysis/CallGraphSCCPass.h"
24 #include "llvm/IR/DebugInfoMetadata.h"
25 #include "llvm/IR/InstIterator.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/LegacyPassManager.h"
28 #include "llvm/IR/Verifier.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Utils/Cloning.h"
31 #include "llvm/Transforms/Utils/Local.h"
32 #include "llvm/Transforms/Utils/ValueMapper.h"
33
34 using namespace llvm;
35
36 #define DEBUG_TYPE "coro-split"
37
38 // Create an entry block for a resume function with a switch that will jump to
39 // suspend points.
40 static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
41   LLVMContext &C = F.getContext();
42
43   // resume.entry:
44   //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
45   //  i32 2
46   //  % index = load i32, i32* %index.addr
47   //  switch i32 %index, label %unreachable [
48   //    i32 0, label %resume.0
49   //    i32 1, label %resume.1
50   //    ...
51   //  ]
52
53   auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
54   auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
55
56   IRBuilder<> Builder(NewEntry);
57   auto *FramePtr = Shape.FramePtr;
58   auto *FrameTy = Shape.FrameTy;
59   auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
60       FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
61   auto *Index = Builder.CreateLoad(GepIndex, "index");
62   auto *Switch =
63       Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
64   Shape.ResumeSwitch = Switch;
65
66   size_t SuspendIndex = 0;
67   for (CoroSuspendInst *S : Shape.CoroSuspends) {
68     ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
69
70     // Replace CoroSave with a store to Index:
71     //    %index.addr = getelementptr %f.frame... (index field number)
72     //    store i32 0, i32* %index.addr1
73     auto *Save = S->getCoroSave();
74     Builder.SetInsertPoint(Save);
75     if (S->isFinal()) {
76       // Final suspend point is represented by storing zero in ResumeFnAddr.
77       auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
78                                                           0, "ResumeFn.addr");
79       auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
80           cast<PointerType>(GepIndex->getType())->getElementType()));
81       Builder.CreateStore(NullPtr, GepIndex);
82     } else {
83       auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
84           FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
85       Builder.CreateStore(IndexVal, GepIndex);
86     }
87     Save->replaceAllUsesWith(ConstantTokenNone::get(C));
88     Save->eraseFromParent();
89
90     // Split block before and after coro.suspend and add a jump from an entry
91     // switch:
92     //
93     //  whateverBB:
94     //    whatever
95     //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
96     //    switch i8 %0, label %suspend[i8 0, label %resume
97     //                                 i8 1, label %cleanup]
98     // becomes:
99     //
100     //  whateverBB:
101     //     whatever
102     //     br label %resume.0.landing
103     //
104     //  resume.0: ; <--- jump from the switch in the resume.entry
105     //     %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
106     //     br label %resume.0.landing
107     //
108     //  resume.0.landing:
109     //     %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
110     //     switch i8 % 1, label %suspend [i8 0, label %resume
111     //                                    i8 1, label %cleanup]
112
113     auto *SuspendBB = S->getParent();
114     auto *ResumeBB =
115         SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
116     auto *LandingBB = ResumeBB->splitBasicBlock(
117         S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
118     Switch->addCase(IndexVal, ResumeBB);
119
120     cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
121     auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
122     S->replaceAllUsesWith(PN);
123     PN->addIncoming(Builder.getInt8(-1), SuspendBB);
124     PN->addIncoming(S, ResumeBB);
125
126     ++SuspendIndex;
127   }
128
129   Builder.SetInsertPoint(UnreachBB);
130   Builder.CreateUnreachable();
131
132   return NewEntry;
133 }
134
135 // In Resumers, we replace fallthrough coro.end with ret void and delete the
136 // rest of the block.
137 static void replaceFallthroughCoroEnd(IntrinsicInst *End,
138                                       ValueToValueMapTy &VMap) {
139   auto *NewE = cast<IntrinsicInst>(VMap[End]);
140   ReturnInst::Create(NewE->getContext(), nullptr, NewE);
141
142   // Remove the rest of the block, by splitting it into an unreachable block.
143   auto *BB = NewE->getParent();
144   BB->splitBasicBlock(NewE);
145   BB->getTerminator()->eraseFromParent();
146 }
147
148 // In Resumers, we replace unwind coro.end with True to force the immediate
149 // unwind to caller.
150 static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
151   if (Shape.CoroEnds.empty())
152     return;
153
154   LLVMContext &Context = Shape.CoroEnds.front()->getContext();
155   auto *True = ConstantInt::getTrue(Context);
156   for (CoroEndInst *CE : Shape.CoroEnds) {
157     if (!CE->isUnwind())
158       continue;
159
160     auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
161
162     // If coro.end has an associated bundle, add cleanupret instruction.
163     if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
164       Value *FromPad = Bundle->Inputs[0];
165       auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
166       NewCE->getParent()->splitBasicBlock(NewCE);
167       CleanupRet->getParent()->getTerminator()->eraseFromParent();
168     }
169
170     NewCE->replaceAllUsesWith(True);
171     NewCE->eraseFromParent();
172   }
173 }
174
175 // Rewrite final suspend point handling. We do not use suspend index to
176 // represent the final suspend point. Instead we zero-out ResumeFnAddr in the
177 // coroutine frame, since it is undefined behavior to resume a coroutine
178 // suspended at the final suspend point. Thus, in the resume function, we can
179 // simply remove the last case (when coro::Shape is built, the final suspend
180 // point (if present) is always the last element of CoroSuspends array).
181 // In the destroy function, we add a code sequence to check if ResumeFnAddress
182 // is Null, and if so, jump to the appropriate label to handle cleanup from the
183 // final suspend point.
184 static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
185                                coro::Shape &Shape, SwitchInst *Switch,
186                                bool IsDestroy) {
187   assert(Shape.HasFinalSuspend);
188   auto FinalCaseIt = std::prev(Switch->case_end());
189   BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
190   Switch->removeCase(FinalCaseIt);
191   if (IsDestroy) {
192     BasicBlock *OldSwitchBB = Switch->getParent();
193     auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
194     Builder.SetInsertPoint(OldSwitchBB->getTerminator());
195     auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
196                                                         0, 0, "ResumeFn.addr");
197     auto *Load = Builder.CreateLoad(GepIndex);
198     auto *NullPtr =
199         ConstantPointerNull::get(cast<PointerType>(Load->getType()));
200     auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
201     Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
202     OldSwitchBB->getTerminator()->eraseFromParent();
203   }
204 }
205
206 // Create a resume clone by cloning the body of the original function, setting
207 // new entry block and replacing coro.suspend an appropriate value to force
208 // resume or cleanup pass for every suspend point.
209 static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
210                              BasicBlock *ResumeEntry, int8_t FnIndex) {
211   Module *M = F.getParent();
212   auto *FrameTy = Shape.FrameTy;
213   auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
214   auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
215
216   Function *NewF =
217       Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
218                        F.getName() + Suffix, M);
219   NewF->addParamAttr(0, Attribute::NonNull);
220   NewF->addParamAttr(0, Attribute::NoAlias);
221
222   ValueToValueMapTy VMap;
223   // Replace all args with undefs. The buildCoroutineFrame algorithm already
224   // rewritten access to the args that occurs after suspend points with loads
225   // and stores to/from the coroutine frame.
226   for (Argument &A : F.args())
227     VMap[&A] = UndefValue::get(A.getType());
228
229   SmallVector<ReturnInst *, 4> Returns;
230
231   CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
232
233   // Remove old returns.
234   for (ReturnInst *Return : Returns)
235     changeToUnreachable(Return, /*UseLLVMTrap=*/false);
236
237   // Remove old return attributes.
238   NewF->removeAttributes(
239       AttributeList::ReturnIndex,
240       AttributeFuncs::typeIncompatible(NewF->getReturnType()));
241
242   // Make AllocaSpillBlock the new entry block.
243   auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
244   auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
245   Entry->moveBefore(&NewF->getEntryBlock());
246   Entry->getTerminator()->eraseFromParent();
247   BranchInst::Create(SwitchBB, Entry);
248   Entry->setName("entry" + Suffix);
249
250   // Clear all predecessors of the new entry block.
251   auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
252   Entry->replaceAllUsesWith(Switch->getDefaultDest());
253
254   IRBuilder<> Builder(&NewF->getEntryBlock().front());
255
256   // Remap frame pointer.
257   Argument *NewFramePtr = &*NewF->arg_begin();
258   Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
259   NewFramePtr->takeName(OldFramePtr);
260   OldFramePtr->replaceAllUsesWith(NewFramePtr);
261
262   // Remap vFrame pointer.
263   auto *NewVFrame = Builder.CreateBitCast(
264       NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
265   Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
266   OldVFrame->replaceAllUsesWith(NewVFrame);
267
268   // Rewrite final suspend handling as it is not done via switch (allows to
269   // remove final case from the switch, since it is undefined behavior to resume
270   // the coroutine suspended at the final suspend point.
271   if (Shape.HasFinalSuspend) {
272     auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
273     bool IsDestroy = FnIndex != 0;
274     handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
275   }
276
277   // Replace coro suspend with the appropriate resume index.
278   // Replacing coro.suspend with (0) will result in control flow proceeding to
279   // a resume label associated with a suspend point, replacing it with (1) will
280   // result in control flow proceeding to a cleanup label associated with this
281   // suspend point.
282   auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
283   for (CoroSuspendInst *CS : Shape.CoroSuspends) {
284     auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
285     MappedCS->replaceAllUsesWith(NewValue);
286     MappedCS->eraseFromParent();
287   }
288
289   // Remove coro.end intrinsics.
290   replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
291   replaceUnwindCoroEnds(Shape, VMap);
292   // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
293   // to suppress deallocation code.
294   coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
295                         /*Elide=*/FnIndex == 2);
296
297   NewF->setCallingConv(CallingConv::Fast);
298
299   return NewF;
300 }
301
302 static void removeCoroEnds(coro::Shape &Shape) {
303   if (Shape.CoroEnds.empty())
304     return;
305
306   LLVMContext &Context = Shape.CoroEnds.front()->getContext();
307   auto *False = ConstantInt::getFalse(Context);
308
309   for (CoroEndInst *CE : Shape.CoroEnds) {
310     CE->replaceAllUsesWith(False);
311     CE->eraseFromParent();
312   }
313 }
314
315 static void replaceFrameSize(coro::Shape &Shape) {
316   if (Shape.CoroSizes.empty())
317     return;
318
319   // In the same function all coro.sizes should have the same result type.
320   auto *SizeIntrin = Shape.CoroSizes.back();
321   Module *M = SizeIntrin->getModule();
322   const DataLayout &DL = M->getDataLayout();
323   auto Size = DL.getTypeAllocSize(Shape.FrameTy);
324   auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
325
326   for (CoroSizeInst *CS : Shape.CoroSizes) {
327     CS->replaceAllUsesWith(SizeConstant);
328     CS->eraseFromParent();
329   }
330 }
331
332 // Create a global constant array containing pointers to functions provided and
333 // set Info parameter of CoroBegin to point at this constant. Example:
334 //
335 //   @f.resumers = internal constant [2 x void(%f.frame*)*]
336 //                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
337 //   define void @f() {
338 //     ...
339 //     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
340 //                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
341 //
342 // Assumes that all the functions have the same signature.
343 static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
344                         std::initializer_list<Function *> Fns) {
345
346   SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
347   assert(!Args.empty());
348   Function *Part = *Fns.begin();
349   Module *M = Part->getParent();
350   auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
351
352   auto *ConstVal = ConstantArray::get(ArrTy, Args);
353   auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
354                                 GlobalVariable::PrivateLinkage, ConstVal,
355                                 F.getName() + Twine(".resumers"));
356
357   // Update coro.begin instruction to refer to this constant.
358   LLVMContext &C = F.getContext();
359   auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
360   CoroBegin->getId()->setInfo(BC);
361 }
362
363 // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
364 static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
365                             Function *DestroyFn, Function *CleanupFn) {
366
367   IRBuilder<> Builder(Shape.FramePtr->getNextNode());
368   auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
369       Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
370       "resume.addr");
371   Builder.CreateStore(ResumeFn, ResumeAddr);
372
373   Value *DestroyOrCleanupFn = DestroyFn;
374
375   CoroIdInst *CoroId = Shape.CoroBegin->getId();
376   if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
377     // If there is a CoroAlloc and it returns false (meaning we elide the
378     // allocation, use CleanupFn instead of DestroyFn).
379     DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
380   }
381
382   auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
383       Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
384       "destroy.addr");
385   Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
386 }
387
388 static void postSplitCleanup(Function &F) {
389   removeUnreachableBlocks(F);
390   llvm::legacy::FunctionPassManager FPM(F.getParent());
391
392   FPM.add(createVerifierPass());
393   FPM.add(createSCCPPass());
394   FPM.add(createCFGSimplificationPass());
395   FPM.add(createEarlyCSEPass());
396   FPM.add(createCFGSimplificationPass());
397
398   FPM.doInitialization();
399   FPM.run(F);
400   FPM.doFinalization();
401 }
402
403 // Coroutine has no suspend points. Remove heap allocation for the coroutine
404 // frame if possible.
405 static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
406   auto *CoroId = CoroBegin->getId();
407   auto *AllocInst = CoroId->getCoroAlloc();
408   coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
409   if (AllocInst) {
410     IRBuilder<> Builder(AllocInst);
411     // FIXME: Need to handle overaligned members.
412     auto *Frame = Builder.CreateAlloca(FrameTy);
413     auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
414     AllocInst->replaceAllUsesWith(Builder.getFalse());
415     AllocInst->eraseFromParent();
416     CoroBegin->replaceAllUsesWith(VFrame);
417   } else {
418     CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
419   }
420   CoroBegin->eraseFromParent();
421 }
422
423 // look for a very simple pattern
424 //    coro.save
425 //    no other calls
426 //    resume or destroy call
427 //    coro.suspend
428 //
429 // If there are other calls between coro.save and coro.suspend, they can
430 // potentially resume or destroy the coroutine, so it is unsafe to eliminate a
431 // suspend point.
432 static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
433                                  CoroBeginInst *CoroBegin) {
434   auto *Save = Suspend->getCoroSave();
435   auto *BB = Suspend->getParent();
436   if (BB != Save->getParent())
437     return false;
438
439   CallSite SingleCallSite;
440
441   // Check that we have only one CallSite.
442   for (Instruction *I = Save->getNextNode(); I != Suspend;
443        I = I->getNextNode()) {
444     if (isa<CoroFrameInst>(I))
445       continue;
446     if (isa<CoroSubFnInst>(I))
447       continue;
448     if (CallSite CS = CallSite(I)) {
449       if (SingleCallSite)
450         return false;
451       else
452         SingleCallSite = CS;
453     }
454   }
455   auto *CallInstr = SingleCallSite.getInstruction();
456   if (!CallInstr)
457     return false;
458
459   auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts();
460
461   // See if the callsite is for resumption or destruction of the coroutine.
462   auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
463   if (!SubFn)
464     return false;
465
466   // Does not refer to the current coroutine, we cannot do anything with it.
467   if (SubFn->getFrame() != CoroBegin)
468     return false;
469
470   // Replace llvm.coro.suspend with the value that results in resumption over
471   // the resume or cleanup path.
472   Suspend->replaceAllUsesWith(SubFn->getRawIndex());
473   Suspend->eraseFromParent();
474   Save->eraseFromParent();
475
476   // No longer need a call to coro.resume or coro.destroy.
477   CallInstr->eraseFromParent();
478
479   if (SubFn->user_empty())
480     SubFn->eraseFromParent();
481
482   return true;
483 }
484
485 // Remove suspend points that are simplified.
486 static void simplifySuspendPoints(coro::Shape &Shape) {
487   auto &S = Shape.CoroSuspends;
488   size_t I = 0, N = S.size();
489   if (N == 0)
490     return;
491   for (;;) {
492     if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
493       if (--N == I)
494         break;
495       std::swap(S[I], S[N]);
496       continue;
497     }
498     if (++I == N)
499       break;
500   }
501   S.resize(N);
502 }
503
504 static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
505   // Collect all blocks that we need to look for instructions to relocate.
506   SmallPtrSet<BasicBlock *, 4> RelocBlocks;
507   SmallVector<BasicBlock *, 4> Work;
508   Work.push_back(CB->getParent());
509
510   do {
511     BasicBlock *Current = Work.pop_back_val();
512     for (BasicBlock *BB : predecessors(Current))
513       if (RelocBlocks.count(BB) == 0) {
514         RelocBlocks.insert(BB);
515         Work.push_back(BB);
516       }
517   } while (!Work.empty());
518   return RelocBlocks;
519 }
520
521 static SmallPtrSet<Instruction *, 8>
522 getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
523                               SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
524   SmallPtrSet<Instruction *, 8> DoNotRelocate;
525   // Collect all instructions that we should not relocate
526   SmallVector<Instruction *, 8> Work;
527
528   // Start with CoroBegin and terminators of all preceding blocks.
529   Work.push_back(CoroBegin);
530   BasicBlock *CoroBeginBB = CoroBegin->getParent();
531   for (BasicBlock *BB : RelocBlocks)
532     if (BB != CoroBeginBB)
533       Work.push_back(BB->getTerminator());
534
535   // For every instruction in the Work list, place its operands in DoNotRelocate
536   // set.
537   do {
538     Instruction *Current = Work.pop_back_val();
539     DoNotRelocate.insert(Current);
540     for (Value *U : Current->operands()) {
541       auto *I = dyn_cast<Instruction>(U);
542       if (!I)
543         continue;
544       if (isa<AllocaInst>(U))
545         continue;
546       if (DoNotRelocate.count(I) == 0) {
547         Work.push_back(I);
548         DoNotRelocate.insert(I);
549       }
550     }
551   } while (!Work.empty());
552   return DoNotRelocate;
553 }
554
555 static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
556   // Analyze which non-alloca instructions are needed for allocation and
557   // relocate the rest to after coro.begin. We need to do it, since some of the
558   // targets of those instructions may be placed into coroutine frame memory
559   // for which becomes available after coro.begin intrinsic.
560
561   auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
562   auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);
563
564   Instruction *InsertPt = CoroBegin->getNextNode();
565   BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
566   for (auto B = BB.begin(), E = BB.end(); B != E;) {
567     Instruction &I = *B++;
568     if (isa<AllocaInst>(&I))
569       continue;
570     if (&I == CoroBegin)
571       break;
572     if (DoNotRelocateSet.count(&I))
573       continue;
574     I.moveBefore(InsertPt);
575   }
576 }
577
578 static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
579   coro::Shape Shape(F);
580   if (!Shape.CoroBegin)
581     return;
582
583   simplifySuspendPoints(Shape);
584   relocateInstructionBefore(Shape.CoroBegin, F);
585   buildCoroutineFrame(F, Shape);
586   replaceFrameSize(Shape);
587
588   // If there are no suspend points, no split required, just remove
589   // the allocation and deallocation blocks, they are not needed.
590   if (Shape.CoroSuspends.empty()) {
591     handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
592     removeCoroEnds(Shape);
593     postSplitCleanup(F);
594     coro::updateCallGraph(F, {}, CG, SCC);
595     return;
596   }
597
598   auto *ResumeEntry = createResumeEntryBlock(F, Shape);
599   auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
600   auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
601   auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);
602
603   // We no longer need coro.end in F.
604   removeCoroEnds(Shape);
605
606   postSplitCleanup(F);
607   postSplitCleanup(*ResumeClone);
608   postSplitCleanup(*DestroyClone);
609   postSplitCleanup(*CleanupClone);
610
611   // Store addresses resume/destroy/cleanup functions in the coroutine frame.
612   updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
613
614   // Create a constant array referring to resume/destroy/clone functions pointed
615   // by the last argument of @llvm.coro.info, so that CoroElide pass can
616   // determined correct function to call.
617   setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});
618
619   // Update call graph and add the functions we created to the SCC.
620   coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
621 }
622
623 // When we see the coroutine the first time, we insert an indirect call to a
624 // devirt trigger function and mark the coroutine that it is now ready for
625 // split.
626 static void prepareForSplit(Function &F, CallGraph &CG) {
627   Module &M = *F.getParent();
628 #ifndef NDEBUG
629   Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
630   assert(DevirtFn && "coro.devirt.trigger function not found");
631 #endif
632
633   F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
634
635   // Insert an indirect call sequence that will be devirtualized by CoroElide
636   // pass:
637   //    %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
638   //    %1 = bitcast i8* %0 to void(i8*)*
639   //    call void %1(i8* null)
640   coro::LowererBase Lowerer(M);
641   Instruction *InsertPt = F.getEntryBlock().getTerminator();
642   auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext()));
643   auto *DevirtFnAddr =
644       Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
645   auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt);
646
647   // Update CG graph with an indirect call we just added.
648   CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
649 }
650
651 // Make sure that there is a devirtualization trigger function that CoroSplit
652 // pass uses the force restart CGSCC pipeline. If devirt trigger function is not
653 // found, we will create one and add it to the current SCC.
654 static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
655   Module &M = CG.getModule();
656   if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
657     return;
658
659   LLVMContext &C = M.getContext();
660   auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
661                                  /*IsVarArgs=*/false);
662   Function *DevirtFn =
663       Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
664                        CORO_DEVIRT_TRIGGER_FN, &M);
665   DevirtFn->addFnAttr(Attribute::AlwaysInline);
666   auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
667   ReturnInst::Create(C, Entry);
668
669   auto *Node = CG.getOrInsertFunction(DevirtFn);
670
671   SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
672   Nodes.push_back(Node);
673   SCC.initialize(Nodes);
674 }
675
676 //===----------------------------------------------------------------------===//
677 //                              Top Level Driver
678 //===----------------------------------------------------------------------===//
679
680 namespace {
681
682 struct CoroSplit : public CallGraphSCCPass {
683   static char ID; // Pass identification, replacement for typeid
684   CoroSplit() : CallGraphSCCPass(ID) {
685     initializeCoroSplitPass(*PassRegistry::getPassRegistry());
686   }
687
688   bool Run = false;
689
690   // A coroutine is identified by the presence of coro.begin intrinsic, if
691   // we don't have any, this pass has nothing to do.
692   bool doInitialization(CallGraph &CG) override {
693     Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
694     return CallGraphSCCPass::doInitialization(CG);
695   }
696
697   bool runOnSCC(CallGraphSCC &SCC) override {
698     if (!Run)
699       return false;
700
701     // Find coroutines for processing.
702     SmallVector<Function *, 4> Coroutines;
703     for (CallGraphNode *CGN : SCC)
704       if (auto *F = CGN->getFunction())
705         if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
706           Coroutines.push_back(F);
707
708     if (Coroutines.empty())
709       return false;
710
711     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
712     createDevirtTriggerFunc(CG, SCC);
713
714     for (Function *F : Coroutines) {
715       Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
716       StringRef Value = Attr.getValueAsString();
717       DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
718                    << "' state: " << Value << "\n");
719       if (Value == UNPREPARED_FOR_SPLIT) {
720         prepareForSplit(*F, CG);
721         continue;
722       }
723       F->removeFnAttr(CORO_PRESPLIT_ATTR);
724       splitCoroutine(*F, CG, SCC);
725     }
726     return true;
727   }
728
729   void getAnalysisUsage(AnalysisUsage &AU) const override {
730     CallGraphSCCPass::getAnalysisUsage(AU);
731   }
732   StringRef getPassName() const override { return "Coroutine Splitting"; }
733 };
734 }
735
736 char CoroSplit::ID = 0;
737 INITIALIZE_PASS(
738     CoroSplit, "coro-split",
739     "Split coroutine into a set of functions driving its state machine", false,
740     false)
741
742 Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }