//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements an algorithm that returns for a divergent branch // the set of basic blocks whose phi nodes become divergent due to divergent // control. These are the blocks that are reachable by two disjoint paths from // the branch or loop exits that have a reaching path that is disjoint from a // path to the loop latch. // // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model // control-induced divergence in phi nodes. // // // -- Reference -- // The algorithm is presented in Section 5 of // // An abstract interpretation for SPMD divergence // on reducible control flow graphs. // Julian Rosemann, Simon Moll and Sebastian Hack // POPL '21 // // // -- Sync dependence -- // Sync dependence characterizes the control flow aspect of the // propagation of branch divergence. For example, // // %cond = icmp slt i32 %tid, 10 // br i1 %cond, label %then, label %else // then: // br label %merge // else: // br label %merge // merge: // %a = phi i32 [ 0, %then ], [ 1, %else ] // // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid // because %tid is not on its use-def chains, %a is sync dependent on %tid // because the branch "br i1 %cond" depends on %tid and affects which value %a // is assigned to. // // // -- Reduction to SSA construction -- // There are two disjoint paths from A to X, if a certain variant of SSA // construction places a phi node in X under the following set-up scheme. // // This variant of SSA construction ignores incoming undef values. // That is paths from the entry without a definition do not result in // phi nodes. // // entry // / \ // A \ // / \ Y // B C / // \ / \ / // D E // \ / // F // // Assume that A contains a divergent branch. We are interested // in the set of all blocks where each block is reachable from A // via two disjoint paths. This would be the set {D, F} in this // case. // To generally reduce this query to SSA construction we introduce // a virtual variable x and assign to x different values in each // successor block of A. // // entry // / \ // A \ // / \ Y // x = 0 x = 1 / // \ / \ / // D E // \ / // F // // Our flavor of SSA construction for x will construct the following // // entry // / \ // A \ // / \ Y // x0 = 0 x1 = 1 / // \ / \ / // x2 = phi E // \ / // x3 = phi // // The blocks D and F contain phi nodes and are thus each reachable // by two disjoins paths from A. // // -- Remarks -- // * In case of loop exits we need to check the disjoint path criterion for loops. // To this end, we check whether the definition of x differs between the // loop exit and the loop header (_after_ SSA construction). // // -- Known Limitations & Future Work -- // * The algorithm requires reducible loops because the implementation // implicitly performs a single iteration of the underlying data flow analysis. // This was done for pragmatism, simplicity and speed. // // Relevant related work for extending the algorithm to irreducible control: // A simple algorithm for global data flow analysis problems. // Matthew S. Hecht and Jeffrey D. Ullman. // SIAM Journal on Computing, 4(4):519–532, December 1975. // // * Another reason for requiring reducible loops is that points of // synchronization in irreducible loops aren't 'obvious' - there is no unique // header where threads 'should' synchronize when entering or coming back // around from the latch. // //===----------------------------------------------------------------------===// #include "llvm/Analysis/SyncDependenceAnalysis.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include #include #include #define DEBUG_TYPE "sync-dependence" // The SDA algorithm operates on a modified CFG - we modify the edges leaving // loop headers as follows: // // * We remove all edges leaving all loop headers. // * We add additional edges from the loop headers to their exit blocks. // // The modification is virtual, that is whenever we visit a loop header we // pretend it had different successors. namespace { using namespace llvm; // Custom Post-Order Traveral // // We cannot use the vanilla (R)PO computation of LLVM because: // * We (virtually) modify the CFG. // * We want a loop-compact block enumeration, that is the numbers assigned to // blocks of a loop form an interval // using POCB = std::function; using VisitedSet = std::set; using BlockStack = std::vector; // forward static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, VisitedSet &Finalized); // for a nested region (top-level loop or nested loop) static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, POCB CallBack, VisitedSet &Finalized) { const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr; while (!Stack.empty()) { const auto *NextBB = Stack.back(); auto *NestedLoop = LI.getLoopFor(NextBB); bool IsNestedLoop = NestedLoop != Loop; // Treat the loop as a node if (IsNestedLoop) { SmallVector NestedExits; NestedLoop->getUniqueExitBlocks(NestedExits); bool PushedNodes = false; for (const auto *NestedExitBB : NestedExits) { if (NestedExitBB == LoopHeader) continue; if (Loop && !Loop->contains(NestedExitBB)) continue; if (Finalized.count(NestedExitBB)) continue; PushedNodes = true; Stack.push_back(NestedExitBB); } if (!PushedNodes) { // All loop exits finalized -> finish this node Stack.pop_back(); computeLoopPO(LI, *NestedLoop, CallBack, Finalized); } continue; } // DAG-style bool PushedNodes = false; for (const auto *SuccBB : successors(NextBB)) { if (SuccBB == LoopHeader) continue; if (Loop && !Loop->contains(SuccBB)) continue; if (Finalized.count(SuccBB)) continue; PushedNodes = true; Stack.push_back(SuccBB); } if (!PushedNodes) { // Never push nodes twice Stack.pop_back(); if (!Finalized.insert(NextBB).second) continue; CallBack(*NextBB); } } } static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { VisitedSet Finalized; BlockStack Stack; Stack.reserve(24); // FIXME made-up number Stack.push_back(&F.getEntryBlock()); computeStackPO(Stack, LI, nullptr, CallBack, Finalized); } static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, VisitedSet &Finalized) { /// Call CallBack on all loop blocks. std::vector Stack; const auto *LoopHeader = Loop.getHeader(); // Visit the header last Finalized.insert(LoopHeader); CallBack(*LoopHeader); // Initialize with immediate successors for (const auto *BB : successors(LoopHeader)) { if (!Loop.contains(BB)) continue; if (BB == LoopHeader) continue; Stack.push_back(BB); } // Compute PO inside region computeStackPO(Stack, LI, &Loop, CallBack, Finalized); } } // namespace namespace llvm { ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI) : DT(DT), PDT(PDT), LI(LI) { computeTopLevelPO(*DT.getRoot()->getParent(), LI, [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); } SyncDependenceAnalysis::~SyncDependenceAnalysis() {} // divergence propagator for reducible CFGs struct DivergencePropagator { const ModifiedPO &LoopPOT; const DominatorTree &DT; const PostDominatorTree &PDT; const LoopInfo &LI; const BasicBlock &DivTermBlock; // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at // block B // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths // from X or B is an immediate successor of X (initial value). using BlockLabelVec = std::vector; BlockLabelVec BlockLabels; // divergent join and loop exit descriptor. std::unique_ptr DivDesc; DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI, const BasicBlock &DivTermBlock) : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), BlockLabels(LoopPOT.size(), nullptr), DivDesc(new ControlDivergenceDesc) {} void printDefs(raw_ostream &Out) { Out << "Propagator::BlockLabels {\n"; for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { const auto *Label = BlockLabels[BlockIdx]; Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx << ") : "; if (!Label) { Out << "\n"; } else { Out << Label->getName() << "\n"; } } Out << "}\n"; } // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this // causes a divergent join. bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); // unset or same reaching label const auto *OldLabel = BlockLabels[SuccIdx]; if (!OldLabel || (OldLabel == &PushedLabel)) { BlockLabels[SuccIdx] = &PushedLabel; return false; } // Update the definition BlockLabels[SuccIdx] = &SuccBlock; return true; } // visiting a virtual loop exit edge from the loop header --> temporal // divergence on join bool visitLoopExitEdge(const BasicBlock &ExitBlock, const BasicBlock &DefBlock, bool FromParentLoop) { // Pushing from a non-parent loop cannot cause temporal divergence. if (!FromParentLoop) return visitEdge(ExitBlock, DefBlock); if (!computeJoin(ExitBlock, DefBlock)) return false; // Identified a divergent loop exit DivDesc->LoopDivBlocks.insert(&ExitBlock); LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() << "\n"); return true; } // process \p SuccBlock with reaching definition \p DefBlock bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { if (!computeJoin(SuccBlock, DefBlock)) return false; // Divergent, disjoint paths join. DivDesc->JoinDivBlocks.insert(&SuccBlock); LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); return true; } std::unique_ptr computeJoinPoints() { assert(DivDesc); LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() << "\n"); const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); // Early stopping criterion int FloorIdx = LoopPOT.size() - 1; const BasicBlock *FloorLabel = nullptr; // bootstrap with branch targets int BlockIdx = 0; for (const auto *SuccBlock : successors(&DivTermBlock)) { auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); BlockLabels[SuccIdx] = SuccBlock; // Find the successor with the highest index to start with BlockIdx = std::max(BlockIdx, SuccIdx); FloorIdx = std::min(FloorIdx, SuccIdx); // Identify immediate divergent loop exits if (!DivBlockLoop) continue; const auto *BlockLoop = LI.getLoopFor(SuccBlock); if (BlockLoop && DivBlockLoop->contains(BlockLoop)) continue; DivDesc->LoopDivBlocks.insert(SuccBlock); LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " << SuccBlock->getName() << "\n"); } // propagate definitions at the immediate successors of the node in RPO for (; BlockIdx >= FloorIdx; --BlockIdx) { LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs())); // Any label available here const auto *Label = BlockLabels[BlockIdx]; if (!Label) continue; // Ok. Get the block const auto *Block = LoopPOT.getBlockAt(BlockIdx); LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); auto *BlockLoop = LI.getLoopFor(Block); bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; bool CausedJoin = false; int LoweredFloorIdx = FloorIdx; if (IsLoopHeader) { // Disconnect from immediate successors and propagate directly to loop // exits. SmallVector BlockLoopExits; BlockLoop->getExitBlocks(BlockLoopExits); bool IsParentLoop = BlockLoop->contains(&DivTermBlock); for (const auto *BlockLoopExit : BlockLoopExits) { CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); LoweredFloorIdx = std::min(LoweredFloorIdx, LoopPOT.getIndexOf(*BlockLoopExit)); } } else { // Acyclic successor case for (const auto *SuccBlock : successors(Block)) { CausedJoin |= visitEdge(*SuccBlock, *Label); LoweredFloorIdx = std::min(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); } } // Floor update if (CausedJoin) { // 1. Different labels pushed to successors FloorIdx = LoweredFloorIdx; } else if (FloorLabel != Label) { // 2. No join caused BUT we pushed a label that is different than the // last pushed label FloorIdx = LoweredFloorIdx; FloorLabel = Label; } } LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs())); return std::move(DivDesc); } }; #ifndef NDEBUG static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { Out << "["; ListSeparator LS; for (const auto *BB : Blocks) Out << LS << BB->getName(); Out << "]"; } #endif const ControlDivergenceDesc & SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { // trivial case if (Term.getNumSuccessors() <= 1) { return EmptyDivergenceDesc; } // already available in cache? auto ItCached = CachedControlDivDescs.find(&Term); if (ItCached != CachedControlDivDescs.end()) return *ItCached->second; // compute all join points // Special handling of divergent loop exits is not needed for LCSSA const auto &TermBlock = *Term.getParent(); DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); auto DivDesc = Propagator.computeJoinPoints(); LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; dbgs() << "JoinDivBlocks: "; printBlockSet(DivDesc->JoinDivBlocks, dbgs()); dbgs() << "\nLoopDivBlocks: "; printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); assert(ItInserted.second); return *ItInserted.first->second; } } // namespace llvm