]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm-project/llvm/lib/Target/ARM/ARMParallelDSP.cpp
Fix a memory leak in if_delgroups() introduced in r334118.
[FreeBSD/FreeBSD.git] / contrib / llvm-project / llvm / lib / Target / ARM / ARMParallelDSP.cpp
1 //===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Armv6 introduced instructions to perform 32-bit SIMD operations. The
11 /// purpose of this pass is do some IR pattern matching to create ACLE
12 /// DSP intrinsics, which map on these 32-bit SIMD operations.
13 /// This pass runs only when unaligned accesses is supported/enabled.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Analysis/AliasAnalysis.h"
20 #include "llvm/Analysis/LoopAccessAnalysis.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/NoFolder.h"
25 #include "llvm/Transforms/Scalar.h"
26 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
27 #include "llvm/Transforms/Utils/LoopUtils.h"
28 #include "llvm/Pass.h"
29 #include "llvm/PassRegistry.h"
30 #include "llvm/PassSupport.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/IR/PatternMatch.h"
33 #include "llvm/CodeGen/TargetPassConfig.h"
34 #include "ARM.h"
35 #include "ARMSubtarget.h"
36
37 using namespace llvm;
38 using namespace PatternMatch;
39
40 #define DEBUG_TYPE "arm-parallel-dsp"
41
42 STATISTIC(NumSMLAD , "Number of smlad instructions generated");
43
44 static cl::opt<bool>
45 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
46                    cl::desc("Disable the ARM Parallel DSP pass"));
47
48 namespace {
49   struct OpChain;
50   struct BinOpChain;
51   class Reduction;
52
53   using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
54   using ReductionList   = SmallVector<Reduction, 8>;
55   using ValueList       = SmallVector<Value*, 8>;
56   using MemInstList     = SmallVector<LoadInst*, 8>;
57   using PMACPair        = std::pair<BinOpChain*,BinOpChain*>;
58   using PMACPairList    = SmallVector<PMACPair, 8>;
59   using Instructions    = SmallVector<Instruction*,16>;
60   using MemLocList      = SmallVector<MemoryLocation, 4>;
61
62   struct OpChain {
63     Instruction   *Root;
64     ValueList     AllValues;
65     MemInstList   VecLd;    // List of all load instructions.
66     MemInstList   Loads;
67     bool          ReadOnly = true;
68
69     OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
70     virtual ~OpChain() = default;
71
72     void PopulateLoads() {
73       for (auto *V : AllValues) {
74         if (auto *Ld = dyn_cast<LoadInst>(V))
75           Loads.push_back(Ld);
76       }
77     }
78
79     unsigned size() const { return AllValues.size(); }
80   };
81
82   // 'BinOpChain' holds the multiplication instructions that are candidates
83   // for parallel execution.
84   struct BinOpChain : public OpChain {
85     ValueList     LHS;      // List of all (narrow) left hand operands.
86     ValueList     RHS;      // List of all (narrow) right hand operands.
87     bool Exchange = false;
88
89     BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
90       OpChain(I, lhs), LHS(lhs), RHS(rhs) {
91         for (auto *V : RHS)
92           AllValues.push_back(V);
93       }
94
95     bool AreSymmetrical(BinOpChain *Other);
96   };
97
98   /// Represent a sequence of multiply-accumulate operations with the aim to
99   /// perform the multiplications in parallel.
100   class Reduction {
101     Instruction     *Root = nullptr;
102     Value           *Acc = nullptr;
103     OpChainList     Muls;
104     PMACPairList        MulPairs;
105     SmallPtrSet<Instruction*, 4> Adds;
106
107   public:
108     Reduction() = delete;
109
110     Reduction (Instruction *Add) : Root(Add) { }
111
112     /// Record an Add instruction that is a part of the this reduction.
113     void InsertAdd(Instruction *I) { Adds.insert(I); }
114
115     /// Record a BinOpChain, rooted at a Mul instruction, that is a part of
116     /// this reduction.
117     void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) {
118       Muls.push_back(make_unique<BinOpChain>(I, LHS, RHS));
119     }
120
121     /// Add the incoming accumulator value, returns true if a value had not
122     /// already been added. Returning false signals to the user that this
123     /// reduction already has a value to initialise the accumulator.
124     bool InsertAcc(Value *V) {
125       if (Acc)
126         return false;
127       Acc = V;
128       return true;
129     }
130
131     /// Set two BinOpChains, rooted at muls, that can be executed as a single
132     /// parallel operation.
133     void AddMulPair(BinOpChain *Mul0, BinOpChain *Mul1) {
134       MulPairs.push_back(std::make_pair(Mul0, Mul1));
135     }
136
137     /// Return true if enough mul operations are found that can be executed in
138     /// parallel.
139     bool CreateParallelPairs();
140
141     /// Return the add instruction which is the root of the reduction.
142     Instruction *getRoot() { return Root; }
143
144     /// Return the incoming value to be accumulated. This maybe null.
145     Value *getAccumulator() { return Acc; }
146
147     /// Return the set of adds that comprise the reduction.
148     SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }
149
150     /// Return the BinOpChain, rooted at mul instruction, that comprise the
151     /// the reduction.
152     OpChainList &getMuls() { return Muls; }
153
154     /// Return the BinOpChain, rooted at mul instructions, that have been
155     /// paired for parallel execution.
156     PMACPairList &getMulPairs() { return MulPairs; }
157
158     /// To finalise, replace the uses of the root with the intrinsic call.
159     void UpdateRoot(Instruction *SMLAD) {
160       Root->replaceAllUsesWith(SMLAD);
161     }
162   };
163
164   class WidenedLoad {
165     LoadInst *NewLd = nullptr;
166     SmallVector<LoadInst*, 4> Loads;
167
168   public:
169     WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
170       : NewLd(Wide) {
171       for (auto *I : Lds)
172         Loads.push_back(I);
173     }
174     LoadInst *getLoad() {
175       return NewLd;
176     }
177   };
178
179   class ARMParallelDSP : public LoopPass {
180     ScalarEvolution   *SE;
181     AliasAnalysis     *AA;
182     TargetLibraryInfo *TLI;
183     DominatorTree     *DT;
184     LoopInfo          *LI;
185     Loop              *L;
186     const DataLayout  *DL;
187     Module            *M;
188     std::map<LoadInst*, LoadInst*> LoadPairs;
189     SmallPtrSet<LoadInst*, 4> OffsetLoads;
190     std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
191
192     template<unsigned>
193     bool IsNarrowSequence(Value *V, ValueList &VL);
194
195     bool RecordMemoryOps(BasicBlock *BB);
196     void InsertParallelMACs(Reduction &Reduction);
197     bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
198     LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
199                              IntegerType *LoadTy);
200     bool CreateParallelPairs(Reduction &R);
201
202     /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
203     /// Dual performs two signed 16x16-bit multiplications. It adds the
204     /// products to a 32-bit accumulate operand. Optionally, the instruction can
205     /// exchange the halfwords of the second operand before performing the
206     /// arithmetic.
207     bool MatchSMLAD(Loop *L);
208
209   public:
210     static char ID;
211
212     ARMParallelDSP() : LoopPass(ID) { }
213
214     bool doInitialization(Loop *L, LPPassManager &LPM) override {
215       LoadPairs.clear();
216       WideLoads.clear();
217       return true;
218     }
219
220     void getAnalysisUsage(AnalysisUsage &AU) const override {
221       LoopPass::getAnalysisUsage(AU);
222       AU.addRequired<AssumptionCacheTracker>();
223       AU.addRequired<ScalarEvolutionWrapperPass>();
224       AU.addRequired<AAResultsWrapperPass>();
225       AU.addRequired<TargetLibraryInfoWrapperPass>();
226       AU.addRequired<LoopInfoWrapperPass>();
227       AU.addRequired<DominatorTreeWrapperPass>();
228       AU.addRequired<TargetPassConfig>();
229       AU.addPreserved<LoopInfoWrapperPass>();
230       AU.setPreservesCFG();
231     }
232
233     bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
234       if (DisableParallelDSP)
235         return false;
236       L = TheLoop;
237       SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
238       AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
239       TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
240       DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
241       LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
242       auto &TPC = getAnalysis<TargetPassConfig>();
243
244       BasicBlock *Header = TheLoop->getHeader();
245       if (!Header)
246         return false;
247
248       // TODO: We assume the loop header and latch to be the same block.
249       // This is not a fundamental restriction, but lifting this would just
250       // require more work to do the transformation and then patch up the CFG.
251       if (Header != TheLoop->getLoopLatch()) {
252         LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
253                              "running pass ARMParallelDSP\n");
254         return false;
255       }
256
257       if (!TheLoop->getLoopPreheader())
258         InsertPreheaderForLoop(L, DT, LI, nullptr, true);
259
260       Function &F = *Header->getParent();
261       M = F.getParent();
262       DL = &M->getDataLayout();
263
264       auto &TM = TPC.getTM<TargetMachine>();
265       auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
266
267       if (!ST->allowsUnalignedMem()) {
268         LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
269                              "running pass ARMParallelDSP\n");
270         return false;
271       }
272
273       if (!ST->hasDSP()) {
274         LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
275                              "ARMParallelDSP\n");
276         return false;
277       }
278
279       if (!ST->isLittle()) {
280         LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
281                           << "ARMParallelDSP\n");
282         return false;
283       }
284
285       LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
286
287       LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
288       LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
289
290       if (!RecordMemoryOps(Header)) {
291         LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
292         return false;
293       }
294
295       bool Changes = MatchSMLAD(L);
296       return Changes;
297     }
298   };
299 }
300
301 template<typename MemInst>
302 static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
303                                   const DataLayout &DL, ScalarEvolution &SE) {
304   if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
305     return true;
306   return false;
307 }
308
309 bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
310                                         MemInstList &VecMem) {
311   if (!Ld0 || !Ld1)
312     return false;
313
314   if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
315     return false;
316
317   LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
318     dbgs() << "Ld0:"; Ld0->dump();
319     dbgs() << "Ld1:"; Ld1->dump();
320   );
321
322   VecMem.clear();
323   VecMem.push_back(Ld0);
324   VecMem.push_back(Ld1);
325   return true;
326 }
327
328 // MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
329 // instructions, which is set to 16. So here we should collect all i8 and i16
330 // narrow operations.
331 // TODO: we currently only collect i16, and will support i8 later, so that's
332 // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
333 template<unsigned MaxBitWidth>
334 bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
335   ConstantInt *CInt;
336
337   if (match(V, m_ConstantInt(CInt))) {
338     // TODO: if a constant is used, it needs to fit within the bit width.
339     return false;
340   }
341
342   auto *I = dyn_cast<Instruction>(V);
343   if (!I)
344     return false;
345
346   Value *Val, *LHS, *RHS;
347   if (match(V, m_Trunc(m_Value(Val)))) {
348     if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
349       return IsNarrowSequence<MaxBitWidth>(Val, VL);
350   } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
351     // TODO: we need to implement sadd16/sadd8 for this, which enables to
352     // also do the rewrite for smlad8.ll, but it is unsupported for now.
353     return false;
354   } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
355     if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
356       return false;
357
358     if (match(Val, m_Load(m_Value()))) {
359       auto *Ld = cast<LoadInst>(Val);
360
361       // Check that these load could be paired.
362       if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
363         return false;
364
365       VL.push_back(Val);
366       VL.push_back(I);
367       return true;
368     }
369   }
370   return false;
371 }
372
373 /// Iterate through the block and record base, offset pairs of loads which can
374 /// be widened into a single load.
375 bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
376   SmallVector<LoadInst*, 8> Loads;
377   SmallVector<Instruction*, 8> Writes;
378
379   // Collect loads and instruction that may write to memory. For now we only
380   // record loads which are simple, sign-extended and have a single user.
381   // TODO: Allow zero-extended loads.
382   for (auto &I : *BB) {
383     if (I.mayWriteToMemory())
384       Writes.push_back(&I);
385     auto *Ld = dyn_cast<LoadInst>(&I);
386     if (!Ld || !Ld->isSimple() ||
387         !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
388       continue;
389     Loads.push_back(Ld);
390   }
391
392   using InstSet = std::set<Instruction*>;
393   using DepMap = std::map<Instruction*, InstSet>;
394   DepMap RAWDeps;
395
396   // Record any writes that may alias a load.
397   const auto Size = LocationSize::unknown();
398   for (auto Read : Loads) {
399     for (auto Write : Writes) {
400       MemoryLocation ReadLoc =
401         MemoryLocation(Read->getPointerOperand(), Size);
402
403       if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
404           ModRefInfo::ModRef)))
405         continue;
406       if (DT->dominates(Write, Read))
407         RAWDeps[Read].insert(Write);
408     }
409   }
410
411   // Check whether there's not a write between the two loads which would
412   // prevent them from being safely merged.
413   auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
414     LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
415     LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;
416
417     if (RAWDeps.count(Dominated)) {
418       InstSet &WritesBefore = RAWDeps[Dominated];
419
420       for (auto Before : WritesBefore) {
421
422         // We can't move the second load backward, past a write, to merge
423         // with the first load.
424         if (DT->dominates(Dominator, Before))
425           return false;
426       }
427     }
428     return true;
429   };
430
431   // Record base, offset load pairs.
432   for (auto *Base : Loads) {
433     for (auto *Offset : Loads) {
434       if (Base == Offset)
435         continue;
436
437       if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
438           SafeToPair(Base, Offset)) {
439         LoadPairs[Base] = Offset;
440         OffsetLoads.insert(Offset);
441         break;
442       }
443     }
444   }
445
446   LLVM_DEBUG(if (!LoadPairs.empty()) {
447                dbgs() << "Consecutive load pairs:\n";
448                for (auto &MapIt : LoadPairs) {
449                  LLVM_DEBUG(dbgs() << *MapIt.first << ", "
450                             << *MapIt.second << "\n");
451                }
452              });
453   return LoadPairs.size() > 1;
454 }
455
456 // Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
457 // multiplications.
458 // To use SMLAD:
459 // 1) we first need to find integer add then look for this pattern:
460 //
461 // acc0 = ...
462 // ld0 = load i16
463 // sext0 = sext i16 %ld0 to i32
464 // ld1 = load i16
465 // sext1 = sext i16 %ld1 to i32
466 // mul0 = mul %sext0, %sext1
467 // ld2 = load i16
468 // sext2 = sext i16 %ld2 to i32
469 // ld3 = load i16
470 // sext3 = sext i16 %ld3 to i32
471 // mul1 = mul i32 %sext2, %sext3
472 // add0 = add i32 %mul0, %acc0
473 // acc1 = add i32 %add0, %mul1
474 //
475 // Which can be selected to:
476 //
477 // ldr r0
478 // ldr r1
479 // smlad r2, r0, r1, r2
480 //
481 // If constants are used instead of loads, these will need to be hoisted
482 // out and into a register.
483 //
484 // If loop invariants are used instead of loads, these need to be packed
485 // before the loop begins.
486 //
487 bool ARMParallelDSP::MatchSMLAD(Loop *L) {
488   // Search recursively back through the operands to find a tree of values that
489   // form a multiply-accumulate chain. The search records the Add and Mul
490   // instructions that form the reduction and allows us to find a single value
491   // to be used as the initial input to the accumlator.
492   std::function<bool(Value*, Reduction&)> Search = [&]
493     (Value *V, Reduction &R) -> bool {
494
495     // If we find a non-instruction, try to use it as the initial accumulator
496     // value. This may have already been found during the search in which case
497     // this function will return false, signaling a search fail.
498     auto *I = dyn_cast<Instruction>(V);
499     if (!I)
500       return R.InsertAcc(V);
501
502     switch (I->getOpcode()) {
503     default:
504       break;
505     case Instruction::PHI:
506       // Could be the accumulator value.
507       return R.InsertAcc(V);
508     case Instruction::Add: {
509       // Adds should be adding together two muls, or another add and a mul to
510       // be within the mac chain. One of the operands may also be the
511       // accumulator value at which point we should stop searching.
512       bool ValidLHS = Search(I->getOperand(0), R);
513       bool ValidRHS = Search(I->getOperand(1), R);
514       if (!ValidLHS && !ValidLHS)
515         return false;
516       else if (ValidLHS && ValidRHS) {
517         R.InsertAdd(I);
518         return true;
519       } else {
520         R.InsertAdd(I);
521         return R.InsertAcc(I);
522       }
523     }
524     case Instruction::Mul: {
525       Value *MulOp0 = I->getOperand(0);
526       Value *MulOp1 = I->getOperand(1);
527       if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
528         ValueList LHS;
529         ValueList RHS;
530         if (IsNarrowSequence<16>(MulOp0, LHS) &&
531             IsNarrowSequence<16>(MulOp1, RHS)) {
532           R.InsertMul(I, LHS, RHS);
533           return true;
534         }
535       }
536       return false;
537     }
538     case Instruction::SExt:
539       return Search(I->getOperand(0), R);
540     }
541     return false;
542   };
543
544   bool Changed = false;
545   SmallPtrSet<Instruction*, 4> AllAdds;
546   BasicBlock *Latch = L->getLoopLatch();
547
548   for (Instruction &I : reverse(*Latch)) {
549     if (I.getOpcode() != Instruction::Add)
550       continue;
551
552     if (AllAdds.count(&I))
553       continue;
554
555     const auto *Ty = I.getType();
556     if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
557       continue;
558
559     Reduction R(&I);
560     if (!Search(&I, R))
561       continue;
562
563     if (!CreateParallelPairs(R))
564       continue;
565
566     InsertParallelMACs(R);
567     Changed = true;
568     AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
569   }
570
571   return Changed;
572 }
573
574 bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
575
576   // Not enough mul operations to make a pair.
577   if (R.getMuls().size() < 2)
578     return false;
579
580   // Check that the muls operate directly upon sign extended loads.
581   for (auto &MulChain : R.getMuls()) {
582     // A mul has 2 operands, and a narrow op consist of sext and a load; thus
583     // we expect at least 4 items in this operand value list.
584     if (MulChain->size() < 4) {
585       LLVM_DEBUG(dbgs() << "Operand list too short.\n");
586       return false;
587     }
588     MulChain->PopulateLoads();
589     ValueList &LHS = static_cast<BinOpChain*>(MulChain.get())->LHS;
590     ValueList &RHS = static_cast<BinOpChain*>(MulChain.get())->RHS;
591
592     // Use +=2 to skip over the expected extend instructions.
593     for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
594       if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
595         return false;
596     }
597   }
598
599   auto CanPair = [&](Reduction &R, BinOpChain *PMul0, BinOpChain *PMul1) {
600     if (!PMul0->AreSymmetrical(PMul1))
601       return false;
602
603     // The first elements of each vector should be loads with sexts. If we
604     // find that its two pairs of consecutive loads, then these can be
605     // transformed into two wider loads and the users can be replaced with
606     // DSP intrinsics.
607     for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
608       auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
609       auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
610       auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
611       auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);
612
613       if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
614         return false;
615
616       LLVM_DEBUG(dbgs() << "Loads:\n"
617                  << " - " << *Ld0 << "\n"
618                  << " - " << *Ld1 << "\n"
619                  << " - " << *Ld2 << "\n"
620                  << " - " << *Ld3 << "\n");
621
622       if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
623         if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
624           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
625           R.AddMulPair(PMul0, PMul1);
626           return true;
627         } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
628           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
629           LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
630           PMul1->Exchange = true;
631           R.AddMulPair(PMul0, PMul1);
632           return true;
633         }
634       } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
635                  AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
636         LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
637         LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
638         LLVM_DEBUG(dbgs() << "    and swapping muls\n");
639         PMul0->Exchange = true;
640         // Only the second operand can be exchanged, so swap the muls.
641         R.AddMulPair(PMul1, PMul0);
642         return true;
643       }
644     }
645     return false;
646   };
647
648   OpChainList &Muls = R.getMuls();
649   const unsigned Elems = Muls.size();
650   SmallPtrSet<const Instruction*, 4> Paired;
651   for (unsigned i = 0; i < Elems; ++i) {
652     BinOpChain *PMul0 = static_cast<BinOpChain*>(Muls[i].get());
653     if (Paired.count(PMul0->Root))
654       continue;
655
656     for (unsigned j = 0; j < Elems; ++j) {
657       if (i == j)
658         continue;
659
660       BinOpChain *PMul1 = static_cast<BinOpChain*>(Muls[j].get());
661       if (Paired.count(PMul1->Root))
662         continue;
663
664       const Instruction *Mul0 = PMul0->Root;
665       const Instruction *Mul1 = PMul1->Root;
666       if (Mul0 == Mul1)
667         continue;
668
669       assert(PMul0 != PMul1 && "expected different chains");
670
671       if (CanPair(R, PMul0, PMul1)) {
672         Paired.insert(Mul0);
673         Paired.insert(Mul1);
674         break;
675       }
676     }
677   }
678   return !R.getMulPairs().empty();
679 }
680
681
682 void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
683
684   auto CreateSMLADCall = [&](SmallVectorImpl<LoadInst*> &VecLd0,
685                              SmallVectorImpl<LoadInst*> &VecLd1,
686                              Value *Acc, bool Exchange,
687                              Instruction *InsertAfter) {
688     // Replace the reduction chain with an intrinsic call
689     IntegerType *Ty = IntegerType::get(M->getContext(), 32);
690     LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
691       WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
692     LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
693       WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
694
695     Value* Args[] = { WideLd0, WideLd1, Acc };
696     Function *SMLAD = nullptr;
697     if (Exchange)
698       SMLAD = Acc->getType()->isIntegerTy(32) ?
699         Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
700         Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
701     else
702       SMLAD = Acc->getType()->isIntegerTy(32) ?
703         Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
704         Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
705
706     IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
707                                 ++BasicBlock::iterator(InsertAfter));
708     Instruction *Call = Builder.CreateCall(SMLAD, Args);
709     NumSMLAD++;
710     return Call;
711   };
712
713   Instruction *InsertAfter = R.getRoot();
714   Value *Acc = R.getAccumulator();
715   if (!Acc)
716     Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
717
718   LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
719              << "Acc: " << *Acc << "\n");
720   for (auto &Pair : R.getMulPairs()) {
721     BinOpChain *PMul0 = Pair.first;
722     BinOpChain *PMul1 = Pair.second;
723     LLVM_DEBUG(dbgs() << "Muls:\n"
724                << "- " << *PMul0->Root << "\n"
725                << "- " << *PMul1->Root << "\n");
726
727     Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
728                           InsertAfter);
729     InsertAfter = cast<Instruction>(Acc);
730   }
731   R.UpdateRoot(cast<Instruction>(Acc));
732 }
733
734 LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
735                                          IntegerType *LoadTy) {
736   assert(Loads.size() == 2 && "currently only support widening two loads");
737
738   LoadInst *Base = Loads[0];
739   LoadInst *Offset = Loads[1];
740
741   Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
742   Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());
743
744   assert((BaseSExt && OffsetSExt)
745          && "Loads should have a single, extending, user");
746
747   std::function<void(Value*, Value*)> MoveBefore =
748     [&](Value *A, Value *B) -> void {
749       if (!isa<Instruction>(A) || !isa<Instruction>(B))
750         return;
751
752       auto *Source = cast<Instruction>(A);
753       auto *Sink = cast<Instruction>(B);
754
755       if (DT->dominates(Source, Sink) ||
756           Source->getParent() != Sink->getParent() ||
757           isa<PHINode>(Source) || isa<PHINode>(Sink))
758         return;
759
760       Source->moveBefore(Sink);
761       for (auto &U : Source->uses())
762         MoveBefore(Source, U.getUser());
763     };
764
765   // Insert the load at the point of the original dominating load.
766   LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
767   IRBuilder<NoFolder> IRB(DomLoad->getParent(),
768                           ++BasicBlock::iterator(DomLoad));
769
770   // Bitcast the pointer to a wider type and create the wide load, while making
771   // sure to maintain the original alignment as this prevents ldrd from being
772   // generated when it could be illegal due to memory alignment.
773   const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
774   Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
775                                     LoadTy->getPointerTo(AddrSpace));
776   LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
777                                              Base->getAlignment());
778
779   // Make sure everything is in the correct order in the basic block.
780   MoveBefore(Base->getPointerOperand(), VecPtr);
781   MoveBefore(VecPtr, WideLoad);
782
783   // From the wide load, create two values that equal the original two loads.
784   // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
785   // TODO: Support big-endian as well.
786   Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
787   BaseSExt->setOperand(0, Bottom);
788
789   IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
790   Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
791   Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
792   Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
793   OffsetSExt->setOperand(0, Trunc);
794
795   WideLoads.emplace(std::make_pair(Base,
796                                    make_unique<WidenedLoad>(Loads, WideLoad)));
797   return WideLoad;
798 }
799
800 // Compare the value lists in Other to this chain.
801 bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
802   // Element-by-element comparison of Value lists returning true if they are
803   // instructions with the same opcode or constants with the same value.
804   auto CompareValueList = [](const ValueList &VL0,
805                              const ValueList &VL1) {
806     if (VL0.size() != VL1.size()) {
807       LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
808                         << VL0.size() << " != " << VL1.size() << "\n");
809       return false;
810     }
811
812     const unsigned Pairs = VL0.size();
813
814     for (unsigned i = 0; i < Pairs; ++i) {
815       const Value *V0 = VL0[i];
816       const Value *V1 = VL1[i];
817       const auto *Inst0 = dyn_cast<Instruction>(V0);
818       const auto *Inst1 = dyn_cast<Instruction>(V1);
819
820       if (!Inst0 || !Inst1)
821         return false;
822
823       if (Inst0->isSameOperationAs(Inst1))
824         continue;
825
826       const APInt *C0, *C1;
827       if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
828         return false;
829     }
830
831     return true;
832   };
833
834   return CompareValueList(LHS, Other->LHS) &&
835          CompareValueList(RHS, Other->RHS);
836 }
837
838 Pass *llvm::createARMParallelDSPPass() {
839   return new ARMParallelDSP();
840 }
841
842 char ARMParallelDSP::ID = 0;
843
844 INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
845                 "Transform loops to use DSP intrinsics", false, false)
846 INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
847                 "Transform loops to use DSP intrinsics", false, false)