1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //===----------------------------------------------------------------------===//
10 // This pass replaces masked memory intrinsics - when unsupported by the target
11 // - with a chain of basic blocks, that deal with the elements one-by-one if the
12 // appropriate mask bit is set.
14 //===----------------------------------------------------------------------===//
16 #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Analysis/DomTreeUpdater.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/IR/BasicBlock.h"
21 #include "llvm/IR/Constant.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/DerivedTypes.h"
24 #include "llvm/IR/Dominators.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/IntrinsicInst.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/InitializePasses.h"
33 #include "llvm/Pass.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Transforms/Scalar.h"
36 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
42 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
// Legacy pass-manager wrapper. All real work happens in the file-local
// runImpl() invoked from runOnFunction() below.
46 class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
48 static char ID; // Pass identification, replacement for typeid
50 explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
51 initializeScalarizeMaskedMemIntrinLegacyPassPass(
52 *PassRegistry::getPassRegistry());
55 bool runOnFunction(Function &F) override;
57 StringRef getPassName() const override {
58 return "Scalarize Masked Memory Intrinsics";
// Requires TTI to decide which masked intrinsics are legal on the target;
// the dominator tree is preserved (it is patched via DomTreeUpdater when
// new basic blocks are introduced).
61 void getAnalysisUsage(AnalysisUsage &AU) const override {
62 AU.addRequired<TargetTransformInfoWrapperPass>();
63 AU.addPreserved<DominatorTreeWrapperPass>();
67 } // end anonymous namespace
// Forward declarations of the per-block / per-call workers; the definitions
// live at the bottom of this file.
69 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
70 const TargetTransformInfo &TTI, const DataLayout &DL,
72 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
73 const TargetTransformInfo &TTI,
74 const DataLayout &DL, DomTreeUpdater *DTU);
// Definition of the legacy pass's unique ID (address is what matters).
76 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
// Register the legacy pass and its analysis dependencies with the
// PassRegistry so it can be requested by name (DEBUG_TYPE).
78 INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
79 "Scalarize unsupported masked memory intrinsics", false,
81 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
82 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
83 INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
84 "Scalarize unsupported masked memory intrinsics", false,
// Factory used by legacy pipeline construction code.
87 FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
88 return new ScalarizeMaskedMemIntrinLegacyPass();
91 static bool isConstantIntVector(Value *Mask) {
92 Constant *C = dyn_cast<Constant>(Mask);
96 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
97 for (unsigned i = 0; i != NumElts; ++i) {
98 Constant *CElt = C->getAggregateElement(i);
99 if (!CElt || !isa<ConstantInt>(CElt))
106 static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
108 return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
111 // Translate a masked load intrinsic like
112 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
113 // <16 x i1> %mask, <16 x i32> %passthru)
114 // to a chain of basic blocks, with loading element one-by-one if
115 // the appropriate mask bit is set
117 // %1 = bitcast i8* %addr to i32*
118 // %2 = extractelement <16 x i1> %mask, i32 0
119 // br i1 %2, label %cond.load, label %else
121 // cond.load: ; preds = %0
122 // %3 = getelementptr i32* %1, i32 0
124 // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
127 // else: ; preds = %0, %cond.load
128 // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
129 // %6 = extractelement <16 x i1> %mask, i32 1
130 // br i1 %6, label %cond.load1, label %else2
132 // cond.load1: ; preds = %else
133 // %7 = getelementptr i32* %1, i32 1
135 // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
138 // else2: ; preds = %else, %cond.load1
139 // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
140 // %10 = extractelement <16 x i1> %mask, i32 2
141 // br i1 %10, label %cond.load4, label %else5
143 static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
144 DomTreeUpdater *DTU, bool &ModifiedDT) {
// Operands of llvm.masked.load: (pointer, alignment, mask, passthru).
145 Value *Ptr = CI->getArgOperand(0);
146 Value *Alignment = CI->getArgOperand(1);
147 Value *Mask = CI->getArgOperand(2);
148 Value *Src0 = CI->getArgOperand(3);
150 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
151 VectorType *VecType = cast<FixedVectorType>(CI->getType());
153 Type *EltTy = VecType->getElementType();
155 IRBuilder<> Builder(CI->getContext());
156 Instruction *InsertPt = CI;
157 BasicBlock *IfBlock = CI->getParent();
159 Builder.SetInsertPoint(InsertPt);
160 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
162 // Short-cut if the mask is all-true.
// Replace the intrinsic with a single plain vector load.
163 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
164 Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
165 CI->replaceAllUsesWith(NewI);
166 CI->eraseFromParent();
170 // Adjust alignment for the scalar instruction.
171 const Align AdjustedAlignVal =
172 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
173 // Bitcast %addr from i8* to EltTy*
175 EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
176 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
177 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
// Result starts as the passthru value; loaded lanes are inserted over it.
180 Value *VResult = Src0;
// Fast path: all mask lanes are compile-time constants, so only emit loads
// for lanes whose bit is set — no control flow needed.
182 if (isConstantIntVector(Mask)) {
183 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
184 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
186 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
187 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
188 VResult = Builder.CreateInsertElement(VResult, Load, Idx);
190 CI->replaceAllUsesWith(VResult);
191 CI->eraseFromParent();
195 // If the mask is not v1i1, use scalar bit test operations. This generates
196 // better results on X86 at least.
198 if (VectorWidth != 1) {
199 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
200 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
// General case: for each lane emit a guarded cond.load block chained through
// "else" blocks, testing one bit of the scalarized mask per iteration.
203 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
204 // Fill the "else" block, created in the previous iteration
206 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
207 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx
208 // %cond = icmp ne i16 %mask_1, 0
209 // br i1 %mask_1, label %cond.load, label %else
212 if (VectorWidth != 1) {
213 Value *Mask = Builder.getInt(APInt::getOneBitSet(
214 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
215 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
216 Builder.getIntN(VectorWidth, 0));
// v1i1 masks: extract the single i1 directly instead of a bit test.
218 Predicate = Builder.CreateExtractElement(Mask, Idx);
221 // Create "cond" block
223 // %EltAddr = getelementptr i32* %1, i32 0
224 // %Elt = load i32* %EltAddr
225 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
227 Instruction *ThenTerm =
228 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
229 /*BranchWeights=*/nullptr, DTU);
231 BasicBlock *CondBlock = ThenTerm->getParent();
232 CondBlock->setName("cond.load");
234 Builder.SetInsertPoint(CondBlock->getTerminator());
235 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
236 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
237 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
239 // Create "else" block, fill it in the next iteration
240 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
241 NewIfBlock->setName("else");
242 BasicBlock *PrevIfBlock = IfBlock;
243 IfBlock = NewIfBlock;
245 // Create the phi to join the new and previous value.
246 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
247 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
248 Phi->addIncoming(NewVResult, CondBlock);
249 Phi->addIncoming(VResult, PrevIfBlock);
// After the loop the accumulated result replaces all uses of the intrinsic.
253 CI->replaceAllUsesWith(VResult);
254 CI->eraseFromParent();
259 // Translate a masked store intrinsic, like
260 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
262 // to a chain of basic blocks, that stores element one-by-one if
263 // the appropriate mask bit is set
265 // %1 = bitcast i8* %addr to i32*
266 // %2 = extractelement <16 x i1> %mask, i32 0
267 // br i1 %2, label %cond.store, label %else
269 // cond.store: ; preds = %0
270 // %3 = extractelement <16 x i32> %val, i32 0
271 // %4 = getelementptr i32* %1, i32 0
272 // store i32 %3, i32* %4
275 // else: ; preds = %0, %cond.store
276 // %5 = extractelement <16 x i1> %mask, i32 1
277 // br i1 %5, label %cond.store1, label %else2
279 // cond.store1: ; preds = %else
280 // %6 = extractelement <16 x i32> %val, i32 1
281 // %7 = getelementptr i32* %1, i32 1
282 // store i32 %6, i32* %7
285 static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
286 DomTreeUpdater *DTU, bool &ModifiedDT) {
// Operands of llvm.masked.store: (value, pointer, alignment, mask).
287 Value *Src = CI->getArgOperand(0);
288 Value *Ptr = CI->getArgOperand(1);
289 Value *Alignment = CI->getArgOperand(2);
290 Value *Mask = CI->getArgOperand(3);
292 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
293 auto *VecType = cast<VectorType>(Src->getType());
295 Type *EltTy = VecType->getElementType();
297 IRBuilder<> Builder(CI->getContext());
298 Instruction *InsertPt = CI;
299 Builder.SetInsertPoint(InsertPt);
300 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
302 // Short-cut if the mask is all-true.
// Replace the intrinsic with a single plain vector store.
303 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
304 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
305 CI->eraseFromParent();
309 // Adjust alignment for the scalar instruction.
310 const Align AdjustedAlignVal =
311 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
312 // Bitcast %addr from i8* to EltTy*
314 EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
315 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
316 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
// Fast path: mask is a compile-time constant — emit only the stores for
// lanes whose bit is set, no control flow.
318 if (isConstantIntVector(Mask)) {
319 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
320 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
322 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
323 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
324 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
326 CI->eraseFromParent();
330 // If the mask is not v1i1, use scalar bit test operations. This generates
331 // better results on X86 at least.
333 if (VectorWidth != 1) {
334 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
335 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
// General case: one guarded cond.store block per lane, chained through
// "else" blocks. No phi is needed because stores produce no value.
338 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
339 // Fill the "else" block, created in the previous iteration
341 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx
342 // %cond = icmp ne i16 %mask_1, 0
343 // br i1 %mask_1, label %cond.store, label %else
346 if (VectorWidth != 1) {
347 Value *Mask = Builder.getInt(APInt::getOneBitSet(
348 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
349 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
350 Builder.getIntN(VectorWidth, 0));
// v1i1 masks: extract the single i1 directly.
352 Predicate = Builder.CreateExtractElement(Mask, Idx);
355 // Create "cond" block
357 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
358 // %EltAddr = getelementptr i32* %1, i32 0
359 // %store i32 %OneElt, i32* %EltAddr
361 Instruction *ThenTerm =
362 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
363 /*BranchWeights=*/nullptr, DTU);
365 BasicBlock *CondBlock = ThenTerm->getParent();
366 CondBlock->setName("cond.store");
368 Builder.SetInsertPoint(CondBlock->getTerminator());
369 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
370 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
371 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
373 // Create "else" block, fill it in the next iteration
374 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
375 NewIfBlock->setName("else");
377 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
379 CI->eraseFromParent();
384 // Translate a masked gather intrinsic like
385 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
386 // <16 x i1> %Mask, <16 x i32> %Src)
387 // to a chain of basic blocks, with loading element one-by-one if
388 // the appropriate mask bit is set
390 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
391 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
392 // br i1 %Mask0, label %cond.load, label %else
395 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
396 // %Load0 = load i32, i32* %Ptr0, align 4
397 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
401 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
402 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
403 // br i1 %Mask1, label %cond.load1, label %else2
406 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
407 // %Load1 = load i32, i32* %Ptr1, align 4
408 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
411 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
412 // ret <16 x i32> %Result
413 static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
414 DomTreeUpdater *DTU, bool &ModifiedDT) {
// Operands of llvm.masked.gather: (vector-of-pointers, alignment, mask,
// passthru).
415 Value *Ptrs = CI->getArgOperand(0);
416 Value *Alignment = CI->getArgOperand(1);
417 Value *Mask = CI->getArgOperand(2);
418 Value *Src0 = CI->getArgOperand(3);
420 auto *VecType = cast<FixedVectorType>(CI->getType());
421 Type *EltTy = VecType->getElementType();
423 IRBuilder<> Builder(CI->getContext());
424 Instruction *InsertPt = CI;
425 BasicBlock *IfBlock = CI->getParent();
426 Builder.SetInsertPoint(InsertPt);
427 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
429 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
// Result starts as the passthru; gathered lanes are inserted over it.
432 Value *VResult = Src0;
433 unsigned VectorWidth = VecType->getNumElements();
435 // Shorten the way if the mask is a vector of constants.
436 if (isConstantIntVector(Mask)) {
437 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
438 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
// Each enabled lane loads through its own extracted pointer.
440 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
442 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
444 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
446 CI->replaceAllUsesWith(VResult);
447 CI->eraseFromParent();
451 // If the mask is not v1i1, use scalar bit test operations. This generates
452 // better results on X86 at least.
454 if (VectorWidth != 1) {
455 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
456 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
// General case: one guarded cond.load block per lane, with a phi joining the
// updated and previous result vectors in each "else" block.
459 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
460 // Fill the "else" block, created in the previous iteration
462 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx
463 // %cond = icmp ne i16 %mask_1, 0
464 // br i1 %Mask1, label %cond.load, label %else
468 if (VectorWidth != 1) {
469 Value *Mask = Builder.getInt(APInt::getOneBitSet(
470 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
471 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
472 Builder.getIntN(VectorWidth, 0));
474 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
477 // Create "cond" block
479 // %EltAddr = getelementptr i32* %1, i32 0
480 // %Elt = load i32* %EltAddr
481 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
483 Instruction *ThenTerm =
484 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
485 /*BranchWeights=*/nullptr, DTU);
487 BasicBlock *CondBlock = ThenTerm->getParent();
488 CondBlock->setName("cond.load");
490 Builder.SetInsertPoint(CondBlock->getTerminator());
491 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
493 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
495 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
497 // Create "else" block, fill it in the next iteration
498 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
499 NewIfBlock->setName("else");
500 BasicBlock *PrevIfBlock = IfBlock;
501 IfBlock = NewIfBlock;
503 // Create the phi to join the new and previous value.
504 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
505 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
506 Phi->addIncoming(NewVResult, CondBlock);
507 Phi->addIncoming(VResult, PrevIfBlock);
511 CI->replaceAllUsesWith(VResult);
512 CI->eraseFromParent();
517 // Translate a masked scatter intrinsic, like
518 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
520 // to a chain of basic blocks, that stores element one-by-one if
521 // the appropriate mask bit is set.
523 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
524 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
525 // br i1 %Mask0, label %cond.store, label %else
528 // %Elt0 = extractelement <16 x i32> %Src, i32 0
529 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
530 // store i32 %Elt0, i32* %Ptr0, align 4
534 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
535 // br i1 %Mask1, label %cond.store1, label %else2
538 // %Elt1 = extractelement <16 x i32> %Src, i32 1
539 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
540 // store i32 %Elt1, i32* %Ptr1, align 4
543 static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
544 DomTreeUpdater *DTU, bool &ModifiedDT) {
// Operands of llvm.masked.scatter: (value, vector-of-pointers, alignment,
// mask).
545 Value *Src = CI->getArgOperand(0);
546 Value *Ptrs = CI->getArgOperand(1);
547 Value *Alignment = CI->getArgOperand(2);
548 Value *Mask = CI->getArgOperand(3);
550 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
553 isa<VectorType>(Ptrs->getType()) &&
554 isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
555 "Vector of pointers is expected in masked scatter intrinsic");
557 IRBuilder<> Builder(CI->getContext());
558 Instruction *InsertPt = CI;
559 Builder.SetInsertPoint(InsertPt);
560 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
562 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
563 unsigned VectorWidth = SrcFVTy->getNumElements();
565 // Shorten the way if the mask is a vector of constants.
566 if (isConstantIntVector(Mask)) {
567 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
568 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
// Each enabled lane stores through its own extracted pointer.
571 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
572 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
573 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
575 CI->eraseFromParent();
579 // If the mask is not v1i1, use scalar bit test operations. This generates
580 // better results on X86 at least.
582 if (VectorWidth != 1) {
583 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
584 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
// General case: one guarded cond.store block per lane, chained through
// "else" blocks; no phi is needed since stores produce no value.
587 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
588 // Fill the "else" block, created in the previous iteration
590 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx
591 // %cond = icmp ne i16 %mask_1, 0
592 // br i1 %Mask1, label %cond.store, label %else
595 if (VectorWidth != 1) {
596 Value *Mask = Builder.getInt(APInt::getOneBitSet(
597 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
598 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
599 Builder.getIntN(VectorWidth, 0));
601 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
604 // Create "cond" block
606 // %Elt1 = extractelement <16 x i32> %Src, i32 1
607 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
608 // %store i32 %Elt1, i32* %Ptr1
610 Instruction *ThenTerm =
611 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
612 /*BranchWeights=*/nullptr, DTU);
614 BasicBlock *CondBlock = ThenTerm->getParent();
615 CondBlock->setName("cond.store");
617 Builder.SetInsertPoint(CondBlock->getTerminator());
618 Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
619 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
620 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
622 // Create "else" block, fill it in the next iteration
623 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
624 NewIfBlock->setName("else");
626 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
628 CI->eraseFromParent();
// Translate llvm.masked.expandload: elements are read from *consecutive*
// memory locations, but only for enabled mask lanes — the source pointer
// advances once per enabled lane, not per vector lane.
633 static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
634 DomTreeUpdater *DTU, bool &ModifiedDT) {
// Operands: (pointer, mask, passthru).
635 Value *Ptr = CI->getArgOperand(0);
636 Value *Mask = CI->getArgOperand(1);
637 Value *PassThru = CI->getArgOperand(2);
639 auto *VecType = cast<FixedVectorType>(CI->getType());
641 Type *EltTy = VecType->getElementType();
643 IRBuilder<> Builder(CI->getContext());
644 Instruction *InsertPt = CI;
645 BasicBlock *IfBlock = CI->getParent();
647 Builder.SetInsertPoint(InsertPt);
648 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
650 unsigned VectorWidth = VecType->getNumElements();
653 Value *VResult = PassThru;
655 // Shorten the way if the mask is a vector of constants.
656 // Create a build_vector pattern, with loads/undefs as necessary and then
657 // shuffle blend with the pass through value.
658 if (isConstantIntVector(Mask)) {
// MemIndex tracks the position in the packed source: it only moves past a
// loaded element (enabled lanes are densely packed in memory).
659 unsigned MemIndex = 0;
660 VResult = PoisonValue::get(VecType);
661 SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
662 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
664 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
// Disabled lane: take the element from PassThru (second shuffle operand).
665 InsertElt = UndefValue::get(EltTy);
666 ShuffleMask[Idx] = Idx + VectorWidth;
669 Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
670 InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
671 "Load" + Twine(Idx));
672 ShuffleMask[Idx] = Idx;
675 VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
// Blend the loaded lanes with the passthru lanes in one shuffle.
678 VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
679 CI->replaceAllUsesWith(VResult);
680 CI->eraseFromParent();
684 // If the mask is not v1i1, use scalar bit test operations. This generates
685 // better results on X86 at least.
687 if (VectorWidth != 1) {
688 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
689 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
// General case: guarded cond.load blocks; each "else" block carries a phi
// for the result vector and (except on the last lane) a phi for the moving
// source pointer.
692 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
693 // Fill the "else" block, created in the previous iteration
695 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
696 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
697 // br i1 %mask_1, label %cond.load, label %else
701 if (VectorWidth != 1) {
702 Value *Mask = Builder.getInt(APInt::getOneBitSet(
703 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
704 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
705 Builder.getIntN(VectorWidth, 0));
707 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
710 // Create "cond" block
712 // %EltAddr = getelementptr i32* %1, i32 0
713 // %Elt = load i32* %EltAddr
714 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
716 Instruction *ThenTerm =
717 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
718 /*BranchWeights=*/nullptr, DTU);
720 BasicBlock *CondBlock = ThenTerm->getParent();
721 CondBlock->setName("cond.load");
723 Builder.SetInsertPoint(CondBlock->getTerminator());
724 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
725 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
727 // Move the pointer if there are more blocks to come.
729 if ((Idx + 1) != VectorWidth)
730 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
732 // Create "else" block, fill it in the next iteration
733 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
734 NewIfBlock->setName("else");
735 BasicBlock *PrevIfBlock = IfBlock;
736 IfBlock = NewIfBlock;
738 // Create the phi to join the new and previous value.
739 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
740 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
741 ResultPhi->addIncoming(NewVResult, CondBlock);
742 ResultPhi->addIncoming(VResult, PrevIfBlock);
745 // Add a PHI for the pointer if this isn't the last iteration.
746 if ((Idx + 1) != VectorWidth) {
747 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
748 PtrPhi->addIncoming(NewPtr, CondBlock);
749 PtrPhi->addIncoming(Ptr, PrevIfBlock);
754 CI->replaceAllUsesWith(VResult);
755 CI->eraseFromParent();
// Translate llvm.masked.compressstore: enabled lanes of the source vector
// are written to *consecutive* memory locations — the destination pointer
// advances once per enabled lane.
760 static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
// Operands: (value, pointer, mask).
763 Value *Src = CI->getArgOperand(0);
764 Value *Ptr = CI->getArgOperand(1);
765 Value *Mask = CI->getArgOperand(2);
767 auto *VecType = cast<FixedVectorType>(Src->getType());
769 IRBuilder<> Builder(CI->getContext());
770 Instruction *InsertPt = CI;
771 BasicBlock *IfBlock = CI->getParent();
773 Builder.SetInsertPoint(InsertPt);
774 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
776 Type *EltTy = VecType->getElementType();
778 unsigned VectorWidth = VecType->getNumElements();
780 // Shorten the way if the mask is a vector of constants.
781 if (isConstantIntVector(Mask)) {
// MemIndex tracks the next packed destination slot; it only moves past a
// stored element.
782 unsigned MemIndex = 0;
783 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
784 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
787 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
788 Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
789 Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
792 CI->eraseFromParent();
796 // If the mask is not v1i1, use scalar bit test operations. This generates
797 // better results on X86 at least.
799 if (VectorWidth != 1) {
800 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
801 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
// General case: guarded cond.store blocks; each "else" block (except after
// the last lane) carries a phi for the moving destination pointer.
804 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
805 // Fill the "else" block, created in the previous iteration
807 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
808 // br i1 %mask_1, label %cond.store, label %else
811 if (VectorWidth != 1) {
812 Value *Mask = Builder.getInt(APInt::getOneBitSet(
813 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
814 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
815 Builder.getIntN(VectorWidth, 0));
817 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
820 // Create "cond" block
822 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
823 // %EltAddr = getelementptr i32* %1, i32 0
824 // %store i32 %OneElt, i32* %EltAddr
826 Instruction *ThenTerm =
827 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
828 /*BranchWeights=*/nullptr, DTU);
830 BasicBlock *CondBlock = ThenTerm->getParent();
831 CondBlock->setName("cond.store");
833 Builder.SetInsertPoint(CondBlock->getTerminator());
834 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
835 Builder.CreateAlignedStore(OneElt, Ptr, Align(1));
837 // Move the pointer if there are more blocks to come.
839 if ((Idx + 1) != VectorWidth)
840 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
842 // Create "else" block, fill it in the next iteration
843 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
844 NewIfBlock->setName("else");
845 BasicBlock *PrevIfBlock = IfBlock;
846 IfBlock = NewIfBlock;
848 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
850 // Add a PHI for the pointer if this isn't the last iteration.
851 if ((Idx + 1) != VectorWidth) {
852 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
853 PtrPhi->addIncoming(NewPtr, CondBlock);
854 PtrPhi->addIncoming(Ptr, PrevIfBlock);
858 CI->eraseFromParent();
// Shared driver for both the legacy and new pass managers. Walks every basic
// block and scalarizes unsupported masked intrinsics; returns true if the
// function was changed.
863 static bool runImpl(Function &F, const TargetTransformInfo &TTI,
// A lazy DomTreeUpdater is created only when a DominatorTree was supplied
// (presumably the DT parameter — elided here; confirm against callers).
865 std::optional<DomTreeUpdater> DTU;
867 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
869 bool EverMadeChange = false;
870 bool MadeChange = true;
871 auto &DL = F.getParent()->getDataLayout();
// early-inc iteration tolerates blocks being split/added while visiting.
874 for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
875 bool ModifiedDTOnIteration = false;
876 MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
877 DTU ? &*DTU : nullptr);
879 // Restart BB iteration if the dominator tree of the Function was changed
880 if (ModifiedDTOnIteration)
884 EverMadeChange |= MadeChange;
886 return EverMadeChange;
// Legacy PM entry point: gather TTI (required) and the dominator tree (only
// if already computed — it is optional) and delegate to runImpl.
889 bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
890 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
891 DominatorTree *DT = nullptr;
892 if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
893 DT = &DTWP->getDomTree();
894 return runImpl(F, TTI, DT);
// New PM entry point. Uses the cached dominator tree if present (never forces
// its computation) and reports which analyses survive when changes were made.
898 ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
899 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
900 auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
901 if (!runImpl(F, TTI, DT))
902 return PreservedAnalyses::all();
903 PreservedAnalyses PA;
904 PA.preserve<TargetIRAnalysis>();
905 PA.preserve<DominatorTreeAnalysis>();
// Scan one basic block for call instructions and hand each to
// optimizeCallInst. The iterator is advanced before the call so erasing CI
// (which every scalarizer does) cannot invalidate it.
909 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
910 const TargetTransformInfo &TTI, const DataLayout &DL,
911 DomTreeUpdater *DTU) {
912 bool MadeChange = false;
914 BasicBlock::iterator CurInstIterator = BB.begin();
915 while (CurInstIterator != BB.end()) {
916 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
917 MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
// Dispatch on the masked-memory intrinsic kind. Each case first asks TTI
// whether the target supports the operation natively; only unsupported (or
// force-scalarized) forms are expanded by the scalarize* helpers above.
925 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
926 const TargetTransformInfo &TTI,
927 const DataLayout &DL, DomTreeUpdater *DTU) {
928 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
930 // The scalarization code below does not work for scalable vectors.
931 if (isa<ScalableVectorType>(II->getType()) ||
933 [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
936 switch (II->getIntrinsicID()) {
939 case Intrinsic::masked_load:
940 // Scalarize unsupported vector masked load
941 if (TTI.isLegalMaskedLoad(
943 cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
945 scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
947 case Intrinsic::masked_store:
948 if (TTI.isLegalMaskedStore(
949 CI->getArgOperand(0)->getType(),
950 cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
952 scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
954 case Intrinsic::masked_gather: {
// Gathers/scatters may omit the alignment; fall back to the element type's
// ABI alignment in that case.
956 cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
957 Type *LoadTy = CI->getType();
958 Align Alignment = DL.getValueOrABITypeAlignment(MA,
959 LoadTy->getScalarType());
960 if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
961 !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
963 scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
966 case Intrinsic::masked_scatter: {
968 cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
969 Type *StoreTy = CI->getArgOperand(0)->getType();
970 Align Alignment = DL.getValueOrABITypeAlignment(MA,
971 StoreTy->getScalarType());
972 if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
973 !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
976 scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
979 case Intrinsic::masked_expandload:
980 if (TTI.isLegalMaskedExpandLoad(CI->getType()))
982 scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
984 case Intrinsic::masked_compressstore:
985 if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
987 scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);