]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - lib/Transforms/IPO/LowerTypeTests.cpp
Vendor import of llvm trunk r291274:
[FreeBSD/FreeBSD.git] / lib / Transforms / IPO / LowerTypeTests.cpp
1 //===-- LowerTypeTests.cpp - type metadata lowering pass ------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass lowers type metadata and calls to the llvm.type.test intrinsic.
11 // See http://llvm.org/docs/TypeMetadata.html for more information.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/IPO/LowerTypeTests.h"
16 #include "llvm/ADT/EquivalenceClasses.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/Statistic.h"
19 #include "llvm/ADT/Triple.h"
20 #include "llvm/IR/Constant.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/GlobalObject.h"
24 #include "llvm/IR/GlobalVariable.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InlineAsm.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/ModuleSummaryIndexYAML.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/Error.h"
35 #include "llvm/Support/FileSystem.h"
36 #include "llvm/Support/TrailingObjects.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/Transforms/IPO.h"
39 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
40 #include "llvm/Transforms/Utils/ModuleUtils.h"
41
42 using namespace llvm;
43 using namespace lowertypetests;
44
45 #define DEBUG_TYPE "lowertypetests"
46
47 STATISTIC(ByteArraySizeBits, "Byte array size in bits");
48 STATISTIC(ByteArraySizeBytes, "Byte array size in bytes");
49 STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
50 STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered");
51 STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers");
52
53 static cl::opt<bool> AvoidReuse(
54     "lowertypetests-avoid-reuse",
55     cl::desc("Try to avoid reuse of byte array addresses using aliases"),
56     cl::Hidden, cl::init(true));
57
58 static cl::opt<std::string> ClSummaryAction(
59     "lowertypetests-summary-action",
60     cl::desc("What to do with the summary when running this pass"), cl::Hidden);
61
62 static cl::opt<std::string> ClReadSummary(
63     "lowertypetests-read-summary",
64     cl::desc("Read summary from given YAML file before running pass"),
65     cl::Hidden);
66
67 static cl::opt<std::string> ClWriteSummary(
68     "lowertypetests-write-summary",
69     cl::desc("Write summary to given YAML file after running pass"),
70     cl::Hidden);
71
72 bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
73   if (Offset < ByteOffset)
74     return false;
75
76   if ((Offset - ByteOffset) % (uint64_t(1) << AlignLog2) != 0)
77     return false;
78
79   uint64_t BitOffset = (Offset - ByteOffset) >> AlignLog2;
80   if (BitOffset >= BitSize)
81     return false;
82
83   return Bits.count(BitOffset);
84 }
85
86 void BitSetInfo::print(raw_ostream &OS) const {
87   OS << "offset " << ByteOffset << " size " << BitSize << " align "
88      << (1 << AlignLog2);
89
90   if (isAllOnes()) {
91     OS << " all-ones\n";
92     return;
93   }
94
95   OS << " { ";
96   for (uint64_t B : Bits)
97     OS << B << ' ';
98   OS << "}\n";
99 }
100
101 BitSetInfo BitSetBuilder::build() {
102   if (Min > Max)
103     Min = 0;
104
105   // Normalize each offset against the minimum observed offset, and compute
106   // the bitwise OR of each of the offsets. The number of trailing zeros
107   // in the mask gives us the log2 of the alignment of all offsets, which
108   // allows us to compress the bitset by only storing one bit per aligned
109   // address.
110   uint64_t Mask = 0;
111   for (uint64_t &Offset : Offsets) {
112     Offset -= Min;
113     Mask |= Offset;
114   }
115
116   BitSetInfo BSI;
117   BSI.ByteOffset = Min;
118
119   BSI.AlignLog2 = 0;
120   if (Mask != 0)
121     BSI.AlignLog2 = countTrailingZeros(Mask, ZB_Undefined);
122
123   // Build the compressed bitset while normalizing the offsets against the
124   // computed alignment.
125   BSI.BitSize = ((Max - Min) >> BSI.AlignLog2) + 1;
126   for (uint64_t Offset : Offsets) {
127     Offset >>= BSI.AlignLog2;
128     BSI.Bits.insert(Offset);
129   }
130
131   return BSI;
132 }
133
134 void GlobalLayoutBuilder::addFragment(const std::set<uint64_t> &F) {
135   // Create a new fragment to hold the layout for F.
136   Fragments.emplace_back();
137   std::vector<uint64_t> &Fragment = Fragments.back();
138   uint64_t FragmentIndex = Fragments.size() - 1;
139
140   for (auto ObjIndex : F) {
141     uint64_t OldFragmentIndex = FragmentMap[ObjIndex];
142     if (OldFragmentIndex == 0) {
143       // We haven't seen this object index before, so just add it to the current
144       // fragment.
145       Fragment.push_back(ObjIndex);
146     } else {
147       // This index belongs to an existing fragment. Copy the elements of the
148       // old fragment into this one and clear the old fragment. We don't update
149       // the fragment map just yet, this ensures that any further references to
150       // indices from the old fragment in this fragment do not insert any more
151       // indices.
152       std::vector<uint64_t> &OldFragment = Fragments[OldFragmentIndex];
153       Fragment.insert(Fragment.end(), OldFragment.begin(), OldFragment.end());
154       OldFragment.clear();
155     }
156   }
157
158   // Update the fragment map to point our object indices to this fragment.
159   for (uint64_t ObjIndex : Fragment)
160     FragmentMap[ObjIndex] = FragmentIndex;
161 }
162
163 void ByteArrayBuilder::allocate(const std::set<uint64_t> &Bits,
164                                 uint64_t BitSize, uint64_t &AllocByteOffset,
165                                 uint8_t &AllocMask) {
166   // Find the smallest current allocation.
167   unsigned Bit = 0;
168   for (unsigned I = 1; I != BitsPerByte; ++I)
169     if (BitAllocs[I] < BitAllocs[Bit])
170       Bit = I;
171
172   AllocByteOffset = BitAllocs[Bit];
173
174   // Add our size to it.
175   unsigned ReqSize = AllocByteOffset + BitSize;
176   BitAllocs[Bit] = ReqSize;
177   if (Bytes.size() < ReqSize)
178     Bytes.resize(ReqSize);
179
180   // Set our bits.
181   AllocMask = 1 << Bit;
182   for (uint64_t B : Bits)
183     Bytes[AllocByteOffset + B] |= AllocMask;
184 }
185
186 namespace {
187
188 struct ByteArrayInfo {
189   std::set<uint64_t> Bits;
190   uint64_t BitSize;
191   GlobalVariable *ByteArray;
192   GlobalVariable *MaskGlobal;
193 };
194
195 /// A POD-like structure that we use to store a global reference together with
196 /// its metadata types. In this pass we frequently need to query the set of
197 /// metadata types referenced by a global, which at the IR level is an expensive
198 /// operation involving a map lookup; this data structure helps to reduce the
199 /// number of times we need to do this lookup.
200 class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
201   GlobalObject *GO;
202   size_t NTypes;
203
204   friend TrailingObjects;
205   size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; }
206
207 public:
208   static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO,
209                                   ArrayRef<MDNode *> Types) {
210     auto *GTM = static_cast<GlobalTypeMember *>(Alloc.Allocate(
211         totalSizeToAlloc<MDNode *>(Types.size()), alignof(GlobalTypeMember)));
212     GTM->GO = GO;
213     GTM->NTypes = Types.size();
214     std::uninitialized_copy(Types.begin(), Types.end(),
215                             GTM->getTrailingObjects<MDNode *>());
216     return GTM;
217   }
218   GlobalObject *getGlobal() const {
219     return GO;
220   }
221   ArrayRef<MDNode *> types() const {
222     return makeArrayRef(getTrailingObjects<MDNode *>(), NTypes);
223   }
224 };
225
226 class LowerTypeTestsModule {
227   Module &M;
228
229   // This is for testing purposes only.
230   std::unique_ptr<ModuleSummaryIndex> OwnedSummary;
231
232   bool LinkerSubsectionsViaSymbols;
233   Triple::ArchType Arch;
234   Triple::OSType OS;
235   Triple::ObjectFormatType ObjectFormat;
236
237   IntegerType *Int1Ty = Type::getInt1Ty(M.getContext());
238   IntegerType *Int8Ty = Type::getInt8Ty(M.getContext());
239   PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
240   IntegerType *Int32Ty = Type::getInt32Ty(M.getContext());
241   PointerType *Int32PtrTy = PointerType::getUnqual(Int32Ty);
242   IntegerType *Int64Ty = Type::getInt64Ty(M.getContext());
243   IntegerType *IntPtrTy = M.getDataLayout().getIntPtrType(M.getContext(), 0);
244
245   // Indirect function call index assignment counter for WebAssembly
246   uint64_t IndirectIndex = 1;
247
248   // Mapping from type identifiers to the call sites that test them.
249   DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites;
250
251   /// This structure describes how to lower type tests for a particular type
252   /// identifier. It is either built directly from the global analysis (during
253   /// regular LTO or the regular LTO phase of ThinLTO), or indirectly using type
254   /// identifier summaries and external symbol references (in ThinLTO backends).
255   struct TypeIdLowering {
256     TypeTestResolution::Kind TheKind;
257
258     /// All except Unsat: the start address within the combined global.
259     Constant *OffsetedGlobal;
260
261     /// ByteArray, Inline, AllOnes: log2 of the required global alignment
262     /// relative to the start address.
263     Constant *AlignLog2;
264
265     /// ByteArray, Inline, AllOnes: size of the memory region covering members
266     /// of this type identifier as a multiple of 2^AlignLog2.
267     Constant *Size;
268
269     /// ByteArray, Inline, AllOnes: range of the size expressed as a bit width.
270     unsigned SizeBitWidth;
271
272     /// ByteArray: the byte array to test the address against.
273     Constant *TheByteArray;
274
275     /// ByteArray: the bit mask to apply to bytes loaded from the byte array.
276     Constant *BitMask;
277
278     /// Inline: the bit mask to test the address against.
279     Constant *InlineBits;
280   };
281
282   std::vector<ByteArrayInfo> ByteArrayInfos;
283
284   Function *WeakInitializerFn = nullptr;
285
286   BitSetInfo
287   buildBitSet(Metadata *TypeId,
288               const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
289   ByteArrayInfo *createByteArray(BitSetInfo &BSI);
290   void allocateByteArrays();
291   Value *createBitSetTest(IRBuilder<> &B, const TypeIdLowering &TIL,
292                           Value *BitOffset);
293   void lowerTypeTestCalls(
294       ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
295       const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
296   Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
297                            const TypeIdLowering &TIL);
298   void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds,
299                                        ArrayRef<GlobalTypeMember *> Globals);
300   unsigned getJumpTableEntrySize();
301   Type *getJumpTableEntryType();
302   void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS,
303                             SmallVectorImpl<Value *> &AsmArgs, Function *Dest);
304   void verifyTypeMDNode(GlobalObject *GO, MDNode *Type);
305   void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
306                                  ArrayRef<GlobalTypeMember *> Functions);
307   void buildBitSetsFromFunctionsNative(ArrayRef<Metadata *> TypeIds,
308                                     ArrayRef<GlobalTypeMember *> Functions);
309   void buildBitSetsFromFunctionsWASM(ArrayRef<Metadata *> TypeIds,
310                                      ArrayRef<GlobalTypeMember *> Functions);
311   void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds,
312                                    ArrayRef<GlobalTypeMember *> Globals);
313
314   void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT);
315   void moveInitializerToModuleConstructor(GlobalVariable *GV);
316   void findGlobalVariableUsersOf(Constant *C,
317                                  SmallSetVector<GlobalVariable *, 8> &Out);
318
319   void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions);
320
321 public:
322   LowerTypeTestsModule(Module &M);
323   ~LowerTypeTestsModule();
324   bool lower();
325 };
326
327 struct LowerTypeTests : public ModulePass {
328   static char ID;
329   LowerTypeTests() : ModulePass(ID) {
330     initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());
331   }
332
333   bool runOnModule(Module &M) override {
334     if (skipModule(M))
335       return false;
336     return LowerTypeTestsModule(M).lower();
337   }
338 };
339
340 } // anonymous namespace
341
342 INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false,
343                 false)
344 char LowerTypeTests::ID = 0;
345
346 ModulePass *llvm::createLowerTypeTestsPass() { return new LowerTypeTests; }
347
348 /// Build a bit set for TypeId using the object layouts in
349 /// GlobalLayout.
350 BitSetInfo LowerTypeTestsModule::buildBitSet(
351     Metadata *TypeId,
352     const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
353   BitSetBuilder BSB;
354
355   // Compute the byte offset of each address associated with this type
356   // identifier.
357   for (auto &GlobalAndOffset : GlobalLayout) {
358     for (MDNode *Type : GlobalAndOffset.first->types()) {
359       if (Type->getOperand(1) != TypeId)
360         continue;
361       uint64_t Offset =
362           cast<ConstantInt>(
363               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
364               ->getZExtValue();
365       BSB.addOffset(GlobalAndOffset.second + Offset);
366     }
367   }
368
369   return BSB.build();
370 }
371
372 /// Build a test that bit BitOffset mod sizeof(Bits)*8 is set in
373 /// Bits. This pattern matches to the bt instruction on x86.
374 static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits,
375                                   Value *BitOffset) {
376   auto BitsType = cast<IntegerType>(Bits->getType());
377   unsigned BitWidth = BitsType->getBitWidth();
378
379   BitOffset = B.CreateZExtOrTrunc(BitOffset, BitsType);
380   Value *BitIndex =
381       B.CreateAnd(BitOffset, ConstantInt::get(BitsType, BitWidth - 1));
382   Value *BitMask = B.CreateShl(ConstantInt::get(BitsType, 1), BitIndex);
383   Value *MaskedBits = B.CreateAnd(Bits, BitMask);
384   return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0));
385 }
386
387 ByteArrayInfo *LowerTypeTestsModule::createByteArray(BitSetInfo &BSI) {
388   // Create globals to stand in for byte arrays and masks. These never actually
389   // get initialized, we RAUW and erase them later in allocateByteArrays() once
390   // we know the offset and mask to use.
391   auto ByteArrayGlobal = new GlobalVariable(
392       M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);
393   auto MaskGlobal = new GlobalVariable(M, Int8Ty, /*isConstant=*/true,
394                                        GlobalValue::PrivateLinkage, nullptr);
395
396   ByteArrayInfos.emplace_back();
397   ByteArrayInfo *BAI = &ByteArrayInfos.back();
398
399   BAI->Bits = BSI.Bits;
400   BAI->BitSize = BSI.BitSize;
401   BAI->ByteArray = ByteArrayGlobal;
402   BAI->MaskGlobal = MaskGlobal;
403   return BAI;
404 }
405
406 void LowerTypeTestsModule::allocateByteArrays() {
407   std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(),
408                    [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) {
409                      return BAI1.BitSize > BAI2.BitSize;
410                    });
411
412   std::vector<uint64_t> ByteArrayOffsets(ByteArrayInfos.size());
413
414   ByteArrayBuilder BAB;
415   for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
416     ByteArrayInfo *BAI = &ByteArrayInfos[I];
417
418     uint8_t Mask;
419     BAB.allocate(BAI->Bits, BAI->BitSize, ByteArrayOffsets[I], Mask);
420
421     BAI->MaskGlobal->replaceAllUsesWith(
422         ConstantExpr::getIntToPtr(ConstantInt::get(Int8Ty, Mask), Int8PtrTy));
423     BAI->MaskGlobal->eraseFromParent();
424   }
425
426   Constant *ByteArrayConst = ConstantDataArray::get(M.getContext(), BAB.Bytes);
427   auto ByteArray =
428       new GlobalVariable(M, ByteArrayConst->getType(), /*isConstant=*/true,
429                          GlobalValue::PrivateLinkage, ByteArrayConst);
430
431   for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
432     ByteArrayInfo *BAI = &ByteArrayInfos[I];
433
434     Constant *Idxs[] = {ConstantInt::get(IntPtrTy, 0),
435                         ConstantInt::get(IntPtrTy, ByteArrayOffsets[I])};
436     Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(
437         ByteArrayConst->getType(), ByteArray, Idxs);
438
439     // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures
440     // that the pc-relative displacement is folded into the lea instead of the
441     // test instruction getting another displacement.
442     if (LinkerSubsectionsViaSymbols) {
443       BAI->ByteArray->replaceAllUsesWith(GEP);
444     } else {
445       GlobalAlias *Alias = GlobalAlias::create(
446           Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, &M);
447       BAI->ByteArray->replaceAllUsesWith(Alias);
448     }
449     BAI->ByteArray->eraseFromParent();
450   }
451
452   ByteArraySizeBits = BAB.BitAllocs[0] + BAB.BitAllocs[1] + BAB.BitAllocs[2] +
453                       BAB.BitAllocs[3] + BAB.BitAllocs[4] + BAB.BitAllocs[5] +
454                       BAB.BitAllocs[6] + BAB.BitAllocs[7];
455   ByteArraySizeBytes = BAB.Bytes.size();
456 }
457
458 /// Build a test that bit BitOffset is set in the type identifier that was
459 /// lowered to TIL, which must be either an Inline or a ByteArray.
460 Value *LowerTypeTestsModule::createBitSetTest(IRBuilder<> &B,
461                                               const TypeIdLowering &TIL,
462                                               Value *BitOffset) {
463   if (TIL.TheKind == TypeTestResolution::Inline) {
464     // If the bit set is sufficiently small, we can avoid a load by bit testing
465     // a constant.
466     return createMaskedBitTest(B, TIL.InlineBits, BitOffset);
467   } else {
468     Constant *ByteArray = TIL.TheByteArray;
469     if (!LinkerSubsectionsViaSymbols && AvoidReuse) {
470       // Each use of the byte array uses a different alias. This makes the
471       // backend less likely to reuse previously computed byte array addresses,
472       // improving the security of the CFI mechanism based on this pass.
473       ByteArray = GlobalAlias::create(Int8Ty, 0, GlobalValue::PrivateLinkage,
474                                       "bits_use", ByteArray, &M);
475     }
476
477     Value *ByteAddr = B.CreateGEP(Int8Ty, ByteArray, BitOffset);
478     Value *Byte = B.CreateLoad(ByteAddr);
479
480     Value *ByteAndMask =
481         B.CreateAnd(Byte, ConstantExpr::getPtrToInt(TIL.BitMask, Int8Ty));
482     return B.CreateICmpNE(ByteAndMask, ConstantInt::get(Int8Ty, 0));
483   }
484 }
485
486 static bool isKnownTypeIdMember(Metadata *TypeId, const DataLayout &DL,
487                                 Value *V, uint64_t COffset) {
488   if (auto GV = dyn_cast<GlobalObject>(V)) {
489     SmallVector<MDNode *, 2> Types;
490     GV->getMetadata(LLVMContext::MD_type, Types);
491     for (MDNode *Type : Types) {
492       if (Type->getOperand(1) != TypeId)
493         continue;
494       uint64_t Offset =
495           cast<ConstantInt>(
496               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
497               ->getZExtValue();
498       if (COffset == Offset)
499         return true;
500     }
501     return false;
502   }
503
504   if (auto GEP = dyn_cast<GEPOperator>(V)) {
505     APInt APOffset(DL.getPointerSizeInBits(0), 0);
506     bool Result = GEP->accumulateConstantOffset(DL, APOffset);
507     if (!Result)
508       return false;
509     COffset += APOffset.getZExtValue();
510     return isKnownTypeIdMember(TypeId, DL, GEP->getPointerOperand(), COffset);
511   }
512
513   if (auto Op = dyn_cast<Operator>(V)) {
514     if (Op->getOpcode() == Instruction::BitCast)
515       return isKnownTypeIdMember(TypeId, DL, Op->getOperand(0), COffset);
516
517     if (Op->getOpcode() == Instruction::Select)
518       return isKnownTypeIdMember(TypeId, DL, Op->getOperand(1), COffset) &&
519              isKnownTypeIdMember(TypeId, DL, Op->getOperand(2), COffset);
520   }
521
522   return false;
523 }
524
525 /// Lower a llvm.type.test call to its implementation. Returns the value to
526 /// replace the call with.
527 Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
528                                                const TypeIdLowering &TIL) {
529   if (TIL.TheKind == TypeTestResolution::Unsat)
530     return ConstantInt::getFalse(M.getContext());
531
532   Value *Ptr = CI->getArgOperand(0);
533   const DataLayout &DL = M.getDataLayout();
534   if (isKnownTypeIdMember(TypeId, DL, Ptr, 0))
535     return ConstantInt::getTrue(M.getContext());
536
537   BasicBlock *InitialBB = CI->getParent();
538
539   IRBuilder<> B(CI);
540
541   Value *PtrAsInt = B.CreatePtrToInt(Ptr, IntPtrTy);
542
543   Constant *OffsetedGlobalAsInt =
544       ConstantExpr::getPtrToInt(TIL.OffsetedGlobal, IntPtrTy);
545   if (TIL.TheKind == TypeTestResolution::Single)
546     return B.CreateICmpEQ(PtrAsInt, OffsetedGlobalAsInt);
547
548   Value *PtrOffset = B.CreateSub(PtrAsInt, OffsetedGlobalAsInt);
549
550   // We need to check that the offset both falls within our range and is
551   // suitably aligned. We can check both properties at the same time by
552   // performing a right rotate by log2(alignment) followed by an integer
553   // comparison against the bitset size. The rotate will move the lower
554   // order bits that need to be zero into the higher order bits of the
555   // result, causing the comparison to fail if they are nonzero. The rotate
556   // also conveniently gives us a bit offset to use during the load from
557   // the bitset.
558   Value *OffsetSHR =
559       B.CreateLShr(PtrOffset, ConstantExpr::getZExt(TIL.AlignLog2, IntPtrTy));
560   Value *OffsetSHL = B.CreateShl(
561       PtrOffset, ConstantExpr::getZExt(
562                      ConstantExpr::getSub(
563                          ConstantInt::get(Int8Ty, DL.getPointerSizeInBits(0)),
564                          TIL.AlignLog2),
565                      IntPtrTy));
566   Value *BitOffset = B.CreateOr(OffsetSHR, OffsetSHL);
567
568   Constant *BitSizeConst = ConstantExpr::getZExt(TIL.Size, IntPtrTy);
569   Value *OffsetInRange = B.CreateICmpULT(BitOffset, BitSizeConst);
570
571   // If the bit set is all ones, testing against it is unnecessary.
572   if (TIL.TheKind == TypeTestResolution::AllOnes)
573     return OffsetInRange;
574
575   TerminatorInst *Term = SplitBlockAndInsertIfThen(OffsetInRange, CI, false);
576   IRBuilder<> ThenB(Term);
577
578   // Now that we know that the offset is in range and aligned, load the
579   // appropriate bit from the bitset.
580   Value *Bit = createBitSetTest(ThenB, TIL, BitOffset);
581
582   // The value we want is 0 if we came directly from the initial block
583   // (having failed the range or alignment checks), or the loaded bit if
584   // we came from the block in which we loaded it.
585   B.SetInsertPoint(CI);
586   PHINode *P = B.CreatePHI(Int1Ty, 2);
587   P->addIncoming(ConstantInt::get(Int1Ty, 0), InitialBB);
588   P->addIncoming(Bit, ThenB.GetInsertBlock());
589   return P;
590 }
591
592 /// Given a disjoint set of type identifiers and globals, lay out the globals,
593 /// build the bit sets and lower the llvm.type.test calls.
594 void LowerTypeTestsModule::buildBitSetsFromGlobalVariables(
595     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) {
596   // Build a new global with the combined contents of the referenced globals.
597   // This global is a struct whose even-indexed elements contain the original
598   // contents of the referenced globals and whose odd-indexed elements contain
599   // any padding required to align the next element to the next power of 2.
600   std::vector<Constant *> GlobalInits;
601   const DataLayout &DL = M.getDataLayout();
602   for (GlobalTypeMember *G : Globals) {
603     GlobalVariable *GV = cast<GlobalVariable>(G->getGlobal());
604     GlobalInits.push_back(GV->getInitializer());
605     uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType());
606
607     // Compute the amount of padding required.
608     uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize;
609
610     // Cap at 128 was found experimentally to have a good data/instruction
611     // overhead tradeoff.
612     if (Padding > 128)
613       Padding = alignTo(InitSize, 128) - InitSize;
614
615     GlobalInits.push_back(
616         ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
617   }
618   if (!GlobalInits.empty())
619     GlobalInits.pop_back();
620   Constant *NewInit = ConstantStruct::getAnon(M.getContext(), GlobalInits);
621   auto *CombinedGlobal =
622       new GlobalVariable(M, NewInit->getType(), /*isConstant=*/true,
623                          GlobalValue::PrivateLinkage, NewInit);
624
625   StructType *NewTy = cast<StructType>(NewInit->getType());
626   const StructLayout *CombinedGlobalLayout = DL.getStructLayout(NewTy);
627
628   // Compute the offsets of the original globals within the new global.
629   DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
630   for (unsigned I = 0; I != Globals.size(); ++I)
631     // Multiply by 2 to account for padding elements.
632     GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2);
633
634   lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout);
635
636   // Build aliases pointing to offsets into the combined global for each
637   // global from which we built the combined global, and replace references
638   // to the original globals with references to the aliases.
639   for (unsigned I = 0; I != Globals.size(); ++I) {
640     GlobalVariable *GV = cast<GlobalVariable>(Globals[I]->getGlobal());
641
642     // Multiply by 2 to account for padding elements.
643     Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0),
644                                       ConstantInt::get(Int32Ty, I * 2)};
645     Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr(
646         NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs);
647     if (LinkerSubsectionsViaSymbols) {
648       GV->replaceAllUsesWith(CombinedGlobalElemPtr);
649     } else {
650       assert(GV->getType()->getAddressSpace() == 0);
651       GlobalAlias *GAlias = GlobalAlias::create(NewTy->getElementType(I * 2), 0,
652                                                 GV->getLinkage(), "",
653                                                 CombinedGlobalElemPtr, &M);
654       GAlias->setVisibility(GV->getVisibility());
655       GAlias->takeName(GV);
656       GV->replaceAllUsesWith(GAlias);
657     }
658     GV->eraseFromParent();
659   }
660 }
661
662 void LowerTypeTestsModule::lowerTypeTestCalls(
663     ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
664     const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
665   CombinedGlobalAddr = ConstantExpr::getBitCast(CombinedGlobalAddr, Int8PtrTy);
666
667   // For each type identifier in this disjoint set...
668   for (Metadata *TypeId : TypeIds) {
669     // Build the bitset.
670     BitSetInfo BSI = buildBitSet(TypeId, GlobalLayout);
671     DEBUG({
672       if (auto MDS = dyn_cast<MDString>(TypeId))
673         dbgs() << MDS->getString() << ": ";
674       else
675         dbgs() << "<unnamed>: ";
676       BSI.print(dbgs());
677     });
678
679     TypeIdLowering TIL;
680     TIL.OffsetedGlobal = ConstantExpr::getGetElementPtr(
681         Int8Ty, CombinedGlobalAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)),
682     TIL.AlignLog2 = ConstantInt::get(Int8Ty, BSI.AlignLog2);
683     if (BSI.isAllOnes()) {
684       TIL.TheKind = (BSI.BitSize == 1) ? TypeTestResolution::Single
685                                        : TypeTestResolution::AllOnes;
686       TIL.SizeBitWidth = (BSI.BitSize <= 256) ? 8 : 32;
687       TIL.Size = ConstantInt::get((BSI.BitSize <= 256) ? Int8Ty : Int32Ty,
688                                   BSI.BitSize);
689     } else if (BSI.BitSize <= 64) {
690       TIL.TheKind = TypeTestResolution::Inline;
691       TIL.SizeBitWidth = (BSI.BitSize <= 32) ? 5 : 6;
692       TIL.Size = ConstantInt::get(Int8Ty, BSI.BitSize);
693       uint64_t InlineBits = 0;
694       for (auto Bit : BSI.Bits)
695         InlineBits |= uint64_t(1) << Bit;
696       if (InlineBits == 0)
697         TIL.TheKind = TypeTestResolution::Unsat;
698       else
699         TIL.InlineBits = ConstantInt::get(
700             (BSI.BitSize <= 32) ? Int32Ty : Int64Ty, InlineBits);
701     } else {
702       TIL.TheKind = TypeTestResolution::ByteArray;
703       TIL.SizeBitWidth = (BSI.BitSize <= 256) ? 8 : 32;
704       TIL.Size = ConstantInt::get((BSI.BitSize <= 256) ? Int8Ty : Int32Ty,
705                                   BSI.BitSize);
706       ++NumByteArraysCreated;
707       ByteArrayInfo *BAI = createByteArray(BSI);
708       TIL.TheByteArray = BAI->ByteArray;
709       TIL.BitMask = BAI->MaskGlobal;
710     }
711
712     // Lower each call to llvm.type.test for this type identifier.
713     for (CallInst *CI : TypeTestCallSites[TypeId]) {
714       ++NumTypeTestCallsLowered;
715       Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL);
716       CI->replaceAllUsesWith(Lowered);
717       CI->eraseFromParent();
718     }
719   }
720 }
721
722 void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) {
723   if (Type->getNumOperands() != 2)
724     report_fatal_error("All operands of type metadata must have 2 elements");
725
726   if (GO->isThreadLocal())
727     report_fatal_error("Bit set element may not be thread-local");
728   if (isa<GlobalVariable>(GO) && GO->hasSection())
729     report_fatal_error(
730         "A member of a type identifier may not have an explicit section");
731
732   if (isa<GlobalVariable>(GO) && GO->isDeclarationForLinker())
733     report_fatal_error(
734         "A global var member of a type identifier must be a definition");
735
736   auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0));
737   if (!OffsetConstMD)
738     report_fatal_error("Type offset must be a constant");
739   auto OffsetInt = dyn_cast<ConstantInt>(OffsetConstMD->getValue());
740   if (!OffsetInt)
741     report_fatal_error("Type offset must be an integer constant");
742 }
743
744 static const unsigned kX86JumpTableEntrySize = 8;
745 static const unsigned kARMJumpTableEntrySize = 4;
746
747 unsigned LowerTypeTestsModule::getJumpTableEntrySize() {
748   switch (Arch) {
749     case Triple::x86:
750     case Triple::x86_64:
751       return kX86JumpTableEntrySize;
752     case Triple::arm:
753     case Triple::thumb:
754     case Triple::aarch64:
755       return kARMJumpTableEntrySize;
756     default:
757       report_fatal_error("Unsupported architecture for jump tables");
758   }
759 }
760
761 // Create a jump table entry for the target. This consists of an instruction
762 // sequence containing a relative branch to Dest. Appends inline asm text,
763 // constraints and arguments to AsmOS, ConstraintOS and AsmArgs.
764 void LowerTypeTestsModule::createJumpTableEntry(
765     raw_ostream &AsmOS, raw_ostream &ConstraintOS,
766     SmallVectorImpl<Value *> &AsmArgs, Function *Dest) {
767   unsigned ArgIndex = AsmArgs.size();
768
769   if (Arch == Triple::x86 || Arch == Triple::x86_64) {
770     AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n";
771     AsmOS << "int3\nint3\nint3\n";
772   } else if (Arch == Triple::arm || Arch == Triple::aarch64) {
773     AsmOS << "b $" << ArgIndex << "\n";
774   } else if (Arch == Triple::thumb) {
775     AsmOS << "b.w $" << ArgIndex << "\n";
776   } else {
777     report_fatal_error("Unsupported architecture for jump tables");
778   }
779
780   ConstraintOS << (ArgIndex > 0 ? ",s" : "s");
781   AsmArgs.push_back(Dest);
782 }
783
784 Type *LowerTypeTestsModule::getJumpTableEntryType() {
785   return ArrayType::get(Int8Ty, getJumpTableEntrySize());
786 }
787
788 /// Given a disjoint set of type identifiers and functions, build the bit sets
789 /// and lower the llvm.type.test calls, architecture dependently.
790 void LowerTypeTestsModule::buildBitSetsFromFunctions(
791     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
792   if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm ||
793       Arch == Triple::thumb || Arch == Triple::aarch64)
794     buildBitSetsFromFunctionsNative(TypeIds, Functions);
795   else if (Arch == Triple::wasm32 || Arch == Triple::wasm64)
796     buildBitSetsFromFunctionsWASM(TypeIds, Functions);
797   else
798     report_fatal_error("Unsupported architecture for jump tables");
799 }
800
801 void LowerTypeTestsModule::moveInitializerToModuleConstructor(
802     GlobalVariable *GV) {
803   if (WeakInitializerFn == nullptr) {
804     WeakInitializerFn = Function::Create(
805         FunctionType::get(Type::getVoidTy(M.getContext()),
806                           /* IsVarArg */ false),
807         GlobalValue::InternalLinkage, "__cfi_global_var_init", &M);
808     BasicBlock *BB =
809         BasicBlock::Create(M.getContext(), "entry", WeakInitializerFn);
810     ReturnInst::Create(M.getContext(), BB);
811     WeakInitializerFn->setSection(
812         ObjectFormat == Triple::MachO
813             ? "__TEXT,__StaticInit,regular,pure_instructions"
814             : ".text.startup");
815     // This code is equivalent to relocation application, and should run at the
816     // earliest possible time (i.e. with the highest priority).
817     appendToGlobalCtors(M, WeakInitializerFn, /* Priority */ 0);
818   }
819
820   IRBuilder<> IRB(WeakInitializerFn->getEntryBlock().getTerminator());
821   GV->setConstant(false);
822   IRB.CreateAlignedStore(GV->getInitializer(), GV, GV->getAlignment());
823   GV->setInitializer(Constant::getNullValue(GV->getValueType()));
824 }
825
826 void LowerTypeTestsModule::findGlobalVariableUsersOf(
827     Constant *C, SmallSetVector<GlobalVariable *, 8> &Out) {
828   for (auto *U : C->users()){
829     if (auto *GV = dyn_cast<GlobalVariable>(U))
830       Out.insert(GV);
831     else if (auto *C2 = dyn_cast<Constant>(U))
832       findGlobalVariableUsersOf(C2, Out);
833   }
834 }
835
836 // Replace all uses of F with (F ? JT : 0).
837 void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr(
838     Function *F, Constant *JT) {
839   // The target expression can not appear in a constant initializer on most
840   // (all?) targets. Switch to a runtime initializer.
841   SmallSetVector<GlobalVariable *, 8> GlobalVarUsers;
842   findGlobalVariableUsersOf(F, GlobalVarUsers);
843   for (auto GV : GlobalVarUsers)
844     moveInitializerToModuleConstructor(GV);
845
846   // Can not RAUW F with an expression that uses F. Replace with a temporary
847   // placeholder first.
848   Function *PlaceholderFn =
849       Function::Create(cast<FunctionType>(F->getValueType()),
850                        GlobalValue::ExternalWeakLinkage, "", &M);
851   F->replaceAllUsesWith(PlaceholderFn);
852
853   Constant *Target = ConstantExpr::getSelect(
854       ConstantExpr::getICmp(CmpInst::ICMP_NE, F,
855                             Constant::getNullValue(F->getType())),
856       JT, Constant::getNullValue(F->getType()));
857   PlaceholderFn->replaceAllUsesWith(Target);
858   PlaceholderFn->eraseFromParent();
859 }
860
861 void LowerTypeTestsModule::createJumpTable(
862     Function *F, ArrayRef<GlobalTypeMember *> Functions) {
863   std::string AsmStr, ConstraintStr;
864   raw_string_ostream AsmOS(AsmStr), ConstraintOS(ConstraintStr);
865   SmallVector<Value *, 16> AsmArgs;
866   AsmArgs.reserve(Functions.size() * 2);
867
868   for (unsigned I = 0; I != Functions.size(); ++I)
869     createJumpTableEntry(AsmOS, ConstraintOS, AsmArgs,
870                          cast<Function>(Functions[I]->getGlobal()));
871
872   // Try to emit the jump table at the end of the text segment.
873   // Jump table must come after __cfi_check in the cross-dso mode.
874   // FIXME: this magic section name seems to do the trick.
875   F->setSection(ObjectFormat == Triple::MachO
876                     ? "__TEXT,__text,regular,pure_instructions"
877                     : ".text.cfi");
878   // Align the whole table by entry size.
879   F->setAlignment(getJumpTableEntrySize());
880   // Skip prologue.
881   // Disabled on win32 due to https://llvm.org/bugs/show_bug.cgi?id=28641#c3.
882   // Luckily, this function does not get any prologue even without the
883   // attribute.
884   if (OS != Triple::Win32)
885     F->addFnAttr(llvm::Attribute::Naked);
886   // Thumb jump table assembly needs Thumb2. The following attribute is added by
887   // Clang for -march=armv7.
888   if (Arch == Triple::thumb)
889     F->addFnAttr("target-cpu", "cortex-a8");
890
891   BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F);
892   IRBuilder<> IRB(BB);
893
894   SmallVector<Type *, 16> ArgTypes;
895   ArgTypes.reserve(AsmArgs.size());
896   for (const auto &Arg : AsmArgs)
897     ArgTypes.push_back(Arg->getType());
898   InlineAsm *JumpTableAsm =
899       InlineAsm::get(FunctionType::get(IRB.getVoidTy(), ArgTypes, false),
900                      AsmOS.str(), ConstraintOS.str(),
901                      /*hasSideEffects=*/true);
902
903   IRB.CreateCall(JumpTableAsm, AsmArgs);
904   IRB.CreateUnreachable();
905 }
906
907 /// Given a disjoint set of type identifiers and functions, build a jump table
908 /// for the functions, build the bit sets and lower the llvm.type.test calls.
909 void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
910     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
911   // Unlike the global bitset builder, the function bitset builder cannot
912   // re-arrange functions in a particular order and base its calculations on the
913   // layout of the functions' entry points, as we have no idea how large a
914   // particular function will end up being (the size could even depend on what
915   // this pass does!) Instead, we build a jump table, which is a block of code
916   // consisting of one branch instruction for each of the functions in the bit
917   // set that branches to the target function, and redirect any taken function
918   // addresses to the corresponding jump table entry. In the object file's
919   // symbol table, the symbols for the target functions also refer to the jump
920   // table entries, so that addresses taken outside the module will pass any
921   // verification done inside the module.
922   //
923   // In more concrete terms, suppose we have three functions f, g, h which are
924   // of the same type, and a function foo that returns their addresses:
925   //
926   // f:
927   // mov 0, %eax
928   // ret
929   //
930   // g:
931   // mov 1, %eax
932   // ret
933   //
934   // h:
935   // mov 2, %eax
936   // ret
937   //
938   // foo:
939   // mov f, %eax
940   // mov g, %edx
941   // mov h, %ecx
942   // ret
943   //
944   // We output the jump table as module-level inline asm string. The end result
945   // will (conceptually) look like this:
946   //
947   // f = .cfi.jumptable
948   // g = .cfi.jumptable + 4
949   // h = .cfi.jumptable + 8
950   // .cfi.jumptable:
951   // jmp f.cfi  ; 5 bytes
952   // int3       ; 1 byte
953   // int3       ; 1 byte
954   // int3       ; 1 byte
955   // jmp g.cfi  ; 5 bytes
956   // int3       ; 1 byte
957   // int3       ; 1 byte
958   // int3       ; 1 byte
959   // jmp h.cfi  ; 5 bytes
960   // int3       ; 1 byte
961   // int3       ; 1 byte
962   // int3       ; 1 byte
963   //
964   // f.cfi:
965   // mov 0, %eax
966   // ret
967   //
968   // g.cfi:
969   // mov 1, %eax
970   // ret
971   //
972   // h.cfi:
973   // mov 2, %eax
974   // ret
975   //
976   // foo:
977   // mov f, %eax
978   // mov g, %edx
979   // mov h, %ecx
980   // ret
981   //
982   // Because the addresses of f, g, h are evenly spaced at a power of 2, in the
983   // normal case the check can be carried out using the same kind of simple
984   // arithmetic that we normally use for globals.
985
986   // FIXME: find a better way to represent the jumptable in the IR.
987
988   assert(!Functions.empty());
989
990   // Build a simple layout based on the regular layout of jump tables.
991   DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
992   unsigned EntrySize = getJumpTableEntrySize();
993   for (unsigned I = 0; I != Functions.size(); ++I)
994     GlobalLayout[Functions[I]] = I * EntrySize;
995
996   Function *JumpTableFn =
997       Function::Create(FunctionType::get(Type::getVoidTy(M.getContext()),
998                                          /* IsVarArg */ false),
999                        GlobalValue::PrivateLinkage, ".cfi.jumptable", &M);
1000   ArrayType *JumpTableType =
1001       ArrayType::get(getJumpTableEntryType(), Functions.size());
1002   auto JumpTable =
1003       ConstantExpr::getPointerCast(JumpTableFn, JumpTableType->getPointerTo(0));
1004
1005   lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout);
1006
1007   // Build aliases pointing to offsets into the jump table, and replace
1008   // references to the original functions with references to the aliases.
1009   for (unsigned I = 0; I != Functions.size(); ++I) {
1010     Function *F = cast<Function>(Functions[I]->getGlobal());
1011
1012     Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast(
1013         ConstantExpr::getInBoundsGetElementPtr(
1014             JumpTableType, JumpTable,
1015             ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
1016                                  ConstantInt::get(IntPtrTy, I)}),
1017         F->getType());
1018     if (LinkerSubsectionsViaSymbols || F->isDeclarationForLinker()) {
1019
1020       if (F->isWeakForLinker())
1021         replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr);
1022       else
1023         F->replaceAllUsesWith(CombinedGlobalElemPtr);
1024     } else {
1025       assert(F->getType()->getAddressSpace() == 0);
1026
1027       GlobalAlias *FAlias = GlobalAlias::create(F->getValueType(), 0,
1028                                                 F->getLinkage(), "",
1029                                                 CombinedGlobalElemPtr, &M);
1030       FAlias->setVisibility(F->getVisibility());
1031       FAlias->takeName(F);
1032       if (FAlias->hasName())
1033         F->setName(FAlias->getName() + ".cfi");
1034       F->replaceAllUsesWith(FAlias);
1035     }
1036     if (!F->isDeclarationForLinker())
1037       F->setLinkage(GlobalValue::InternalLinkage);
1038   }
1039
1040   createJumpTable(JumpTableFn, Functions);
1041 }
1042
1043 /// Assign a dummy layout using an incrementing counter, tag each function
1044 /// with its index represented as metadata, and lower each type test to an
1045 /// integer range comparison. During generation of the indirect function call
1046 /// table in the backend, it will assign the given indexes.
1047 /// Note: Dynamic linking is not supported, as the WebAssembly ABI has not yet
1048 /// been finalized.
1049 void LowerTypeTestsModule::buildBitSetsFromFunctionsWASM(
1050     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
1051   assert(!Functions.empty());
1052
1053   // Build consecutive monotonic integer ranges for each call target set
1054   DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
1055
1056   for (GlobalTypeMember *GTM : Functions) {
1057     Function *F = cast<Function>(GTM->getGlobal());
1058
1059     // Skip functions that are not address taken, to avoid bloating the table
1060     if (!F->hasAddressTaken())
1061       continue;
1062
1063     // Store metadata with the index for each function
1064     MDNode *MD = MDNode::get(F->getContext(),
1065                              ArrayRef<Metadata *>(ConstantAsMetadata::get(
1066                                  ConstantInt::get(Int64Ty, IndirectIndex))));
1067     F->setMetadata("wasm.index", MD);
1068
1069     // Assign the counter value
1070     GlobalLayout[GTM] = IndirectIndex++;
1071   }
1072
1073   // The indirect function table index space starts at zero, so pass a NULL
1074   // pointer as the subtracted "jump table" offset.
1075   lowerTypeTestCalls(TypeIds, ConstantPointerNull::get(Int32PtrTy),
1076                      GlobalLayout);
1077 }
1078
1079 void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
1080     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) {
1081   llvm::DenseMap<Metadata *, uint64_t> TypeIdIndices;
1082   for (unsigned I = 0; I != TypeIds.size(); ++I)
1083     TypeIdIndices[TypeIds[I]] = I;
1084
1085   // For each type identifier, build a set of indices that refer to members of
1086   // the type identifier.
1087   std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size());
1088   unsigned GlobalIndex = 0;
1089   for (GlobalTypeMember *GTM : Globals) {
1090     for (MDNode *Type : GTM->types()) {
1091       // Type = { offset, type identifier }
1092       unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)];
1093       TypeMembers[TypeIdIndex].insert(GlobalIndex);
1094     }
1095     GlobalIndex++;
1096   }
1097
1098   // Order the sets of indices by size. The GlobalLayoutBuilder works best
1099   // when given small index sets first.
1100   std::stable_sort(
1101       TypeMembers.begin(), TypeMembers.end(),
1102       [](const std::set<uint64_t> &O1, const std::set<uint64_t> &O2) {
1103         return O1.size() < O2.size();
1104       });
1105
1106   // Create a GlobalLayoutBuilder and provide it with index sets as layout
1107   // fragments. The GlobalLayoutBuilder tries to lay out members of fragments as
1108   // close together as possible.
1109   GlobalLayoutBuilder GLB(Globals.size());
1110   for (auto &&MemSet : TypeMembers)
1111     GLB.addFragment(MemSet);
1112
1113   // Build the bitsets from this disjoint set.
1114   if (Globals.empty() || isa<GlobalVariable>(Globals[0]->getGlobal())) {
1115     // Build a vector of global variables with the computed layout.
1116     std::vector<GlobalTypeMember *> OrderedGVs(Globals.size());
1117     auto OGI = OrderedGVs.begin();
1118     for (auto &&F : GLB.Fragments) {
1119       for (auto &&Offset : F) {
1120         auto GV = dyn_cast<GlobalVariable>(Globals[Offset]->getGlobal());
1121         if (!GV)
1122           report_fatal_error("Type identifier may not contain both global "
1123                              "variables and functions");
1124         *OGI++ = Globals[Offset];
1125       }
1126     }
1127
1128     buildBitSetsFromGlobalVariables(TypeIds, OrderedGVs);
1129   } else {
1130     // Build a vector of functions with the computed layout.
1131     std::vector<GlobalTypeMember *> OrderedFns(Globals.size());
1132     auto OFI = OrderedFns.begin();
1133     for (auto &&F : GLB.Fragments) {
1134       for (auto &&Offset : F) {
1135         auto Fn = dyn_cast<Function>(Globals[Offset]->getGlobal());
1136         if (!Fn)
1137           report_fatal_error("Type identifier may not contain both global "
1138                              "variables and functions");
1139         *OFI++ = Globals[Offset];
1140       }
1141     }
1142
1143     buildBitSetsFromFunctions(TypeIds, OrderedFns);
1144   }
1145 }
1146
1147 /// Lower all type tests in this module.
1148 LowerTypeTestsModule::LowerTypeTestsModule(Module &M) : M(M) {
1149   // Handle the command-line summary arguments. This code is for testing
1150   // purposes only, so we handle errors directly.
1151   if (!ClSummaryAction.empty()) {
1152     OwnedSummary = make_unique<ModuleSummaryIndex>();
1153     if (!ClReadSummary.empty()) {
1154       ExitOnError ExitOnErr("-lowertypetests-read-summary: " + ClReadSummary +
1155                             ": ");
1156       auto ReadSummaryFile =
1157           ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
1158
1159       yaml::Input In(ReadSummaryFile->getBuffer());
1160       In >> *OwnedSummary;
1161       ExitOnErr(errorCodeToError(In.error()));
1162     }
1163   }
1164
1165   Triple TargetTriple(M.getTargetTriple());
1166   LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX();
1167   Arch = TargetTriple.getArch();
1168   OS = TargetTriple.getOS();
1169   ObjectFormat = TargetTriple.getObjectFormat();
1170 }
1171
1172 LowerTypeTestsModule::~LowerTypeTestsModule() {
1173   if (ClSummaryAction.empty() || ClWriteSummary.empty())
1174     return;
1175
1176   ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary +
1177                         ": ");
1178   std::error_code EC;
1179   raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
1180   ExitOnErr(errorCodeToError(EC));
1181
1182   yaml::Output Out(OS);
1183   Out << *OwnedSummary;
1184 }
1185
1186 bool LowerTypeTestsModule::lower() {
1187   Function *TypeTestFunc =
1188       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
1189   if (!TypeTestFunc || TypeTestFunc->use_empty())
1190     return false;
1191
1192   // Equivalence class set containing type identifiers and the globals that
1193   // reference them. This is used to partition the set of type identifiers in
1194   // the module into disjoint sets.
1195   typedef EquivalenceClasses<PointerUnion<GlobalTypeMember *, Metadata *>>
1196       GlobalClassesTy;
1197   GlobalClassesTy GlobalClasses;
1198
1199   // Verify the type metadata and build a few data structures to let us
1200   // efficiently enumerate the type identifiers associated with a global:
1201   // a list of GlobalTypeMembers (a GlobalObject stored alongside a vector
1202   // of associated type metadata) and a mapping from type identifiers to their
1203   // list of GlobalTypeMembers and last observed index in the list of globals.
1204   // The indices will be used later to deterministically order the list of type
1205   // identifiers.
1206   BumpPtrAllocator Alloc;
1207   struct TIInfo {
1208     unsigned Index;
1209     std::vector<GlobalTypeMember *> RefGlobals;
1210   };
1211   llvm::DenseMap<Metadata *, TIInfo> TypeIdInfo;
1212   unsigned I = 0;
1213   SmallVector<MDNode *, 2> Types;
1214   for (GlobalObject &GO : M.global_objects()) {
1215     Types.clear();
1216     GO.getMetadata(LLVMContext::MD_type, Types);
1217     if (Types.empty())
1218       continue;
1219
1220     auto *GTM = GlobalTypeMember::create(Alloc, &GO, Types);
1221     for (MDNode *Type : Types) {
1222       verifyTypeMDNode(&GO, Type);
1223       auto &Info = TypeIdInfo[cast<MDNode>(Type)->getOperand(1)];
1224       Info.Index = ++I;
1225       Info.RefGlobals.push_back(GTM);
1226     }
1227   }
1228
1229   for (const Use &U : TypeTestFunc->uses()) {
1230     auto CI = cast<CallInst>(U.getUser());
1231
1232     auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
1233     if (!BitSetMDVal)
1234       report_fatal_error("Second argument of llvm.type.test must be metadata");
1235     auto BitSet = BitSetMDVal->getMetadata();
1236
1237     // Add the call site to the list of call sites for this type identifier. We
1238     // also use TypeTestCallSites to keep track of whether we have seen this
1239     // type identifier before. If we have, we don't need to re-add the
1240     // referenced globals to the equivalence class.
1241     std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, bool>
1242         Ins = TypeTestCallSites.insert(
1243             std::make_pair(BitSet, std::vector<CallInst *>()));
1244     Ins.first->second.push_back(CI);
1245     if (!Ins.second)
1246       continue;
1247
1248     // Add the type identifier to the equivalence class.
1249     GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet);
1250     GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI);
1251
1252     // Add the referenced globals to the type identifier's equivalence class.
1253     for (GlobalTypeMember *GTM : TypeIdInfo[BitSet].RefGlobals)
1254       CurSet = GlobalClasses.unionSets(
1255           CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM)));
1256   }
1257
1258   if (GlobalClasses.empty())
1259     return false;
1260
1261   // Build a list of disjoint sets ordered by their maximum global index for
1262   // determinism.
1263   std::vector<std::pair<GlobalClassesTy::iterator, unsigned>> Sets;
1264   for (GlobalClassesTy::iterator I = GlobalClasses.begin(),
1265                                  E = GlobalClasses.end();
1266        I != E; ++I) {
1267     if (!I->isLeader())
1268       continue;
1269     ++NumTypeIdDisjointSets;
1270
1271     unsigned MaxIndex = 0;
1272     for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I);
1273          MI != GlobalClasses.member_end(); ++MI) {
1274       if ((*MI).is<Metadata *>())
1275         MaxIndex = std::max(MaxIndex, TypeIdInfo[MI->get<Metadata *>()].Index);
1276     }
1277     Sets.emplace_back(I, MaxIndex);
1278   }
1279   std::sort(Sets.begin(), Sets.end(),
1280             [](const std::pair<GlobalClassesTy::iterator, unsigned> &S1,
1281                const std::pair<GlobalClassesTy::iterator, unsigned> &S2) {
1282               return S1.second < S2.second;
1283             });
1284
1285   // For each disjoint set we found...
1286   for (const auto &S : Sets) {
1287     // Build the list of type identifiers in this disjoint set.
1288     std::vector<Metadata *> TypeIds;
1289     std::vector<GlobalTypeMember *> Globals;
1290     for (GlobalClassesTy::member_iterator MI =
1291              GlobalClasses.member_begin(S.first);
1292          MI != GlobalClasses.member_end(); ++MI) {
1293       if ((*MI).is<Metadata *>())
1294         TypeIds.push_back(MI->get<Metadata *>());
1295       else
1296         Globals.push_back(MI->get<GlobalTypeMember *>());
1297     }
1298
1299     // Order type identifiers by global index for determinism. This ordering is
1300     // stable as there is a one-to-one mapping between metadata and indices.
1301     std::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) {
1302       return TypeIdInfo[M1].Index < TypeIdInfo[M2].Index;
1303     });
1304
1305     // Build bitsets for this disjoint set.
1306     buildBitSetsFromDisjointSet(TypeIds, Globals);
1307   }
1308
1309   allocateByteArrays();
1310
1311   return true;
1312 }
1313
1314 PreservedAnalyses LowerTypeTestsPass::run(Module &M,
1315                                           ModuleAnalysisManager &AM) {
1316   bool Changed = LowerTypeTestsModule(M).lower();
1317   if (!Changed)
1318     return PreservedAnalyses::all();
1319   return PreservedAnalyses::none();
1320 }