1 //===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
11 // The divergence analysis determines which instructions and branches are
12 // divergent given a set of divergent source instructions.
14 //===----------------------------------------------------------------------===//
16 #ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
17 #define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
#include "llvm/ADT/DenseSet.h"
#include "llvm/Analysis/SyncDependenceAnalysis.h"
#include "llvm/IR/Function.h"
#include "llvm/Pass.h"
#include <vector>
31 class TargetTransformInfo;
33 /// \brief Generic divergence analysis for reducible CFGs.
35 /// This analysis propagates divergence in a data-parallel context from sources
36 /// of divergence to all users. It requires reducible CFGs. All assignments
37 /// should be in SSA form.
38 class DivergenceAnalysis {
40 /// \brief This instance will analyze the whole function \p F or the loop \p
43 /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
44 /// Otherwise the whole function is analyzed.
45 /// \param IsLCSSAForm whether the analysis may assume that the IR in the
46 /// region in in LCSSA form.
47 DivergenceAnalysis(const Function &F, const Loop *RegionLoop,
48 const DominatorTree &DT, const LoopInfo &LI,
49 SyncDependenceAnalysis &SDA, bool IsLCSSAForm);
51 /// \brief The loop that defines the analyzed region (if any).
52 const Loop *getRegionLoop() const { return RegionLoop; }
53 const Function &getFunction() const { return F; }
55 /// \brief Whether \p BB is part of the region.
56 bool inRegion(const BasicBlock &BB) const;
57 /// \brief Whether \p I is part of the region.
58 bool inRegion(const Instruction &I) const;
60 /// \brief Mark \p UniVal as a value that is always uniform.
61 void addUniformOverride(const Value &UniVal);
63 /// \brief Mark \p DivVal as a value that is always divergent.
64 void markDivergent(const Value &DivVal);
66 /// \brief Propagate divergence to all instructions in the region.
67 /// Divergence is seeded by calls to \p markDivergent.
70 /// \brief Whether any value was marked or analyzed to be divergent.
71 bool hasDetectedDivergence() const { return !DivergentValues.empty(); }
73 /// \brief Whether \p Val will always return a uniform value regardless of its
75 bool isAlwaysUniform(const Value &Val) const;
77 /// \brief Whether \p Val is a divergent value
78 bool isDivergent(const Value &Val) const;
80 void print(raw_ostream &OS, const Module *) const;
83 bool updateTerminator(const Instruction &Term) const;
84 bool updatePHINode(const PHINode &Phi) const;
86 /// \brief Computes whether \p Inst is divergent based on the
87 /// divergence of its operands.
89 /// \returns Whether \p Inst is divergent.
91 /// This should only be called for non-phi, non-terminator instructions.
92 bool updateNormalInstruction(const Instruction &Inst) const;
94 /// \brief Mark users of live-out users as divergent.
96 /// \param LoopHeader the header of the divergent loop.
98 /// Marks all users of live-out values of the loop headed by \p LoopHeader
99 /// as divergent and puts them on the worklist.
100 void taintLoopLiveOuts(const BasicBlock &LoopHeader);
102 /// \brief Push all users of \p Val (in the region) to the worklist
103 void pushUsers(const Value &I);
105 /// \brief Push all phi nodes in @block to the worklist
106 void pushPHINodes(const BasicBlock &Block);
108 /// \brief Mark \p Block as join divergent
110 /// A block is join divergent if two threads may reach it from different
111 /// incoming blocks at the same time.
112 void markBlockJoinDivergent(const BasicBlock &Block) {
113 DivergentJoinBlocks.insert(&Block);
116 /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
117 bool isTemporalDivergent(const BasicBlock &ObservingBlock,
118 const Value &Val) const;
120 /// \brief Whether \p Block is join divergent
122 /// (see markBlockJoinDivergent).
123 bool isJoinDivergent(const BasicBlock &Block) const {
124 return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end();
127 /// \brief Propagate control-induced divergence to users (phi nodes and
130 // \param JoinBlock is a divergent loop exit or join point of two disjoint
132 // \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop.
133 bool propagateJoinDivergence(const BasicBlock &JoinBlock,
134 const Loop *TermLoop);
136 /// \brief Propagate induced value divergence due to control divergence in \p
138 void propagateBranchDivergence(const Instruction &Term);
140 /// \brief Propagate divergent caused by a divergent loop exit.
142 /// \param ExitingLoop is a divergent loop.
143 void propagateLoopDivergence(const Loop &ExitingLoop);
147 // If regionLoop != nullptr, analysis is only performed within \p RegionLoop.
148 // Otw, analyze the whole function
149 const Loop *RegionLoop;
151 const DominatorTree &DT;
154 // Recognized divergent loops
155 DenseSet<const Loop *> DivergentLoops;
157 // The SDA links divergent branches to divergent control-flow joins.
158 SyncDependenceAnalysis &SDA;
160 // Use simplified code path for LCSSA form.
163 // Set of known-uniform values.
164 DenseSet<const Value *> UniformOverrides;
166 // Blocks with joining divergent control from different predecessors.
167 DenseSet<const BasicBlock *> DivergentJoinBlocks;
169 // Detected/marked divergent values.
170 DenseSet<const Value *> DivergentValues;
172 // Internal worklist for divergence propagation.
173 std::vector<const Instruction *> Worklist;
176 /// \brief Divergence analysis frontend for GPU kernels.
177 class GPUDivergenceAnalysis {
178 SyncDependenceAnalysis SDA;
179 DivergenceAnalysis DA;
182 /// Runs the divergence analysis on @F, a GPU kernel
183 GPUDivergenceAnalysis(Function &F, const DominatorTree &DT,
184 const PostDominatorTree &PDT, const LoopInfo &LI,
185 const TargetTransformInfo &TTI);
187 /// Whether any divergence was detected.
188 bool hasDivergence() const { return DA.hasDetectedDivergence(); }
190 /// The GPU kernel this analysis result is for
191 const Function &getFunction() const { return DA.getFunction(); }
193 /// Whether \p V is divergent.
194 bool isDivergent(const Value &V) const;
196 /// Whether \p V is uniform/non-divergent
197 bool isUniform(const Value &V) const { return !isDivergent(V); }
199 /// Print all divergent values in the kernel.
200 void print(raw_ostream &OS, const Module *) const;
205 #endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H