1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines some vectorizer utilities.
11 //===----------------------------------------------------------------------===//
13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
14 #define LLVM_ANALYSIS_VECTORUTILS_H
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/IR/VFABIDemangler.h"
20 #include "llvm/Support/CheckedArithmetic.h"
23 class TargetLibraryInfo;
25 /// The Vector Function Database.
27 /// Helper class used to find the vector functions associated to a
/// scalar CallInst (via its vector-function-abi-variant attribute).
// NOTE(review): several early-exit `return` statements, access specifiers
// and closing braces appear to be elided from this excerpt; the comments
// below annotate only the logic that is visible here.
30 /// The Module of the CallInst CI.
32 /// The CallInst instance being queried for scalar to vector mappings.
34 /// List of vector functions descriptors associated to the call
36 const SmallVector<VFInfo, 8> ScalarToVectorMappings;
38 /// Retrieve the scalar-to-vector mappings associated to the rule of
39 /// a vector Function ABI.
40 static void getVFABIMappings(const CallInst &CI,
41 SmallVectorImpl<VFInfo> &Mappings) {
// Indirect calls have no called Function, hence no scalar name to map
// from (the early `return` is not visible in this excerpt).
42 if (!CI.getCalledFunction())
45 const StringRef ScalarName = CI.getCalledFunction()->getName();
47 SmallVector<std::string, 8> ListOfStrings;
48 // The check for the vector-function-abi-variant attribute is done when
49 // retrieving the vector variant names here.
50 VFABI::getVectorVariantNames(CI, ListOfStrings);
51 if (ListOfStrings.empty())
// Demangle each attribute entry and keep those matching the scalar name.
53 for (const auto &MangledName : ListOfStrings) {
54 const std::optional<VFInfo> Shape =
55 VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType())
56 // A match is found via scalar and vector names, and also by
57 // ensuring that the variant described in the attribute has a
58 // corresponding definition or declaration of the vector
59 // function in the Module M.
60 if (Shape && (Shape->ScalarName == ScalarName)) {
61 assert(CI.getModule()->getFunction(Shape->VectorName) &&
62 "Vector function is missing.");
63 Mappings.push_back(*Shape);
69 /// Retrieve all the VFInfo instances associated to the CallInst CI.
70 static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
71 SmallVector<VFInfo, 8> Ret;
73 // Get mappings from the Vector Function ABI variants.
74 getVFABIMappings(CI, Ret);
76 // Other non-VFABI variants should be retrieved here.
// (Ret is handed back to the caller; the `return Ret;` line is elided.)
81 static bool hasMaskedVariant(const CallInst &CI,
82 std::optional<ElementCount> VF = std::nullopt) {
83 // Check whether we have at least one masked vector version of a scalar
84 // function. If no VF is specified then we check for any masked variant,
85 // otherwise we look for one that matches the supplied VF.
86 auto Mappings = VFDatabase::getMappings(CI);
87 for (VFInfo Info : Mappings)
88 if (!VF || Info.Shape.VF == *VF)
95 /// Constructor, requires a CallInst instance.
96 VFDatabase(CallInst &CI)
97 : M(CI.getModule()), CI(CI),
98 ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
99 /// \defgroup VFDatabase query interface.
102 /// Retrieve the Function with VFShape \p Shape.
103 Function *getVectorizedFunction(const VFShape &Shape) const {
// The scalar shape maps back to the scalar callee itself.
104 if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
105 return CI.getCalledFunction();
// Otherwise scan the precomputed mappings for an exact shape match.
107 for (const auto &Info : ScalarToVectorMappings)
108 if (Info.Shape == Shape)
109 return M->getFunction(Info.VectorName);
116 template <typename T> class ArrayRef;
118 template <typename InstTy> class InterleaveGroup;
121 class ScalarEvolution;
122 class TargetTransformInfo;
126 namespace Intrinsic {
130 /// A helper function for converting Scalar types to vector types. If
131 /// the incoming type is void, we return void. If the EC represents a
132 /// scalar, we return the scalar type.
133 inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
134 if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
136 return VectorType::get(Scalar, EC);
139 inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
140 return ToVectorTy(Scalar, ElementCount::getFixed(VF));
143 /// Identify if the intrinsic is trivially vectorizable.
144 /// This method returns true if the intrinsic's argument types are all scalars
145 /// for the scalar form of the intrinsic and all vectors (or scalars handled by
146 /// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
147 bool isTriviallyVectorizable(Intrinsic::ID ID);
149 /// Identifies if the vector form of the intrinsic has a scalar operand.
150 bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
151 unsigned ScalarOpdIdx);
153 /// Identifies if the vector form of the intrinsic is overloaded on the type of
154 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
155 bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
157 /// Returns intrinsic ID for call.
158 /// For the input call instruction it finds mapping intrinsic and returns
159 /// its intrinsic ID; if no such mapping is found, it returns not_intrinsic.
160 Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
161 const TargetLibraryInfo *TLI);
163 /// Given a vector and an element number, see if the scalar value is
164 /// already around as a register, for example if it were inserted then
/// extracted from the vector.
166 Value *findScalarElement(Value *V, unsigned EltNo);
168 /// If all non-negative \p Mask elements are the same value, return that value.
169 /// If all elements are negative (undefined) or \p Mask contains different
170 /// non-negative values, return -1.
171 int getSplatIndex(ArrayRef<int> Mask);
173 /// Get splat value if the input is a splat vector or return nullptr.
174 /// The value may be extracted from a splat constants vector or from
175 /// a sequence of instructions that broadcast a single value into a vector.
176 Value *getSplatValue(const Value *V);
178 /// Return true if each element of the vector value \p V is poisoned or equal to
179 /// every other non-poisoned element. If an index element is specified, either
180 /// every element of the vector is poisoned or the element at that index is not
181 /// poisoned and equal to every other non-poisoned element.
182 /// This may be more powerful than the related getSplatValue() because it is
183 /// not limited by finding a scalar source value to a splatted vector.
184 bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
186 /// Transform a shuffle mask's output demanded element mask into demanded
187 /// element masks for the 2 operands, returns false if the mask isn't valid.
188 /// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
189 /// \p AllowUndefElts permits "-1" indices to be treated as undef.
190 bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
191 const APInt &DemandedElts, APInt &DemandedLHS,
192 APInt &DemandedRHS, bool AllowUndefElts = false);
194 /// Replace each shuffle mask index with the scaled sequential indices for an
195 /// equivalent mask of narrowed elements. Mask elements that are less than 0
196 /// (sentinel values) are repeated in the output mask.
198 /// Example with Scale = 4:
199 /// <4 x i32> <3, 2, 0, -1> -->
200 /// <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
202 /// This is the reverse process of widening shuffle mask elements, but it always
203 /// succeeds because the indexes can always be multiplied (scaled up) to map to
204 /// narrower vector elements.
205 void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
206 SmallVectorImpl<int> &ScaledMask);
208 /// Try to transform a shuffle mask by replacing elements with the scaled index
209 /// for an equivalent mask of widened elements. If all mask elements that would
210 /// map to a wider element of the new mask are the same negative number
211 /// (sentinel value), that element of the new mask is the same value. If any
212 /// element in a given slice is negative and some other element in that slice is
213 /// not the same value, return false (partial matches with sentinel values are
216 /// Example with Scale = 4:
217 /// <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
218 /// <4 x i32> <3, 2, 0, -1>
220 /// This is the reverse process of narrowing shuffle mask elements if it
221 /// succeeds. This transform is not always possible because indexes may not
222 /// divide evenly (scale down) to map to wider vector elements.
223 bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
224 SmallVectorImpl<int> &ScaledMask);
226 /// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
227 /// to get the shuffle mask with widest possible elements.
228 void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
229 SmallVectorImpl<int> &ScaledMask);
231 /// Splits and processes shuffle mask depending on the number of input and
232 /// output registers. The function does 2 main things: 1) splits the
233 /// source/destination vectors into real registers; 2) do the mask analysis to
234 /// identify which real registers are permuted. Then the function processes
235 /// resulting registers mask using provided action items. If no input register
236 /// is defined, \p NoInputAction action is used. If only 1 input register is
237 /// used, \p SingleInputAction is used, otherwise \p ManyInputsAction is used to
238 /// process > 2 input registers and masks.
239 /// \param Mask Original shuffle mask.
240 /// \param NumOfSrcRegs Number of source registers.
241 /// \param NumOfDestRegs Number of destination registers.
242 /// \param NumOfUsedRegs Number of actually used destination registers.
243 void processShuffleMasks(
244 ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
245 unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
246 function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
247 function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);
249 /// Compute a map of integer instructions to their minimum legal type
252 /// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
253 /// type (e.g. i32) whenever arithmetic is performed on them.
255 /// For targets with native i8 or i16 operations, usually InstCombine can shrink
256 /// the arithmetic type down again. However InstCombine refuses to create
257 /// illegal types, so for targets without i8 or i16 registers, the lengthening
258 /// and shrinking remains.
260 /// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
261 /// their scalar equivalents do not, so during vectorization it is important to
262 /// remove these lengthens and truncates when deciding the profitability of
265 /// This function analyzes the given range of instructions and determines the
266 /// minimum type size each can be converted to. It attempts to remove or
267 /// minimize type size changes across each def-use chain, so for example in the
270 /// %1 = load i8, i8*
271 /// %2 = add i8 %1, 2
272 /// %3 = load i16, i16*
273 /// %4 = zext i8 %2 to i32
274 /// %5 = zext i16 %3 to i32
275 /// %6 = add i32 %4, %5
276 /// %7 = trunc i32 %6 to i16
278 /// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
279 /// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
281 /// If the optional TargetTransformInfo is provided, this function tries harder
282 /// to do less work by only looking at illegal types.
283 MapVector<Instruction*, uint64_t>
284 computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
286 const TargetTransformInfo *TTI=nullptr);
288 /// Compute the union of two access-group lists.
290 /// If the list contains just one access group, it is returned directly. If the
291 /// list is empty, returns nullptr.
292 MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);
294 /// Compute the access-group list of access groups that @p Inst1 and @p Inst2
295 /// are both in. If either instruction does not access memory at all, it is
296 /// considered to be in every list.
298 /// If the list contains just one access group, it is returned directly. If the
299 /// list is empty, returns nullptr.
300 MDNode *intersectAccessGroups(const Instruction *Inst1,
301 const Instruction *Inst2);
303 /// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
304 /// MD_nontemporal, MD_access_group].
305 /// For K in Kinds, we get the MDNode for K from each of the
306 /// elements of VL, compute their "intersection" (i.e., the most generic
307 /// metadata value that covers all of the individual values), and set I's
308 /// metadata for M equal to the intersection value.
310 /// This function always sets a (possibly null) value for each K in Kinds.
311 Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
313 /// Create a mask that filters the members of an interleave group where there
316 /// For example, the mask for \p Group with interleave-factor 3
317 /// and \p VF 4, that has only its first member present is:
319 /// <1,0,0,1,0,0,1,0,0,1,0,0>
321 /// Note: The result is a mask of 0's and 1's, as opposed to the other
322 /// create[*]Mask() utilities which create a shuffle mask (mask that
323 /// consists of indices).
324 Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
325 const InterleaveGroup<Instruction> &Group);
327 /// Create a mask with replicated elements.
329 /// This function creates a shuffle mask for replicating each of the \p VF
330 /// elements in a vector \p ReplicationFactor times. It can be used to
331 /// transform a mask of \p VF elements into a mask of
332 /// \p VF * \p ReplicationFactor elements used by a predicated
333 /// interleaved-group of loads/stores whose Interleaved-factor ==
334 /// \p ReplicationFactor.
336 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
338 /// <0,0,0,1,1,1,2,2,2,3,3,3>
339 llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
342 /// Create an interleave shuffle mask.
344 /// This function creates a shuffle mask for interleaving \p NumVecs vectors of
345 /// vectorization factor \p VF into a single wide vector. The mask is of the
348 /// <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
350 /// For example, the mask for VF = 4 and NumVecs = 2 is:
352 /// <0, 4, 1, 5, 2, 6, 3, 7>.
353 llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);
355 /// Create a stride shuffle mask.
357 /// This function creates a shuffle mask whose elements begin at \p Start and
358 /// are incremented by \p Stride. The mask can be used to deinterleave an
359 /// interleaved vector into separate vectors of vectorization factor \p VF. The
360 /// mask is of the form:
362 /// <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
364 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
/// <0, 2, 4, 6>
367 llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
370 /// Create a sequential shuffle mask.
372 /// This function creates shuffle mask whose elements are sequential and begin
373 /// at \p Start. The mask contains \p NumInts integers and is padded with \p
374 /// NumUndefs undef values. The mask is of the form:
376 /// <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
378 /// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
380 /// <0, 1, 2, 3, undef, undef, undef, undef>
381 llvm::SmallVector<int, 16>
382 createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);
384 /// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
385 /// mask assuming both operands are identical. This assumes that the unary
386 /// shuffle will use elements from operand 0 (operand 1 will be unused).
387 llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
390 /// Concatenate a list of vectors.
392 /// This function generates code that concatenate the vectors in \p Vecs into a
393 /// single large vector. The number of vectors should be greater than one, and
394 /// their element types should be the same. The number of elements in the
395 /// vectors should also be the same; however, if the last vector has fewer
396 /// elements, it will be padded with undefs.
397 Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
399 /// Given a mask vector of i1, Return true if all of the elements of this
400 /// predicate mask are known to be false or undef. That is, return true if all
401 /// lanes can be assumed inactive.
402 bool maskIsAllZeroOrUndef(Value *Mask);
404 /// Given a mask vector of i1, Return true if all of the elements of this
405 /// predicate mask are known to be true or undef. That is, return true if all
406 /// lanes can be assumed active.
407 bool maskIsAllOneOrUndef(Value *Mask);
409 /// Given a mask vector of i1, Return true if any of the elements of this
410 /// predicate mask are known to be true or undef. That is, return true if at
411 /// least one lane can be assumed active.
412 bool maskContainsAllOneOrUndef(Value *Mask);
414 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
415 /// for each lane which may be active.
416 APInt possiblyDemandedEltsInMask(Value *Mask);
418 /// The group of interleaved loads/stores sharing the same stride and
419 /// close to each other.
421 /// Each member in this group has an index starting from 0, and the largest
422 /// index should be less than interleaved factor, which is equal to the absolute
423 /// value of the access's stride.
425 /// E.g. An interleaved load group of factor 4:
426 /// for (unsigned i = 0; i < 1024; i+=4) {
427 /// a = A[i]; // Member of index 0
428 /// b = A[i+1]; // Member of index 1
429 /// d = A[i+3]; // Member of index 3
433 /// An interleaved store group of factor 4:
434 /// for (unsigned i = 0; i < 1024; i+=4) {
436 /// A[i] = a; // Member of index 0
437 /// A[i+1] = b; // Member of index 1
438 /// A[i+2] = c; // Member of index 2
439 /// A[i+3] = d; // Member of index 3
442 /// Note: the interleaved load group could have gaps (missing members), but
443 /// the interleaved store group doesn't allow gaps.
444 template <typename InstTy> class InterleaveGroup {
// NOTE(review): access specifiers, early `return` lines and closing braces
// appear to be elided from this excerpt; comments below annotate only the
// logic that is visible here.
446 InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
447 : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
448 InsertPos(nullptr) {}
450 InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
451 : Alignment(Alignment), InsertPos(Instr) {
452 Factor = std::abs(Stride);
453 assert(Factor > 1 && "Invalid interleave factor");
// A negative stride means the group is accessed in reverse order.
455 Reverse = Stride < 0;
459 bool isReverse() const { return Reverse; }
460 uint32_t getFactor() const { return Factor; }
461 Align getAlign() const { return Alignment; }
462 uint32_t getNumMembers() const { return Members.size(); }
464 /// Try to insert a new member \p Instr with index \p Index and
465 /// alignment \p NewAlign. The index is related to the leader and it could be
466 /// negative if it is the new leader.
468 /// \returns false if the instruction doesn't belong to the group.
469 bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
470 // Make sure the key fits in an int32_t.
471 std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
474 int32_t Key = *MaybeKey;
476 // Skip if the key is used for either the tombstone or empty special values.
477 if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
478 DenseMapInfo<int32_t>::getEmptyKey() == Key)
481 // Skip if there is already a member with the same index.
482 if (Members.contains(Key))
485 if (Key > LargestKey) {
486 // The largest index is always less than the interleave factor.
487 if (Index >= static_cast<int32_t>(Factor))
491 } else if (Key < SmallestKey) {
493 // Make sure the largest index fits in an int32_t.
494 std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
495 if (!MaybeLargestIndex)
498 // The largest index is always less than the interleave factor.
499 if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
505 // It's always safe to select the minimum alignment.
506 Alignment = std::min(Alignment, NewAlign);
507 Members[Key] = Instr;
511 /// Get the member with the given index \p Index
513 /// \returns nullptr if contains no such member.
514 InstTy *getMember(uint32_t Index) const {
// Member keys are offsets from SmallestKey, so index 0 is the first member.
515 int32_t Key = SmallestKey + Index;
516 return Members.lookup(Key);
519 /// Get the index for the given member. Unlike the key in the member
520 /// map, the index starts from 0.
521 uint32_t getIndex(const InstTy *Instr) const {
522 for (auto I : Members) {
523 if (I.second == Instr)
524 return I.first - SmallestKey;
527 llvm_unreachable("InterleaveGroup contains no such member");
530 InstTy *getInsertPos() const { return InsertPos; }
531 void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
533 /// Add metadata (e.g. alias info) from the instructions in this group to \p
536 /// FIXME: this function currently does not add noalias metadata a'la
537 /// addNewMetadata. To do that we need to compute the intersection of the
538 /// noalias info from all members.
539 void addMetadata(InstTy *NewInst) const;
541 /// Returns true if this Group requires a scalar iteration to handle gaps.
542 bool requiresScalarEpilogue() const {
543 // If the last member of the Group exists, then a scalar epilog is not
544 // needed for this group.
545 if (getMember(getFactor() - 1))
548 // We have a group with gaps. It therefore can't be a reversed access,
549 // because such groups get invalidated (TODO).
550 assert(!isReverse() && "Group should have been invalidated");
552 // This is a group of loads, with gaps, and without a last-member
557 uint32_t Factor; // Interleave Factor.
560 DenseMap<int32_t, InstTy *> Members;
561 int32_t SmallestKey = 0;
562 int32_t LargestKey = 0;
564 // To avoid breaking dependences, vectorized instructions of an interleave
565 // group should be inserted at either the first load or the last store in
568 // E.g. %even = load i32 // Insert Position
569 // %add = add i32 %even // Use of %even
573 // %odd = add i32 // Def of %odd
574 // store i32 %odd // Insert Position
578 /// Drive the analysis of interleaved memory accesses in the loop.
580 /// Use this class to analyze interleaved accesses only when we can vectorize
581 /// a loop. Otherwise it's meaningless to do analysis as the vectorization
582 /// on interleaved accesses is unsafe.
584 /// The analysis collects interleave groups and records the relationships
585 /// between the member and the group in a map.
586 class InterleavedAccessInfo {
588 InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
589 DominatorTree *DT, LoopInfo *LI,
590 const LoopAccessInfo *LAI)
591 : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
593 ~InterleavedAccessInfo() { invalidateGroups(); }
595 /// Analyze the interleaved accesses and collect them in interleave
596 /// groups. Substitute symbolic strides using \p Strides.
597 /// Consider also predicated loads/stores in the analysis if
598 /// \p EnableMaskedInterleavedGroup is true.
599 void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
601 /// Invalidate groups, e.g., in case all blocks in loop will be predicated
602 /// contrary to original assumption. Although we currently prevent group
603 /// formation for predicated accesses, we may be able to relax this limitation
604 /// in the future once we handle more complicated blocks. Returns true if any
605 /// groups were invalidated.
606 bool invalidateGroups() {
607 if (InterleaveGroups.empty()) {
609 !RequiresScalarEpilogue &&
610 "RequiresScalarEpilog should not be set without interleave groups");
614 InterleaveGroupMap.clear();
615 for (auto *Ptr : InterleaveGroups)
617 InterleaveGroups.clear();
618 RequiresScalarEpilogue = false;
622 /// Check if \p Instr belongs to any interleave group.
623 bool isInterleaved(Instruction *Instr) const {
624 return InterleaveGroupMap.contains(Instr);
627 /// Get the interleave group that \p Instr belongs to.
629 /// \returns nullptr if doesn't have such group.
630 InterleaveGroup<Instruction> *
631 getInterleaveGroup(const Instruction *Instr) const {
632 return InterleaveGroupMap.lookup(Instr);
635 iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
636 getInterleaveGroups() {
637 return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
640 /// Returns true if an interleaved group that may access memory
641 /// out-of-bounds requires a scalar epilogue iteration for correctness.
642 bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
644 /// Invalidate groups that require a scalar epilogue (due to gaps). This can
645 /// happen when optimizing for size forbids a scalar epilogue, and the gap
646 /// cannot be filtered by masking the load/store.
647 void invalidateGroupsRequiringScalarEpilogue();
649 /// Returns true if we have any interleave groups.
650 bool hasGroups() const { return !InterleaveGroups.empty(); }
653 /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
654 /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
655 /// The interleaved access analysis can also add new predicates (for example
656 /// by versioning strides of pointers).
657 PredicatedScalarEvolution &PSE;
662 const LoopAccessInfo *LAI;
664 /// True if the loop may contain non-reversed interleaved groups with
665 /// out-of-bounds accesses. We ensure we don't speculatively access memory
666 /// out-of-bounds by executing at least one scalar epilogue iteration.
667 bool RequiresScalarEpilogue = false;
669 /// Holds the relationships between the members and the interleave group.
670 DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
672 SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
674 /// Holds dependences among the memory accesses in the loop. It maps a source
675 /// access to a set of dependent sink accesses.
676 DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
678 /// The descriptor for a strided memory access.
679 struct StrideDescriptor {
680 StrideDescriptor() = default;
681 StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
683 : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
685 // The access's stride. It is negative for a reverse access.
688 // The scalar expression of this access.
689 const SCEV *Scev = nullptr;
691 // The size of the memory object.
694 // The alignment of this access.
698 /// A type for holding instructions and their stride descriptors.
699 using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
701 /// Create a new interleave group with the given instruction \p Instr,
702 /// stride \p Stride and alignment \p Align.
704 /// \returns the newly created interleave group.
705 InterleaveGroup<Instruction> *
706 createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
707 assert(!InterleaveGroupMap.count(Instr) &&
708 "Already in an interleaved access group");
709 InterleaveGroupMap[Instr] =
710 new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
711 InterleaveGroups.insert(InterleaveGroupMap[Instr]);
712 return InterleaveGroupMap[Instr];
715 /// Release the group and remove all the relationships.
716 void releaseGroup(InterleaveGroup<Instruction> *Group) {
717 for (unsigned i = 0; i < Group->getFactor(); i++)
718 if (Instruction *Member = Group->getMember(i))
719 InterleaveGroupMap.erase(Member);
721 InterleaveGroups.erase(Group);
725 /// Collect all the accesses with a constant stride in program order.
726 void collectConstStrideAccesses(
727 MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
728 const DenseMap<Value *, const SCEV *> &Strides);
730 /// Returns true if \p Stride is allowed in an interleaved group.
731 static bool isStrided(int Stride);
733 /// Returns true if \p BB is a predicated block.
734 bool isPredicated(BasicBlock *BB) const {
735 return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
738 /// Returns true if LoopAccessInfo can be used for dependence queries.
739 bool areDependencesValid() const {
740 return LAI && LAI->getDepChecker().getDependences();
743 /// Returns true if memory accesses \p A and \p B can be reordered, if
744 /// necessary, when constructing interleaved groups.
746 /// \p A must precede \p B in program order. We return false if reordering is
747 /// not necessary or is prevented because \p A and \p B may be dependent.
748 bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
749 StrideEntry *B) const {
750 // Code motion for interleaved accesses can potentially hoist strided loads
751 // and sink strided stores. The code below checks the legality of the
752 // following two conditions:
754 // 1. Potentially moving a strided load (B) before any store (A) that
757 // 2. Potentially moving a strided store (A) after any load or store (B)
760 // It's legal to reorder A and B if we know there isn't a dependence from A
761 // to B. Note that this determination is conservative since some
762 // dependences could potentially be reordered safely.
764 // A is potentially the source of a dependence.
765 auto *Src = A->first;
766 auto SrcDes = A->second;
768 // B is potentially the sink of a dependence.
769 auto *Sink = B->first;
770 auto SinkDes = B->second;
772 // Code motion for interleaved accesses can't violate WAR dependences.
773 // Thus, reordering is legal if the source isn't a write.
774 if (!Src->mayWriteToMemory())
777 // At least one of the accesses must be strided.
778 if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
781 // If dependence information is not available from LoopAccessInfo,
782 // conservatively assume the instructions can't be reordered.
783 if (!areDependencesValid())
786 // If we know there is a dependence from source to sink, assume the
787 // instructions can't be reordered. Otherwise, reordering is legal.
788 return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
791 /// Collect the dependences from LoopAccessInfo.
793 /// We process the dependences once during the interleaved access analysis to
794 /// enable constant-time dependence queries.
795 void collectDependences() {
796 if (!areDependencesValid())
798 auto *Deps = LAI->getDepChecker().getDependences();
799 for (auto Dep : *Deps)
800 Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));