1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to
10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
11 /// produce a better final result as we go.
13 //===----------------------------------------------------------------------===//
16 #include "ARMBaseInstrInfo.h"
17 #include "ARMSubtarget.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/CodeGen/TargetLowering.h"
21 #include "llvm/CodeGen/TargetPassConfig.h"
22 #include "llvm/CodeGen/TargetSubtargetInfo.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/IR/BasicBlock.h"
25 #include "llvm/IR/Constant.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/InstrTypes.h"
30 #include "llvm/IR/Instruction.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/IntrinsicInst.h"
33 #include "llvm/IR/Intrinsics.h"
34 #include "llvm/IR/IntrinsicsARM.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/PatternMatch.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Transforms/Utils/Local.h"
// Debug category string for this file; it is also handed to INITIALIZE_PASS
// as the pass argument.
47 #define DEBUG_TYPE "mve-gather-scatter-lowering"
// Hidden command-line switch "enable-arm-maskedgatscat"; initial value false.
// NOTE(review): the place where this flag is consulted is not visible in this
// excerpt.
49 cl::opt<bool> EnableMaskedGatherScatters(
50 "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
51 cl::desc("Enable the generation of masked gathers and scatters"));
// Function pass that rewrites llvm.masked.gather / llvm.masked.scatter calls
// into arm.mve gather/scatter intrinsics (see lowerGather / lowerScatter).
// NOTE(review): this listing appears to be missing lines (access specifiers,
// closing braces, some declarations); confirm against the complete file.
55 class MVEGatherScatterLowering : public FunctionPass {
57 static char ID; // Pass identification, replacement for typeid
// The constructor registers the pass with the global PassRegistry.
59 explicit MVEGatherScatterLowering() : FunctionPass(ID) {
60 initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
// Pass entry point; defined out of line.
63 bool runOnFunction(Function &F) override;
// Human-readable name of the pass.
65 StringRef getPassName() const override {
66 return "MVE gather/scatter lowering";
// Declares the analyses this pass requires (target pass configuration and
// loop information) before deferring to the base class.
69 void getAnalysisUsage(AnalysisUsage &AU) const override {
71 AU.addRequired<TargetPassConfig>();
72 AU.addRequired<LoopInfoWrapperPass>();
73 FunctionPass::getAnalysisUsage(AU);
// Loop information queried through LI->getLoopFor(...) by the incrementing
// gather/scatter helpers and by optimiseOffsets.
// NOTE(review): where LI is assigned is not visible in this excerpt.
77 LoopInfo *LI = nullptr;
79 // Check this is a valid gather/scatter vector type with correct alignment
80 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
82 // Check whether Ptr is hidden behind a bitcast and look through it
83 void lookThroughBitcast(Value *&Ptr);
84 // Check for a getelementptr and deduce base and offsets from it, on success
85 // returning the base directly and the offsets indirectly using the Offsets
// reference parameter.
87 Value *checkGEP(Value *&Offsets, Type *Ty, GetElementPtrInst *GEP,
88 IRBuilder<> &Builder);
89 // Compute the scale of this gather/scatter instruction
90 int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
91 // If the value is a constant, or derived from constants via additions
92 // and multiplications, return its numeric value
93 Optional<int64_t> getIfConst(const Value *V);
94 // If Inst is an add instruction, check whether one summand is a
95 // constant. If so, scale this constant and return it together with
// the other summand.
97 std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
// Entry point for lowering one llvm.masked.gather call.
99 Value *lowerGather(IntrinsicInst *I);
100 // Create a gather from a base + vector of offsets
101 Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
102 Instruction *&Root, IRBuilder<> &Builder);
103 // Create a gather from a vector of pointers
104 Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
105 IRBuilder<> &Builder, int64_t Increment = 0);
106 // Create an incrementing gather from a vector of pointers
107 Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
108 IRBuilder<> &Builder,
109 int64_t Increment = 0);
// Entry point for lowering one llvm.masked.scatter call.
111 Value *lowerScatter(IntrinsicInst *I);
112 // Create a scatter to a base + vector of offsets
113 Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
114 IRBuilder<> &Builder);
115 // Create a scatter to a vector of pointers
116 Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
117 IRBuilder<> &Builder,
118 int64_t Increment = 0);
119 // Create an incrementing scatter to a vector of pointers
120 Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
121 IRBuilder<> &Builder,
122 int64_t Increment = 0);
124 // QI gathers and scatters can increment their offsets on their own if
125 // the increment is a constant value (digit)
// NOTE(review): the third parameter is named Ptr here but Offsets in the
// out-of-line definition.
126 Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
127 Value *Ptr, GetElementPtrInst *GEP,
128 IRBuilder<> &Builder);
129 // QI gathers/scatters can increment their offsets on their own if the
130 // increment is a constant value (digit) - this creates a writeback QI
// gather/scatter.
132 Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
133 Value *Ptr, unsigned TypeScale,
134 IRBuilder<> &Builder);
135 // Check whether these offsets could be moved out of the loop they're in
136 bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
137 // Pushes the given add out of the loop
138 void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
139 // Pushes the given mul out of the loop
140 void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
141 Value *OffsSecondOperand, unsigned LoopIncrement,
142 IRBuilder<> &Builder);
145 } // end anonymous namespace
// Definition of the static pass identifier declared in the class.
147 char MVEGatherScatterLowering::ID = 0;
// Registers the pass under the DEBUG_TYPE argument string with the given
// description. NOTE(review): the two trailing `false` values are presumably
// the macro's "CFG only" and "is analysis" flags - confirm against the
// INITIALIZE_PASS definition.
149 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
150 "MVE gather/scattering lowering pass", false, false)
// Factory function: allocates a fresh MVEGatherScatterLowering instance with
// `new` and returns it as a Pass pointer. The caller receives the raw pointer;
// presumably the pass manager takes ownership - verify at the call site.
152 Pass *llvm::createMVEGatherScatterLoweringPass() {
153 return new MVEGatherScatterLowering();
// Tests whether a vector of NumElements lanes of ElemSize bits each, accessed
// with the given Alignment, has a shape this pass handles. The accepted shapes
// are 4 x {32, 16, 8}, 8 x {16, 8} and 16 x 8 bits, and the alignment must be
// at least the element size in bytes (ElemSize / 8).
// NOTE(review): the rest of the parameter list and the return statements are
// not visible in this excerpt; judging by its text, the debug message is
// presumably emitted on the rejecting path.
156 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
159 if (((NumElements == 4 &&
160 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
161 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
162 (NumElements == 16 && ElemSize == 8)) &&
163 Alignment >= ElemSize / 8)
165 LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
166 << "valid alignment or vector type \n");
// Inspects GEP to see whether it has the "base pointer + vector of offsets"
// form usable by an offset gather/scatter.
//   Offsets  out-parameter; set to the GEP's index operand, then possibly
//            replaced by the source of a zext and/or a newly built zext.
//   Ty       vector type of the value being loaded/stored; its lane count must
//            match that of the offsets.
//   GEP      the candidate getelementptr (callers pass a dyn_cast result).
//   Builder  used to emit the zext of the offsets when their type differs
//            from Ty.
// NOTE(review): the early exits and the final return are not visible in this
// excerpt; per the closing comment the gep's base pointer is returned when no
// check fails.
170 Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty,
171 GetElementPtrInst *GEP,
172 IRBuilder<> &Builder) {
175 dbgs() << "masked gathers/scatters: no getelementpointer found\n");
178 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
179 << " Looking at intrinsic for base + vector of offsets\n");
180 Value *GEPPtr = GEP->getPointerOperand();
// A vector-typed pointer operand gets dedicated handling (body not visible
// here).
181 if (GEPPtr->getType()->isVectorTy()) {
// Only geps with exactly two operands (pointer + one index) are considered.
184 if (GEP->getNumOperands() != 2) {
185 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
186 << " operands. Expanding.\n");
189 Offsets = GEP->getOperand(1);
190 // Paranoid check whether the number of parallel lanes is the same
191 assert(cast<FixedVectorType>(Ty)->getNumElements() ==
192 cast<FixedVectorType>(Offsets->getType())->getNumElements());
193 // Only <N x i32> offsets can be integrated into an arm gather, any smaller
194 // type would have to be sign extended by the gep - and arm gathers can only
195 // zero extend. Additionally, the offsets do have to originate from a zext of
196 // a vector with element types smaller or equal the type of the gather we're
// looking at (NOTE(review): sentence completed by inference; the original
// continuation line is missing from this excerpt).
198 if (Offsets->getType()->getScalarSizeInBits() != 32)
// Strip a zext to reach the narrower source offsets; offsets that are not a
// zext are only tolerated when they are already a 4 x 32-bit vector.
200 if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
201 Offsets = ZextOffs->getOperand(0);
202 else if (!(cast<FixedVectorType>(Offsets->getType())->getNumElements() == 4 &&
203 Offsets->getType()->getScalarSizeInBits() == 32))
// If the (possibly stripped) offsets differ in type from Ty, they must not be
// wider than Ty's elements; they are then zero-extended to the integer vector
// type matching Ty.
206 if (Ty != Offsets->getType()) {
207 if ((Ty->getScalarSizeInBits() <
208 Offsets->getType()->getScalarSizeInBits())) {
209 LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type."
210 << " Can't create intrinsic.\n");
213 Offsets = Builder.CreateZExt(
214 Offsets, VectorType::getInteger(cast<VectorType>(Ty)));
217 // If none of the checks failed, return the gep's base pointer
218 LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
// If Ptr is a bitcast instruction whose result and source vector types have
// the same number of elements, replaces Ptr (passed by reference) with the
// bitcast's source operand. Otherwise Ptr is left untouched.
// Both types are cast<> to FixedVectorType unconditionally, so a bitcast
// between non-fixed-vector types would trip the cast's assertion.
222 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
223 // Look through bitcast instruction if #elements is the same
224 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
225 auto *BCTy = cast<FixedVectorType>(BitCast->getType());
226 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
227 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
229 dbgs() << "masked gathers/scatters: looking through bitcast\n");
230 Ptr = BitCast->getOperand(0);
// Determines the scale to use for a gather/scatter from the size in bits of
// the element type the gep steps over (GEPElemSize) and the size in bits of
// the elements actually accessed in memory (MemoryElemSize).
// Three combinations are recognised: 32/32, 16/16 and any access through an
// 8-bit gep element type; everything else reaches the "incorrect scale" debug
// message.
// NOTE(review): the returned values are not visible in this excerpt. Callers
// use the result as a shift amount (see getVarAndConst and the Shl in
// tryCreateIncrementingGatScat), so it is presumably log2 of the byte scale
// named in the comment below - confirm.
235 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
236 unsigned MemoryElemSize) {
237 // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
238 // or a 8bit, 16bit or 32bit load/store scaled by 1
239 if (GEPElemSize == 32 && MemoryElemSize == 32)
241 else if (GEPElemSize == 16 && MemoryElemSize == 16)
243 else if (GEPElemSize == 8)
245 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
246 << "create intrinsic\n");
// Tries to evaluate V to a single 64-bit integer.
//   - A Constant yields its unique integer value, sign-extended to 64 bits.
//   - An add or mul instruction is evaluated recursively from its two
//     operands.
//   - Anything else yields an empty Optional.
// NOTE(review): the guards in front of some returns (the null check on C and
// the check that both operands evaluated) are not visible in this excerpt.
// The arithmetic is done on int64_t with no overflow check.
250 Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
251 const Constant *C = dyn_cast<Constant>(V);
253 return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
// Non-constant values that are not instructions cannot be folded.
254 if (!isa<Instruction>(V))
255 return Optional<int64_t>{};
257 const Instruction *I = cast<Instruction>(V);
258 if (I->getOpcode() == Instruction::Add ||
259 I->getOpcode() == Instruction::Mul) {
// Evaluate both operands first; the result is combined below.
260 Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
261 Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
263 return Optional<int64_t>{};
264 if (I->getOpcode() == Instruction::Add)
265 return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
266 if (I->getOpcode() == Instruction::Mul)
267 return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
// Any other instruction kind is not treated as constant.
269 return Optional<int64_t>{};
// Splits an add into a constant summand and the remaining operand.
//   Inst       value expected to be an add instruction.
//   TypeScale  shift amount applied to the constant (Const << TypeScale).
// On success returns (other summand, scaled constant). The scaled constant
// must lie within [-512, 512] and be a multiple of 4.
// NOTE(review): the failure returns are not visible in this excerpt;
// presumably they return ReturnFalse, i.e. (nullptr, 0), which callers detect
// through `.first == nullptr`.
// NOTE(review): the range check is inclusive of +/-512; confirm this matches
// the immediate range the target instruction can encode.
272 std::pair<Value *, int64_t>
273 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
274 std::pair<Value *, int64_t> ReturnFalse =
275 std::pair<Value *, int64_t>(nullptr, 0);
276 // At this point, the instruction we're looking at must be an add or we
// cannot proceed (NOTE(review): sentence completed by inference; the original
// continuation line is missing from this excerpt).
278 Instruction *Add = dyn_cast<Instruction>(Inst);
279 if (Add == nullptr || Add->getOpcode() != Instruction::Add)
283 Optional<int64_t> Const;
284 // Find out which operand the value that is increased is
// The assignment inside each condition stores the folded constant; operand 0
// is tried first, then operand 1.
285 if ((Const = getIfConst(Add->getOperand(0))))
286 Summand = Add->getOperand(1);
287 else if ((Const = getIfConst(Add->getOperand(1))))
288 Summand = Add->getOperand(0);
292 // Check that the constant is small enough for an incrementing gather
293 int64_t Immediate = Const.getValue() << TypeScale;
294 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
297 return std::pair<Value *, int64_t>(Summand, Immediate);
// Lowers one llvm.masked.gather call I to an MVE gather intrinsic.
// Steps visible here: check type/alignment legality, look through a bitcast
// of the pointer vector, try the "base + offsets" form and then the "vector
// of pointers" form, wrap the result in a select when the pass-through value
// is neither undef nor zero, and finally replace and erase the original
// instruction(s).
// NOTE(review): the early exits, the conditions guarding some statements and
// the final return are not visible in this excerpt.
300 Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
301 using namespace PatternMatch;
302 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");
304 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
305 // Attempt to turn the masked gather in I into a MVE intrinsic
306 // Potentially optimising the addressing modes as we do so.
307 auto *Ty = cast<FixedVectorType>(I->getType());
308 Value *Ptr = I->getArgOperand(0);
309 Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
310 Value *Mask = I->getArgOperand(2);
311 Value *PassThru = I->getArgOperand(3);
313 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
316 lookThroughBitcast(Ptr);
317 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
// New instructions are inserted directly before I and inherit its debug
// location.
319 IRBuilder<> Builder(I->getContext());
320 Builder.SetInsertPoint(I);
321 Builder.SetCurrentDebugLocation(I->getDebugLoc());
// Root is the instruction whose uses get replaced. It starts as I and is
// passed by reference so the offset variant can redirect it to a following
// extend.
323 Instruction *Root = I;
324 Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
// Second attempt using the vector-of-pointers form (presumably only when the
// offset form produced nothing - the guard is not visible here).
326 Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
// The select restores the pass-through value in lanes where Mask is false.
330 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
331 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
332 << "creating select\n");
333 Load = Builder.CreateSelect(Mask, Load, PassThru);
336 Root->replaceAllUsesWith(Load);
337 Root->eraseFromParent();
339 // If this was an extending gather, we need to get rid of the sext/zext
340 // as well as of the gather itself
// NOTE(review): the condition guarding this erase is not visible here.
341 I->eraseFromParent();
343 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
// Builds a gather from a vector of pointers (Ptr) for the gather call I.
// Only results with 4 lanes of 32 bits are handled. Increment is passed to the
// intrinsic as an i32 immediate. When the mask matches m_One() the
// unpredicated arm_mve_vldr_gather_base intrinsic is emitted, otherwise the
// predicated variant taking Mask as an extra operand.
// NOTE(review): the rest of the parameter list and the bail-out statement for
// unsupported types are not visible in this excerpt.
347 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
349 IRBuilder<> &Builder,
351 using namespace PatternMatch;
352 auto *Ty = cast<FixedVectorType>(I->getType());
353 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
354 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
355 // Can't build an intrinsic for this
// Operand 2 of llvm.masked.gather is the mask.
357 Value *Mask = I->getArgOperand(2);
358 if (match(Mask, m_One()))
359 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
360 {Ty, Ptr->getType()},
361 {Ptr, Builder.getInt32(Increment)});
363 return Builder.CreateIntrinsic(
364 Intrinsic::arm_mve_vldr_gather_base_predicated,
365 {Ty, Ptr->getType(), Mask->getType()},
366 {Ptr, Builder.getInt32(Increment), Mask});
// Writeback counterpart of tryCreateMaskedGatherBase: builds
// arm_mve_vldr_gather_base_wb (or its predicated variant when the mask does
// not match m_One()) for a 4 x 32-bit gather, passing Increment as an i32
// immediate. The caller in tryCreateIncrementingWBGatScat extracts element 0
// (loaded data) and element 1 (updated pointers) from the result.
// NOTE(review): the bail-out statement for unsupported types is not visible
// in this excerpt.
369 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
370 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
371 using namespace PatternMatch;
372 auto *Ty = cast<FixedVectorType>(I->getType());
375 << "masked gathers: loading from vector of pointers with writeback\n");
376 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
377 // Can't build an intrinsic for this
// Operand 2 of llvm.masked.gather is the mask.
379 Value *Mask = I->getArgOperand(2);
380 if (match(Mask, m_One()))
381 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
382 {Ty, Ptr->getType()},
383 {Ptr, Builder.getInt32(Increment)});
385 return Builder.CreateIntrinsic(
386 Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
387 {Ty, Ptr->getType(), Mask->getType()},
388 {Ptr, Builder.getInt32(Increment), Mask});
// Builds a gather of the form "scalar base + vector of offsets" for the
// gather call I whose pointer operand is Ptr.
//   Root     in/out: the instruction the caller will replace (see the
//            comment on the extend below).
//   Builder  insertion point for the new intrinsic call.
// When the gathered type is narrower than 128 bits, the single user of I is
// expected to be a sext/zext to a 128-bit type, which becomes the result type
// of the intrinsic. The offsets are validated by checkGEP, an incrementing
// form is tried first, and otherwise arm_mve_vldr_gather_offset (or the
// predicated variant when the mask does not match m_One()) is emitted with
// the memory element size, scale and signedness flag as i32 operands.
// NOTE(review): several statements (early exits, the update of Unsigned and
// Root, declarations of Extend/Offsets) are not visible in this excerpt.
391 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
392 IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
393 using namespace PatternMatch;
395 Type *OriginalTy = I->getType();
396 Type *ResultTy = OriginalTy;
// Passed as the last i32 operand of the intrinsic; defaults to 1. The sext
// branch below presumably changes it (statement not visible here).
398 unsigned Unsigned = 1;
399 // The size of the gather was already checked in isLegalTypeAndAlignment;
400 // if it was not a full vector width an appropriate extend should follow.
402 if (OriginalTy->getPrimitiveSizeInBits() < 128) {
403 // Only transform gathers with exactly one use
407 // The correct root to replace is not the CallInst itself, but the
408 // instruction which extends it
409 Extend = cast<Instruction>(*I->users().begin());
410 if (isa<SExtInst>(Extend)) {
412 } else if (!isa<ZExtInst>(Extend)) {
413 LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
417 LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
418 ResultTy = Extend->getType();
419 // The final size of the gather must be a full vector width
420 if (ResultTy->getPrimitiveSizeInBits() != 128) {
421 LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
// GEP is null when Ptr is not a getelementptr; checkGEP receives it as is.
427 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
429 Value *BasePtr = checkGEP(Offsets, ResultTy, GEP, Builder);
432 // Check whether the offset is a constant increment that could be merged into
// the gather itself (NOTE(review): sentence completed by inference; the
// original continuation line is missing from this excerpt).
434 Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
// Scale is derived from the pointee size of the base pointer and the element
// size of the value as it exists in memory (OriginalTy, before any extend).
438 int Scale = computeScale(
439 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
440 OriginalTy->getScalarSizeInBits());
// Operand 2 of llvm.masked.gather is the mask.
445 Value *Mask = I->getArgOperand(2);
446 if (!match(Mask, m_One()))
447 return Builder.CreateIntrinsic(
448 Intrinsic::arm_mve_vldr_gather_offset_predicated,
449 {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
450 {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
451 Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
453 return Builder.CreateIntrinsic(
454 Intrinsic::arm_mve_vldr_gather_offset,
455 {ResultTy, BasePtr->getType(), Offsets->getType()},
456 {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
457 Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
// Lowers one llvm.masked.scatter call I to an MVE scatter intrinsic.
// Steps visible here: check type/alignment legality of the stored value, look
// through a bitcast of the pointer vector, try the "base + offsets" form and
// then the "vector of pointers" form, and erase the original call.
// NOTE(review): the early exits, the guard in front of the second attempt and
// the final return are not visible in this excerpt.
460 Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
461 using namespace PatternMatch;
462 LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");
464 // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
465 // Attempt to turn the masked scatter in I into a MVE intrinsic
466 // Potentially optimising the addressing modes as we do so.
467 Value *Input = I->getArgOperand(0);
468 Value *Ptr = I->getArgOperand(1);
469 Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
// Legality is judged on the type of the stored data.
470 auto *Ty = cast<FixedVectorType>(Input->getType());
472 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
476 lookThroughBitcast(Ptr);
477 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
// New instructions are inserted directly before I and inherit its debug
// location.
479 IRBuilder<> Builder(I->getContext());
480 Builder.SetInsertPoint(I);
481 Builder.SetCurrentDebugLocation(I->getDebugLoc());
483 Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
485 Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
489 LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
490 I->eraseFromParent();
// Builds a scatter to a vector of pointers (Ptr) for the scatter call I.
// Only stored values with 4 lanes of 32 bits are handled. Increment is passed
// to the intrinsic as an i32 immediate. When the mask matches m_One() the
// unpredicated arm_mve_vstr_scatter_base intrinsic is emitted, otherwise the
// predicated variant taking Mask as an extra operand.
// NOTE(review): the bail-out statement for unsupported types is not visible
// in this excerpt.
494 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
495 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
496 using namespace PatternMatch;
497 Value *Input = I->getArgOperand(0);
498 auto *Ty = cast<FixedVectorType>(Input->getType());
499 // Only QR variants allow truncating
500 if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
501 // Can't build an intrinsic for this
// Operand 3 of llvm.masked.scatter is the mask.
504 Value *Mask = I->getArgOperand(3);
505 // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
506 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
507 if (match(Mask, m_One()))
508 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
509 {Ptr->getType(), Input->getType()},
510 {Ptr, Builder.getInt32(Increment), Input});
512 return Builder.CreateIntrinsic(
513 Intrinsic::arm_mve_vstr_scatter_base_predicated,
514 {Ptr->getType(), Input->getType(), Mask->getType()},
515 {Ptr, Builder.getInt32(Increment), Input, Mask});
// Writeback counterpart of tryCreateMaskedScatterBase: builds
// arm_mve_vstr_scatter_base_wb (or its predicated variant when the mask does
// not match m_One()) for a 4 x 32-bit scatter, passing Increment as an i32
// immediate. The caller in tryCreateIncrementingWBGatScat uses the returned
// value as the new induction value.
// NOTE(review): the bail-out statement for unsupported types is not visible
// in this excerpt.
518 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
519 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
520 using namespace PatternMatch;
521 Value *Input = I->getArgOperand(0);
522 auto *Ty = cast<FixedVectorType>(Input->getType());
525 << "masked scatters: storing to a vector of pointers with writeback\n");
526 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
527 // Can't build an intrinsic for this
// Operand 3 of llvm.masked.scatter is the mask.
529 Value *Mask = I->getArgOperand(3);
530 if (match(Mask, m_One()))
531 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
532 {Ptr->getType(), Input->getType()},
533 {Ptr, Builder.getInt32(Increment), Input});
535 return Builder.CreateIntrinsic(
536 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
537 {Ptr->getType(), Input->getType(), Mask->getType()},
538 {Ptr, Builder.getInt32(Increment), Input, Mask});
// Builds a scatter of the form "scalar base + vector of offsets" for the
// scatter call I whose pointer operand is Ptr.
// MemoryTy is the type as stored in memory; InputTy is the type of the value
// handed to the intrinsic. If the stored value is a trunc of a 128-bit vector,
// InputTy becomes the pre-truncation type so the truncation can be folded
// into the scatter. The input must end up 128 bits wide. Offsets are
// validated by checkGEP, an incrementing form is tried first, and otherwise
// arm_mve_vstr_scatter_offset (or the predicated variant when the mask does
// not match m_One()) is emitted with the memory element size and scale as
// i32 operands.
// NOTE(review): several statements (early exits, the reassignment of Input
// in the trunc case, the declaration of Offsets) are not visible in this
// excerpt.
541 Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
542 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
543 using namespace PatternMatch;
544 Value *Input = I->getArgOperand(0);
545 Value *Mask = I->getArgOperand(3);
546 Type *InputTy = Input->getType();
547 Type *MemoryTy = InputTy;
548 LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
549 << " to base + vector of offsets\n");
550 // If the input has been truncated, try to integrate that trunc into the
551 // scatter instruction (we don't care about alignment here)
552 if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
553 Value *PreTrunc = Trunc->getOperand(0);
554 Type *PreTruncTy = PreTrunc->getType();
555 if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
557 InputTy = PreTruncTy;
560 if (InputTy->getPrimitiveSizeInBits() != 128) {
562 dbgs() << "masked scatters: cannot create scatters for non-standard"
563 << " input types. Expanding.\n");
// GEP is null when Ptr is not a getelementptr; checkGEP receives it as is.
567 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
569 Value *BasePtr = checkGEP(Offsets, InputTy, GEP, Builder);
572 // Check whether the offset is a constant increment that could be merged into
// the scatter itself (NOTE(review): sentence completed by inference; the
// original continuation line is missing from this excerpt).
575 tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
// Scale is derived from the pointee size of the base pointer and the element
// size as stored in memory (MemoryTy).
578 int Scale = computeScale(
579 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
580 MemoryTy->getScalarSizeInBits());
584 if (!match(Mask, m_One()))
585 return Builder.CreateIntrinsic(
586 Intrinsic::arm_mve_vstr_scatter_offset_predicated,
587 {BasePtr->getType(), Offsets->getType(), Input->getType(),
589 {BasePtr, Offsets, Input,
590 Builder.getInt32(MemoryTy->getScalarSizeInBits()),
591 Builder.getInt32(Scale), Mask});
593 return Builder.CreateIntrinsic(
594 Intrinsic::arm_mve_vstr_scatter_offset,
595 {BasePtr->getType(), Offsets->getType(), Input->getType()},
596 {BasePtr, Offsets, Input,
597 Builder.getInt32(MemoryTy->getScalarSizeInBits()),
598 Builder.getInt32(Scale)});
// Tries to turn a gather/scatter whose offsets are "variable + constant" into
// a vector-of-pointers gather/scatter with the constant folded in as the
// immediate increment.
//   I        the masked gather or scatter call.
//   BasePtr  scalar base pointer found by checkGEP.
//   Offsets  offset vector found by checkGEP.
//   GEP      the getelementptr the address came from.
//   Builder  insertion point for new instructions.
// Only 4 x 32-bit accesses inside a loop are considered. If the gep has a
// single use, a writeback form is attempted first via
// tryCreateIncrementingWBGatScat; otherwise, or (presumably) if that fails,
// the non-writeback form is built: offsets are shifted left by TypeScale,
// the base pointer (as integer) is added, and the base gather/scatter is
// created with the scaled constant as its increment.
// NOTE(review): several statements (early exits, declarations of Ty and
// TypeScale, parts of the operand lists) are not visible in this excerpt.
601 Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
602 IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
603 IRBuilder<> &Builder) {
// For a gather the accessed type is the call's result type; for a scatter it
// is the type of the stored value (operand 0).
605 if (I->getIntrinsicID() == Intrinsic::masked_gather)
606 Ty = cast<FixedVectorType>(I->getType());
608 Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
609 // Incrementing gathers only exist for v4i32
610 if (Ty->getNumElements() != 4 ||
611 Ty->getScalarSizeInBits() != 32)
613 Loop *L = LI->getLoopFor(I->getParent());
615 // Incrementing gathers are not beneficial outside of a loop
617 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
618 "wb gather/scatter\n");
620 // The gep was in charge of making sure the offsets are scaled correctly
621 // - calculate that factor so it can be applied by hand
// NOTE(review): DT is declared by value, so this copies the module's
// DataLayout if getDataLayout() returns a reference; a const reference would
// avoid the copy - confirm and consider changing.
622 DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
624 computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
625 DT.getTypeSizeInBits(GEP->getType()) /
626 cast<FixedVectorType>(GEP->getType())->getNumElements());
630 if (GEP->hasOneUse()) {
631 // Only in this case do we want to build a wb gather, because the wb will
632 // change the phi which does affect other users of the gep (which will still
633 // be using the phi in the old way)
635 tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
639 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
640 "non-wb gather/scatter\n");
// Split the offsets into their variable part and the scaled constant.
642 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
643 if (Add.first == nullptr)
645 Value *OffsetsIncoming = Add.first;
646 int64_t Immediate = Add.second;
648 // Make sure the offsets are scaled correctly
649 Instruction *ScaledOffsets = BinaryOperator::Create(
650 Instruction::Shl, OffsetsIncoming,
651 Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
653 // Add the base to the offsets
654 OffsetsIncoming = BinaryOperator::Create(
655 Instruction::Add, ScaledOffsets,
656 Builder.CreateVectorSplat(
657 Ty->getNumElements(),
658 Builder.CreatePtrToInt(
660 cast<VectorType>(ScaledOffsets->getType())->getElementType())),
// The cast<> asserts that the helper produced an intrinsic call (it would
// fail on a null result).
663 if (I->getIntrinsicID() == Intrinsic::masked_gather)
664 return cast<IntrinsicInst>(
665 tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
667 return cast<IntrinsicInst>(
668 tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
// Tries to build a writeback gather/scatter whose pointer vector is the loop
// induction phi itself, so the instruction performs the increment.
//   I          the masked gather or scatter call.
//   BasePtr    scalar base pointer.
//   Offsets    offset vector; must be a loop-header phi with two incoming
//              values and exactly two uses.
//   TypeScale  shift amount used to scale offsets and the constant increment.
//   Builder    insertion point helper; temporarily moved to the end of the
//              phi's non-latch incoming block.
// The phi's start value is rewritten to (start << TypeScale) + base -
// Immediate (the subtraction compensates for the pre-incrementing access),
// the writeback intrinsic is created at I, and the original add feeding the
// phi is replaced by the intrinsic's updated-pointer result and erased.
// NOTE(review): early exits, some declarations (NumElems, EndResult,
// NewInduction), parts of operand lists and the final return are not visible
// in this excerpt.
671 Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
672 IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
673 IRBuilder<> &Builder) {
674 // Check whether this gather's offset is incremented by a constant - if so,
675 // and the load is of the right type, we can merge this into a QI gather
676 Loop *L = LI->getLoopFor(I->getParent());
677 // Offsets that are worth merging into this instruction will be incremented
678 // by a constant, thus we're looking for an add of a phi and a constant
679 PHINode *Phi = dyn_cast<PHINode>(Offsets);
680 if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
681 Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
682 // No phi means no IV to write back to; if there is a phi, we expect it
683 // to have exactly two incoming values; the only phis we are interested in
684 // will be loop IV's and have exactly two uses, one in their increment and
685 // one in the gather's gep
// IncrementIndex selects the incoming edge from the loop latch;
// 1 - IncrementIndex is the other (loop-entry) edge.
688 unsigned IncrementIndex =
689 Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
690 // Look through the phi to the phi increment
691 Offsets = Phi->getIncomingValue(IncrementIndex);
693 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
694 if (Add.first == nullptr)
696 Value *OffsetsIncoming = Add.first;
697 int64_t Immediate = Add.second;
698 if (OffsetsIncoming != Phi)
699 // Then the increment we are looking at is not an increment of the
700 // induction variable, and we don't want to do a writeback
// Emit the start-value computation before the last instruction of the
// loop-entry incoming block.
703 Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
705 cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
707 // Make sure the offsets are scaled correctly
708 Instruction *ScaledOffsets = BinaryOperator::Create(
709 Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
710 Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
711 "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
712 // Add the base to the offsets
713 OffsetsIncoming = BinaryOperator::Create(
714 Instruction::Add, ScaledOffsets,
715 Builder.CreateVectorSplat(
717 Builder.CreatePtrToInt(
719 cast<VectorType>(ScaledOffsets->getType())->getElementType())),
720 "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
721 // The gather is pre-incrementing
722 OffsetsIncoming = BinaryOperator::Create(
723 Instruction::Sub, OffsetsIncoming,
724 Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
725 "PreIncrementStartIndex",
726 &Phi->getIncomingBlock(1 - IncrementIndex)->back());
727 Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
// Back to the original position for the gather/scatter itself.
729 Builder.SetInsertPoint(I);
733 if (I->getIntrinsicID() == Intrinsic::masked_gather) {
734 // Build the incrementing gather
735 Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
736 // One value to be handed to whoever uses the gather, one is the loop
// increment.
738 EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
739 NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
741 // Build the incrementing scatter
742 NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
743 EndResult = NewInduction;
// The old add is now redundant: redirect its users to the writeback result,
// delete it, and feed the phi from the new induction value.
745 Instruction *AddInst = cast<Instruction>(Offsets);
746 AddInst->replaceAllUsesWith(NewInduction);
747 AddInst->eraseFromParent();
748 Phi->setIncomingValue(IncrementIndex, NewInduction);
// Folds an addition of OffsSecondOperand into the start value of Phi so the
// add no longer has to be performed where it was.
//   Phi                two-input phi to rewrite (modified in place).
//   OffsSecondOperand  the value to add to the phi's start value.
//   StartIndex         index of the phi's start-value incoming edge (0 or 1).
// A new add "PushedOutAdd" is inserted before the last instruction of the
// start block; the phi's incoming list is then rebuilt as (new start value,
// old increment value) by appending both and removing the two old entries.
753 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
754 Value *OffsSecondOperand,
755 unsigned StartIndex) {
756 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
757 Instruction *InsertionPoint =
758 &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
759 // Initialize the phi with a vector that contains a sum of the constants
760 Instruction *NewIndex = BinaryOperator::Create(
761 Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
762 "PushedOutAdd", InsertionPoint);
// The other edge of the two-input phi.
763 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
765 // Order such that start index comes first (this reduces mov's)
766 Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
767 Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
768 Phi->getIncomingBlock(IncrementIndex));
// Remove the two original entries by their original indices.
// NOTE(review): correctness depends on how removeIncomingValue renumbers the
// remaining entries after the first removal - confirm against its
// implementation.
769 Phi->removeIncomingValue(IncrementIndex);
770 Phi->removeIncomingValue(StartIndex);
// Folds a multiplication by OffsSecondOperand into Phi by scaling both its
// start value and its per-iteration increment.
//   Phi                two-input phi to rewrite (modified in place).
//   IncrementPerRound  amount added to the phi each iteration.
//   OffsSecondOperand  the multiplier.
//   LoopIncrement      index (0 or 1) of the phi's increment incoming edge.
//   Builder            not used by the statements visible in this excerpt.
// "PushedOutMul" (start * multiplier) and "Product" (increment * multiplier)
// are inserted before the last instruction of the start block;
// "IncrementPushedOutMul" (Phi + Product) is created for the increment block.
// The phi's incoming list is rebuilt as (new start, new increment).
// NOTE(review): the tail of the NewIncrement creation (its insertion-point
// argument) is only partly visible in this excerpt.
773 void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
774 Value *IncrementPerRound,
775 Value *OffsSecondOperand,
776 unsigned LoopIncrement,
777 IRBuilder<> &Builder) {
778 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
780 // Create a new scalar add outside of the loop and transform it to a splat
781 // by which loop variable can be incremented
782 Instruction *InsertionPoint = &cast<Instruction>(
783 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
785 // Create a new index
786 Value *StartIndex = BinaryOperator::Create(
787 Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
788 OffsSecondOperand, "PushedOutMul", InsertionPoint);
790 Instruction *Product =
791 BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
792 OffsSecondOperand, "Product", InsertionPoint);
793 // Increment NewIndex by Product instead of the multiplication
794 Instruction *NewIncrement = BinaryOperator::Create(
795 Instruction::Add, Phi, Product, "IncrementPushedOutMul",
796 cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
// Append the two new entries, then drop the two original ones, which sit at
// the front of the incoming list.
799 Phi->addIncoming(StartIndex,
800 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
801 Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
802 Phi->removeIncomingValue((unsigned)0);
803 Phi->removeIncomingValue((unsigned)0);
807 // Check whether all usages of this instruction are as offsets of
808 // gathers/scatters or simple arithmetics only used by gathers/scatters
// Walks the users of I recursively: a getelementptr or a gather/scatter
// intrinsic user is accepted directly, and an add or mul user is accepted if
// its own users satisfy the same condition.
// NOTE(review): the return statements (including the result for an
// instruction with no uses and for non-instruction users) are not visible in
// this excerpt. isGatherScatter is defined outside this excerpt; it is called
// with a dyn_cast result, so presumably it tolerates a null argument -
// confirm.
809 static bool hasAllGatScatUsers(Instruction *I) {
810 if (I->hasNUses(0)) {
814 for (User *U : I->users()) {
815 if (!isa<Instruction>(U))
817 if (isa<GetElementPtrInst>(U) ||
818 isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
821 unsigned OpCode = cast<Instruction>(U)->getOpcode();
822 if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
823 hasAllGatScatUsers(cast<Instruction>(U))) {
832 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
834 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
835 // Optimise the addresses of gathers/scatters by moving invariant
836 // calculations out of the loop
837 if (!isa<Instruction>(Offsets))
839 Instruction *Offs = cast<Instruction>(Offsets);
840 if (Offs->getOpcode() != Instruction::Add &&
841 Offs->getOpcode() != Instruction::Mul)
843 Loop *L = LI->getLoopFor(BB);
846 if (!Offs->hasOneUse()) {
847 if (!hasAllGatScatUsers(Offs))
851 // Find out which, if any, operand of the instruction
855 if (isa<PHINode>(Offs->getOperand(0))) {
856 Phi = cast<PHINode>(Offs->getOperand(0));
858 } else if (isa<PHINode>(Offs->getOperand(1))) {
859 Phi = cast<PHINode>(Offs->getOperand(1));
863 if (isa<Instruction>(Offs->getOperand(0)) &&
864 L->contains(cast<Instruction>(Offs->getOperand(0))))
865 Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
866 if (isa<Instruction>(Offs->getOperand(1)) &&
867 L->contains(cast<Instruction>(Offs->getOperand(1))))
868 Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
872 if (isa<PHINode>(Offs->getOperand(0))) {
873 Phi = cast<PHINode>(Offs->getOperand(0));
875 } else if (isa<PHINode>(Offs->getOperand(1))) {
876 Phi = cast<PHINode>(Offs->getOperand(1));
883 // A phi node we want to perform this function on should be from the
884 // loop header, and shouldn't have more than 2 incoming values
885 if (Phi->getParent() != L->getHeader() ||
886 Phi->getNumIncomingValues() != 2)
889 // The phi must be an induction variable
891 int IncrementingBlock = -1;
893 for (int i = 0; i < 2; i++)
894 if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
895 if (Op->getOpcode() == Instruction::Add &&
896 (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
897 IncrementingBlock = i;
898 if (IncrementingBlock == -1)
901 Instruction *IncInstruction =
902 cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));
904 // If the phi is not used by anything else, we can just adapt it when
905 // replacing the instruction; if it is, we'll have to duplicate it
907 Value *IncrementPerRound = IncInstruction->getOperand(
908 (IncInstruction->getOperand(0) == Phi) ? 1 : 0);
910 // Get the value that is added to/multiplied with the phi
911 Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
913 if (IncrementPerRound->getType() != OffsSecondOperand->getType())
914 // Something has gone wrong, abort
917 // Only proceed if the increment per round is a constant or an instruction
918 // which does not originate from within the loop
919 if (!isa<Constant>(IncrementPerRound) &&
920 !(isa<Instruction>(IncrementPerRound) &&
921 !L->contains(cast<Instruction>(IncrementPerRound))))
924 if (Phi->getNumUses() == 2) {
925 // No other users -> reuse existing phi (One user is the instruction
926 // we're looking at, the other is the phi increment)
927 if (IncInstruction->getNumUses() != 1) {
928 // If the incrementing instruction does have more users than
929 // our phi, we need to copy it
930 IncInstruction = BinaryOperator::Create(
931 Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
932 IncrementPerRound, "LoopIncrement", IncInstruction);
933 Phi->setIncomingValue(IncrementingBlock, IncInstruction);
937 // There are other users -> create a new phi
938 NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
939 std::vector<Value *> Increases;
940 // Copy the incoming values of the old phi
941 NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
942 Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
943 IncInstruction = BinaryOperator::Create(
944 Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
945 IncrementPerRound, "LoopIncrement", IncInstruction);
946 NewPhi->addIncoming(IncInstruction,
947 Phi->getIncomingBlock(IncrementingBlock));
948 IncrementingBlock = 1;
951 IRBuilder<> Builder(BB->getContext());
952 Builder.SetInsertPoint(Phi);
953 Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
955 switch (Offs->getOpcode()) {
956 case Instruction::Add:
957 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
959 case Instruction::Mul:
960 pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
967 dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");
969 // The instruction has now been "absorbed" into the phi value
970 Offs->replaceAllUsesWith(NewPhi);
971 if (Offs->hasNUses(0))
972 Offs->eraseFromParent();
973 // Clean up the old increment in case it's unused because we built a new
975 if (IncInstruction->hasNUses(0))
976 IncInstruction->eraseFromParent();
981 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
982 if (!EnableMaskedGatherScatters)
984 auto &TPC = getAnalysis<TargetPassConfig>();
985 auto &TM = TPC.getTM<TargetMachine>();
986 auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
987 if (!ST->hasMVEIntegerOps())
989 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
990 SmallVector<IntrinsicInst *, 4> Gathers;
991 SmallVector<IntrinsicInst *, 4> Scatters;
993 bool Changed = false;
995 for (BasicBlock &BB : F) {
996 for (Instruction &I : BB) {
997 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
998 if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
999 Gathers.push_back(II);
1000 if (isa<GetElementPtrInst>(II->getArgOperand(0)))
1001 Changed |= optimiseOffsets(
1002 cast<Instruction>(II->getArgOperand(0))->getOperand(1),
1003 II->getParent(), LI);
1004 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
1005 Scatters.push_back(II);
1006 if (isa<GetElementPtrInst>(II->getArgOperand(1)))
1007 Changed |= optimiseOffsets(
1008 cast<Instruction>(II->getArgOperand(1))->getOperand(1),
1009 II->getParent(), LI);
1014 for (unsigned i = 0; i < Gathers.size(); i++) {
1015 IntrinsicInst *I = Gathers[i];
1016 Value *L = lowerGather(I);
1020 // Get rid of any now dead instructions
1021 SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
1025 for (unsigned i = 0; i < Scatters.size(); i++) {
1026 IntrinsicInst *I = Scatters[i];
1027 Value *S = lowerScatter(I);
1031 // Get rid of any now dead instructions
1032 SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());