]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
Merge llvm-project main llvmorg-17-init-19304-gd0b54bb50e51
[FreeBSD/FreeBSD.git] / contrib / llvm-project / llvm / lib / CodeGen / ComplexDeinterleavingPass.cpp
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference  to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementValue is the transformed Value* that has been emitted
54 // to the IR.
55 //
56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58 // should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/Statistic.h"
64 #include "llvm/Analysis/TargetLibraryInfo.h"
65 #include "llvm/Analysis/TargetTransformInfo.h"
66 #include "llvm/CodeGen/TargetLowering.h"
67 #include "llvm/CodeGen/TargetPassConfig.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
75
76 using namespace llvm;
77 using namespace PatternMatch;
78
79 #define DEBUG_TYPE "complex-deinterleaving"
80
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82
83 static cl::opt<bool> ComplexDeinterleavingEnabled(
84     "enable-complex-deinterleaving",
85     cl::desc("Enable generation of complex instructions"), cl::init(true),
86     cl::Hidden);
87
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
94
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
102
103 /// Returns true if the operation is a negation of V, and it works for both
104 /// integers and floats.
105 static bool isNeg(Value *V);
106
107 /// Returns the operand for negation operation.
108 static Value *getNegOperand(Value *V);
109
110 namespace {
111
112 class ComplexDeinterleavingLegacyPass : public FunctionPass {
113 public:
114   static char ID;
115
116   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
117       : FunctionPass(ID), TM(TM) {
118     initializeComplexDeinterleavingLegacyPassPass(
119         *PassRegistry::getPassRegistry());
120   }
121
122   StringRef getPassName() const override {
123     return "Complex Deinterleaving Pass";
124   }
125
126   bool runOnFunction(Function &F) override;
127   void getAnalysisUsage(AnalysisUsage &AU) const override {
128     AU.addRequired<TargetLibraryInfoWrapperPass>();
129     AU.setPreservesCFG();
130   }
131
132 private:
133   const TargetMachine *TM;
134 };
135
136 class ComplexDeinterleavingGraph;
137 struct ComplexDeinterleavingCompositeNode {
138
139   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
140                                      Value *R, Value *I)
141       : Operation(Op), Real(R), Imag(I) {}
142
143 private:
144   friend class ComplexDeinterleavingGraph;
145   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
146   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
147
148 public:
149   ComplexDeinterleavingOperation Operation;
150   Value *Real;
151   Value *Imag;
152
153   // This two members are required exclusively for generating
154   // ComplexDeinterleavingOperation::Symmetric operations.
155   unsigned Opcode;
156   std::optional<FastMathFlags> Flags;
157
158   ComplexDeinterleavingRotation Rotation =
159       ComplexDeinterleavingRotation::Rotation_0;
160   SmallVector<RawNodePtr> Operands;
161   Value *ReplacementNode = nullptr;
162
163   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
164
165   void dump() { dump(dbgs()); }
166   void dump(raw_ostream &OS) {
167     auto PrintValue = [&](Value *V) {
168       if (V) {
169         OS << "\"";
170         V->print(OS, true);
171         OS << "\"\n";
172       } else
173         OS << "nullptr\n";
174     };
175     auto PrintNodeRef = [&](RawNodePtr Ptr) {
176       if (Ptr)
177         OS << Ptr << "\n";
178       else
179         OS << "nullptr\n";
180     };
181
182     OS << "- CompositeNode: " << this << "\n";
183     OS << "  Real: ";
184     PrintValue(Real);
185     OS << "  Imag: ";
186     PrintValue(Imag);
187     OS << "  ReplacementNode: ";
188     PrintValue(ReplacementNode);
189     OS << "  Operation: " << (int)Operation << "\n";
190     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
191     OS << "  Operands: \n";
192     for (const auto &Op : Operands) {
193       OS << "    - ";
194       PrintNodeRef(Op);
195     }
196   }
197 };
198
199 class ComplexDeinterleavingGraph {
200 public:
201   struct Product {
202     Value *Multiplier;
203     Value *Multiplicand;
204     bool IsPositive;
205   };
206
207   using Addend = std::pair<Value *, bool>;
208   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
209   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
210
211   // Helper struct for holding info about potential partial multiplication
212   // candidates
213   struct PartialMulCandidate {
214     Value *Common;
215     NodePtr Node;
216     unsigned RealIdx;
217     unsigned ImagIdx;
218     bool IsNodeInverted;
219   };
220
221   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
222                                       const TargetLibraryInfo *TLI)
223       : TL(TL), TLI(TLI) {}
224
225 private:
226   const TargetLowering *TL = nullptr;
227   const TargetLibraryInfo *TLI = nullptr;
228   SmallVector<NodePtr> CompositeNodes;
229
230   SmallPtrSet<Instruction *, 16> FinalInstructions;
231
232   /// Root instructions are instructions from which complex computation starts
233   std::map<Instruction *, NodePtr> RootToNode;
234
235   /// Topologically sorted root instructions
236   SmallVector<Instruction *, 1> OrderedRoots;
237
238   /// When examining a basic block for complex deinterleaving, if it is a simple
239   /// one-block loop, then the only incoming block is 'Incoming' and the
240   /// 'BackEdge' block is the block itself."
241   BasicBlock *BackEdge = nullptr;
242   BasicBlock *Incoming = nullptr;
243
244   /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
245   /// %OutsideUser as it is shown in the IR:
246   ///
247   /// vector.body:
248   ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
249   ///                                [ %ReductionOp, %vector.body ]
250   ///   ...
251   ///   %ReductionOp = fadd i64 ...
252   ///   ...
253   ///   br i1 %condition, label %vector.body, %middle.block
254   ///
255   /// middle.block:
256   ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
257   ///
258   /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
259   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
260   std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
261
262   /// In the process of detecting a reduction, we consider a pair of
263   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
264   /// traverse the use-tree to detect complex operations. As this is a reduction
265   /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
266   /// to the %ReductionOPs that we suspect to be complex.
267   /// RealPHI and ImagPHI are used by the identifyPHINode method.
268   PHINode *RealPHI = nullptr;
269   PHINode *ImagPHI = nullptr;
270
271   /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
272   /// detection.
273   bool PHIsFound = false;
274
275   /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
276   /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
277   /// This mapping is populated during
278   /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
279   /// used in the ComplexDeinterleavingOperation::ReductionOperation node
280   /// replacement process.
281   std::map<PHINode *, PHINode *> OldToNewPHI;
282
283   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
284                                Value *R, Value *I) {
285     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
286              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
287             (R && I)) &&
288            "Reduction related nodes must have Real and Imaginary parts");
289     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
290                                                                 I);
291   }
292
293   NodePtr submitCompositeNode(NodePtr Node) {
294     CompositeNodes.push_back(Node);
295     return Node;
296   }
297
298   NodePtr getContainingComposite(Value *R, Value *I) {
299     for (const auto &CN : CompositeNodes) {
300       if (CN->Real == R && CN->Imag == I)
301         return CN;
302     }
303     return nullptr;
304   }
305
306   /// Identifies a complex partial multiply pattern and its rotation, based on
307   /// the following patterns
308   ///
309   ///  0:  r: cr + ar * br
310   ///      i: ci + ar * bi
311   /// 90:  r: cr - ai * bi
312   ///      i: ci + ai * br
313   /// 180: r: cr - ar * br
314   ///      i: ci - ar * bi
315   /// 270: r: cr + ai * bi
316   ///      i: ci - ai * br
317   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
318
319   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
320   /// is partially known from identifyPartialMul, filling in the other half of
321   /// the complex pair.
322   NodePtr
323   identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
324                               std::pair<Value *, Value *> &CommonOperandI);
325
326   /// Identifies a complex add pattern and its rotation, based on the following
327   /// patterns.
328   ///
329   /// 90:  r: ar - bi
330   ///      i: ai + br
331   /// 270: r: ar + bi
332   ///      i: ai - br
333   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
334   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
335
336   NodePtr identifyNode(Value *R, Value *I);
337
338   /// Determine if a sum of complex numbers can be formed from \p RealAddends
339   /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
340   /// Return nullptr if it is not possible to construct a complex number.
341   /// \p Flags are needed to generate symmetric Add and Sub operations.
342   NodePtr identifyAdditions(std::list<Addend> &RealAddends,
343                             std::list<Addend> &ImagAddends,
344                             std::optional<FastMathFlags> Flags,
345                             NodePtr Accumulator);
346
347   /// Extract one addend that have both real and imaginary parts positive.
348   NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
349                                 std::list<Addend> &ImagAddends);
350
351   /// Determine if sum of multiplications of complex numbers can be formed from
352   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
353   /// to it. Return nullptr if it is not possible to construct a complex number.
354   NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
355                                   std::vector<Product> &ImagMuls,
356                                   NodePtr Accumulator);
357
358   /// Go through pairs of multiplication (one Real and one Imag) and find all
359   /// possible candidates for partial multiplication and put them into \p
360   /// Candidates. Returns true if all Product has pair with common operand
361   bool collectPartialMuls(const std::vector<Product> &RealMuls,
362                           const std::vector<Product> &ImagMuls,
363                           std::vector<PartialMulCandidate> &Candidates);
364
365   /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
366   /// the order of complex computation operations may be significantly altered,
367   /// and the real and imaginary parts may not be executed in parallel. This
368   /// function takes this into consideration and employs a more general approach
369   /// to identify complex computations. Initially, it gathers all the addends
370   /// and multiplicands and then constructs a complex expression from them.
371   NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
372
373   NodePtr identifyRoot(Instruction *I);
374
375   /// Identifies the Deinterleave operation applied to a vector containing
376   /// complex numbers. There are two ways to represent the Deinterleave
377   /// operation:
378   /// * Using two shufflevectors with even indices for /pReal instruction and
379   /// odd indices for /pImag instructions (only for fixed-width vectors)
380   /// * Using two extractvalue instructions applied to `vector.deinterleave2`
381   /// intrinsic (for both fixed and scalable vectors)
382   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
383
384   /// identifying the operation that represents a complex number repeated in a
385   /// Splat vector. There are two possible types of splats: ConstantExpr with
386   /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
387   /// initialization mask with all values set to zero.
388   NodePtr identifySplat(Value *Real, Value *Imag);
389
390   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
391
392   /// Identifies SelectInsts in a loop that has reduction with predication masks
393   /// and/or predicated tail folding
394   NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
395
396   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
397
398   /// Complete IR modifications after producing new reduction operation:
399   /// * Populate the PHINode generated for
400   /// ComplexDeinterleavingOperation::ReductionPHI
401   /// * Deinterleave the final value outside of the loop and repurpose original
402   /// reduction users
403   void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
404
405 public:
406   void dump() { dump(dbgs()); }
407   void dump(raw_ostream &OS) {
408     for (const auto &Node : CompositeNodes)
409       Node->dump(OS);
410   }
411
412   /// Returns false if the deinterleaving operation should be cancelled for the
413   /// current graph.
414   bool identifyNodes(Instruction *RootI);
415
416   /// In case \pB is one-block loop, this function seeks potential reductions
417   /// and populates ReductionInfo. Returns true if any reductions were
418   /// identified.
419   bool collectPotentialReductions(BasicBlock *B);
420
421   void identifyReductionNodes();
422
423   /// Check that every instruction, from the roots to the leaves, has internal
424   /// uses.
425   bool checkNodes();
426
427   /// Perform the actual replacement of the underlying instruction graph.
428   void replaceNodes();
429 };
430
431 class ComplexDeinterleaving {
432 public:
433   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
434       : TL(tl), TLI(tli) {}
435   bool runOnFunction(Function &F);
436
437 private:
438   bool evaluateBasicBlock(BasicBlock *B);
439
440   const TargetLowering *TL = nullptr;
441   const TargetLibraryInfo *TLI = nullptr;
442 };
443
444 } // namespace
445
446 char ComplexDeinterleavingLegacyPass::ID = 0;
447
448 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
449                       "Complex Deinterleaving", false, false)
450 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
451                     "Complex Deinterleaving", false, false)
452
453 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
454                                                  FunctionAnalysisManager &AM) {
455   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
456   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
457   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
458     return PreservedAnalyses::all();
459
460   PreservedAnalyses PA;
461   PA.preserve<FunctionAnalysisManagerModuleProxy>();
462   return PA;
463 }
464
465 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
466   return new ComplexDeinterleavingLegacyPass(TM);
467 }
468
469 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
470   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
471   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
472   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
473 }
474
475 bool ComplexDeinterleaving::runOnFunction(Function &F) {
476   if (!ComplexDeinterleavingEnabled) {
477     LLVM_DEBUG(
478         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
479     return false;
480   }
481
482   if (!TL->isComplexDeinterleavingSupported()) {
483     LLVM_DEBUG(
484         dbgs() << "Complex deinterleaving has been disabled, target does "
485                   "not support lowering of complex number operations.\n");
486     return false;
487   }
488
489   bool Changed = false;
490   for (auto &B : F)
491     Changed |= evaluateBasicBlock(&B);
492
493   return Changed;
494 }
495
496 static bool isInterleavingMask(ArrayRef<int> Mask) {
497   // If the size is not even, it's not an interleaving mask
498   if ((Mask.size() & 1))
499     return false;
500
501   int HalfNumElements = Mask.size() / 2;
502   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
503     int MaskIdx = Idx * 2;
504     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
505       return false;
506   }
507
508   return true;
509 }
510
511 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
512   int Offset = Mask[0];
513   int HalfNumElements = Mask.size() / 2;
514
515   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
516     if (Mask[Idx] != (Idx * 2) + Offset)
517       return false;
518   }
519
520   return true;
521 }
522
523 bool isNeg(Value *V) {
524   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
525 }
526
527 Value *getNegOperand(Value *V) {
528   assert(isNeg(V));
529   auto *I = cast<Instruction>(V);
530   if (I->getOpcode() == Instruction::FNeg)
531     return I->getOperand(0);
532
533   return I->getOperand(1);
534 }
535
536 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
537   ComplexDeinterleavingGraph Graph(TL, TLI);
538   if (Graph.collectPotentialReductions(B))
539     Graph.identifyReductionNodes();
540
541   for (auto &I : *B)
542     Graph.identifyNodes(&I);
543
544   if (Graph.checkNodes()) {
545     Graph.replaceNodes();
546     return true;
547   }
548
549   return false;
550 }
551
552 ComplexDeinterleavingGraph::NodePtr
553 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
554     Instruction *Real, Instruction *Imag,
555     std::pair<Value *, Value *> &PartialMatch) {
556   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
557                     << "\n");
558
559   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
560     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
561     return nullptr;
562   }
563
564   if ((Real->getOpcode() != Instruction::FMul &&
565        Real->getOpcode() != Instruction::Mul) ||
566       (Imag->getOpcode() != Instruction::FMul &&
567        Imag->getOpcode() != Instruction::Mul)) {
568     LLVM_DEBUG(
569         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
570     return nullptr;
571   }
572
573   Value *R0 = Real->getOperand(0);
574   Value *R1 = Real->getOperand(1);
575   Value *I0 = Imag->getOperand(0);
576   Value *I1 = Imag->getOperand(1);
577
578   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
579   // rotations and use the operand.
580   unsigned Negs = 0;
581   Value *Op;
582   if (match(R0, m_Neg(m_Value(Op)))) {
583     Negs |= 1;
584     R0 = Op;
585   } else if (match(R1, m_Neg(m_Value(Op)))) {
586     Negs |= 1;
587     R1 = Op;
588   }
589
590   if (isNeg(I0)) {
591     Negs |= 2;
592     Negs ^= 1;
593     I0 = Op;
594   } else if (match(I1, m_Neg(m_Value(Op)))) {
595     Negs |= 2;
596     Negs ^= 1;
597     I1 = Op;
598   }
599
600   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
601
602   Value *CommonOperand;
603   Value *UncommonRealOp;
604   Value *UncommonImagOp;
605
606   if (R0 == I0 || R0 == I1) {
607     CommonOperand = R0;
608     UncommonRealOp = R1;
609   } else if (R1 == I0 || R1 == I1) {
610     CommonOperand = R1;
611     UncommonRealOp = R0;
612   } else {
613     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
614     return nullptr;
615   }
616
617   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
618   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
619       Rotation == ComplexDeinterleavingRotation::Rotation_270)
620     std::swap(UncommonRealOp, UncommonImagOp);
621
622   // Between identifyPartialMul and here we need to have found a complete valid
623   // pair from the CommonOperand of each part.
624   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
625       Rotation == ComplexDeinterleavingRotation::Rotation_180)
626     PartialMatch.first = CommonOperand;
627   else
628     PartialMatch.second = CommonOperand;
629
630   if (!PartialMatch.first || !PartialMatch.second) {
631     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
632     return nullptr;
633   }
634
635   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
636   if (!CommonNode) {
637     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
638     return nullptr;
639   }
640
641   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
642   if (!UncommonNode) {
643     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
644     return nullptr;
645   }
646
647   NodePtr Node = prepareCompositeNode(
648       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
649   Node->Rotation = Rotation;
650   Node->addOperand(CommonNode);
651   Node->addOperand(UncommonNode);
652   return submitCompositeNode(Node);
653 }
654
655 ComplexDeinterleavingGraph::NodePtr
656 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
657                                                Instruction *Imag) {
658   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
659                     << "\n");
660   // Determine rotation
661   auto IsAdd = [](unsigned Op) {
662     return Op == Instruction::FAdd || Op == Instruction::Add;
663   };
664   auto IsSub = [](unsigned Op) {
665     return Op == Instruction::FSub || Op == Instruction::Sub;
666   };
667   ComplexDeinterleavingRotation Rotation;
668   if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
669     Rotation = ComplexDeinterleavingRotation::Rotation_0;
670   else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
671     Rotation = ComplexDeinterleavingRotation::Rotation_90;
672   else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
673     Rotation = ComplexDeinterleavingRotation::Rotation_180;
674   else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
675     Rotation = ComplexDeinterleavingRotation::Rotation_270;
676   else {
677     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
678     return nullptr;
679   }
680
681   if (isa<FPMathOperator>(Real) &&
682       (!Real->getFastMathFlags().allowContract() ||
683        !Imag->getFastMathFlags().allowContract())) {
684     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
685     return nullptr;
686   }
687
688   Value *CR = Real->getOperand(0);
689   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
690   if (!RealMulI)
691     return nullptr;
692   Value *CI = Imag->getOperand(0);
693   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
694   if (!ImagMulI)
695     return nullptr;
696
697   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
698     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
699     return nullptr;
700   }
701
702   Value *R0 = RealMulI->getOperand(0);
703   Value *R1 = RealMulI->getOperand(1);
704   Value *I0 = ImagMulI->getOperand(0);
705   Value *I1 = ImagMulI->getOperand(1);
706
707   Value *CommonOperand;
708   Value *UncommonRealOp;
709   Value *UncommonImagOp;
710
711   if (R0 == I0 || R0 == I1) {
712     CommonOperand = R0;
713     UncommonRealOp = R1;
714   } else if (R1 == I0 || R1 == I1) {
715     CommonOperand = R1;
716     UncommonRealOp = R0;
717   } else {
718     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
719     return nullptr;
720   }
721
722   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
723   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
724       Rotation == ComplexDeinterleavingRotation::Rotation_270)
725     std::swap(UncommonRealOp, UncommonImagOp);
726
727   std::pair<Value *, Value *> PartialMatch(
728       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
729        Rotation == ComplexDeinterleavingRotation::Rotation_180)
730           ? CommonOperand
731           : nullptr,
732       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
733        Rotation == ComplexDeinterleavingRotation::Rotation_270)
734           ? CommonOperand
735           : nullptr);
736
737   auto *CRInst = dyn_cast<Instruction>(CR);
738   auto *CIInst = dyn_cast<Instruction>(CI);
739
740   if (!CRInst || !CIInst) {
741     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
742     return nullptr;
743   }
744
745   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
746   if (!CNode) {
747     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
748     return nullptr;
749   }
750
751   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
752   if (!UncommonRes) {
753     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
754     return nullptr;
755   }
756
757   assert(PartialMatch.first && PartialMatch.second);
758   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
759   if (!CommonRes) {
760     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
761     return nullptr;
762   }
763
764   NodePtr Node = prepareCompositeNode(
765       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
766   Node->Rotation = Rotation;
767   Node->addOperand(CommonRes);
768   Node->addOperand(UncommonRes);
769   Node->addOperand(CNode);
770   return submitCompositeNode(Node);
771 }
772
773 ComplexDeinterleavingGraph::NodePtr
774 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
775   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
776
777   // Determine rotation
778   ComplexDeinterleavingRotation Rotation;
779   if ((Real->getOpcode() == Instruction::FSub &&
780        Imag->getOpcode() == Instruction::FAdd) ||
781       (Real->getOpcode() == Instruction::Sub &&
782        Imag->getOpcode() == Instruction::Add))
783     Rotation = ComplexDeinterleavingRotation::Rotation_90;
784   else if ((Real->getOpcode() == Instruction::FAdd &&
785             Imag->getOpcode() == Instruction::FSub) ||
786            (Real->getOpcode() == Instruction::Add &&
787             Imag->getOpcode() == Instruction::Sub))
788     Rotation = ComplexDeinterleavingRotation::Rotation_270;
789   else {
790     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
791     return nullptr;
792   }
793
794   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
795   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
796   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
797   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
798
799   if (!AR || !AI || !BR || !BI) {
800     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
801     return nullptr;
802   }
803
804   NodePtr ResA = identifyNode(AR, AI);
805   if (!ResA) {
806     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
807     return nullptr;
808   }
809   NodePtr ResB = identifyNode(BR, BI);
810   if (!ResB) {
811     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
812     return nullptr;
813   }
814
815   NodePtr Node =
816       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
817   Node->Rotation = Rotation;
818   Node->addOperand(ResA);
819   Node->addOperand(ResB);
820   return submitCompositeNode(Node);
821 }
822
823 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
824   unsigned OpcA = A->getOpcode();
825   unsigned OpcB = B->getOpcode();
826
827   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
828          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
829          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
830          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
831 }
832
833 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
834   auto Pattern =
835       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
836
837   return match(A, Pattern) && match(B, Pattern);
838 }
839
840 static bool isInstructionPotentiallySymmetric(Instruction *I) {
841   switch (I->getOpcode()) {
842   case Instruction::FAdd:
843   case Instruction::FSub:
844   case Instruction::FMul:
845   case Instruction::FNeg:
846   case Instruction::Add:
847   case Instruction::Sub:
848   case Instruction::Mul:
849     return true;
850   default:
851     return false;
852   }
853 }
854
855 ComplexDeinterleavingGraph::NodePtr
856 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
857                                                        Instruction *Imag) {
858   if (Real->getOpcode() != Imag->getOpcode())
859     return nullptr;
860
861   if (!isInstructionPotentiallySymmetric(Real) ||
862       !isInstructionPotentiallySymmetric(Imag))
863     return nullptr;
864
865   auto *R0 = Real->getOperand(0);
866   auto *I0 = Imag->getOperand(0);
867
868   NodePtr Op0 = identifyNode(R0, I0);
869   NodePtr Op1 = nullptr;
870   if (Op0 == nullptr)
871     return nullptr;
872
873   if (Real->isBinaryOp()) {
874     auto *R1 = Real->getOperand(1);
875     auto *I1 = Imag->getOperand(1);
876     Op1 = identifyNode(R1, I1);
877     if (Op1 == nullptr)
878       return nullptr;
879   }
880
881   if (isa<FPMathOperator>(Real) &&
882       Real->getFastMathFlags() != Imag->getFastMathFlags())
883     return nullptr;
884
885   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
886                                    Real, Imag);
887   Node->Opcode = Real->getOpcode();
888   if (isa<FPMathOperator>(Real))
889     Node->Flags = Real->getFastMathFlags();
890
891   Node->addOperand(Op0);
892   if (Real->isBinaryOp())
893     Node->addOperand(Op1);
894
895   return submitCompositeNode(Node);
896 }
897
898 ComplexDeinterleavingGraph::NodePtr
899 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
900   LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
901   assert(R->getType() == I->getType() &&
902          "Real and imaginary parts should not have different types");
903   if (NodePtr CN = getContainingComposite(R, I)) {
904     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
905     return CN;
906   }
907
908   if (NodePtr CN = identifySplat(R, I))
909     return CN;
910
911   auto *Real = dyn_cast<Instruction>(R);
912   auto *Imag = dyn_cast<Instruction>(I);
913   if (!Real || !Imag)
914     return nullptr;
915
916   if (NodePtr CN = identifyDeinterleave(Real, Imag))
917     return CN;
918
919   if (NodePtr CN = identifyPHINode(Real, Imag))
920     return CN;
921
922   if (NodePtr CN = identifySelectNode(Real, Imag))
923     return CN;
924
925   auto *VTy = cast<VectorType>(Real->getType());
926   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
927
928   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
929       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
930   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
931       ComplexDeinterleavingOperation::CAdd, NewVTy);
932
933   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
934     if (NodePtr CN = identifyPartialMul(Real, Imag))
935       return CN;
936   }
937
938   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
939     if (NodePtr CN = identifyAdd(Real, Imag))
940       return CN;
941   }
942
943   if (HasCMulSupport && HasCAddSupport) {
944     if (NodePtr CN = identifyReassocNodes(Real, Imag))
945       return CN;
946   }
947
948   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
949     return CN;
950
951   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
952   return nullptr;
953 }
954
955 ComplexDeinterleavingGraph::NodePtr
956 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
957                                                  Instruction *Imag) {
958   auto IsOperationSupported = [](unsigned Opcode) -> bool {
959     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
960            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
961            Opcode == Instruction::Sub;
962   };
963
964   if (!IsOperationSupported(Real->getOpcode()) ||
965       !IsOperationSupported(Imag->getOpcode()))
966     return nullptr;
967
968   std::optional<FastMathFlags> Flags;
969   if (isa<FPMathOperator>(Real)) {
970     if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
971       LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
972                            "not identical\n");
973       return nullptr;
974     }
975
976     Flags = Real->getFastMathFlags();
977     if (!Flags->allowReassoc()) {
978       LLVM_DEBUG(
979           dbgs()
980           << "the 'Reassoc' attribute is missing in the FastMath flags\n");
981       return nullptr;
982     }
983   }
984
985   // Collect multiplications and addend instructions from the given instruction
986   // while traversing it operands. Additionally, verify that all instructions
987   // have the same fast math flags.
988   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
989                           std::list<Addend> &Addends) -> bool {
990     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
991     SmallPtrSet<Value *, 8> Visited;
992     while (!Worklist.empty()) {
993       auto [V, IsPositive] = Worklist.back();
994       Worklist.pop_back();
995       if (!Visited.insert(V).second)
996         continue;
997
998       Instruction *I = dyn_cast<Instruction>(V);
999       if (!I) {
1000         Addends.emplace_back(V, IsPositive);
1001         continue;
1002       }
1003
1004       // If an instruction has more than one user, it indicates that it either
1005       // has an external user, which will be later checked by the checkNodes
1006       // function, or it is a subexpression utilized by multiple expressions. In
1007       // the latter case, we will attempt to separately identify the complex
1008       // operation from here in order to create a shared
1009       // ComplexDeinterleavingCompositeNode.
1010       if (I != Insn && I->getNumUses() > 1) {
1011         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1012         Addends.emplace_back(I, IsPositive);
1013         continue;
1014       }
1015       switch (I->getOpcode()) {
1016       case Instruction::FAdd:
1017       case Instruction::Add:
1018         Worklist.emplace_back(I->getOperand(1), IsPositive);
1019         Worklist.emplace_back(I->getOperand(0), IsPositive);
1020         break;
1021       case Instruction::FSub:
1022         Worklist.emplace_back(I->getOperand(1), !IsPositive);
1023         Worklist.emplace_back(I->getOperand(0), IsPositive);
1024         break;
1025       case Instruction::Sub:
1026         if (isNeg(I)) {
1027           Worklist.emplace_back(getNegOperand(I), !IsPositive);
1028         } else {
1029           Worklist.emplace_back(I->getOperand(1), !IsPositive);
1030           Worklist.emplace_back(I->getOperand(0), IsPositive);
1031         }
1032         break;
1033       case Instruction::FMul:
1034       case Instruction::Mul: {
1035         Value *A, *B;
1036         if (isNeg(I->getOperand(0))) {
1037           A = getNegOperand(I->getOperand(0));
1038           IsPositive = !IsPositive;
1039         } else {
1040           A = I->getOperand(0);
1041         }
1042
1043         if (isNeg(I->getOperand(1))) {
1044           B = getNegOperand(I->getOperand(1));
1045           IsPositive = !IsPositive;
1046         } else {
1047           B = I->getOperand(1);
1048         }
1049         Muls.push_back(Product{A, B, IsPositive});
1050         break;
1051       }
1052       case Instruction::FNeg:
1053         Worklist.emplace_back(I->getOperand(0), !IsPositive);
1054         break;
1055       default:
1056         Addends.emplace_back(I, IsPositive);
1057         continue;
1058       }
1059
1060       if (Flags && I->getFastMathFlags() != *Flags) {
1061         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1062                              "inconsistent with the root instructions' flags: "
1063                           << *I << "\n");
1064         return false;
1065       }
1066     }
1067     return true;
1068   };
1069
1070   std::vector<Product> RealMuls, ImagMuls;
1071   std::list<Addend> RealAddends, ImagAddends;
1072   if (!Collect(Real, RealMuls, RealAddends) ||
1073       !Collect(Imag, ImagMuls, ImagAddends))
1074     return nullptr;
1075
1076   if (RealAddends.size() != ImagAddends.size())
1077     return nullptr;
1078
1079   NodePtr FinalNode;
1080   if (!RealMuls.empty() || !ImagMuls.empty()) {
1081     // If there are multiplicands, extract positive addend and use it as an
1082     // accumulator
1083     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1084     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1085     if (!FinalNode)
1086       return nullptr;
1087   }
1088
1089   // Identify and process remaining additions
1090   if (!RealAddends.empty() || !ImagAddends.empty()) {
1091     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1092     if (!FinalNode)
1093       return nullptr;
1094   }
1095   assert(FinalNode && "FinalNode can not be nullptr here");
1096   // Set the Real and Imag fields of the final node and submit it
1097   FinalNode->Real = Real;
1098   FinalNode->Imag = Imag;
1099   submitCompositeNode(FinalNode);
1100   return FinalNode;
1101 }
1102
1103 bool ComplexDeinterleavingGraph::collectPartialMuls(
1104     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1105     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1106   // Helper function to extract a common operand from two products
1107   auto FindCommonInstruction = [](const Product &Real,
1108                                   const Product &Imag) -> Value * {
1109     if (Real.Multiplicand == Imag.Multiplicand ||
1110         Real.Multiplicand == Imag.Multiplier)
1111       return Real.Multiplicand;
1112
1113     if (Real.Multiplier == Imag.Multiplicand ||
1114         Real.Multiplier == Imag.Multiplier)
1115       return Real.Multiplier;
1116
1117     return nullptr;
1118   };
1119
1120   // Iterating over real and imaginary multiplications to find common operands
1121   // If a common operand is found, a partial multiplication candidate is created
1122   // and added to the candidates vector The function returns false if no common
1123   // operands are found for any product
1124   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1125     bool FoundCommon = false;
1126     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1127       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1128       if (!Common)
1129         continue;
1130
1131       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1132                                                    : RealMuls[i].Multiplicand;
1133       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1134                                                    : ImagMuls[j].Multiplicand;
1135
1136       auto Node = identifyNode(A, B);
1137       if (Node) {
1138         FoundCommon = true;
1139         PartialMulCandidates.push_back({Common, Node, i, j, false});
1140       }
1141
1142       Node = identifyNode(B, A);
1143       if (Node) {
1144         FoundCommon = true;
1145         PartialMulCandidates.push_back({Common, Node, i, j, true});
1146       }
1147     }
1148     if (!FoundCommon)
1149       return false;
1150   }
1151   return true;
1152 }
1153
1154 ComplexDeinterleavingGraph::NodePtr
1155 ComplexDeinterleavingGraph::identifyMultiplications(
1156     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1157     NodePtr Accumulator = nullptr) {
1158   if (RealMuls.size() != ImagMuls.size())
1159     return nullptr;
1160
1161   std::vector<PartialMulCandidate> Info;
1162   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1163     return nullptr;
1164
1165   // Map to store common instruction to node pointers
1166   std::map<Value *, NodePtr> CommonToNode;
1167   std::vector<bool> Processed(Info.size(), false);
1168   for (unsigned I = 0; I < Info.size(); ++I) {
1169     if (Processed[I])
1170       continue;
1171
1172     PartialMulCandidate &InfoA = Info[I];
1173     for (unsigned J = I + 1; J < Info.size(); ++J) {
1174       if (Processed[J])
1175         continue;
1176
1177       PartialMulCandidate &InfoB = Info[J];
1178       auto *InfoReal = &InfoA;
1179       auto *InfoImag = &InfoB;
1180
1181       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1182       if (!NodeFromCommon) {
1183         std::swap(InfoReal, InfoImag);
1184         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1185       }
1186       if (!NodeFromCommon)
1187         continue;
1188
1189       CommonToNode[InfoReal->Common] = NodeFromCommon;
1190       CommonToNode[InfoImag->Common] = NodeFromCommon;
1191       Processed[I] = true;
1192       Processed[J] = true;
1193     }
1194   }
1195
1196   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1197   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1198   NodePtr Result = Accumulator;
1199   for (auto &PMI : Info) {
1200     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1201       continue;
1202
1203     auto It = CommonToNode.find(PMI.Common);
1204     // TODO: Process independent complex multiplications. Cases like this:
1205     //  A.real() * B where both A and B are complex numbers.
1206     if (It == CommonToNode.end()) {
1207       LLVM_DEBUG({
1208         dbgs() << "Unprocessed independent partial multiplication:\n";
1209         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1210           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1211                            << " multiplied by " << *Mul->Multiplicand << "\n";
1212       });
1213       return nullptr;
1214     }
1215
1216     auto &RealMul = RealMuls[PMI.RealIdx];
1217     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1218
1219     auto NodeA = It->second;
1220     auto NodeB = PMI.Node;
1221     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1222     // The following table illustrates the relationship between multiplications
1223     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1224     // can see:
1225     //
1226     // Rotation |   Real |   Imag |
1227     // ---------+--------+--------+
1228     //        0 |  x * u |  x * v |
1229     //       90 | -y * v |  y * u |
1230     //      180 | -x * u | -x * v |
1231     //      270 |  y * v | -y * u |
1232     //
1233     // Check if the candidate can indeed be represented by partial
1234     // multiplication
1235     // TODO: Add support for multiplication by complex one
1236     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1237         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1238       continue;
1239
1240     // Determine the rotation based on the multiplications
1241     ComplexDeinterleavingRotation Rotation;
1242     if (IsMultiplicandReal) {
1243       // Detect 0 and 180 degrees rotation
1244       if (RealMul.IsPositive && ImagMul.IsPositive)
1245         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1246       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1247         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1248       else
1249         continue;
1250
1251     } else {
1252       // Detect 90 and 270 degrees rotation
1253       if (!RealMul.IsPositive && ImagMul.IsPositive)
1254         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1255       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1256         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1257       else
1258         continue;
1259     }
1260
1261     LLVM_DEBUG({
1262       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1263       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1264       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1265       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1266       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1267       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1268     });
1269
1270     NodePtr NodeMul = prepareCompositeNode(
1271         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1272     NodeMul->Rotation = Rotation;
1273     NodeMul->addOperand(NodeA);
1274     NodeMul->addOperand(NodeB);
1275     if (Result)
1276       NodeMul->addOperand(Result);
1277     submitCompositeNode(NodeMul);
1278     Result = NodeMul;
1279     ProcessedReal[PMI.RealIdx] = true;
1280     ProcessedImag[PMI.ImagIdx] = true;
1281   }
1282
1283   // Ensure all products have been processed, if not return nullptr.
1284   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1285       !all_of(ProcessedImag, [](bool V) { return V; })) {
1286
1287     // Dump debug information about which partial multiplications are not
1288     // processed.
1289     LLVM_DEBUG({
1290       dbgs() << "Unprocessed products (Real):\n";
1291       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1292         if (!ProcessedReal[i])
1293           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1294                            << *RealMuls[i].Multiplier << " multiplied by "
1295                            << *RealMuls[i].Multiplicand << "\n";
1296       }
1297       dbgs() << "Unprocessed products (Imag):\n";
1298       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1299         if (!ProcessedImag[i])
1300           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1301                            << *ImagMuls[i].Multiplier << " multiplied by "
1302                            << *ImagMuls[i].Multiplicand << "\n";
1303       }
1304     });
1305     return nullptr;
1306   }
1307
1308   return Result;
1309 }
1310
1311 ComplexDeinterleavingGraph::NodePtr
1312 ComplexDeinterleavingGraph::identifyAdditions(
1313     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1314     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1315   if (RealAddends.size() != ImagAddends.size())
1316     return nullptr;
1317
1318   NodePtr Result;
1319   // If we have accumulator use it as first addend
1320   if (Accumulator)
1321     Result = Accumulator;
1322   // Otherwise find an element with both positive real and imaginary parts.
1323   else
1324     Result = extractPositiveAddend(RealAddends, ImagAddends);
1325
1326   if (!Result)
1327     return nullptr;
1328
1329   while (!RealAddends.empty()) {
1330     auto ItR = RealAddends.begin();
1331     auto [R, IsPositiveR] = *ItR;
1332
1333     bool FoundImag = false;
1334     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1335       auto [I, IsPositiveI] = *ItI;
1336       ComplexDeinterleavingRotation Rotation;
1337       if (IsPositiveR && IsPositiveI)
1338         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1339       else if (!IsPositiveR && IsPositiveI)
1340         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1341       else if (!IsPositiveR && !IsPositiveI)
1342         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1343       else
1344         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1345
1346       NodePtr AddNode;
1347       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1348           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1349         AddNode = identifyNode(R, I);
1350       } else {
1351         AddNode = identifyNode(I, R);
1352       }
1353       if (AddNode) {
1354         LLVM_DEBUG({
1355           dbgs() << "Identified addition:\n";
1356           dbgs().indent(4) << "X: " << *R << "\n";
1357           dbgs().indent(4) << "Y: " << *I << "\n";
1358           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1359         });
1360
1361         NodePtr TmpNode;
1362         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1363           TmpNode = prepareCompositeNode(
1364               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1365           if (Flags) {
1366             TmpNode->Opcode = Instruction::FAdd;
1367             TmpNode->Flags = *Flags;
1368           } else {
1369             TmpNode->Opcode = Instruction::Add;
1370           }
1371         } else if (Rotation ==
1372                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1373           TmpNode = prepareCompositeNode(
1374               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1375           if (Flags) {
1376             TmpNode->Opcode = Instruction::FSub;
1377             TmpNode->Flags = *Flags;
1378           } else {
1379             TmpNode->Opcode = Instruction::Sub;
1380           }
1381         } else {
1382           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1383                                          nullptr, nullptr);
1384           TmpNode->Rotation = Rotation;
1385         }
1386
1387         TmpNode->addOperand(Result);
1388         TmpNode->addOperand(AddNode);
1389         submitCompositeNode(TmpNode);
1390         Result = TmpNode;
1391         RealAddends.erase(ItR);
1392         ImagAddends.erase(ItI);
1393         FoundImag = true;
1394         break;
1395       }
1396     }
1397     if (!FoundImag)
1398       return nullptr;
1399   }
1400   return Result;
1401 }
1402
1403 ComplexDeinterleavingGraph::NodePtr
1404 ComplexDeinterleavingGraph::extractPositiveAddend(
1405     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1406   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1407     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1408       auto [R, IsPositiveR] = *ItR;
1409       auto [I, IsPositiveI] = *ItI;
1410       if (IsPositiveR && IsPositiveI) {
1411         auto Result = identifyNode(R, I);
1412         if (Result) {
1413           RealAddends.erase(ItR);
1414           ImagAddends.erase(ItI);
1415           return Result;
1416         }
1417       }
1418     }
1419   }
1420   return nullptr;
1421 }
1422
1423 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1424   // This potential root instruction might already have been recognized as
1425   // reduction. Because RootToNode maps both Real and Imaginary parts to
1426   // CompositeNode we should choose only one either Real or Imag instruction to
1427   // use as an anchor for generating complex instruction.
1428   auto It = RootToNode.find(RootI);
1429   if (It != RootToNode.end() && It->second->Real == RootI) {
1430     OrderedRoots.push_back(RootI);
1431     return true;
1432   }
1433
1434   auto RootNode = identifyRoot(RootI);
1435   if (!RootNode)
1436     return false;
1437
1438   LLVM_DEBUG({
1439     Function *F = RootI->getFunction();
1440     BasicBlock *B = RootI->getParent();
1441     dbgs() << "Complex deinterleaving graph for " << F->getName()
1442            << "::" << B->getName() << ".\n";
1443     dump(dbgs());
1444     dbgs() << "\n";
1445   });
1446   RootToNode[RootI] = RootNode;
1447   OrderedRoots.push_back(RootI);
1448   return true;
1449 }
1450
1451 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1452   bool FoundPotentialReduction = false;
1453
1454   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1455   if (!Br || Br->getNumSuccessors() != 2)
1456     return false;
1457
1458   // Identify simple one-block loop
1459   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1460     return false;
1461
1462   SmallVector<PHINode *> PHIs;
1463   for (auto &PHI : B->phis()) {
1464     if (PHI.getNumIncomingValues() != 2)
1465       continue;
1466
1467     if (!PHI.getType()->isVectorTy())
1468       continue;
1469
1470     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1471     if (!ReductionOp)
1472       continue;
1473
1474     // Check if final instruction is reduced outside of current block
1475     Instruction *FinalReduction = nullptr;
1476     auto NumUsers = 0u;
1477     for (auto *U : ReductionOp->users()) {
1478       ++NumUsers;
1479       if (U == &PHI)
1480         continue;
1481       FinalReduction = dyn_cast<Instruction>(U);
1482     }
1483
1484     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1485         isa<PHINode>(FinalReduction))
1486       continue;
1487
1488     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1489     BackEdge = B;
1490     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1491     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1492     Incoming = PHI.getIncomingBlock(IncomingIdx);
1493     FoundPotentialReduction = true;
1494
1495     // If the initial value of PHINode is an Instruction, consider it a leaf
1496     // value of a complex deinterleaving graph.
1497     if (auto *InitPHI =
1498             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1499       FinalInstructions.insert(InitPHI);
1500   }
1501   return FoundPotentialReduction;
1502 }
1503
1504 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1505   SmallVector<bool> Processed(ReductionInfo.size(), false);
1506   SmallVector<Instruction *> OperationInstruction;
1507   for (auto &P : ReductionInfo)
1508     OperationInstruction.push_back(P.first);
1509
1510   // Identify a complex computation by evaluating two reduction operations that
1511   // potentially could be involved
1512   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1513     if (Processed[i])
1514       continue;
1515     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1516       if (Processed[j])
1517         continue;
1518
1519       auto *Real = OperationInstruction[i];
1520       auto *Imag = OperationInstruction[j];
1521       if (Real->getType() != Imag->getType())
1522         continue;
1523
1524       RealPHI = ReductionInfo[Real].first;
1525       ImagPHI = ReductionInfo[Imag].first;
1526       PHIsFound = false;
1527       auto Node = identifyNode(Real, Imag);
1528       if (!Node) {
1529         std::swap(Real, Imag);
1530         std::swap(RealPHI, ImagPHI);
1531         Node = identifyNode(Real, Imag);
1532       }
1533
1534       // If a node is identified and reduction PHINode is used in the chain of
1535       // operations, mark its operation instructions as used to prevent
1536       // re-identification and attach the node to the real part
1537       if (Node && PHIsFound) {
1538         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1539                           << *Real << " / " << *Imag << "\n");
1540         Processed[i] = true;
1541         Processed[j] = true;
1542         auto RootNode = prepareCompositeNode(
1543             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1544         RootNode->addOperand(Node);
1545         RootToNode[Real] = RootNode;
1546         RootToNode[Imag] = RootNode;
1547         submitCompositeNode(RootNode);
1548         break;
1549       }
1550     }
1551   }
1552
1553   RealPHI = nullptr;
1554   ImagPHI = nullptr;
1555 }
1556
1557 bool ComplexDeinterleavingGraph::checkNodes() {
1558   // Collect all instructions from roots to leaves
1559   SmallPtrSet<Instruction *, 16> AllInstructions;
1560   SmallVector<Instruction *, 8> Worklist;
1561   for (auto &Pair : RootToNode)
1562     Worklist.push_back(Pair.first);
1563
1564   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1565   // chains
1566   while (!Worklist.empty()) {
1567     auto *I = Worklist.back();
1568     Worklist.pop_back();
1569
1570     if (!AllInstructions.insert(I).second)
1571       continue;
1572
1573     for (Value *Op : I->operands()) {
1574       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1575         if (!FinalInstructions.count(I))
1576           Worklist.emplace_back(OpI);
1577       }
1578     }
1579   }
1580
1581   // Find instructions that have users outside of chain
1582   SmallVector<Instruction *, 2> OuterInstructions;
1583   for (auto *I : AllInstructions) {
1584     // Skip root nodes
1585     if (RootToNode.count(I))
1586       continue;
1587
1588     for (User *U : I->users()) {
1589       if (AllInstructions.count(cast<Instruction>(U)))
1590         continue;
1591
1592       // Found an instruction that is not used by XCMLA/XCADD chain
1593       Worklist.emplace_back(I);
1594       break;
1595     }
1596   }
1597
1598   // If any instructions are found to be used outside, find and remove roots
1599   // that somehow connect to those instructions.
1600   SmallPtrSet<Instruction *, 16> Visited;
1601   while (!Worklist.empty()) {
1602     auto *I = Worklist.back();
1603     Worklist.pop_back();
1604     if (!Visited.insert(I).second)
1605       continue;
1606
1607     // Found an impacted root node. Removing it from the nodes to be
1608     // deinterleaved
1609     if (RootToNode.count(I)) {
1610       LLVM_DEBUG(dbgs() << "Instruction " << *I
1611                         << " could be deinterleaved but its chain of complex "
1612                            "operations have an outside user\n");
1613       RootToNode.erase(I);
1614     }
1615
1616     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1617       continue;
1618
1619     for (User *U : I->users())
1620       Worklist.emplace_back(cast<Instruction>(U));
1621
1622     for (Value *Op : I->operands()) {
1623       if (auto *OpI = dyn_cast<Instruction>(Op))
1624         Worklist.emplace_back(OpI);
1625     }
1626   }
1627   return !RootToNode.empty();
1628 }
1629
1630 ComplexDeinterleavingGraph::NodePtr
1631 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1632   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1633     if (Intrinsic->getIntrinsicID() !=
1634         Intrinsic::experimental_vector_interleave2)
1635       return nullptr;
1636
1637     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1638     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1639     if (!Real || !Imag)
1640       return nullptr;
1641
1642     return identifyNode(Real, Imag);
1643   }
1644
1645   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1646   if (!SVI)
1647     return nullptr;
1648
1649   // Look for a shufflevector that takes separate vectors of the real and
1650   // imaginary components and recombines them into a single vector.
1651   if (!isInterleavingMask(SVI->getShuffleMask()))
1652     return nullptr;
1653
1654   Instruction *Real;
1655   Instruction *Imag;
1656   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1657     return nullptr;
1658
1659   return identifyNode(Real, Imag);
1660 }
1661
1662 ComplexDeinterleavingGraph::NodePtr
1663 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1664                                                  Instruction *Imag) {
1665   Instruction *I = nullptr;
1666   Value *FinalValue = nullptr;
1667   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1668       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1669       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1670                    m_Value(FinalValue)))) {
1671     NodePtr PlaceholderNode = prepareCompositeNode(
1672         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1673     PlaceholderNode->ReplacementNode = FinalValue;
1674     FinalInstructions.insert(Real);
1675     FinalInstructions.insert(Imag);
1676     return submitCompositeNode(PlaceholderNode);
1677   }
1678
1679   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1680   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1681   if (!RealShuffle || !ImagShuffle) {
1682     if (RealShuffle || ImagShuffle)
1683       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1684     return nullptr;
1685   }
1686
1687   Value *RealOp1 = RealShuffle->getOperand(1);
1688   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1689     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1690     return nullptr;
1691   }
1692   Value *ImagOp1 = ImagShuffle->getOperand(1);
1693   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1694     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1695     return nullptr;
1696   }
1697
1698   Value *RealOp0 = RealShuffle->getOperand(0);
1699   Value *ImagOp0 = ImagShuffle->getOperand(0);
1700
1701   if (RealOp0 != ImagOp0) {
1702     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1703     return nullptr;
1704   }
1705
1706   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1707   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1708   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1709     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1710     return nullptr;
1711   }
1712
1713   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1714     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1715     return nullptr;
1716   }
1717
1718   // Type checking, the shuffle type should be a vector type of the same
1719   // scalar type, but half the size
1720   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1721     Value *Op = Shuffle->getOperand(0);
1722     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1723     auto *OpTy = cast<FixedVectorType>(Op->getType());
1724
1725     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1726       return false;
1727     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1728       return false;
1729
1730     return true;
1731   };
1732
1733   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1734     if (!CheckType(Shuffle))
1735       return false;
1736
1737     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1738     int Last = *Mask.rbegin();
1739
1740     Value *Op = Shuffle->getOperand(0);
1741     auto *OpTy = cast<FixedVectorType>(Op->getType());
1742     int NumElements = OpTy->getNumElements();
1743
1744     // Ensure that the deinterleaving shuffle only pulls from the first
1745     // shuffle operand.
1746     return Last < NumElements;
1747   };
1748
1749   if (RealShuffle->getType() != ImagShuffle->getType()) {
1750     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1751     return nullptr;
1752   }
1753   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1754     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1755     return nullptr;
1756   }
1757   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1758     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1759     return nullptr;
1760   }
1761
1762   NodePtr PlaceholderNode =
1763       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1764                            RealShuffle, ImagShuffle);
1765   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1766   FinalInstructions.insert(RealShuffle);
1767   FinalInstructions.insert(ImagShuffle);
1768   return submitCompositeNode(PlaceholderNode);
1769 }
1770
1771 ComplexDeinterleavingGraph::NodePtr
1772 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1773   auto IsSplat = [](Value *V) -> bool {
1774     // Fixed-width vector with constants
1775     if (isa<ConstantDataVector>(V))
1776       return true;
1777
1778     VectorType *VTy;
1779     ArrayRef<int> Mask;
1780     // Splats are represented differently depending on whether the repeated
1781     // value is a constant or an Instruction
1782     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1783       if (Const->getOpcode() != Instruction::ShuffleVector)
1784         return false;
1785       VTy = cast<VectorType>(Const->getType());
1786       Mask = Const->getShuffleMask();
1787     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1788       VTy = Shuf->getType();
1789       Mask = Shuf->getShuffleMask();
1790     } else {
1791       return false;
1792     }
1793
1794     // When the data type is <1 x Type>, it's not possible to differentiate
1795     // between the ComplexDeinterleaving::Deinterleave and
1796     // ComplexDeinterleaving::Splat operations.
1797     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1798       return false;
1799
1800     return all_equal(Mask) && Mask[0] == 0;
1801   };
1802
1803   if (!IsSplat(R) || !IsSplat(I))
1804     return nullptr;
1805
1806   auto *Real = dyn_cast<Instruction>(R);
1807   auto *Imag = dyn_cast<Instruction>(I);
1808   if ((!Real && Imag) || (Real && !Imag))
1809     return nullptr;
1810
1811   if (Real && Imag) {
1812     // Non-constant splats should be in the same basic block
1813     if (Real->getParent() != Imag->getParent())
1814       return nullptr;
1815
1816     FinalInstructions.insert(Real);
1817     FinalInstructions.insert(Imag);
1818   }
1819   NodePtr PlaceholderNode =
1820       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1821   return submitCompositeNode(PlaceholderNode);
1822 }
1823
1824 ComplexDeinterleavingGraph::NodePtr
1825 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1826                                             Instruction *Imag) {
1827   if (Real != RealPHI || Imag != ImagPHI)
1828     return nullptr;
1829
1830   PHIsFound = true;
1831   NodePtr PlaceholderNode = prepareCompositeNode(
1832       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1833   return submitCompositeNode(PlaceholderNode);
1834 }
1835
1836 ComplexDeinterleavingGraph::NodePtr
1837 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1838                                                Instruction *Imag) {
1839   auto *SelectReal = dyn_cast<SelectInst>(Real);
1840   auto *SelectImag = dyn_cast<SelectInst>(Imag);
1841   if (!SelectReal || !SelectImag)
1842     return nullptr;
1843
1844   Instruction *MaskA, *MaskB;
1845   Instruction *AR, *AI, *RA, *BI;
1846   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1847                             m_Instruction(RA))) ||
1848       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1849                             m_Instruction(BI))))
1850     return nullptr;
1851
1852   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1853     return nullptr;
1854
1855   if (!MaskA->getType()->isVectorTy())
1856     return nullptr;
1857
1858   auto NodeA = identifyNode(AR, AI);
1859   if (!NodeA)
1860     return nullptr;
1861
1862   auto NodeB = identifyNode(RA, BI);
1863   if (!NodeB)
1864     return nullptr;
1865
1866   NodePtr PlaceholderNode = prepareCompositeNode(
1867       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1868   PlaceholderNode->addOperand(NodeA);
1869   PlaceholderNode->addOperand(NodeB);
1870   FinalInstructions.insert(MaskA);
1871   FinalInstructions.insert(MaskB);
1872   return submitCompositeNode(PlaceholderNode);
1873 }
1874
1875 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1876                                    std::optional<FastMathFlags> Flags,
1877                                    Value *InputA, Value *InputB) {
1878   Value *I;
1879   switch (Opcode) {
1880   case Instruction::FNeg:
1881     I = B.CreateFNeg(InputA);
1882     break;
1883   case Instruction::FAdd:
1884     I = B.CreateFAdd(InputA, InputB);
1885     break;
1886   case Instruction::Add:
1887     I = B.CreateAdd(InputA, InputB);
1888     break;
1889   case Instruction::FSub:
1890     I = B.CreateFSub(InputA, InputB);
1891     break;
1892   case Instruction::Sub:
1893     I = B.CreateSub(InputA, InputB);
1894     break;
1895   case Instruction::FMul:
1896     I = B.CreateFMul(InputA, InputB);
1897     break;
1898   case Instruction::Mul:
1899     I = B.CreateMul(InputA, InputB);
1900     break;
1901   default:
1902     llvm_unreachable("Incorrect symmetric opcode");
1903   }
1904   if (Flags)
1905     cast<Instruction>(I)->setFastMathFlags(*Flags);
1906   return I;
1907 }
1908
1909 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1910                                                RawNodePtr Node) {
1911   if (Node->ReplacementNode)
1912     return Node->ReplacementNode;
1913
1914   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1915     return Node->Operands.size() > Idx
1916                ? replaceNode(Builder, Node->Operands[Idx])
1917                : nullptr;
1918   };
1919
1920   Value *ReplacementNode;
1921   switch (Node->Operation) {
1922   case ComplexDeinterleavingOperation::CAdd:
1923   case ComplexDeinterleavingOperation::CMulPartial:
1924   case ComplexDeinterleavingOperation::Symmetric: {
1925     Value *Input0 = ReplaceOperandIfExist(Node, 0);
1926     Value *Input1 = ReplaceOperandIfExist(Node, 1);
1927     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1928     assert(!Input1 || (Input0->getType() == Input1->getType() &&
1929                        "Node inputs need to be of the same type"));
1930     assert(!Accumulator ||
1931            (Input0->getType() == Accumulator->getType() &&
1932             "Accumulator and input need to be of the same type"));
1933     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1934       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1935                                              Input0, Input1);
1936     else
1937       ReplacementNode = TL->createComplexDeinterleavingIR(
1938           Builder, Node->Operation, Node->Rotation, Input0, Input1,
1939           Accumulator);
1940     break;
1941   }
1942   case ComplexDeinterleavingOperation::Deinterleave:
1943     llvm_unreachable("Deinterleave node should already have ReplacementNode");
1944     break;
1945   case ComplexDeinterleavingOperation::Splat: {
1946     auto *NewTy = VectorType::getDoubleElementsVectorType(
1947         cast<VectorType>(Node->Real->getType()));
1948     auto *R = dyn_cast<Instruction>(Node->Real);
1949     auto *I = dyn_cast<Instruction>(Node->Imag);
1950     if (R && I) {
1951       // Splats that are not constant are interleaved where they are located
1952       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1953       IRBuilder<> IRB(InsertPoint);
1954       ReplacementNode =
1955           IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1956                               {Node->Real, Node->Imag});
1957     } else {
1958       ReplacementNode =
1959           Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1960                                   NewTy, {Node->Real, Node->Imag});
1961     }
1962     break;
1963   }
1964   case ComplexDeinterleavingOperation::ReductionPHI: {
1965     // If Operation is ReductionPHI, a new empty PHINode is created.
1966     // It is filled later when the ReductionOperation is processed.
1967     auto *VTy = cast<VectorType>(Node->Real->getType());
1968     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1969     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1970     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1971     ReplacementNode = NewPHI;
1972     break;
1973   }
1974   case ComplexDeinterleavingOperation::ReductionOperation:
1975     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1976     processReductionOperation(ReplacementNode, Node);
1977     break;
1978   case ComplexDeinterleavingOperation::ReductionSelect: {
1979     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1980     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1981     auto *A = replaceNode(Builder, Node->Operands[0]);
1982     auto *B = replaceNode(Builder, Node->Operands[1]);
1983     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1984         cast<VectorType>(MaskReal->getType()));
1985     auto *NewMask =
1986         Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1987                                 NewMaskTy, {MaskReal, MaskImag});
1988     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1989     break;
1990   }
1991   }
1992
1993   assert(ReplacementNode && "Target failed to create Intrinsic call.");
1994   NumComplexTransformations += 1;
1995   Node->ReplacementNode = ReplacementNode;
1996   return ReplacementNode;
1997 }
1998
1999 void ComplexDeinterleavingGraph::processReductionOperation(
2000     Value *OperationReplacement, RawNodePtr Node) {
2001   auto *Real = cast<Instruction>(Node->Real);
2002   auto *Imag = cast<Instruction>(Node->Imag);
2003   auto *OldPHIReal = ReductionInfo[Real].first;
2004   auto *OldPHIImag = ReductionInfo[Imag].first;
2005   auto *NewPHI = OldToNewPHI[OldPHIReal];
2006
2007   auto *VTy = cast<VectorType>(Real->getType());
2008   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2009
2010   // We have to interleave initial origin values coming from IncomingBlock
2011   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2012   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2013
2014   IRBuilder<> Builder(Incoming->getTerminator());
2015   auto *NewInit = Builder.CreateIntrinsic(
2016       Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
2017
2018   NewPHI->addIncoming(NewInit, Incoming);
2019   NewPHI->addIncoming(OperationReplacement, BackEdge);
2020
2021   // Deinterleave complex vector outside of loop so that it can be finally
2022   // reduced
2023   auto *FinalReductionReal = ReductionInfo[Real].second;
2024   auto *FinalReductionImag = ReductionInfo[Imag].second;
2025
2026   Builder.SetInsertPoint(
2027       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2028   auto *Deinterleave = Builder.CreateIntrinsic(
2029       Intrinsic::experimental_vector_deinterleave2,
2030       OperationReplacement->getType(), OperationReplacement);
2031
2032   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2033   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2034
2035   Builder.SetInsertPoint(FinalReductionImag);
2036   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2037   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2038 }
2039
2040 void ComplexDeinterleavingGraph::replaceNodes() {
2041   SmallVector<Instruction *, 16> DeadInstrRoots;
2042   for (auto *RootInstruction : OrderedRoots) {
2043     // Check if this potential root went through check process and we can
2044     // deinterleave it
2045     if (!RootToNode.count(RootInstruction))
2046       continue;
2047
2048     IRBuilder<> Builder(RootInstruction);
2049     auto RootNode = RootToNode[RootInstruction];
2050     Value *R = replaceNode(Builder, RootNode.get());
2051
2052     if (RootNode->Operation ==
2053         ComplexDeinterleavingOperation::ReductionOperation) {
2054       auto *RootReal = cast<Instruction>(RootNode->Real);
2055       auto *RootImag = cast<Instruction>(RootNode->Imag);
2056       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2057       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2058       DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2059       DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2060     } else {
2061       assert(R && "Unable to find replacement for RootInstruction");
2062       DeadInstrRoots.push_back(RootInstruction);
2063       RootInstruction->replaceAllUsesWith(R);
2064     }
2065   }
2066
2067   for (auto *I : DeadInstrRoots)
2068     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2069 }