1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
11 /// This pass performs the following type substitutions on all
12 /// non-compute shaders:
15 /// - v16i8 is used for constant memory resource descriptors. This type is
16 /// legal for some compute APIs, and we don't want to declare it as legal
17 /// in the backend, because we want the legalizer to expand all v16i8
20 /// - Having v1* types complicates the legalizer and we can easily replace
21 ///   them with the element type.
22 //===----------------------------------------------------------------------===//
25 #include "Utils/AMDGPUBaseInfo.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/InstVisitor.h"
/// Function pass that is also an InstVisitor over itself; it declares visitor
/// hooks for load, call and bitcast instructions.
33 class SITypeRewriter : public FunctionPass,
34 public InstVisitor<SITypeRewriter> {
// Hands the static pass ID to the FunctionPass base class.
42 SITypeRewriter() : FunctionPass(ID) { }
// Module-level setup; the out-of-line definition builds the v16i8 and v4i32
// vector types from the module's context.
43 bool doInitialization(Module &M) override;
// Per-function entry point of the pass.
44 bool runOnFunction(Function &F) override;
// Human-readable name of this pass.
45 StringRef getPassName() const override { return "SI Type Rewriter"; }
// Instruction visitor hooks, one per instruction kind this pass inspects.
46 void visitLoadInst(LoadInst &I);
47 void visitCallInst(CallInst &I);
48 void visitBitCast(BitCastInst &I);
51 } // End anonymous namespace
// Static pass identifier; the constructor passes it to FunctionPass.
53 char SITypeRewriter::ID = 0;
/// Builds the two vector types the visitors compare against and convert
/// between: a 16-element i8 vector (v16i8) and a 4-element i32 vector (v4i32),
/// both created in the LLVMContext of module \p M.
55 bool SITypeRewriter::doInitialization(Module &M) {
57 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
58 v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
/// Per-function entry point. It starts by asking AMDGPU::isShader whether the
/// calling convention of \p F is a shader convention; the negated result
/// guards what follows.
62 bool SITypeRewriter::runOnFunction(Function &F) {
63 if (!AMDGPU::isShader(F.getCallingConv()))
/// Handles loads whose pointee type is v16i8: the pointer is bitcast to a
/// v4i32 pointer, the value is loaded as v4i32, and the result is bitcast back
/// to the original load's type before replacing all uses of \p I.
72 void SITypeRewriter::visitLoadInst(LoadInst &I) {
73 Value *Ptr = I.getPointerOperand();
74 Type *PtrTy = Ptr->getType();
75 Type *ElemTy = PtrTy->getPointerElementType();
// The builder is positioned at I, so new instructions are emitted before it.
76 IRBuilder<> Builder(&I);
77 if (ElemTy == v16i8) {
// Reinterpret the pointer as pointing to v4i32, keeping its address space.
78 Value *BitCast = Builder.CreateBitCast(Ptr,
79 PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
80 LoadInst *Load = Builder.CreateLoad(BitCast);
// Copy every metadata attachment except the debug location onto the new load.
81 SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
82 I.getAllMetadataOtherThanDebugLoc(MD);
83 for (unsigned i = 0, e = MD.size(); i != e; ++i) {
84 Load->setMetadata(MD[i].first, MD[i].second);
// Cast the v4i32 value back to I's type so existing users keep their type.
86 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
87 I.replaceAllUsesWith(BitCastLoad);
/// Rewrites the arguments of a call: v16i8 arguments are bitcast to v4i32 and
/// single-element i32 vector arguments are replaced by a scalar i32. The
/// callee is swapped for a function whose name is derived from the original
/// callee's name, and all uses of \p I are replaced by the new call.
92 void SITypeRewriter::visitCallInst(CallInst &I) {
93 IRBuilder<> Builder(&I);
// Args / Types collect the operands and parameter types of the new call.
95 SmallVector <Value*, 8> Args;
96 SmallVector <Type*, 8> Types;
97 bool NeedToReplace = false;
98 Function *F = I.getCalledFunction();
// Start from the original callee's name and edit it as arguments are rewritten.
102 std::string Name = F->getName();
103 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
104 Value *Arg = I.getArgOperand(i);
// Case 1: v16i8 argument -> pass it as v4i32. ".v4i32" is appended to the
// name once per such argument.
105 if (Arg->getType() == v16i8) {
106 Args.push_back(Builder.CreateBitCast(Arg, v4i32));
107 Types.push_back(v4i32);
108 NeedToReplace = true;
109 Name = Name + ".v4i32";
// Case 2: one-element vector of i32 -> pass the scalar element instead.
110 } else if (Arg->getType()->isVectorTy() &&
111 Arg->getType()->getVectorNumElements() == 1 &&
112 Arg->getType()->getVectorElementType() ==
113 Type::getInt32Ty(I.getContext())){
114 Type *ElementTy = Arg->getType()->getVectorElementType();
115 std::string TypeName = "i32";
// NOTE(review): cast<> assumes the argument is produced by an insertelement
// instruction and asserts otherwise -- confirm no other producer can reach here.
116 InsertElementInst *Def = cast<InsertElementInst>(Arg);
// Operand 1 of insertelement is the scalar being inserted.
117 Args.push_back(Def->getOperand(1));
118 Types.push_back(ElementTy);
119 std::string VecTypeName = "v1" + TypeName;
// NOTE(review): if Name does not contain "v1i32", find() yields npos and
// std::string::replace throws std::out_of_range -- confirm every callee that
// reaches this branch carries that substring in its name.
120 Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
121 NeedToReplace = true;
// Any other argument keeps its original type in the new signature.
124 Types.push_back(Arg->getType());
// Guard on whether any argument was actually rewritten.
128 if (!NeedToReplace) {
// Look the replacement callee up by its derived name in the module.
131 Function *NewF = Mod->getFunction(Name);
// The created function is non-variadic, has external linkage, keeps F's return
// type, uses the rewritten parameter types, and copies F's attributes.
133 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
134 NewF->setAttributes(F->getAttributes());
// Emit the new call before I and redirect I's users to it.
136 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
/// Looks for a bitcast whose operand is itself a bitcast from v4i32. The
/// visible logic tests the destination type of \p I against v4i32 and, when
/// the operand is a bitcast with source type v4i32, redirects users of \p I to
/// that original v4i32 value, bypassing both casts.
140 void SITypeRewriter::visitBitCast(BitCastInst &I) {
141 IRBuilder<> Builder(&I);
142 if (I.getDestTy() != v4i32) {
// Only a directly nested bitcast is considered.
146 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
147 if (Op->getSrcTy() == v4i32) {
148 I.replaceAllUsesWith(Op->getOperand(0));
/// Factory for the pass: returns a newly allocated SITypeRewriter as a
/// FunctionPass pointer.
// NOTE(review): presumably the caller (e.g. a pass manager) takes ownership of
// the returned object -- confirm against the call sites.
154 FunctionPass *llvm::createSITypeRewriter() {
155 return new SITypeRewriter();