]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm/lib/Target/AMDGPU/AMDGPUOpenCLImageTypeLoweringPass.cpp
MFC r343601:
[FreeBSD/FreeBSD.git] / contrib / llvm / lib / Target / AMDGPU / AMDGPUOpenCLImageTypeLoweringPass.cpp
1 //===- AMDGPUOpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
18 /// re-created using the new signature. The new arguments are added to the
19 /// kernel metadata with kernel_arg_type set to "image_size" or "image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //
26 //===----------------------------------------------------------------------===//
27
28 #include "AMDGPU.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/IR/Argument.h"
33 #include "llvm/IR/DerivedTypes.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/Metadata.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/IR/Use.h"
42 #include "llvm/IR/User.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Support/Casting.h"
45 #include "llvm/Support/ErrorHandling.h"
46 #include "llvm/Transforms/Utils/Cloning.h"
47 #include "llvm/Transforms/Utils/ValueMapper.h"
48 #include <cassert>
49 #include <cstddef>
50 #include <cstdint>
51 #include <tuple>
52
53 using namespace llvm;
54
55 static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
56 static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
57 static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
58 static StringRef GetSamplerResourceIDFunc =
59     "llvm.OpenCL.sampler.get.resource.id";
60
61 static StringRef ImageSizeArgMDType =   "__llvm_image_size";
62 static StringRef ImageFormatArgMDType = "__llvm_image_format";
63
64 static StringRef KernelsMDNodeName = "opencl.kernels";
65 static StringRef KernelArgMDNodeNames[] = {
66   "kernel_arg_addr_space",
67   "kernel_arg_access_qual",
68   "kernel_arg_type",
69   "kernel_arg_base_type",
70   "kernel_arg_type_qual"};
71 static const unsigned NumKernelArgMDNodes = 5;
72
73 namespace {
74
75 using MDVector = SmallVector<Metadata *, 8>;
76 struct KernelArgMD {
77   MDVector ArgVector[NumKernelArgMDNodes];
78 };
79
80 } // end anonymous namespace
81
82 static inline bool
83 IsImageType(StringRef TypeString) {
84   return TypeString == "image2d_t" || TypeString == "image3d_t";
85 }
86
87 static inline bool
88 IsSamplerType(StringRef TypeString) {
89   return TypeString == "sampler_t";
90 }
91
92 static Function *
93 GetFunctionFromMDNode(MDNode *Node) {
94   if (!Node)
95     return nullptr;
96
97   size_t NumOps = Node->getNumOperands();
98   if (NumOps != NumKernelArgMDNodes + 1)
99     return nullptr;
100
101   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
102   if (!F)
103     return nullptr;
104
105   // Sanity checks.
106   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
107   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
108     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
109     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
110       return nullptr;
111     if (!ArgNode->getOperand(0))
112       return nullptr;
113
114     // FIXME: It should be possible to do image lowering when some metadata
115     // args missing or not in the expected order.
116     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
117     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
118       return nullptr;
119   }
120
121   return F;
122 }
123
124 static StringRef
125 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
126   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
127   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
128 }
129
130 static StringRef
131 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
132   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
133   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
134 }
135
136 static MDVector
137 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
138   MDVector Res;
139   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
140     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
141     Res.push_back(Node->getOperand(OpIdx));
142   }
143   return Res;
144 }
145
146 static void
147 PushArgMD(KernelArgMD &MD, const MDVector &V) {
148   assert(V.size() == NumKernelArgMDNodes);
149   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
150     MD.ArgVector[i].push_back(V[i]);
151   }
152 }
153
154 namespace {
155
156 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
157   static char ID;
158
159   LLVMContext *Context;
160   Type *Int32Type;
161   Type *ImageSizeType;
162   Type *ImageFormatType;
163   SmallVector<Instruction *, 4> InstsToErase;
164
165   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
166                         Argument &ImageSizeArg,
167                         Argument &ImageFormatArg) {
168     bool Modified = false;
169
170     for (auto &Use : ImageArg.uses()) {
171       auto Inst = dyn_cast<CallInst>(Use.getUser());
172       if (!Inst) {
173         continue;
174       }
175
176       Function *F = Inst->getCalledFunction();
177       if (!F)
178         continue;
179
180       Value *Replacement = nullptr;
181       StringRef Name = F->getName();
182       if (Name.startswith(GetImageResourceIDFunc)) {
183         Replacement = ConstantInt::get(Int32Type, ResourceID);
184       } else if (Name.startswith(GetImageSizeFunc)) {
185         Replacement = &ImageSizeArg;
186       } else if (Name.startswith(GetImageFormatFunc)) {
187         Replacement = &ImageFormatArg;
188       } else {
189         continue;
190       }
191
192       Inst->replaceAllUsesWith(Replacement);
193       InstsToErase.push_back(Inst);
194       Modified = true;
195     }
196
197     return Modified;
198   }
199
200   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
201     bool Modified = false;
202
203     for (const auto &Use : SamplerArg.uses()) {
204       auto Inst = dyn_cast<CallInst>(Use.getUser());
205       if (!Inst) {
206         continue;
207       }
208
209       Function *F = Inst->getCalledFunction();
210       if (!F)
211         continue;
212
213       Value *Replacement = nullptr;
214       StringRef Name = F->getName();
215       if (Name == GetSamplerResourceIDFunc) {
216         Replacement = ConstantInt::get(Int32Type, ResourceID);
217       } else {
218         continue;
219       }
220
221       Inst->replaceAllUsesWith(Replacement);
222       InstsToErase.push_back(Inst);
223       Modified = true;
224     }
225
226     return Modified;
227   }
228
229   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
230     uint32_t NumReadOnlyImageArgs = 0;
231     uint32_t NumWriteOnlyImageArgs = 0;
232     uint32_t NumSamplerArgs = 0;
233
234     bool Modified = false;
235     InstsToErase.clear();
236     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
237       Argument &Arg = *ArgI;
238       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
239
240       // Handle image types.
241       if (IsImageType(Type)) {
242         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
243         uint32_t ResourceID;
244         if (AccessQual == "read_only") {
245           ResourceID = NumReadOnlyImageArgs++;
246         } else if (AccessQual == "write_only") {
247           ResourceID = NumWriteOnlyImageArgs++;
248         } else {
249           llvm_unreachable("Wrong image access qualifier.");
250         }
251
252         Argument &SizeArg = *(++ArgI);
253         Argument &FormatArg = *(++ArgI);
254         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
255
256       // Handle sampler type.
257       } else if (IsSamplerType(Type)) {
258         uint32_t ResourceID = NumSamplerArgs++;
259         Modified |= replaceSamplerUses(Arg, ResourceID);
260       }
261     }
262     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
263       InstsToErase[i]->eraseFromParent();
264     }
265
266     return Modified;
267   }
268
269   std::tuple<Function *, MDNode *>
270   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
271     bool Modified = false;
272
273     FunctionType *FT = F->getFunctionType();
274     SmallVector<Type *, 8> ArgTypes;
275
276     // Metadata operands for new MDNode.
277     KernelArgMD NewArgMDs;
278     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
279
280     // Add implicit arguments to the signature.
281     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
282       ArgTypes.push_back(FT->getParamType(i));
283       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
284       PushArgMD(NewArgMDs, ArgMD);
285
286       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
287         continue;
288
289       // Add size implicit argument.
290       ArgTypes.push_back(ImageSizeType);
291       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
292       PushArgMD(NewArgMDs, ArgMD);
293
294       // Add format implicit argument.
295       ArgTypes.push_back(ImageFormatType);
296       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
297       PushArgMD(NewArgMDs, ArgMD);
298
299       Modified = true;
300     }
301     if (!Modified) {
302       return std::make_tuple(nullptr, nullptr);
303     }
304
305     // Create function with new signature and clone the old body into it.
306     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
307     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
308     ValueToValueMapTy VMap;
309     auto NewFArgIt = NewF->arg_begin();
310     for (auto &Arg: F->args()) {
311       auto ArgName = Arg.getName();
312       NewFArgIt->setName(ArgName);
313       VMap[&Arg] = &(*NewFArgIt++);
314       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
315         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
316         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
317       }
318     }
319     SmallVector<ReturnInst*, 8> Returns;
320     CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
321
322     // Build new MDNode.
323     SmallVector<Metadata *, 6> KernelMDArgs;
324     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
325     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
326       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
327     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
328
329     return std::make_tuple(NewF, NewMDNode);
330   }
331
332   bool transformKernels(Module &M) {
333     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
334     if (!KernelsMDNode)
335       return false;
336
337     bool Modified = false;
338     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
339       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
340       Function *F = GetFunctionFromMDNode(KernelMDNode);
341       if (!F)
342         continue;
343
344       Function *NewF;
345       MDNode *NewMDNode;
346       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
347       if (NewF) {
348         // Replace old function and metadata with new ones.
349         F->eraseFromParent();
350         M.getFunctionList().push_back(NewF);
351         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
352                               NewF->getAttributes());
353         KernelsMDNode->setOperand(i, NewMDNode);
354
355         F = NewF;
356         KernelMDNode = NewMDNode;
357         Modified = true;
358       }
359
360       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
361     }
362
363     return Modified;
364   }
365
366 public:
367   AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
368
369   bool runOnModule(Module &M) override {
370     Context = &M.getContext();
371     Int32Type = Type::getInt32Ty(M.getContext());
372     ImageSizeType = ArrayType::get(Int32Type, 3);
373     ImageFormatType = ArrayType::get(Int32Type, 2);
374
375     return transformKernels(M);
376   }
377
378   StringRef getPassName() const override {
379     return "AMDGPU OpenCL Image Type Pass";
380   }
381 };
382
383 } // end anonymous namespace
384
385 char AMDGPUOpenCLImageTypeLoweringPass::ID = 0;
386
387 ModulePass *llvm::createAMDGPUOpenCLImageTypeLoweringPass() {
388   return new AMDGPUOpenCLImageTypeLoweringPass();
389 }