//===-- NVPTXLowerStructArgs.cpp - Copy struct args to local memory =====--===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // Copy struct args to local memory. This is needed for kernel functions only. // This is a preparation for handling cases like // // kernel void foo(struct A arg, ...) // { // struct A *p = &arg; // ... // ... = p->filed1 ... (this is no generic address for .param) // p->filed2 = ... (this is no write access to .param) // } // //===----------------------------------------------------------------------===// #include "NVPTX.h" #include "NVPTXUtilities.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" using namespace llvm; namespace llvm { void initializeNVPTXLowerStructArgsPass(PassRegistry &); } class LLVM_LIBRARY_VISIBILITY NVPTXLowerStructArgs : public FunctionPass { bool runOnFunction(Function &F) override; void handleStructPtrArgs(Function &); void handleParam(Argument *); public: static char ID; // Pass identification, replacement for typeid NVPTXLowerStructArgs() : FunctionPass(ID) {} const char *getPassName() const override { return "Copy structure (byval *) arguments to stack"; } }; char NVPTXLowerStructArgs::ID = 1; INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args", "Lower structure arguments (NVPTX)", false, false) void NVPTXLowerStructArgs::handleParam(Argument *Arg) { Function *Func = Arg->getParent(); Instruction *FirstInst = &(Func->getEntryBlock().front()); PointerType *PType = dyn_cast(Arg->getType()); assert(PType && "Expecting pointer type in handleParam"); Type *StructType = PType->getElementType(); AllocaInst *AllocA = new AllocaInst(StructType, Arg->getName(), FirstInst); /* Set the alignment to alignment of the byval parameter. This is because, * later load/stores assume that alignment, and we are going to replace * the use of the byval parameter with this alloca instruction. */ AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1)); Arg->replaceAllUsesWith(AllocA); // Get the cvt.gen.to.param intrinsic Type *CvtTypes[] = { Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_PARAM), Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_GENERIC)}; Function *CvtFunc = Intrinsic::getDeclaration( Func->getParent(), Intrinsic::nvvm_ptr_gen_to_param, CvtTypes); Value *BitcastArgs[] = { new BitCastInst(Arg, Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_GENERIC), Arg->getName(), FirstInst)}; CallInst *CallCVT = CallInst::Create(CvtFunc, BitcastArgs, "cvt_to_param", FirstInst); BitCastInst *BitCast = new BitCastInst( CallCVT, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(), FirstInst); LoadInst *LI = new LoadInst(BitCast, Arg->getName(), FirstInst); new StoreInst(LI, AllocA, FirstInst); } // ============================================================================= // If the function had a struct ptr arg, say foo(%struct.x *byval %d), then // add the following instructions to the first basic block : // // %temp = alloca %struct.x, align 8 // %tt1 = bitcast %struct.x * %d to i8 * // %tt2 = llvm.nvvm.cvt.gen.to.param %tt2 // %tempd = bitcast i8 addrspace(101) * to %struct.x addrspace(101) * // %tv = load %struct.x addrspace(101) * %tempd // store %struct.x %tv, %struct.x * %temp, align 8 // // The above code allocates some space in the stack and copies the incoming // struct from param space to local space. // Then replace all occurences of %d by %temp. // ============================================================================= void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) { for (Argument &Arg : F.args()) { if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) { handleParam(&Arg); } } } // ============================================================================= // Main function for this pass. // ============================================================================= bool NVPTXLowerStructArgs::runOnFunction(Function &F) { // Skip non-kernels. See the comments at the top of this file. if (!isKernelFunction(F)) return false; handleStructPtrArgs(F); return true; } FunctionPass *llvm::createNVPTXLowerStructArgsPass() { return new NVPTXLowerStructArgs(); }