]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
Merge llvm, clang, compiler-rt, libc++, libunwind, lld, lldb and openmp
[FreeBSD/FreeBSD.git] / contrib / llvm / lib / CodeGen / ScalarizeMaskedMemIntrin.cpp
1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2 //                                    instrinsics
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass replaces masked memory intrinsics - when unsupported by the target
11 // - with a chain of basic blocks, that deal with the elements one-by-one if the
12 // appropriate mask bit is set.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/CodeGen/TargetSubtargetInfo.h"
19 #include "llvm/IR/BasicBlock.h"
20 #include "llvm/IR/Constant.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DerivedTypes.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/InstrTypes.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Casting.h"
34 #include <algorithm>
35 #include <cassert>
36
37 using namespace llvm;
38
39 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
40
41 namespace {
42
43 class ScalarizeMaskedMemIntrin : public FunctionPass {
44   const TargetTransformInfo *TTI = nullptr;
45
46 public:
47   static char ID; // Pass identification, replacement for typeid
48
49   explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
50     initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
51   }
52
53   bool runOnFunction(Function &F) override;
54
55   StringRef getPassName() const override {
56     return "Scalarize Masked Memory Intrinsics";
57   }
58
59   void getAnalysisUsage(AnalysisUsage &AU) const override {
60     AU.addRequired<TargetTransformInfoWrapperPass>();
61   }
62
63 private:
64   bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
65   bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
66 };
67
68 } // end anonymous namespace
69
70 char ScalarizeMaskedMemIntrin::ID = 0;
71
72 INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
73                 "Scalarize unsupported masked memory intrinsics", false, false)
74
75 FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
76   return new ScalarizeMaskedMemIntrin();
77 }
78
79 static bool isConstantIntVector(Value *Mask) {
80   Constant *C = dyn_cast<Constant>(Mask);
81   if (!C)
82     return false;
83
84   unsigned NumElts = Mask->getType()->getVectorNumElements();
85   for (unsigned i = 0; i != NumElts; ++i) {
86     Constant *CElt = C->getAggregateElement(i);
87     if (!CElt || !isa<ConstantInt>(CElt))
88       return false;
89   }
90
91   return true;
92 }
93
94 // Translate a masked load intrinsic like
95 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
96 //                               <16 x i1> %mask, <16 x i32> %passthru)
97 // to a chain of basic blocks, with loading element one-by-one if
98 // the appropriate mask bit is set
99 //
100 //  %1 = bitcast i8* %addr to i32*
101 //  %2 = extractelement <16 x i1> %mask, i32 0
102 //  br i1 %2, label %cond.load, label %else
103 //
104 // cond.load:                                        ; preds = %0
105 //  %3 = getelementptr i32* %1, i32 0
106 //  %4 = load i32* %3
107 //  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
108 //  br label %else
109 //
110 // else:                                             ; preds = %0, %cond.load
111 //  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
112 //  %6 = extractelement <16 x i1> %mask, i32 1
113 //  br i1 %6, label %cond.load1, label %else2
114 //
115 // cond.load1:                                       ; preds = %else
116 //  %7 = getelementptr i32* %1, i32 1
117 //  %8 = load i32* %7
118 //  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
119 //  br label %else2
120 //
121 // else2:                                          ; preds = %else, %cond.load1
122 //  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
123 //  %10 = extractelement <16 x i1> %mask, i32 2
124 //  br i1 %10, label %cond.load4, label %else5
125 //
126 static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
127   Value *Ptr = CI->getArgOperand(0);
128   Value *Alignment = CI->getArgOperand(1);
129   Value *Mask = CI->getArgOperand(2);
130   Value *Src0 = CI->getArgOperand(3);
131
132   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
133   VectorType *VecType = cast<VectorType>(CI->getType());
134
135   Type *EltTy = VecType->getElementType();
136
137   IRBuilder<> Builder(CI->getContext());
138   Instruction *InsertPt = CI;
139   BasicBlock *IfBlock = CI->getParent();
140
141   Builder.SetInsertPoint(InsertPt);
142   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
143
144   // Short-cut if the mask is all-true.
145   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
146     Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
147     CI->replaceAllUsesWith(NewI);
148     CI->eraseFromParent();
149     return;
150   }
151
152   // Adjust alignment for the scalar instruction.
153   AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
154   // Bitcast %addr from i8* to EltTy*
155   Type *NewPtrType =
156       EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
157   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
158   unsigned VectorWidth = VecType->getNumElements();
159
160   // The result vector
161   Value *VResult = Src0;
162
163   if (isConstantIntVector(Mask)) {
164     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
165       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
166         continue;
167       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
168       LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
169       VResult = Builder.CreateInsertElement(VResult, Load, Idx);
170     }
171     CI->replaceAllUsesWith(VResult);
172     CI->eraseFromParent();
173     return;
174   }
175
176   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
177     // Fill the "else" block, created in the previous iteration
178     //
179     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
180     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
181     //  br i1 %mask_1, label %cond.load, label %else
182     //
183
184     Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
185
186     // Create "cond" block
187     //
188     //  %EltAddr = getelementptr i32* %1, i32 0
189     //  %Elt = load i32* %EltAddr
190     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
191     //
192     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
193                                                      "cond.load");
194     Builder.SetInsertPoint(InsertPt);
195
196     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
197     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
198     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
199
200     // Create "else" block, fill it in the next iteration
201     BasicBlock *NewIfBlock =
202         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
203     Builder.SetInsertPoint(InsertPt);
204     Instruction *OldBr = IfBlock->getTerminator();
205     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
206     OldBr->eraseFromParent();
207     BasicBlock *PrevIfBlock = IfBlock;
208     IfBlock = NewIfBlock;
209
210     // Create the phi to join the new and previous value.
211     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
212     Phi->addIncoming(NewVResult, CondBlock);
213     Phi->addIncoming(VResult, PrevIfBlock);
214     VResult = Phi;
215   }
216
217   CI->replaceAllUsesWith(VResult);
218   CI->eraseFromParent();
219
220   ModifiedDT = true;
221 }
222
223 // Translate a masked store intrinsic, like
224 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
225 //                               <16 x i1> %mask)
226 // to a chain of basic blocks, that stores element one-by-one if
227 // the appropriate mask bit is set
228 //
229 //   %1 = bitcast i8* %addr to i32*
230 //   %2 = extractelement <16 x i1> %mask, i32 0
231 //   br i1 %2, label %cond.store, label %else
232 //
233 // cond.store:                                       ; preds = %0
234 //   %3 = extractelement <16 x i32> %val, i32 0
235 //   %4 = getelementptr i32* %1, i32 0
236 //   store i32 %3, i32* %4
237 //   br label %else
238 //
239 // else:                                             ; preds = %0, %cond.store
240 //   %5 = extractelement <16 x i1> %mask, i32 1
241 //   br i1 %5, label %cond.store1, label %else2
242 //
243 // cond.store1:                                      ; preds = %else
244 //   %6 = extractelement <16 x i32> %val, i32 1
245 //   %7 = getelementptr i32* %1, i32 1
246 //   store i32 %6, i32* %7
247 //   br label %else2
248 //   . . .
249 static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
250   Value *Src = CI->getArgOperand(0);
251   Value *Ptr = CI->getArgOperand(1);
252   Value *Alignment = CI->getArgOperand(2);
253   Value *Mask = CI->getArgOperand(3);
254
255   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
256   VectorType *VecType = cast<VectorType>(Src->getType());
257
258   Type *EltTy = VecType->getElementType();
259
260   IRBuilder<> Builder(CI->getContext());
261   Instruction *InsertPt = CI;
262   BasicBlock *IfBlock = CI->getParent();
263   Builder.SetInsertPoint(InsertPt);
264   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
265
266   // Short-cut if the mask is all-true.
267   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
268     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
269     CI->eraseFromParent();
270     return;
271   }
272
273   // Adjust alignment for the scalar instruction.
274   AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
275   // Bitcast %addr from i8* to EltTy*
276   Type *NewPtrType =
277       EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
278   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
279   unsigned VectorWidth = VecType->getNumElements();
280
281   if (isConstantIntVector(Mask)) {
282     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
283       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
284         continue;
285       Value *OneElt = Builder.CreateExtractElement(Src, Idx);
286       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
287       Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
288     }
289     CI->eraseFromParent();
290     return;
291   }
292
293   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
294     // Fill the "else" block, created in the previous iteration
295     //
296     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
297     //  br i1 %mask_1, label %cond.store, label %else
298     //
299     Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
300
301     // Create "cond" block
302     //
303     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
304     //  %EltAddr = getelementptr i32* %1, i32 0
305     //  %store i32 %OneElt, i32* %EltAddr
306     //
307     BasicBlock *CondBlock =
308         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
309     Builder.SetInsertPoint(InsertPt);
310
311     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
312     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
313     Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
314
315     // Create "else" block, fill it in the next iteration
316     BasicBlock *NewIfBlock =
317         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
318     Builder.SetInsertPoint(InsertPt);
319     Instruction *OldBr = IfBlock->getTerminator();
320     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
321     OldBr->eraseFromParent();
322     IfBlock = NewIfBlock;
323   }
324   CI->eraseFromParent();
325
326   ModifiedDT = true;
327 }
328
329 // Translate a masked gather intrinsic like
330 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
331 //                               <16 x i1> %Mask, <16 x i32> %Src)
332 // to a chain of basic blocks, with loading element one-by-one if
333 // the appropriate mask bit is set
334 //
335 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
336 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
337 // br i1 %Mask0, label %cond.load, label %else
338 //
339 // cond.load:
340 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
341 // %Load0 = load i32, i32* %Ptr0, align 4
342 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
343 // br label %else
344 //
345 // else:
346 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
347 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
348 // br i1 %Mask1, label %cond.load1, label %else2
349 //
350 // cond.load1:
351 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
352 // %Load1 = load i32, i32* %Ptr1, align 4
353 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
354 // br label %else2
355 // . . .
356 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
357 // ret <16 x i32> %Result
358 static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
359   Value *Ptrs = CI->getArgOperand(0);
360   Value *Alignment = CI->getArgOperand(1);
361   Value *Mask = CI->getArgOperand(2);
362   Value *Src0 = CI->getArgOperand(3);
363
364   VectorType *VecType = cast<VectorType>(CI->getType());
365   Type *EltTy = VecType->getElementType();
366
367   IRBuilder<> Builder(CI->getContext());
368   Instruction *InsertPt = CI;
369   BasicBlock *IfBlock = CI->getParent();
370   Builder.SetInsertPoint(InsertPt);
371   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
372
373   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
374
375   // The result vector
376   Value *VResult = Src0;
377   unsigned VectorWidth = VecType->getNumElements();
378
379   // Shorten the way if the mask is a vector of constants.
380   if (isConstantIntVector(Mask)) {
381     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
382       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
383         continue;
384       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
385       LoadInst *Load =
386           Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
387       VResult =
388           Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
389     }
390     CI->replaceAllUsesWith(VResult);
391     CI->eraseFromParent();
392     return;
393   }
394
395   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
396     // Fill the "else" block, created in the previous iteration
397     //
398     //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
399     //  br i1 %Mask1, label %cond.load, label %else
400     //
401
402     Value *Predicate =
403         Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
404
405     // Create "cond" block
406     //
407     //  %EltAddr = getelementptr i32* %1, i32 0
408     //  %Elt = load i32* %EltAddr
409     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
410     //
411     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
412     Builder.SetInsertPoint(InsertPt);
413
414     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
415     LoadInst *Load =
416         Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
417     Value *NewVResult =
418         Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
419
420     // Create "else" block, fill it in the next iteration
421     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
422     Builder.SetInsertPoint(InsertPt);
423     Instruction *OldBr = IfBlock->getTerminator();
424     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
425     OldBr->eraseFromParent();
426     BasicBlock *PrevIfBlock = IfBlock;
427     IfBlock = NewIfBlock;
428
429     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
430     Phi->addIncoming(NewVResult, CondBlock);
431     Phi->addIncoming(VResult, PrevIfBlock);
432     VResult = Phi;
433   }
434
435   CI->replaceAllUsesWith(VResult);
436   CI->eraseFromParent();
437
438   ModifiedDT = true;
439 }
440
441 // Translate a masked scatter intrinsic, like
442 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
443 //                                  <16 x i1> %Mask)
444 // to a chain of basic blocks, that stores element one-by-one if
445 // the appropriate mask bit is set.
446 //
447 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
448 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
449 // br i1 %Mask0, label %cond.store, label %else
450 //
451 // cond.store:
452 // %Elt0 = extractelement <16 x i32> %Src, i32 0
453 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
454 // store i32 %Elt0, i32* %Ptr0, align 4
455 // br label %else
456 //
457 // else:
458 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
459 // br i1 %Mask1, label %cond.store1, label %else2
460 //
461 // cond.store1:
462 // %Elt1 = extractelement <16 x i32> %Src, i32 1
463 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
464 // store i32 %Elt1, i32* %Ptr1, align 4
465 // br label %else2
466 //   . . .
467 static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
468   Value *Src = CI->getArgOperand(0);
469   Value *Ptrs = CI->getArgOperand(1);
470   Value *Alignment = CI->getArgOperand(2);
471   Value *Mask = CI->getArgOperand(3);
472
473   assert(isa<VectorType>(Src->getType()) &&
474          "Unexpected data type in masked scatter intrinsic");
475   assert(isa<VectorType>(Ptrs->getType()) &&
476          isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
477          "Vector of pointers is expected in masked scatter intrinsic");
478
479   IRBuilder<> Builder(CI->getContext());
480   Instruction *InsertPt = CI;
481   BasicBlock *IfBlock = CI->getParent();
482   Builder.SetInsertPoint(InsertPt);
483   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
484
485   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
486   unsigned VectorWidth = Src->getType()->getVectorNumElements();
487
488   // Shorten the way if the mask is a vector of constants.
489   if (isConstantIntVector(Mask)) {
490     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
491       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
492         continue;
493       Value *OneElt =
494           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
495       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
496       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
497     }
498     CI->eraseFromParent();
499     return;
500   }
501
502   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
503     // Fill the "else" block, created in the previous iteration
504     //
505     //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
506     //  br i1 %Mask1, label %cond.store, label %else
507     //
508     Value *Predicate =
509         Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
510
511     // Create "cond" block
512     //
513     //  %Elt1 = extractelement <16 x i32> %Src, i32 1
514     //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
515     //  %store i32 %Elt1, i32* %Ptr1
516     //
517     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
518     Builder.SetInsertPoint(InsertPt);
519
520     Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
521     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
522     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
523
524     // Create "else" block, fill it in the next iteration
525     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
526     Builder.SetInsertPoint(InsertPt);
527     Instruction *OldBr = IfBlock->getTerminator();
528     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
529     OldBr->eraseFromParent();
530     IfBlock = NewIfBlock;
531   }
532   CI->eraseFromParent();
533
534   ModifiedDT = true;
535 }
536
537 static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
538   Value *Ptr = CI->getArgOperand(0);
539   Value *Mask = CI->getArgOperand(1);
540   Value *PassThru = CI->getArgOperand(2);
541
542   VectorType *VecType = cast<VectorType>(CI->getType());
543
544   Type *EltTy = VecType->getElementType();
545
546   IRBuilder<> Builder(CI->getContext());
547   Instruction *InsertPt = CI;
548   BasicBlock *IfBlock = CI->getParent();
549
550   Builder.SetInsertPoint(InsertPt);
551   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
552
553   unsigned VectorWidth = VecType->getNumElements();
554
555   // The result vector
556   Value *VResult = PassThru;
557
558   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
559     // Fill the "else" block, created in the previous iteration
560     //
561     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
562     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
563     //  br i1 %mask_1, label %cond.load, label %else
564     //
565
566     Value *Predicate =
567         Builder.CreateExtractElement(Mask, Idx);
568
569     // Create "cond" block
570     //
571     //  %EltAddr = getelementptr i32* %1, i32 0
572     //  %Elt = load i32* %EltAddr
573     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
574     //
575     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
576                                                      "cond.load");
577     Builder.SetInsertPoint(InsertPt);
578
579     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
580     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
581
582     // Move the pointer if there are more blocks to come.
583     Value *NewPtr;
584     if ((Idx + 1) != VectorWidth)
585       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
586
587     // Create "else" block, fill it in the next iteration
588     BasicBlock *NewIfBlock =
589         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
590     Builder.SetInsertPoint(InsertPt);
591     Instruction *OldBr = IfBlock->getTerminator();
592     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
593     OldBr->eraseFromParent();
594     BasicBlock *PrevIfBlock = IfBlock;
595     IfBlock = NewIfBlock;
596
597     // Create the phi to join the new and previous value.
598     PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
599     ResultPhi->addIncoming(NewVResult, CondBlock);
600     ResultPhi->addIncoming(VResult, PrevIfBlock);
601     VResult = ResultPhi;
602
603     // Add a PHI for the pointer if this isn't the last iteration.
604     if ((Idx + 1) != VectorWidth) {
605       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
606       PtrPhi->addIncoming(NewPtr, CondBlock);
607       PtrPhi->addIncoming(Ptr, PrevIfBlock);
608       Ptr = PtrPhi;
609     }
610   }
611
612   CI->replaceAllUsesWith(VResult);
613   CI->eraseFromParent();
614
615   ModifiedDT = true;
616 }
617
618 static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
619   Value *Src = CI->getArgOperand(0);
620   Value *Ptr = CI->getArgOperand(1);
621   Value *Mask = CI->getArgOperand(2);
622
623   VectorType *VecType = cast<VectorType>(Src->getType());
624
625   IRBuilder<> Builder(CI->getContext());
626   Instruction *InsertPt = CI;
627   BasicBlock *IfBlock = CI->getParent();
628
629   Builder.SetInsertPoint(InsertPt);
630   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
631
632   Type *EltTy = VecType->getVectorElementType();
633
634   unsigned VectorWidth = VecType->getNumElements();
635
636   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
637     // Fill the "else" block, created in the previous iteration
638     //
639     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
640     //  br i1 %mask_1, label %cond.store, label %else
641     //
642     Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
643
644     // Create "cond" block
645     //
646     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
647     //  %EltAddr = getelementptr i32* %1, i32 0
648     //  %store i32 %OneElt, i32* %EltAddr
649     //
650     BasicBlock *CondBlock =
651         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
652     Builder.SetInsertPoint(InsertPt);
653
654     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
655     Builder.CreateAlignedStore(OneElt, Ptr, 1);
656
657     // Move the pointer if there are more blocks to come.
658     Value *NewPtr;
659     if ((Idx + 1) != VectorWidth)
660       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
661
662     // Create "else" block, fill it in the next iteration
663     BasicBlock *NewIfBlock =
664         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
665     Builder.SetInsertPoint(InsertPt);
666     Instruction *OldBr = IfBlock->getTerminator();
667     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
668     OldBr->eraseFromParent();
669     BasicBlock *PrevIfBlock = IfBlock;
670     IfBlock = NewIfBlock;
671
672     // Add a PHI for the pointer if this isn't the last iteration.
673     if ((Idx + 1) != VectorWidth) {
674       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
675       PtrPhi->addIncoming(NewPtr, CondBlock);
676       PtrPhi->addIncoming(Ptr, PrevIfBlock);
677       Ptr = PtrPhi;
678     }
679   }
680   CI->eraseFromParent();
681
682   ModifiedDT = true;
683 }
684
685 bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
686   bool EverMadeChange = false;
687
688   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
689
690   bool MadeChange = true;
691   while (MadeChange) {
692     MadeChange = false;
693     for (Function::iterator I = F.begin(); I != F.end();) {
694       BasicBlock *BB = &*I++;
695       bool ModifiedDTOnIteration = false;
696       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
697
698       // Restart BB iteration if the dominator tree of the Function was changed
699       if (ModifiedDTOnIteration)
700         break;
701     }
702
703     EverMadeChange |= MadeChange;
704   }
705
706   return EverMadeChange;
707 }
708
709 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
710   bool MadeChange = false;
711
712   BasicBlock::iterator CurInstIterator = BB.begin();
713   while (CurInstIterator != BB.end()) {
714     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
715       MadeChange |= optimizeCallInst(CI, ModifiedDT);
716     if (ModifiedDT)
717       return true;
718   }
719
720   return MadeChange;
721 }
722
723 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
724                                                 bool &ModifiedDT) {
725   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
726   if (II) {
727     switch (II->getIntrinsicID()) {
728     default:
729       break;
730     case Intrinsic::masked_load:
731       // Scalarize unsupported vector masked load
732       if (TTI->isLegalMaskedLoad(CI->getType()))
733         return false;
734       scalarizeMaskedLoad(CI, ModifiedDT);
735       return true;
736     case Intrinsic::masked_store:
737       if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
738         return false;
739       scalarizeMaskedStore(CI, ModifiedDT);
740       return true;
741     case Intrinsic::masked_gather:
742       if (TTI->isLegalMaskedGather(CI->getType()))
743         return false;
744       scalarizeMaskedGather(CI, ModifiedDT);
745       return true;
746     case Intrinsic::masked_scatter:
747       if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
748         return false;
749       scalarizeMaskedScatter(CI, ModifiedDT);
750       return true;
751     case Intrinsic::masked_expandload:
752       if (TTI->isLegalMaskedExpandLoad(CI->getType()))
753         return false;
754       scalarizeMaskedExpandLoad(CI, ModifiedDT);
755       return true;
756     case Intrinsic::masked_compressstore:
757       if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
758         return false;
759       scalarizeMaskedCompressStore(CI, ModifiedDT);
760       return true;
761     }
762   }
763
764   return false;
765 }