//===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file
///
/// The header file for the GVN pass that contains expression handling
/// classes.
///
//===----------------------------------------------------------------------===//
17 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
18 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ArrayRecycler.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <utility>
41 namespace GVNExpression {
65 mutable hash_code HashVal = 0;
68 Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
69 : EType(ET), Opcode(O) {}
70 Expression(const Expression &) = delete;
71 Expression &operator=(const Expression &) = delete;
72 virtual ~Expression();
74 static unsigned getEmptyKey() { return ~0U; }
75 static unsigned getTombstoneKey() { return ~1U; }
77 bool operator!=(const Expression &Other) const { return !(*this == Other); }
78 bool operator==(const Expression &Other) const {
79 if (getOpcode() != Other.getOpcode())
81 if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
83 // Compare the expression type for anything but load and store.
84 // For load and store we set the opcode to zero to make them equal.
85 if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
86 getExpressionType() != Other.getExpressionType())
92 hash_code getComputedHash() const {
93 // It's theoretically possible for a thing to hash to zero. In that case,
94 // we will just compute the hash a few extra times, which is no worse that
95 // we did before, which was to compute it always.
96 if (static_cast<unsigned>(HashVal) == 0)
97 HashVal = getHashValue();
101 virtual bool equals(const Expression &Other) const { return true; }
103 // Return true if the two expressions are exactly the same, including the
104 // normally ignored fields.
105 virtual bool exactlyEquals(const Expression &Other) const {
106 return getExpressionType() == Other.getExpressionType() && equals(Other);
109 unsigned getOpcode() const { return Opcode; }
110 void setOpcode(unsigned opcode) { Opcode = opcode; }
111 ExpressionType getExpressionType() const { return EType; }
113 // We deliberately leave the expression type out of the hash value.
114 virtual hash_code getHashValue() const { return getOpcode(); }
117 virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
119 OS << "etype = " << getExpressionType() << ",";
120 OS << "opcode = " << getOpcode() << ", ";
123 void print(raw_ostream &OS) const {
125 printInternal(OS, true);
129 LLVM_DUMP_METHOD void dump() const;
132 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
137 class BasicExpression : public Expression {
139 using RecyclerType = ArrayRecycler<Value *>;
140 using RecyclerCapacity = RecyclerType::Capacity;
142 Value **Operands = nullptr;
143 unsigned MaxOperands;
144 unsigned NumOperands = 0;
145 Type *ValueType = nullptr;
148 BasicExpression(unsigned NumOperands)
149 : BasicExpression(NumOperands, ET_Basic) {}
150 BasicExpression(unsigned NumOperands, ExpressionType ET)
151 : Expression(ET), MaxOperands(NumOperands) {}
152 BasicExpression() = delete;
153 BasicExpression(const BasicExpression &) = delete;
154 BasicExpression &operator=(const BasicExpression &) = delete;
155 ~BasicExpression() override;
157 static bool classof(const Expression *EB) {
158 ExpressionType ET = EB->getExpressionType();
159 return ET > ET_BasicStart && ET < ET_BasicEnd;
162 /// Swap two operands. Used during GVN to put commutative operands in
164 void swapOperands(unsigned First, unsigned Second) {
165 std::swap(Operands[First], Operands[Second]);
168 Value *getOperand(unsigned N) const {
169 assert(Operands && "Operands not allocated");
170 assert(N < NumOperands && "Operand out of range");
174 void setOperand(unsigned N, Value *V) {
175 assert(Operands && "Operands not allocated before setting");
176 assert(N < NumOperands && "Operand out of range");
180 unsigned getNumOperands() const { return NumOperands; }
182 using op_iterator = Value **;
183 using const_op_iterator = Value *const *;
185 op_iterator op_begin() { return Operands; }
186 op_iterator op_end() { return Operands + NumOperands; }
187 const_op_iterator op_begin() const { return Operands; }
188 const_op_iterator op_end() const { return Operands + NumOperands; }
189 iterator_range<op_iterator> operands() {
190 return iterator_range<op_iterator>(op_begin(), op_end());
192 iterator_range<const_op_iterator> operands() const {
193 return iterator_range<const_op_iterator>(op_begin(), op_end());
196 void op_push_back(Value *Arg) {
197 assert(NumOperands < MaxOperands && "Tried to add too many operands");
198 assert(Operands && "Operandss not allocated before pushing");
199 Operands[NumOperands++] = Arg;
201 bool op_empty() const { return getNumOperands() == 0; }
203 void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
204 assert(!Operands && "Operands already allocated");
205 Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
207 void deallocateOperands(RecyclerType &Recycler) {
208 Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
211 void setType(Type *T) { ValueType = T; }
212 Type *getType() const { return ValueType; }
214 bool equals(const Expression &Other) const override {
215 if (getOpcode() != Other.getOpcode())
218 const auto &OE = cast<BasicExpression>(Other);
219 return getType() == OE.getType() && NumOperands == OE.NumOperands &&
220 std::equal(op_begin(), op_end(), OE.op_begin());
223 hash_code getHashValue() const override {
224 return hash_combine(this->Expression::getHashValue(), ValueType,
225 hash_combine_range(op_begin(), op_end()));
229 void printInternal(raw_ostream &OS, bool PrintEType) const override {
231 OS << "ExpressionTypeBasic, ";
233 this->Expression::printInternal(OS, false);
234 OS << "operands = {";
235 for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
236 OS << "[" << i << "] = ";
237 Operands[i]->printAsOperand(OS);
245 : public std::iterator<std::output_iterator_tag, void, void, void, void> {
247 using Container = BasicExpression;
252 explicit op_inserter(BasicExpression &E) : BE(&E) {}
253 explicit op_inserter(BasicExpression *E) : BE(E) {}
255 op_inserter &operator=(Value *val) {
256 BE->op_push_back(val);
259 op_inserter &operator*() { return *this; }
260 op_inserter &operator++() { return *this; }
261 op_inserter &operator++(int) { return *this; }
264 class MemoryExpression : public BasicExpression {
266 const MemoryAccess *MemoryLeader;
269 MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
270 const MemoryAccess *MemoryLeader)
271 : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
272 MemoryExpression() = delete;
273 MemoryExpression(const MemoryExpression &) = delete;
274 MemoryExpression &operator=(const MemoryExpression &) = delete;
276 static bool classof(const Expression *EB) {
277 return EB->getExpressionType() > ET_MemoryStart &&
278 EB->getExpressionType() < ET_MemoryEnd;
281 hash_code getHashValue() const override {
282 return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
285 bool equals(const Expression &Other) const override {
286 if (!this->BasicExpression::equals(Other))
288 const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
290 return MemoryLeader == OtherMCE.MemoryLeader;
293 const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
294 void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
297 class CallExpression final : public MemoryExpression {
302 CallExpression(unsigned NumOperands, CallInst *C,
303 const MemoryAccess *MemoryLeader)
304 : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
305 CallExpression() = delete;
306 CallExpression(const CallExpression &) = delete;
307 CallExpression &operator=(const CallExpression &) = delete;
308 ~CallExpression() override;
310 static bool classof(const Expression *EB) {
311 return EB->getExpressionType() == ET_Call;
315 void printInternal(raw_ostream &OS, bool PrintEType) const override {
317 OS << "ExpressionTypeCall, ";
318 this->BasicExpression::printInternal(OS, false);
319 OS << " represents call at ";
320 Call->printAsOperand(OS);
324 class LoadExpression final : public MemoryExpression {
330 LoadExpression(unsigned NumOperands, LoadInst *L,
331 const MemoryAccess *MemoryLeader)
332 : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
334 LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
335 const MemoryAccess *MemoryLeader)
336 : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {
337 Alignment = L ? L->getAlignment() : 0;
340 LoadExpression() = delete;
341 LoadExpression(const LoadExpression &) = delete;
342 LoadExpression &operator=(const LoadExpression &) = delete;
343 ~LoadExpression() override;
345 static bool classof(const Expression *EB) {
346 return EB->getExpressionType() == ET_Load;
349 LoadInst *getLoadInst() const { return Load; }
350 void setLoadInst(LoadInst *L) { Load = L; }
352 unsigned getAlignment() const { return Alignment; }
353 void setAlignment(unsigned Align) { Alignment = Align; }
355 bool equals(const Expression &Other) const override;
356 bool exactlyEquals(const Expression &Other) const override {
357 return Expression::exactlyEquals(Other) &&
358 cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
362 void printInternal(raw_ostream &OS, bool PrintEType) const override {
364 OS << "ExpressionTypeLoad, ";
365 this->BasicExpression::printInternal(OS, false);
366 OS << " represents Load at ";
367 Load->printAsOperand(OS);
368 OS << " with MemoryLeader " << *getMemoryLeader();
372 class StoreExpression final : public MemoryExpression {
378 StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
379 const MemoryAccess *MemoryLeader)
380 : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
381 StoredValue(StoredValue) {}
382 StoreExpression() = delete;
383 StoreExpression(const StoreExpression &) = delete;
384 StoreExpression &operator=(const StoreExpression &) = delete;
385 ~StoreExpression() override;
387 static bool classof(const Expression *EB) {
388 return EB->getExpressionType() == ET_Store;
391 StoreInst *getStoreInst() const { return Store; }
392 Value *getStoredValue() const { return StoredValue; }
394 bool equals(const Expression &Other) const override;
396 bool exactlyEquals(const Expression &Other) const override {
397 return Expression::exactlyEquals(Other) &&
398 cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
402 void printInternal(raw_ostream &OS, bool PrintEType) const override {
404 OS << "ExpressionTypeStore, ";
405 this->BasicExpression::printInternal(OS, false);
406 OS << " represents Store " << *Store;
407 OS << " with StoredValue ";
408 StoredValue->printAsOperand(OS);
409 OS << " and MemoryLeader " << *getMemoryLeader();
413 class AggregateValueExpression final : public BasicExpression {
415 unsigned MaxIntOperands;
416 unsigned NumIntOperands = 0;
417 unsigned *IntOperands = nullptr;
420 AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
421 : BasicExpression(NumOperands, ET_AggregateValue),
422 MaxIntOperands(NumIntOperands) {}
423 AggregateValueExpression() = delete;
424 AggregateValueExpression(const AggregateValueExpression &) = delete;
425 AggregateValueExpression &
426 operator=(const AggregateValueExpression &) = delete;
427 ~AggregateValueExpression() override;
429 static bool classof(const Expression *EB) {
430 return EB->getExpressionType() == ET_AggregateValue;
433 using int_arg_iterator = unsigned *;
434 using const_int_arg_iterator = const unsigned *;
436 int_arg_iterator int_op_begin() { return IntOperands; }
437 int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
438 const_int_arg_iterator int_op_begin() const { return IntOperands; }
439 const_int_arg_iterator int_op_end() const {
440 return IntOperands + NumIntOperands;
442 unsigned int_op_size() const { return NumIntOperands; }
443 bool int_op_empty() const { return NumIntOperands == 0; }
444 void int_op_push_back(unsigned IntOperand) {
445 assert(NumIntOperands < MaxIntOperands &&
446 "Tried to add too many int operands");
447 assert(IntOperands && "Operands not allocated before pushing");
448 IntOperands[NumIntOperands++] = IntOperand;
451 virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
452 assert(!IntOperands && "Operands already allocated");
453 IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
456 bool equals(const Expression &Other) const override {
457 if (!this->BasicExpression::equals(Other))
459 const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
460 return NumIntOperands == OE.NumIntOperands &&
461 std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
464 hash_code getHashValue() const override {
465 return hash_combine(this->BasicExpression::getHashValue(),
466 hash_combine_range(int_op_begin(), int_op_end()));
470 void printInternal(raw_ostream &OS, bool PrintEType) const override {
472 OS << "ExpressionTypeAggregateValue, ";
473 this->BasicExpression::printInternal(OS, false);
474 OS << ", intoperands = {";
475 for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
476 OS << "[" << i << "] = " << IntOperands[i] << " ";
482 class int_op_inserter
483 : public std::iterator<std::output_iterator_tag, void, void, void, void> {
485 using Container = AggregateValueExpression;
490 explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
491 explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
493 int_op_inserter &operator=(unsigned int val) {
494 AVE->int_op_push_back(val);
497 int_op_inserter &operator*() { return *this; }
498 int_op_inserter &operator++() { return *this; }
499 int_op_inserter &operator++(int) { return *this; }
502 class PHIExpression final : public BasicExpression {
507 PHIExpression(unsigned NumOperands, BasicBlock *B)
508 : BasicExpression(NumOperands, ET_Phi), BB(B) {}
509 PHIExpression() = delete;
510 PHIExpression(const PHIExpression &) = delete;
511 PHIExpression &operator=(const PHIExpression &) = delete;
512 ~PHIExpression() override;
514 static bool classof(const Expression *EB) {
515 return EB->getExpressionType() == ET_Phi;
518 bool equals(const Expression &Other) const override {
519 if (!this->BasicExpression::equals(Other))
521 const PHIExpression &OE = cast<PHIExpression>(Other);
525 hash_code getHashValue() const override {
526 return hash_combine(this->BasicExpression::getHashValue(), BB);
530 void printInternal(raw_ostream &OS, bool PrintEType) const override {
532 OS << "ExpressionTypePhi, ";
533 this->BasicExpression::printInternal(OS, false);
538 class DeadExpression final : public Expression {
540 DeadExpression() : Expression(ET_Dead) {}
541 DeadExpression(const DeadExpression &) = delete;
542 DeadExpression &operator=(const DeadExpression &) = delete;
544 static bool classof(const Expression *E) {
545 return E->getExpressionType() == ET_Dead;
549 class VariableExpression final : public Expression {
551 Value *VariableValue;
554 VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
555 VariableExpression() = delete;
556 VariableExpression(const VariableExpression &) = delete;
557 VariableExpression &operator=(const VariableExpression &) = delete;
559 static bool classof(const Expression *EB) {
560 return EB->getExpressionType() == ET_Variable;
563 Value *getVariableValue() const { return VariableValue; }
564 void setVariableValue(Value *V) { VariableValue = V; }
566 bool equals(const Expression &Other) const override {
567 const VariableExpression &OC = cast<VariableExpression>(Other);
568 return VariableValue == OC.VariableValue;
571 hash_code getHashValue() const override {
572 return hash_combine(this->Expression::getHashValue(),
573 VariableValue->getType(), VariableValue);
577 void printInternal(raw_ostream &OS, bool PrintEType) const override {
579 OS << "ExpressionTypeVariable, ";
580 this->Expression::printInternal(OS, false);
581 OS << " variable = " << *VariableValue;
585 class ConstantExpression final : public Expression {
587 Constant *ConstantValue = nullptr;
590 ConstantExpression() : Expression(ET_Constant) {}
591 ConstantExpression(Constant *constantValue)
592 : Expression(ET_Constant), ConstantValue(constantValue) {}
593 ConstantExpression(const ConstantExpression &) = delete;
594 ConstantExpression &operator=(const ConstantExpression &) = delete;
596 static bool classof(const Expression *EB) {
597 return EB->getExpressionType() == ET_Constant;
600 Constant *getConstantValue() const { return ConstantValue; }
601 void setConstantValue(Constant *V) { ConstantValue = V; }
603 bool equals(const Expression &Other) const override {
604 const ConstantExpression &OC = cast<ConstantExpression>(Other);
605 return ConstantValue == OC.ConstantValue;
608 hash_code getHashValue() const override {
609 return hash_combine(this->Expression::getHashValue(),
610 ConstantValue->getType(), ConstantValue);
614 void printInternal(raw_ostream &OS, bool PrintEType) const override {
616 OS << "ExpressionTypeConstant, ";
617 this->Expression::printInternal(OS, false);
618 OS << " constant = " << *ConstantValue;
622 class UnknownExpression final : public Expression {
627 UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
628 UnknownExpression() = delete;
629 UnknownExpression(const UnknownExpression &) = delete;
630 UnknownExpression &operator=(const UnknownExpression &) = delete;
632 static bool classof(const Expression *EB) {
633 return EB->getExpressionType() == ET_Unknown;
636 Instruction *getInstruction() const { return Inst; }
637 void setInstruction(Instruction *I) { Inst = I; }
639 bool equals(const Expression &Other) const override {
640 const auto &OU = cast<UnknownExpression>(Other);
641 return Inst == OU.Inst;
644 hash_code getHashValue() const override {
645 return hash_combine(this->Expression::getHashValue(), Inst);
649 void printInternal(raw_ostream &OS, bool PrintEType) const override {
651 OS << "ExpressionTypeUnknown, ";
652 this->Expression::printInternal(OS, false);
653 OS << " inst = " << *Inst;
657 } // end namespace GVNExpression
659 } // end namespace llvm
661 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H