1 //======- GVNExpression.h - GVN Expression classes -------*- C++ -*-==-------=//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
11 /// The header file for the GVN pass that contains expression handling classes.
14 //===----------------------------------------------------------------------===//
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/IR/Constant.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Value.h"
23 #include "llvm/Support/Allocator.h"
24 #include "llvm/Support/ArrayRecycler.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "llvm/Transforms/Utils/MemorySSA.h"
33 namespace GVNExpression {
/// Copy construction of an Expression is disallowed.
56 Expression(const Expression &) = delete;
/// Builds an expression from its type tag and opcode. The default opcode,
/// ~2U, is distinct from the ~0U and ~1U values returned by getEmptyKey()
/// and getTombstoneKey().
57 Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
58 : EType(ET), Opcode(O) {}
/// Copy assignment is disallowed as well.
59 void operator=(const Expression &) = delete;
/// Virtual destructor; only declared here.
60 virtual ~Expression();
/// Two reserved opcode values that operator== tests for explicitly.
// NOTE(review): presumably these are the empty/tombstone markers of the hash
// container that stores expressions -- confirm against that container.
62 static unsigned getEmptyKey() { return ~0U; }
63 static unsigned getTombstoneKey() { return ~1U; }
/// Equality test. In order it checks: whether the two opcodes differ; whether
/// the opcode is one of the two reserved key values; and whether the
/// expression types differ, a check that is bypassed when this expression is
/// ET_Load or ET_Store.
65 bool operator==(const Expression &Other) const {
66 if (getOpcode() != Other.getOpcode())
68 if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
70 // Compare the expression type for anything but load and store.
71 // For load and store we set the opcode to zero.
72 // This is needed for load coercion.
73 if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
74 getExpressionType() != Other.getExpressionType())
/// Comparison hook that subclasses override; this base version always
/// returns true.
80 virtual bool equals(const Expression &Other) const { return true; }
/// Plain accessors for the opcode and the expression type tag.
82 unsigned getOpcode() const { return Opcode; }
83 void setOpcode(unsigned opcode) { Opcode = opcode; }
84 ExpressionType getExpressionType() const { return EType; }
/// Base hash: hash_combine of the expression type and the opcode.
86 virtual hash_code getHashValue() const {
87 return hash_combine(getExpressionType(), getOpcode());
/// Writes "etype = <type>," and "opcode = <opcode>, " to OS.
// NOTE(review): PrintEType presumably controls whether the etype field is
// emitted -- confirm.
93 virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
95 OS << "etype = " << getExpressionType() << ",";
96 OS << "opcode = " << getOpcode() << ", ";
/// Prints this expression to OS by calling printInternal with PrintEType
/// set to true.
99 void print(raw_ostream &OS) const {
101 printInternal(OS, true);
/// Prints this expression to the stream returned by dbgs().
104 void dump() const { print(dbgs()); }
/// Stream-insertion overload taking the output stream and an Expression by
/// const reference, and returning a raw_ostream reference.
107 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
/// An Expression that carries an array of Value* operands and a result Type.
/// Operand storage is obtained from an ArrayRecycler<Value *> through
/// allocateOperands() and filled one element at a time with op_push_back().
112 class BasicExpression : public Expression {
/// Recycler type (and its capacity type) used for the operand array.
114 typedef ArrayRecycler<Value *> RecyclerType;
115 typedef RecyclerType::Capacity RecyclerCapacity;
/// MaxOperands is the capacity requested at construction; NumOperands counts
/// the operands pushed so far.
117 unsigned MaxOperands;
118 unsigned NumOperands;
/// Returns true when EB's expression type lies strictly between
/// ET_BasicStart and ET_BasicEnd.
122 static bool classof(const Expression *EB) {
123 ExpressionType ET = EB->getExpressionType();
124 return ET > ET_BasicStart && ET < ET_BasicEnd;
/// Convenience constructor: delegates with the ET_Basic type tag.
127 BasicExpression(unsigned NumOperands)
128 : BasicExpression(NumOperands, ET_Basic) {}
/// Records the requested operand capacity in MaxOperands. The expression
/// starts with zero operands, no operand storage and a null type.
129 BasicExpression(unsigned NumOperands, ExpressionType ET)
130 : Expression(ET), Operands(nullptr), MaxOperands(NumOperands),
131 NumOperands(0), ValueType(nullptr) {}
132 virtual ~BasicExpression() override;
/// Copying, copy assignment and default construction are all disallowed.
133 void operator=(const BasicExpression &) = delete;
134 BasicExpression(const BasicExpression &) = delete;
135 BasicExpression() = delete;
137 /// \brief Swap two operands. Used during GVN to put commutative operands in
// NOTE(review): the sentence above appears to be cut off; presumably it ends
// "... in order." -- confirm against upstream.
// No bounds or allocation asserts are performed here, unlike
// getOperand()/setOperand().
139 void swapOperands(unsigned First, unsigned Second) {
140 std::swap(Operands[First], Operands[Second]);
/// Operand read access; asserts that storage exists and that N is below the
/// current operand count.
143 Value *getOperand(unsigned N) const {
144 assert(Operands && "Operands not allocated");
145 assert(N < NumOperands && "Operand out of range");
/// Operand write access, guarded by the same two asserts as getOperand().
149 void setOperand(unsigned N, Value *V) {
150 assert(Operands && "Operands not allocated before setting");
151 assert(N < NumOperands && "Operand out of range");
/// Number of operands pushed so far.
155 unsigned getNumOperands() const { return NumOperands; }
/// Operand iteration: raw pointers into the operand array, spanning
/// [Operands, Operands + NumOperands).
157 typedef Value **op_iterator;
158 typedef Value *const *const_op_iterator;
159 op_iterator op_begin() { return Operands; }
160 op_iterator op_end() { return Operands + NumOperands; }
161 const_op_iterator op_begin() const { return Operands; }
162 const_op_iterator op_end() const { return Operands + NumOperands; }
/// Range wrappers over op_begin()/op_end().
163 iterator_range<op_iterator> operands() {
164 return iterator_range<op_iterator>(op_begin(), op_end());
166 iterator_range<const_op_iterator> operands() const {
167 return iterator_range<const_op_iterator>(op_begin(), op_end());
/// Appends Arg as the next operand. Asserts that there is remaining capacity
/// and that storage has been allocated.
170 void op_push_back(Value *Arg) {
171 assert(NumOperands < MaxOperands && "Tried to add too many operands");
172 assert(Operands && "Operands not allocated before pushing");
173 Operands[NumOperands++] = Arg;
/// True when no operands have been pushed.
175 bool op_empty() const { return getNumOperands() == 0; }
/// Obtains the operand array from Recycler, sized by
/// RecyclerCapacity::get(MaxOperands). Asserts that no array is held yet.
177 void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
178 assert(!Operands && "Operands already allocated");
179 Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
/// Hands the operand array back to Recycler using the same capacity value
/// that allocateOperands() requested.
181 void deallocateOperands(RecyclerType &Recycler) {
182 Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
/// Accessors for the Type associated with this expression.
185 void setType(Type *T) { ValueType = T; }
186 Type *getType() const { return ValueType; }
/// Checks the opcodes first, then casts Other to BasicExpression and
/// compares the type, the operand count and the operands element by element.
188 virtual bool equals(const Expression &Other) const override {
189 if (getOpcode() != Other.getOpcode())
192 const auto &OE = cast<BasicExpression>(Other);
193 return getType() == OE.getType() && NumOperands == OE.NumOperands &&
194 std::equal(op_begin(), op_end(), OE.op_begin());
/// Hash over the expression type, opcode, ValueType and the operand range.
197 virtual hash_code getHashValue() const override {
198 return hash_combine(getExpressionType(), getOpcode(), ValueType,
199 hash_combine_range(op_begin(), op_end()));
/// Writes "ExpressionTypeBasic, ", then the base-class fields (with
/// PrintEType passed as false), then each operand as "[i] = <operand>"
/// after an "operands = {" prefix.
205 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
207 OS << "ExpressionTypeBasic, ";
209 this->Expression::printInternal(OS, false);
210 OS << "operands = {";
211 for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
212 OS << "[" << i << "] = ";
213 Operands[i]->printAsOperand(OS);
// Output-iterator style adaptor: assigning a Value* through it appends that
// value to the wrapped BasicExpression via op_push_back().
220 : public std::iterator<std::output_iterator_tag, void, void, void, void> {
222 typedef BasicExpression Container;
/// Wraps a BasicExpression given either by reference or by pointer.
226 explicit op_inserter(BasicExpression &E) : BE(&E) {}
227 explicit op_inserter(BasicExpression *E) : BE(E) {}
/// Assignment forwards the value to BE->op_push_back().
229 op_inserter &operator=(Value *val) {
230 BE->op_push_back(val);
/// Dereference and both increment forms do nothing except return *this.
233 op_inserter &operator*() { return *this; }
234 op_inserter &operator++() { return *this; }
235 op_inserter &operator++(int) { return *this; }
/// BasicExpression tagged ET_Call that additionally records the CallInst it
/// was built from and a MemoryAccess called DefiningAccess.
238 class CallExpression final : public BasicExpression {
241 MemoryAccess *DefiningAccess;
/// Returns true when EB's expression type is ET_Call.
244 static bool classof(const Expression *EB) {
245 return EB->getExpressionType() == ET_Call;
/// Stores the call instruction and defining access; NumOperands is forwarded
/// to BasicExpression as the operand capacity.
248 CallExpression(unsigned NumOperands, CallInst *C, MemoryAccess *DA)
249 : BasicExpression(NumOperands, ET_Call), Call(C), DefiningAccess(DA) {}
/// Copying, copy assignment and default construction are all disallowed.
250 void operator=(const CallExpression &) = delete;
251 CallExpression(const CallExpression &) = delete;
252 CallExpression() = delete;
253 virtual ~CallExpression() override;
/// Runs the BasicExpression comparison first, then compares the
/// DefiningAccess pointers of the two call expressions.
255 virtual bool equals(const Expression &Other) const override {
256 if (!this->BasicExpression::equals(Other))
258 const auto &OE = cast<CallExpression>(Other);
259 return DefiningAccess == OE.DefiningAccess;
/// Extends the BasicExpression hash with the DefiningAccess pointer.
262 virtual hash_code getHashValue() const override {
263 return hash_combine(this->BasicExpression::getHashValue(), DefiningAccess);
/// Writes "ExpressionTypeCall, ", the BasicExpression fields, and then the
/// Call pointer itself (it is streamed without being dereferenced).
269 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
271 OS << "ExpressionTypeCall, ";
272 this->BasicExpression::printInternal(OS, false);
273 OS << " represents call at " << Call;
/// BasicExpression tagged ET_Load that records the LoadInst, a MemoryAccess
/// called DefiningAccess, and an alignment value.
277 class LoadExpression final : public BasicExpression {
280 MemoryAccess *DefiningAccess;
/// Returns true when EB's expression type is ET_Load.
284 static bool classof(const Expression *EB) {
285 return EB->getExpressionType() == ET_Load;
/// Convenience constructor: delegates with the ET_Load type tag.
288 LoadExpression(unsigned NumOperands, LoadInst *L, MemoryAccess *DA)
289 : LoadExpression(ET_Load, NumOperands, L, DA) {}
/// Stores the load and defining access. Alignment is taken from L when L is
/// non-null and is 0 otherwise.
290 LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
292 : BasicExpression(NumOperands, EType), Load(L), DefiningAccess(DA) {
293 Alignment = L ? L->getAlignment() : 0;
/// Copying, copy assignment and default construction are all disallowed.
295 void operator=(const LoadExpression &) = delete;
296 LoadExpression(const LoadExpression &) = delete;
297 LoadExpression() = delete;
298 virtual ~LoadExpression() override;
/// Accessors for the recorded load instruction.
300 LoadInst *getLoadInst() const { return Load; }
301 void setLoadInst(LoadInst *L) { Load = L; }
/// Accessors for the defining access and the alignment.
303 MemoryAccess *getDefiningAccess() const { return DefiningAccess; }
304 void setDefiningAccess(MemoryAccess *MA) { DefiningAccess = MA; }
305 unsigned getAlignment() const { return Alignment; }
306 void setAlignment(unsigned Align) { Alignment = Align; }
/// Declared here only; no inline definition is given.
308 virtual bool equals(const Expression &Other) const override;
/// Hash over the opcode, type, DefiningAccess and the operand range. Unlike
/// BasicExpression::getHashValue(), the expression type is not an input.
310 virtual hash_code getHashValue() const override {
311 return hash_combine(getOpcode(), getType(), DefiningAccess,
312 hash_combine_range(op_begin(), op_end()));
/// Writes "ExpressionTypeLoad, ", the BasicExpression fields, the Load
/// pointer, and the dereferenced DefiningAccess.
// NOTE(review): *DefiningAccess is dereferenced without a null check --
// confirm it is always set before printing.
318 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
320 OS << "ExpressionTypeLoad, ";
321 this->BasicExpression::printInternal(OS, false);
322 OS << " represents Load at " << Load;
323 OS << " with DefiningAccess " << *DefiningAccess;
/// BasicExpression tagged ET_Store that records the StoreInst and a
/// MemoryAccess called DefiningAccess.
327 class StoreExpression final : public BasicExpression {
330 MemoryAccess *DefiningAccess;
/// Returns true when EB's expression type is ET_Store.
333 static bool classof(const Expression *EB) {
334 return EB->getExpressionType() == ET_Store;
/// Stores the store instruction and defining access; NumOperands is
/// forwarded to BasicExpression as the operand capacity.
337 StoreExpression(unsigned NumOperands, StoreInst *S, MemoryAccess *DA)
338 : BasicExpression(NumOperands, ET_Store), Store(S), DefiningAccess(DA) {}
/// Copying, copy assignment and default construction are all disallowed.
339 void operator=(const StoreExpression &) = delete;
340 StoreExpression(const StoreExpression &) = delete;
341 StoreExpression() = delete;
342 virtual ~StoreExpression() override;
/// Read-only accessors; no setters are provided for either field.
344 StoreInst *getStoreInst() const { return Store; }
345 MemoryAccess *getDefiningAccess() const { return DefiningAccess; }
/// Declared here only; no inline definition is given.
347 virtual bool equals(const Expression &Other) const override;
/// Hash over the opcode, type, DefiningAccess and the operand range; the
/// expression type is not an input (same inputs as LoadExpression's hash).
349 virtual hash_code getHashValue() const override {
350 return hash_combine(getOpcode(), getType(), DefiningAccess,
351 hash_combine_range(op_begin(), op_end()));
/// Writes "ExpressionTypeStore, ", the BasicExpression fields, the Store
/// pointer, and the dereferenced DefiningAccess.
// NOTE(review): *DefiningAccess is dereferenced without a null check --
// confirm it is always set before printing.
357 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
359 OS << "ExpressionTypeStore, ";
360 this->BasicExpression::printInternal(OS, false);
361 OS << " represents Store at " << Store;
362 OS << " with DefiningAccess " << *DefiningAccess;
/// BasicExpression tagged ET_AggregateValue that, besides its Value*
/// operands, carries a second array of unsigned integer operands.
366 class AggregateValueExpression final : public BasicExpression {
/// Capacity of, current count in, and storage for the integer operand array.
368 unsigned MaxIntOperands;
369 unsigned NumIntOperands;
370 unsigned *IntOperands;
/// Returns true when EB's expression type is ET_AggregateValue.
373 static bool classof(const Expression *EB) {
374 return EB->getExpressionType() == ET_AggregateValue;
/// The NumIntOperands argument becomes the integer-operand capacity
/// (MaxIntOperands); the count starts at 0 and no integer storage is held.
377 AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
378 : BasicExpression(NumOperands, ET_AggregateValue),
379 MaxIntOperands(NumIntOperands), NumIntOperands(0),
380 IntOperands(nullptr) {}
/// Copying, copy assignment and default construction are all disallowed.
382 void operator=(const AggregateValueExpression &) = delete;
383 AggregateValueExpression(const AggregateValueExpression &) = delete;
384 AggregateValueExpression() = delete;
385 virtual ~AggregateValueExpression() override;
/// Integer-operand iteration: raw pointers spanning
/// [IntOperands, IntOperands + NumIntOperands).
387 typedef unsigned *int_arg_iterator;
388 typedef const unsigned *const_int_arg_iterator;
390 int_arg_iterator int_op_begin() { return IntOperands; }
391 int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
392 const_int_arg_iterator int_op_begin() const { return IntOperands; }
393 const_int_arg_iterator int_op_end() const {
394 return IntOperands + NumIntOperands;
/// Count of integer operands pushed so far, and an emptiness test.
396 unsigned int_op_size() const { return NumIntOperands; }
397 bool int_op_empty() const { return NumIntOperands == 0; }
/// Appends one integer operand. Asserts that there is remaining capacity and
/// that the integer array has been allocated.
398 void int_op_push_back(unsigned IntOperand) {
399 assert(NumIntOperands < MaxIntOperands &&
400 "Tried to add too many int operands");
401 assert(IntOperands && "Operands not allocated before pushing");
402 IntOperands[NumIntOperands++] = IntOperand;
/// Allocates MaxIntOperands unsigned slots from Allocator. Asserts that no
/// integer array is held yet.
405 virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
406 assert(!IntOperands && "Operands already allocated");
407 IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
/// Runs the BasicExpression comparison first, then compares the
/// integer-operand count and the integer operands element by element.
410 virtual bool equals(const Expression &Other) const override {
411 if (!this->BasicExpression::equals(Other))
413 const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
414 return NumIntOperands == OE.NumIntOperands &&
415 std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
/// Extends the BasicExpression hash with a hash of the integer-operand range.
418 virtual hash_code getHashValue() const override {
419 return hash_combine(this->BasicExpression::getHashValue(),
420 hash_combine_range(int_op_begin(), int_op_end()));
/// Writes "ExpressionTypeAggregateValue, ", the BasicExpression fields, and
/// then each integer operand as "[i] = <value> " after ", intoperands = {".
426 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
428 OS << "ExpressionTypeAggregateValue, ";
429 this->BasicExpression::printInternal(OS, false);
430 OS << ", intoperands = {";
431 for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
432 OS << "[" << i << "] = " << IntOperands[i] << " ";
/// Output-iterator style adaptor: assigning an unsigned value through it
/// appends that value to the wrapped AggregateValueExpression via
/// int_op_push_back().
437 class int_op_inserter
438 : public std::iterator<std::output_iterator_tag, void, void, void, void> {
440 typedef AggregateValueExpression Container;
/// Wraps an AggregateValueExpression given either by reference or by pointer.
444 explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
445 explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
/// Assignment forwards the value to AVE->int_op_push_back().
446 int_op_inserter &operator=(unsigned int val) {
447 AVE->int_op_push_back(val);
/// Dereference and both increment forms do nothing except return *this.
450 int_op_inserter &operator*() { return *this; }
451 int_op_inserter &operator++() { return *this; }
452 int_op_inserter &operator++(int) { return *this; }
/// BasicExpression tagged ET_Phi that additionally records a BasicBlock.
455 class PHIExpression final : public BasicExpression {
/// Returns true when EB's expression type is ET_Phi.
460 static bool classof(const Expression *EB) {
461 return EB->getExpressionType() == ET_Phi;
/// Stores the block B in BB; NumOperands is forwarded to BasicExpression as
/// the operand capacity.
464 PHIExpression(unsigned NumOperands, BasicBlock *B)
465 : BasicExpression(NumOperands, ET_Phi), BB(B) {}
/// Copying, copy assignment and default construction are all disallowed.
466 void operator=(const PHIExpression &) = delete;
467 PHIExpression(const PHIExpression &) = delete;
468 PHIExpression() = delete;
469 virtual ~PHIExpression() override;
/// Runs the BasicExpression comparison first and then casts Other to
/// PHIExpression for the phi-specific part of the comparison.
471 virtual bool equals(const Expression &Other) const override {
472 if (!this->BasicExpression::equals(Other))
474 const PHIExpression &OE = cast<PHIExpression>(Other);
/// Extends the BasicExpression hash with the BB pointer.
478 virtual hash_code getHashValue() const override {
479 return hash_combine(this->BasicExpression::getHashValue(), BB);
/// Writes "ExpressionTypePhi, " followed by the BasicExpression fields.
485 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
487 OS << "ExpressionTypePhi, ";
488 this->BasicExpression::printInternal(OS, false);
/// Expression tagged ET_Variable that wraps a single Value pointer.
493 class VariableExpression final : public Expression {
495 Value *VariableValue;
/// Returns true when EB's expression type is ET_Variable.
498 static bool classof(const Expression *EB) {
499 return EB->getExpressionType() == ET_Variable;
/// Wraps V. The constructor is not marked explicit, so a Value* converts
/// implicitly.
502 VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
/// Copying, copy assignment and default construction are all disallowed.
503 void operator=(const VariableExpression &) = delete;
504 VariableExpression(const VariableExpression &) = delete;
505 VariableExpression() = delete;
/// Accessors for the wrapped value.
507 Value *getVariableValue() const { return VariableValue; }
508 void setVariableValue(Value *V) { VariableValue = V; }
/// Two variable expressions are equal when they wrap the same Value pointer.
509 virtual bool equals(const Expression &Other) const override {
510 const VariableExpression &OC = cast<VariableExpression>(Other);
511 return VariableValue == OC.VariableValue;
/// Hash built with hash_combine, starting from the expression type and the
/// type of the wrapped value (VariableValue is dereferenced here).
514 virtual hash_code getHashValue() const override {
515 return hash_combine(getExpressionType(), VariableValue->getType(),
/// Writes "ExpressionTypeVariable, ", the base-class fields, and then the
/// dereferenced value after " variable = ".
522 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
524 OS << "ExpressionTypeVariable, ";
525 this->Expression::printInternal(OS, false);
526 OS << " variable = " << *VariableValue;
/// Expression tagged ET_Constant that wraps a single Constant pointer.
530 class ConstantExpression final : public Expression {
532 Constant *ConstantValue;
/// Returns true when EB's expression type is ET_Constant.
535 static bool classof(const Expression *EB) {
536 return EB->getExpressionType() == ET_Constant;
/// Default construction leaves ConstantValue null.
// NOTE(review): getHashValue() and printInternal() dereference ConstantValue,
// so a default-constructed object presumably needs setConstantValue() before
// either is called -- confirm callers do so.
539 ConstantExpression() : Expression(ET_Constant), ConstantValue(nullptr) {}
/// Wraps constantValue. The constructor is not marked explicit, so a
/// Constant* converts implicitly.
540 ConstantExpression(Constant *constantValue)
541 : Expression(ET_Constant), ConstantValue(constantValue) {}
/// Copying and copy assignment are disallowed.
542 void operator=(const ConstantExpression &) = delete;
543 ConstantExpression(const ConstantExpression &) = delete;
/// Accessors for the wrapped constant.
545 Constant *getConstantValue() const { return ConstantValue; }
546 void setConstantValue(Constant *V) { ConstantValue = V; }
/// Two constant expressions are equal when they wrap the same Constant
/// pointer.
548 virtual bool equals(const Expression &Other) const override {
549 const ConstantExpression &OC = cast<ConstantExpression>(Other);
550 return ConstantValue == OC.ConstantValue;
/// Hash built with hash_combine, starting from the expression type and the
/// type of the wrapped constant (ConstantValue is dereferenced here).
553 virtual hash_code getHashValue() const override {
554 return hash_combine(getExpressionType(), ConstantValue->getType(),
/// Writes "ExpressionTypeConstant, ", the base-class fields, and then the
/// dereferenced constant after " constant = ".
561 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
563 OS << "ExpressionTypeConstant, ";
564 this->Expression::printInternal(OS, false);
565 OS << " constant = " << *ConstantValue;
/// Expression tagged ET_Unknown that wraps a single Instruction pointer.
569 class UnknownExpression final : public Expression {
/// Returns true when EB's expression type is ET_Unknown.
574 static bool classof(const Expression *EB) {
575 return EB->getExpressionType() == ET_Unknown;
/// Wraps I. The constructor is not marked explicit, so an Instruction*
/// converts implicitly.
578 UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
/// Copying, copy assignment and default construction are all disallowed.
579 void operator=(const UnknownExpression &) = delete;
580 UnknownExpression(const UnknownExpression &) = delete;
581 UnknownExpression() = delete;
/// Accessors for the wrapped instruction.
583 Instruction *getInstruction() const { return Inst; }
584 void setInstruction(Instruction *I) { Inst = I; }
/// Two unknown expressions are equal when they wrap the same Instruction
/// pointer.
585 virtual bool equals(const Expression &Other) const override {
586 const auto &OU = cast<UnknownExpression>(Other);
587 return Inst == OU.Inst;
/// Hash over the expression type and the Inst pointer.
589 virtual hash_code getHashValue() const override {
590 return hash_combine(getExpressionType(), Inst);
/// Writes "ExpressionTypeUnknown, ", the base-class fields, and then the
/// dereferenced instruction after " inst = ".
595 virtual void printInternal(raw_ostream &OS, bool PrintEType) const override {
597 OS << "ExpressionTypeUnknown, ";
598 this->Expression::printInternal(OS, false);
599 OS << " inst = " << *Inst;