//===- SPIRVBuiltins.cpp - SPIR-V Built-in Functions ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering builtin function calls and types using their
// demangled names and TableGen records.
//
//===----------------------------------------------------------------------===//

#include "SPIRVBuiltins.h"
#include "SPIRV.h"
#include "SPIRVUtils.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include <string>
#include <tuple>

#define DEBUG_TYPE "spirv-builtins"

namespace llvm {
namespace SPIRV {
#define GET_BuiltinGroup_DECL
#include "SPIRVGenTables.inc"

struct DemangledBuiltin {
  StringRef Name;
  InstructionSet::InstructionSet Set;
  BuiltinGroup Group;
  uint8_t MinNumArgs;
  uint8_t MaxNumArgs;
};

#define GET_DemangledBuiltins_DECL
#define GET_DemangledBuiltins_IMPL

struct IncomingCall {
  const std::string BuiltinName;
  const DemangledBuiltin *Builtin;

  const Register ReturnRegister;
  const SPIRVType *ReturnType;
  const SmallVectorImpl<Register> &Arguments;

  IncomingCall(const std::string BuiltinName, const DemangledBuiltin *Builtin,
               const Register ReturnRegister, const SPIRVType *ReturnType,
               const SmallVectorImpl<Register> &Arguments)
      : BuiltinName(BuiltinName), Builtin(Builtin),
        ReturnRegister(ReturnRegister), ReturnType(ReturnType),
        Arguments(Arguments) {}
};

struct NativeBuiltin {
  StringRef Name;
  InstructionSet::InstructionSet Set;
  uint32_t Opcode;
};

#define GET_NativeBuiltins_DECL
#define GET_NativeBuiltins_IMPL

struct GroupBuiltin {
  StringRef Name;
  uint32_t Opcode;
  uint32_t GroupOperation;
  bool IsElect;
  bool IsAllOrAny;
  bool IsAllEqual;
  bool IsBallot;
  bool IsInverseBallot;
  bool IsBallotBitExtract;
  bool IsBallotFindBit;
  bool IsLogical;
  bool NoGroupOperation;
  bool HasBoolArg;
};

#define GET_GroupBuiltins_DECL
#define GET_GroupBuiltins_IMPL

struct GetBuiltin {
  StringRef Name;
  InstructionSet::InstructionSet Set;
  BuiltIn::BuiltIn Value;
};

using namespace BuiltIn;
#define GET_GetBuiltins_DECL
#define GET_GetBuiltins_IMPL

struct ImageQueryBuiltin {
  StringRef Name;
  InstructionSet::InstructionSet Set;
  uint32_t Component;
};

#define GET_ImageQueryBuiltins_DECL
#define GET_ImageQueryBuiltins_IMPL

struct ConvertBuiltin {
  StringRef Name;
  InstructionSet::InstructionSet Set;
  bool IsDestinationSigned;
  bool IsSaturated;
  bool IsRounded;
  FPRoundingMode::FPRoundingMode RoundingMode;
};

struct VectorLoadStoreBuiltin {
  StringRef Name;
  InstructionSet::InstructionSet Set;
  uint32_t Number;
  bool IsRounded;
  FPRoundingMode::FPRoundingMode RoundingMode;
};

using namespace FPRoundingMode;
#define GET_ConvertBuiltins_DECL
#define GET_ConvertBuiltins_IMPL

using namespace InstructionSet;
#define GET_VectorLoadStoreBuiltins_DECL
#define GET_VectorLoadStoreBuiltins_IMPL

#define GET_CLMemoryScope_DECL
#define GET_CLSamplerAddressingMode_DECL
#define GET_CLMemoryFenceFlags_DECL
#define GET_ExtendedBuiltins_DECL
#include "SPIRVGenTables.inc"
} // namespace SPIRV

//===----------------------------------------------------------------------===//
// Misc functions for looking up builtins and verifying requirements using
// TableGen records
//===----------------------------------------------------------------------===//

/// Looks up the demangled builtin call in the SPIRVBuiltins.td records using
/// the provided \p DemangledCall and specified \p Set.
///
/// The lookup uses the following algorithm, returning the first successful
/// match:
/// 1. Search with the plain demangled name (expecting a 1:1 match).
/// 2. Search with the prefix before or suffix after the demangled name
/// signifying the type of the first argument.
///
/// \returns Wrapper around the demangled call and found builtin definition.
static std::unique_ptr<const SPIRV::IncomingCall>
lookupBuiltin(StringRef DemangledCall,
              SPIRV::InstructionSet::InstructionSet Set,
              Register ReturnRegister, const SPIRVType *ReturnType,
              const SmallVectorImpl<Register> &Arguments) {
  // Extract the builtin function name and types of arguments from the call
  // skeleton.
  std::string BuiltinName =
      DemangledCall.substr(0, DemangledCall.find('(')).str();

  // Check if the extracted name contains type information between angle
  // brackets. If so, the builtin is an instantiated template, and the
  // information after the angle brackets and the return type must be removed.
  if (BuiltinName.find('<') && BuiltinName.back() == '>') {
    BuiltinName = BuiltinName.substr(0, BuiltinName.find('<'));
    BuiltinName = BuiltinName.substr(BuiltinName.find_last_of(" ") + 1);
  }

  // Check if the extracted name contains "__spirv_ImageSampleExplicitLod"
  // with return type information at the end "_R<type>"; if so, extract the
  // plain builtin name without the type information.
  if (StringRef(BuiltinName).contains("__spirv_ImageSampleExplicitLod") &&
      StringRef(BuiltinName).contains("_R")) {
    BuiltinName = BuiltinName.substr(0, BuiltinName.find("_R"));
  }

  SmallVector<StringRef, 10> BuiltinArgumentTypes;
  StringRef BuiltinArgs =
      DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
  BuiltinArgs.split(BuiltinArgumentTypes, ',', -1, false);

  // Look up the builtin in the defined set. Start with the plain demangled
  // name, expecting a 1:1 match in the defined builtin set.
  const SPIRV::DemangledBuiltin *Builtin;
  if ((Builtin = SPIRV::lookupBuiltin(BuiltinName, Set)))
    return std::make_unique<SPIRV::IncomingCall>(
        BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);

  // If the initial look up was unsuccessful and the demangled call takes at
  // least 1 argument, add a prefix or suffix signifying the type of the first
  // argument and repeat the search.
  if (BuiltinArgumentTypes.size() >= 1) {
    char FirstArgumentType = BuiltinArgumentTypes[0][0];
    // Prefix to be added to the builtin's name for lookup.
    // For example, OpenCL "abs" taking an unsigned value has a prefix "u_".
    std::string Prefix;

    switch (FirstArgumentType) {
    // Unsigned:
    case 'u':
      if (Set == SPIRV::InstructionSet::OpenCL_std)
        Prefix = "u_";
      else if (Set == SPIRV::InstructionSet::GLSL_std_450)
        Prefix = "u";
      break;
    // Signed:
    case 'c':
    case 's':
    case 'i':
    case 'l':
      if (Set == SPIRV::InstructionSet::OpenCL_std)
        Prefix = "s_";
      else if (Set == SPIRV::InstructionSet::GLSL_std_450)
        Prefix = "s";
      break;
    // Floating-point:
    case 'f':
    case 'd':
    case 'h':
      if (Set == SPIRV::InstructionSet::OpenCL_std ||
          Set == SPIRV::InstructionSet::GLSL_std_450)
        Prefix = "f";
      break;
    }

    // If an argument-type prefix was added, look up the builtin again.
    if (!Prefix.empty() &&
        (Builtin = SPIRV::lookupBuiltin(Prefix + BuiltinName, Set)))
      return std::make_unique<SPIRV::IncomingCall>(
          BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);

    // If lookup with a prefix failed, find a suffix to be added to the
    // builtin's name for lookup. For example, OpenCL "group_reduce_max" taking
    // an unsigned value has a suffix "u".
    std::string Suffix;

    switch (FirstArgumentType) {
    // Unsigned:
    case 'u':
      Suffix = "u";
      break;
    // Signed:
    case 'c':
    case 's':
    case 'i':
    case 'l':
      Suffix = "s";
      break;
    // Floating-point:
    case 'f':
    case 'd':
    case 'h':
      Suffix = "f";
      break;
    }

    // If an argument-type suffix was added, look up the builtin again.
    if (!Suffix.empty() &&
        (Builtin = SPIRV::lookupBuiltin(BuiltinName + Suffix, Set)))
      return std::make_unique<SPIRV::IncomingCall>(
          BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
  }

  // No builtin with such a name was found in the set.
  return nullptr;
}
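
// A minimal walk-through of the lookup order, assuming hypothetical record
// names (whether "u_max" or "maxu" exists depends on SPIRVBuiltins.td):
//   lookupBuiltin("max(uint, uint)", OpenCL_std, ...) tries, in order:
//     1. SPIRV::lookupBuiltin("max", OpenCL_std)    (plain demangled name)
//     2. SPIRV::lookupBuiltin("u_max", OpenCL_std)  (prefix from 'u' argument)
//     3. SPIRV::lookupBuiltin("maxu", OpenCL_std)   (suffix from 'u' argument)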

//===----------------------------------------------------------------------===//
// Helper functions for building misc instructions
//===----------------------------------------------------------------------===//

/// Helper function building either a resulting scalar or vector bool register
/// depending on the expected \p ResultType.
///
/// \returns Tuple of the resulting register and its type.
static std::tuple<Register, SPIRVType *>
buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
                  SPIRVGlobalRegistry *GR) {
  LLT Type;
  SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);

  if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
    unsigned VectorElements = ResultType->getOperand(2).getImm();
    BoolType =
        GR->getOrCreateSPIRVVectorType(BoolType, VectorElements, MIRBuilder);
    const FixedVectorType *LLVMVectorType =
        cast<FixedVectorType>(GR->getTypeForSPIRVType(BoolType));
    Type = LLT::vector(LLVMVectorType->getElementCount(), 1);
  } else {
    Type = LLT::scalar(1);
  }

  Register ResultRegister =
      MIRBuilder.getMRI()->createGenericVirtualRegister(Type);
  GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF());
  return std::make_tuple(ResultRegister, BoolType);
}

/// Helper function for building either a vector or scalar select instruction
/// depending on the expected \p ResultType.
static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
                            Register ReturnRegister, Register SourceRegister,
                            const SPIRVType *ReturnType,
                            SPIRVGlobalRegistry *GR) {
  Register TrueConst, FalseConst;

  if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
    unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
    uint64_t AllOnes = APInt::getAllOnesValue(Bits).getZExtValue();
    TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
    FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
  } else {
    TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
    FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
  }
  return MIRBuilder.buildSelect(ReturnRegister, SourceRegister, TrueConst,
                                FalseConst);
}
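
// For a v4i32 return type the select above produces, roughly (a sketch with
// invented value names; vector elements abbreviated):
//   %true  = OpConstantComposite %v4i32 <0xFFFFFFFF x 4>
//   %false = OpConstantComposite %v4i32 <0 x 4>
//   %res   = G_SELECT %src, %true, %false
// matching the OpenCL rule that vector comparisons yield -1/0 per element.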

/// Helper function for building a load instruction loading into the
/// \p DestinationReg.
static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR, LLT LowLevelType,
                              Register DestinationReg = Register(0)) {
  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
  if (!DestinationReg.isValid()) {
    DestinationReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
    MRI->setType(DestinationReg, LLT::scalar(32));
    GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
  }
  // TODO: consider using correct address space and alignment (p0 is canonical
  // type for selection though).
  MachinePointerInfo PtrInfo = MachinePointerInfo();
  MIRBuilder.buildLoad(DestinationReg, PtrRegister, PtrInfo, Align());
  return DestinationReg;
}

/// Helper function for building a load instruction for loading a builtin global
/// variable of \p BuiltinValue value.
static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
                                         SPIRVType *VariableType,
                                         SPIRVGlobalRegistry *GR,
                                         SPIRV::BuiltIn::BuiltIn BuiltinValue,
                                         LLT LLType,
                                         Register Reg = Register(0)) {
  Register NewRegister =
      MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
  MIRBuilder.getMRI()->setType(NewRegister,
                               LLT::pointer(0, GR->getPointerSize()));
  SPIRVType *PtrType = GR->getOrCreateSPIRVPointerType(
      VariableType, MIRBuilder, SPIRV::StorageClass::Input);
  GR->assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());

  // Set up the global OpVariable with the necessary builtin decorations.
  Register Variable = GR->buildGlobalVariable(
      NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
      SPIRV::StorageClass::Input, nullptr, true, true,
      SPIRV::LinkageType::Import, MIRBuilder, false);

  // Load the value from the global variable.
  Register LoadedRegister =
      buildLoadInst(VariableType, Variable, MIRBuilder, GR, LLType, Reg);
  MIRBuilder.getMRI()->setType(LoadedRegister, LLType);
  return LoadedRegister;
}

/// Helper external function for inserting an ASSIGN_TYPE instruction between
/// \p Reg and its definition, setting the new register as a destination of the
/// definition, and assigning a SPIRVType to both registers. If \p SpirvTy is
/// provided, use it as the SPIRVType in ASSIGN_TYPE, otherwise create it from
/// \p Ty. Defined in SPIRVPreLegalizer.cpp.
extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                                  SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder &MIB,
                                  MachineRegisterInfo &MRI);

// TODO: Move to TableGen.
static SPIRV::MemorySemantics::MemorySemantics
getSPIRVMemSemantics(std::memory_order MemOrder) {
  switch (MemOrder) {
  case std::memory_order::memory_order_relaxed:
    return SPIRV::MemorySemantics::None;
  case std::memory_order::memory_order_acquire:
    return SPIRV::MemorySemantics::Acquire;
  case std::memory_order::memory_order_release:
    return SPIRV::MemorySemantics::Release;
  case std::memory_order::memory_order_acq_rel:
    return SPIRV::MemorySemantics::AcquireRelease;
  case std::memory_order::memory_order_seq_cst:
    return SPIRV::MemorySemantics::SequentiallyConsistent;
  default:
    llvm_unreachable("Unknown CL memory order");
  }
}
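
// Sketch of the intended mapping (caller-side illustration, not code in this
// file):
//   getSPIRVMemSemantics(std::memory_order_acquire)
//       == SPIRV::MemorySemantics::Acquire
//   getSPIRVMemSemantics(std::memory_order_seq_cst)
//       == SPIRV::MemorySemantics::SequentiallyConsistent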

static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
  switch (ClScope) {
  case SPIRV::CLMemoryScope::memory_scope_work_item:
    return SPIRV::Scope::Invocation;
  case SPIRV::CLMemoryScope::memory_scope_work_group:
    return SPIRV::Scope::Workgroup;
  case SPIRV::CLMemoryScope::memory_scope_device:
    return SPIRV::Scope::Device;
  case SPIRV::CLMemoryScope::memory_scope_all_svm_devices:
    return SPIRV::Scope::CrossDevice;
  case SPIRV::CLMemoryScope::memory_scope_sub_group:
    return SPIRV::Scope::Subgroup;
  }
  llvm_unreachable("Unknown CL memory scope");
}
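
// E.g. memory_scope_work_group maps to SPIRV::Scope::Workgroup and
// memory_scope_all_svm_devices to SPIRV::Scope::CrossDevice. The two enums
// use different encodings, which is why buildScopeReg below can only reuse
// the incoming constant register when the numeric values happen to agree.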

static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
                                    SPIRVGlobalRegistry *GR,
                                    unsigned BitWidth = 32) {
  SPIRVType *IntType = GR->getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
  return GR->buildConstantInt(Val, MIRBuilder, IntType);
}

static Register buildScopeReg(Register CLScopeRegister,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR,
                              const MachineRegisterInfo *MRI) {
  auto CLScope =
      static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
  SPIRV::Scope::Scope Scope = getSPIRVScope(CLScope);

  if (CLScope == static_cast<unsigned>(Scope))
    return CLScopeRegister;

  return buildConstantIntReg(Scope, MIRBuilder, GR);
}

static Register buildMemSemanticsReg(Register SemanticsRegister,
                                     Register PtrRegister,
                                     const MachineRegisterInfo *MRI,
                                     SPIRVGlobalRegistry *GR) {
  std::memory_order Order =
      static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
  unsigned Semantics =
      getSPIRVMemSemantics(Order) |
      getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));

  if (Order == Semantics)
    return SemanticsRegister;

  return Register();
}

/// Helper function for translating atomic init to OpStore.
static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
                                MachineIRBuilder &MIRBuilder) {
  assert(Call->Arguments.size() == 2 &&
         "Need 2 arguments for atomic init translation");

  MIRBuilder.buildInstr(SPIRV::OpStore)
      .addUse(Call->Arguments[0])
      .addUse(Call->Arguments[1]);
  return true;
}

/// Helper function for building an atomic load instruction.
static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
                                MachineIRBuilder &MIRBuilder,
                                SPIRVGlobalRegistry *GR) {
  Register PtrRegister = Call->Arguments[0];
  // TODO: if true, insert a call to __translate_ocl_memory_scope before
  // OpAtomicLoad and the function implementation. We can use the Translator's
  // output for transcoding/atomic_explicit_arguments.cl as an example.
  Register ScopeRegister;
  if (Call->Arguments.size() > 1)
    ScopeRegister = Call->Arguments[1];
  else
    ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);

  Register MemSemanticsReg;
  if (Call->Arguments.size() > 2) {
    // TODO: Insert call to __translate_ocl_memory_order before OpAtomicLoad.
    MemSemanticsReg = Call->Arguments[2];
  } else {
    int Semantics =
        SPIRV::MemorySemantics::SequentiallyConsistent |
        getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
  }

  MIRBuilder.buildInstr(SPIRV::OpAtomicLoad)
      .addDef(Call->ReturnRegister)
      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
      .addUse(PtrRegister)
      .addUse(ScopeRegister)
      .addUse(MemSemanticsReg);
  return true;
}
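
// Expected shape of the lowering for a one-argument OpenCL "atomic_load(p)"
// (a sketch; value names invented, literal IDs omitted):
//   %scope = OpConstant %i32 Device
//   %sem   = OpConstant %i32 (SequentiallyConsistent | storage-class bits)
//   %res   = OpAtomicLoad %type %p %scope %sem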

/// Helper function for building an atomic store instruction.
static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
  Register ScopeRegister =
      buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
  Register PtrRegister = Call->Arguments[0];
  int Semantics =
      SPIRV::MemorySemantics::SequentiallyConsistent |
      getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
  Register MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);

  MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
      .addUse(PtrRegister)
      .addUse(ScopeRegister)
      .addUse(MemSemanticsReg)
      .addUse(Call->Arguments[1]);
  return true;
}

/// Helper function for building an atomic compare-exchange instruction.
static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
                                           MachineIRBuilder &MIRBuilder,
                                           SPIRVGlobalRegistry *GR) {
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  unsigned Opcode =
      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
  bool IsCmpxchg = Call->Builtin->Name.contains("cmpxchg");
  MachineRegisterInfo *MRI = MIRBuilder.getMRI();

  Register ObjectPtr = Call->Arguments[0];   // Pointer (volatile A *object.)
  Register ExpectedArg = Call->Arguments[1]; // Comparator (C* expected).
  Register Desired = Call->Arguments[2];     // Value (C Desired).
  SPIRVType *SpvDesiredTy = GR->getSPIRVTypeForVReg(Desired);
  LLT DesiredLLT = MRI->getType(Desired);

  assert(GR->getSPIRVTypeForVReg(ObjectPtr)->getOpcode() ==
         SPIRV::OpTypePointer);
  unsigned ExpectedType = GR->getSPIRVTypeForVReg(ExpectedArg)->getOpcode();
  assert(IsCmpxchg ? ExpectedType == SPIRV::OpTypeInt
                   : ExpectedType == SPIRV::OpTypePointer);
  assert(GR->isScalarOfType(Desired, SPIRV::OpTypeInt));

  SPIRVType *SpvObjectPtrTy = GR->getSPIRVTypeForVReg(ObjectPtr);
  assert(SpvObjectPtrTy->getOperand(2).isReg() && "SPIRV type is expected");
  auto StorageClass = static_cast<SPIRV::StorageClass::StorageClass>(
      SpvObjectPtrTy->getOperand(1).getImm());
  auto MemSemStorage = getMemSemanticsForStorageClass(StorageClass);

  Register MemSemEqualReg;
  Register MemSemUnequalReg;
  uint64_t MemSemEqual =
      IsCmpxchg
          ? SPIRV::MemorySemantics::None
          : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
  uint64_t MemSemUnequal =
      IsCmpxchg
          ? SPIRV::MemorySemantics::None
          : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
  if (Call->Arguments.size() >= 4) {
    assert(Call->Arguments.size() >= 5 &&
           "Need 5+ args for explicit atomic cmpxchg");
    auto MemOrdEq =
        static_cast<std::memory_order>(getIConstVal(Call->Arguments[3], MRI));
    auto MemOrdNeq =
        static_cast<std::memory_order>(getIConstVal(Call->Arguments[4], MRI));
    MemSemEqual = getSPIRVMemSemantics(MemOrdEq) | MemSemStorage;
    MemSemUnequal = getSPIRVMemSemantics(MemOrdNeq) | MemSemStorage;
    if (MemOrdEq == MemSemEqual)
      MemSemEqualReg = Call->Arguments[3];
    if (MemOrdNeq == MemSemUnequal)
      MemSemUnequalReg = Call->Arguments[4];
  }
  if (!MemSemEqualReg.isValid())
    MemSemEqualReg = buildConstantIntReg(MemSemEqual, MIRBuilder, GR);
  if (!MemSemUnequalReg.isValid())
    MemSemUnequalReg = buildConstantIntReg(MemSemUnequal, MIRBuilder, GR);

  Register ScopeReg;
  auto Scope = IsCmpxchg ? SPIRV::Scope::Workgroup : SPIRV::Scope::Device;
  if (Call->Arguments.size() >= 6) {
    assert(Call->Arguments.size() == 6 &&
           "Extra args for explicit atomic cmpxchg");
    auto ClScope = static_cast<SPIRV::CLMemoryScope>(
        getIConstVal(Call->Arguments[5], MRI));
    Scope = getSPIRVScope(ClScope);
    if (ClScope == static_cast<unsigned>(Scope))
      ScopeReg = Call->Arguments[5];
  }
  if (!ScopeReg.isValid())
    ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);

  Register Expected = IsCmpxchg
                          ? ExpectedArg
                          : buildLoadInst(SpvDesiredTy, ExpectedArg, MIRBuilder,
                                          GR, LLT::scalar(32));
  MRI->setType(Expected, DesiredLLT);
  Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
                            : Call->ReturnRegister;
  GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF());

  SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
  MIRBuilder.buildInstr(Opcode)
      .addDef(Tmp)
      .addUse(GR->getSPIRVTypeID(IntTy))
      .addUse(ObjectPtr)
      .addUse(ScopeReg)
      .addUse(MemSemEqualReg)
      .addUse(MemSemUnequalReg)
      .addUse(Desired)
      .addUse(Expected);
  if (!IsCmpxchg) {
    MIRBuilder.buildInstr(SPIRV::OpStore).addUse(ExpectedArg).addUse(Tmp);
    MIRBuilder.buildICmp(CmpInst::ICMP_EQ, Call->ReturnRegister, Tmp, Expected);
  }
  return true;
}
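
// Rough shape of the lowering for the C11-style (non-"cmpxchg") form, where
// *expected is loaded, compared, and written back (a sketch; names invented):
//   %exp = OpLoad %i32 %expected_ptr
//   %tmp = OpAtomicCompareExchange %i32 %object %scope %semEq %semNeq
//              %desired %exp
//   OpStore %expected_ptr %tmp
//   %res = G_ICMP eq, %tmp, %exp   ; the builtin's bool result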

/// Helper function for building an atomic RMW instruction.
static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR) {
  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
  SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
  Register ScopeRegister;

  if (Call->Arguments.size() >= 4) {
    assert(Call->Arguments.size() == 4 &&
           "Too many args for explicit atomic RMW");
    ScopeRegister = buildScopeReg(Call->Arguments[3], MIRBuilder, GR, MRI);
  }

  if (!ScopeRegister.isValid())
    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);

  Register PtrRegister = Call->Arguments[0];
  unsigned Semantics = SPIRV::MemorySemantics::None;
  Register MemSemanticsReg;

  if (Call->Arguments.size() >= 3)
    MemSemanticsReg =
        buildMemSemanticsReg(Call->Arguments[2], PtrRegister, MRI, GR);

  if (!MemSemanticsReg.isValid())
    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);

  MIRBuilder.buildInstr(Opcode)
      .addDef(Call->ReturnRegister)
      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
      .addUse(PtrRegister)
      .addUse(ScopeRegister)
      .addUse(MemSemanticsReg)
      .addUse(Call->Arguments[1]);
  return true;
}

/// Helper function for building atomic flag instructions (e.g.
/// OpAtomicFlagTestAndSet).
static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
                                unsigned Opcode, MachineIRBuilder &MIRBuilder,
                                SPIRVGlobalRegistry *GR) {
  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();

  Register PtrRegister = Call->Arguments[0];
  unsigned Semantics = SPIRV::MemorySemantics::SequentiallyConsistent;
  Register MemSemanticsReg;

  if (Call->Arguments.size() >= 2)
    MemSemanticsReg =
        buildMemSemanticsReg(Call->Arguments[1], PtrRegister, MRI, GR);

  if (!MemSemanticsReg.isValid())
    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);

  assert((Opcode != SPIRV::OpAtomicFlagClear ||
          (Semantics != SPIRV::MemorySemantics::Acquire &&
           Semantics != SPIRV::MemorySemantics::AcquireRelease)) &&
         "Invalid memory order argument!");

  SPIRV::Scope::Scope Scope = SPIRV::Scope::Device;
  Register ScopeRegister;

  if (Call->Arguments.size() >= 3)
    ScopeRegister = buildScopeReg(Call->Arguments[2], MIRBuilder, GR, MRI);

  if (!ScopeRegister.isValid())
    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);

  auto MIB = MIRBuilder.buildInstr(Opcode);
  if (Opcode == SPIRV::OpAtomicFlagTestAndSet)
    MIB.addDef(Call->ReturnRegister)
        .addUse(GR->getSPIRVTypeID(Call->ReturnType));

  MIB.addUse(PtrRegister).addUse(ScopeRegister).addUse(MemSemanticsReg);
  return true;
}

/// Helper function for building barriers, i.e., memory/control ordering
/// operations.
static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                             MachineIRBuilder &MIRBuilder,
                             SPIRVGlobalRegistry *GR) {
  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
  unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
  unsigned MemSemantics = SPIRV::MemorySemantics::None;

  if (MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE)
    MemSemantics |= SPIRV::MemorySemantics::WorkgroupMemory;

  if (MemFlags & SPIRV::CLK_GLOBAL_MEM_FENCE)
    MemSemantics |= SPIRV::MemorySemantics::CrossWorkgroupMemory;

  if (MemFlags & SPIRV::CLK_IMAGE_MEM_FENCE)
    MemSemantics |= SPIRV::MemorySemantics::ImageMemory;

  if (Opcode == SPIRV::OpMemoryBarrier) {
    std::memory_order MemOrder =
        static_cast<std::memory_order>(getIConstVal(Call->Arguments[1], MRI));
    MemSemantics = getSPIRVMemSemantics(MemOrder) | MemSemantics;
  } else {
    MemSemantics |= SPIRV::MemorySemantics::SequentiallyConsistent;
  }

  Register MemSemanticsReg;
  if (MemFlags == MemSemantics)
    MemSemanticsReg = Call->Arguments[0];
  else
    MemSemanticsReg = buildConstantIntReg(MemSemantics, MIRBuilder, GR);

  Register ScopeReg;
  SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
  SPIRV::Scope::Scope MemScope = Scope;
  if (Call->Arguments.size() >= 2) {
    assert(
        ((Opcode != SPIRV::OpMemoryBarrier && Call->Arguments.size() == 2) ||
         (Opcode == SPIRV::OpMemoryBarrier && Call->Arguments.size() == 3)) &&
        "Extra args for explicitly scoped barrier");
    Register ScopeArg = (Opcode == SPIRV::OpMemoryBarrier) ? Call->Arguments[2]
                                                           : Call->Arguments[1];
    SPIRV::CLMemoryScope CLScope =
        static_cast<SPIRV::CLMemoryScope>(getIConstVal(ScopeArg, MRI));
    MemScope = getSPIRVScope(CLScope);
    if (!(MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE) ||
        (Opcode == SPIRV::OpMemoryBarrier))
      Scope = MemScope;

    if (CLScope == static_cast<unsigned>(Scope))
      ScopeReg = ScopeArg;
  }

  if (!ScopeReg.isValid())
    ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);

  auto MIB = MIRBuilder.buildInstr(Opcode).addUse(ScopeReg);
  if (Opcode != SPIRV::OpMemoryBarrier)
    MIB.addUse(buildConstantIntReg(MemScope, MIRBuilder, GR));
  MIB.addUse(MemSemanticsReg);
  return true;
}
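
// E.g. OpenCL "barrier(CLK_LOCAL_MEM_FENCE)" would lower roughly to (a
// sketch; literal IDs omitted):
//   %scope = OpConstant %i32 Workgroup
//   %sem   = OpConstant %i32 (WorkgroupMemory | SequentiallyConsistent)
//   OpControlBarrier %scope %scope %sem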

static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
  switch (dim) {
  case SPIRV::Dim::DIM_1D:
  case SPIRV::Dim::DIM_Buffer:
    return 1;
  case SPIRV::Dim::DIM_2D:
  case SPIRV::Dim::DIM_Cube:
  case SPIRV::Dim::DIM_Rect:
    return 2;
  case SPIRV::Dim::DIM_3D:
    return 3;
  default:
    llvm_unreachable("Cannot get num components for given Dim");
  }
}

/// Helper function for obtaining the number of size components.
static unsigned getNumSizeComponents(SPIRVType *imgType) {
  assert(imgType->getOpcode() == SPIRV::OpTypeImage);
  auto dim = static_cast<SPIRV::Dim::Dim>(imgType->getOperand(2).getImm());
  unsigned numComps = getNumComponentsForDim(dim);
  bool arrayed = imgType->getOperand(4).getImm() == 1;
  return arrayed ? numComps + 1 : numComps;
}
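
// E.g. a 2D arrayed image (DIM_2D, Arrayed == 1) reports 2 + 1 = 3 size
// components, while a buffer image (DIM_Buffer, Arrayed == 0) reports 1.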

//===----------------------------------------------------------------------===//
// Implementation functions for each builtin group
//===----------------------------------------------------------------------===//

static bool generateExtInst(const SPIRV::IncomingCall *Call,
                            MachineIRBuilder &MIRBuilder,
                            SPIRVGlobalRegistry *GR) {
  // Lookup the extended instruction number in the TableGen records.
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  uint32_t Number =
      SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;

  // Build extended instruction.
  auto MIB =
      MIRBuilder.buildInstr(SPIRV::OpExtInst)
          .addDef(Call->ReturnRegister)
          .addUse(GR->getSPIRVTypeID(Call->ReturnType))
          .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
          .addImm(Number);

  for (auto Argument : Call->Arguments)
    MIB.addUse(Argument);
  return true;
}

static bool generateRelationalInst(const SPIRV::IncomingCall *Call,
                                   MachineIRBuilder &MIRBuilder,
                                   SPIRVGlobalRegistry *GR) {
  // Lookup the instruction opcode in the TableGen records.
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  unsigned Opcode =
      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;

  Register CompareRegister;
  SPIRVType *RelationType;
  std::tie(CompareRegister, RelationType) =
      buildBoolRegister(MIRBuilder, Call->ReturnType, GR);

  // Build relational instruction.
  auto MIB = MIRBuilder.buildInstr(Opcode)
                 .addDef(CompareRegister)
                 .addUse(GR->getSPIRVTypeID(RelationType));

  for (auto Argument : Call->Arguments)
    MIB.addUse(Argument);

  // Build select instruction.
  return buildSelectInst(MIRBuilder, Call->ReturnRegister, CompareRegister,
                         Call->ReturnType, GR);
}
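
// E.g. a relational builtin such as "isnan(float4)" would build (sketch):
//   %cmp = OpIsNan %v4bool %x
//   %res = select %cmp, <all-ones>, <zero>
// so the i1 comparison result is widened to the OpenCL -1/0 convention.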

static bool generateGroupInst(const SPIRV::IncomingCall *Call,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR) {
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  const SPIRV::GroupBuiltin *GroupBuiltin =
      SPIRV::lookupGroupBuiltin(Builtin->Name);
  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
  Register Arg0;
  if (GroupBuiltin->HasBoolArg) {
    Register ConstRegister = Call->Arguments[0];
    auto ArgInstruction = getDefInstrMaybeConstant(ConstRegister, MRI);
    // TODO: support non-constant bool values.
    assert(ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT &&
           "Only constant bool value args are supported");
    if (GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode() !=
        SPIRV::OpTypeBool)
      Arg0 = GR->buildConstantInt(getIConstVal(ConstRegister, MRI), MIRBuilder,
                                  GR->getOrCreateSPIRVBoolType(MIRBuilder));
  }

  Register GroupResultRegister = Call->ReturnRegister;
  SPIRVType *GroupResultType = Call->ReturnType;

  // TODO: maybe we need to check whether the result type is already boolean
  // and in this case do not insert select instruction.
  const bool HasBoolReturnTy =
      GroupBuiltin->IsElect || GroupBuiltin->IsAllOrAny ||
      GroupBuiltin->IsAllEqual || GroupBuiltin->IsLogical ||
      GroupBuiltin->IsInverseBallot || GroupBuiltin->IsBallotBitExtract;

  if (HasBoolReturnTy)
    std::tie(GroupResultRegister, GroupResultType) =
        buildBoolRegister(MIRBuilder, Call->ReturnType, GR);

  auto Scope = Builtin->Name.startswith("sub_group") ? SPIRV::Scope::Subgroup
                                                     : SPIRV::Scope::Workgroup;
  Register ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);

  // Build work/sub group instruction.
  auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
                 .addDef(GroupResultRegister)
                 .addUse(GR->getSPIRVTypeID(GroupResultType))
                 .addUse(ScopeRegister);

  if (!GroupBuiltin->NoGroupOperation)
    MIB.addImm(GroupBuiltin->GroupOperation);
  if (Call->Arguments.size() > 0) {
    MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
    for (unsigned i = 1; i < Call->Arguments.size(); i++)
      MIB.addUse(Call->Arguments[i]);
  }

  // Build select instruction.
  if (HasBoolReturnTy)
    buildSelectInst(MIRBuilder, Call->ReturnRegister, GroupResultRegister,
                    Call->ReturnType, GR);
  return true;
}

// These queries ask for a single size_t result for a given dimension index,
// e.g. size_t get_global_id(uint dimindex). In SPIR-V, the builtins
// corresponding to these values are all vec3 types, so we need to extract the
// correct index or return defaultVal (0 or 1 depending on the query). We also
// handle extending or truncating in case size_t does not match the expected
// result type's bitwidth.
//
// For a constant index >= 3 we generate:
//  %res = OpConstant %SizeT 0
//
// For other indices we generate:
//  %g = OpVariable %ptr_V3_SizeT Input
//  OpDecorate %g BuiltIn XXX
//  OpDecorate %g LinkageAttributes "__spirv_BuiltInXXX"
//  OpDecorate %g Constant
//  %loadedVec = OpLoad %V3_SizeT %g
//
//  Then, if the index is constant < 3, we generate:
//    %res = OpCompositeExtract %SizeT %loadedVec idx
//  If the index is dynamic, we generate:
//    %tmp = OpVectorExtractDynamic %SizeT %loadedVec %idx
//    %cmp = OpULessThan %bool %idx %const_3
//    %res = OpSelect %SizeT %cmp %tmp %const_0
//
//  If the bitwidth of %res does not match the expected return type, we add an
//  extend or truncate.
static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR,
                              SPIRV::BuiltIn::BuiltIn BuiltinValue,
                              uint64_t DefaultValue) {
  Register IndexRegister = Call->Arguments[0];
  const unsigned ResultWidth = Call->ReturnType->getOperand(1).getImm();
  const unsigned PointerSize = GR->getPointerSize();
  const SPIRVType *PointerSizeType =
      GR->getOrCreateSPIRVIntegerType(PointerSize, MIRBuilder);
  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
  auto IndexInstruction = getDefInstrMaybeConstant(IndexRegister, MRI);

  // Set up the final register to do truncation or extension on at the end.
  Register ToTruncate = Call->ReturnRegister;

  // If the index is constant, we can statically determine if it is in range.
  bool IsConstantIndex =
      IndexInstruction->getOpcode() == TargetOpcode::G_CONSTANT;

  // If it's out of range (max dimension is 3), we can just return the constant
  // default value (0 or 1 depending on which query function).
  if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) {
    Register defaultReg = Call->ReturnRegister;
    if (PointerSize != ResultWidth) {
      defaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
      GR->assignSPIRVTypeToVReg(PointerSizeType, defaultReg,
                                MIRBuilder.getMF());
      ToTruncate = defaultReg;
    }
    auto NewRegister =
        GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
    MIRBuilder.buildCopy(defaultReg, NewRegister);
  } else { // If it could be in range, we need to load from the given builtin.
    auto Vec3Ty =
        GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
    Register LoadedVector =
        buildBuiltinVariableLoad(MIRBuilder, Vec3Ty, GR, BuiltinValue,
                                 LLT::fixed_vector(3, PointerSize));
    // Set up the vreg to extract the result to (possibly a new temporary one).
    Register Extracted = Call->ReturnRegister;
    if (!IsConstantIndex || PointerSize != ResultWidth) {
      Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
      GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF());
    }
    // Use Intrinsic::spv_extractelt so dynamic vs static extraction is
    // handled later: extr = spv_extractelt LoadedVector, IndexRegister.
    MachineInstrBuilder ExtractInst = MIRBuilder.buildIntrinsic(
        Intrinsic::spv_extractelt, ArrayRef<Register>{Extracted}, true);
    ExtractInst.addUse(LoadedVector).addUse(IndexRegister);

    // If the index is dynamic, we need to check whether it's < 3 and use a
    // select.
    if (!IsConstantIndex) {
      insertAssignInstr(Extracted, nullptr, PointerSizeType, GR, MIRBuilder,
                        *MRI);

      auto IndexType = GR->getSPIRVTypeForVReg(IndexRegister);
      auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);

      Register CompareRegister =
          MRI->createGenericVirtualRegister(LLT::scalar(1));
      GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());

      // Use G_ICMP to check if idxVReg < 3.
      MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
                           GR->buildConstantInt(3, MIRBuilder, IndexType));

      // Get constant for the default value (0 or 1 depending on which
      // function).
      Register DefaultRegister =
          GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);

      // Get a register for the selection result (possibly a new temporary one).
      Register SelectionResult = Call->ReturnRegister;
      if (PointerSize != ResultWidth) {
        SelectionResult =
            MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
        GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult,
                                  MIRBuilder.getMF());
      }
      // Create the final G_SELECT to return the extracted value or the default.
      MIRBuilder.buildSelect(SelectionResult, CompareRegister, Extracted,
                             DefaultRegister);
      ToTruncate = SelectionResult;
    } else {
      ToTruncate = Extracted;
    }
  }
  // Alter the result's bitwidth if it does not match the SizeT value extracted.
  if (PointerSize != ResultWidth)
    MIRBuilder.buildZExtOrTrunc(Call->ReturnRegister, ToTruncate);
  return true;
}

static bool generateBuiltinVar(const SPIRV::IncomingCall *Call,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR) {
  // Lookup the builtin variable record.
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  SPIRV::BuiltIn::BuiltIn Value =
      SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;

  if (Value == SPIRV::BuiltIn::GlobalInvocationId)
    return genWorkgroupQuery(Call, MIRBuilder, GR, Value, 0);

  // Build a load instruction for the builtin variable.
  unsigned BitWidth = GR->getScalarOrVectorBitWidth(Call->ReturnType);
  LLT LLType;
  if (Call->ReturnType->getOpcode() == SPIRV::OpTypeVector)
    LLType =
        LLT::fixed_vector(Call->ReturnType->getOperand(2).getImm(), BitWidth);
  else
    LLType = LLT::scalar(BitWidth);

  return buildBuiltinVariableLoad(MIRBuilder, Call->ReturnType, GR, Value,
                                  LLType, Call->ReturnRegister);
}

static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR) {
  // Lookup the instruction opcode in the TableGen records.
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  unsigned Opcode =
      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;

  switch (Opcode) {
  case SPIRV::OpStore:
    return buildAtomicInitInst(Call, MIRBuilder);
  case SPIRV::OpAtomicLoad:
    return buildAtomicLoadInst(Call, MIRBuilder, GR);
  case SPIRV::OpAtomicStore:
    return buildAtomicStoreInst(Call, MIRBuilder, GR);
  case SPIRV::OpAtomicCompareExchange:
  case SPIRV::OpAtomicCompareExchangeWeak:
    return buildAtomicCompareExchangeInst(Call, MIRBuilder, GR);
  case SPIRV::OpAtomicIAdd:
  case SPIRV::OpAtomicISub:
  case SPIRV::OpAtomicOr:
  case SPIRV::OpAtomicXor:
  case SPIRV::OpAtomicAnd:
  case SPIRV::OpAtomicExchange:
    return buildAtomicRMWInst(Call, Opcode, MIRBuilder, GR);
  case SPIRV::OpMemoryBarrier:
    return buildBarrierInst(Call, SPIRV::OpMemoryBarrier, MIRBuilder, GR);
  case SPIRV::OpAtomicFlagTestAndSet:
  case SPIRV::OpAtomicFlagClear:
    return buildAtomicFlagInst(Call, Opcode, MIRBuilder, GR);
  default:
    return false;
  }
}

static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
                                MachineIRBuilder &MIRBuilder,
                                SPIRVGlobalRegistry *GR) {
  // Lookup the instruction opcode in the TableGen records.
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  unsigned Opcode =
      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;

  return buildBarrierInst(Call, Opcode, MIRBuilder, GR);
}

static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
  unsigned Opcode = GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode();
  bool IsVec = Opcode == SPIRV::OpTypeVector;
  // Use OpDot only in case of vector args and OpFMul in case of scalar args.
  MIRBuilder.buildInstr(IsVec ? SPIRV::OpDot : SPIRV::OpFMulS)
      .addDef(Call->ReturnRegister)
      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
      .addUse(Call->Arguments[0])
      .addUse(Call->Arguments[1]);
  return true;
}
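
// E.g. OpenCL "dot(float4, float4)" becomes OpDot, while the scalar overload
// "dot(float, float)" degenerates to a plain floating-point multiply, since
// OpDot requires vector operands.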

static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
  // Lookup the builtin record.
  SPIRV::BuiltIn::BuiltIn Value =
      SPIRV::lookupGetBuiltin(Call->Builtin->Name, Call->Builtin->Set)->Value;
  uint64_t IsDefault = (Value == SPIRV::BuiltIn::GlobalSize ||
                        Value == SPIRV::BuiltIn::WorkgroupSize ||
                        Value == SPIRV::BuiltIn::EnqueuedWorkgroupSize);
  return genWorkgroupQuery(Call, MIRBuilder, GR, Value, IsDefault ? 1 : 0);
}

static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
                                       MachineIRBuilder &MIRBuilder,
                                       SPIRVGlobalRegistry *GR) {
  // Lookup the image size query component number in the TableGen records.
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  uint32_t Component =
      SPIRV::lookupImageQueryBuiltin(Builtin->Name, Builtin->Set)->Component;
  // Query result may either be a vector or a scalar. If return type is not a
  // vector, expect only a single size component. Otherwise get the number of
  // expected components.
  SPIRVType *RetTy = Call->ReturnType;
  unsigned NumExpectedRetComponents = RetTy->getOpcode() == SPIRV::OpTypeVector
                                          ? RetTy->getOperand(2).getImm()
                                          : 1;
  // Get the actual number of query result/size components.
  SPIRVType *ImgType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
  unsigned NumActualRetComponents = getNumSizeComponents(ImgType);
  Register QueryResult = Call->ReturnRegister;
  SPIRVType *QueryResultType = Call->ReturnType;
  if (NumExpectedRetComponents != NumActualRetComponents) {
    QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister(
        LLT::fixed_vector(NumActualRetComponents, 32));
    SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
    QueryResultType = GR->getOrCreateSPIRVVectorType(
        IntTy, NumActualRetComponents, MIRBuilder);
    GR->assignSPIRVTypeToVReg(QueryResultType, QueryResult, MIRBuilder.getMF());
  }
  bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
  unsigned Opcode =
      IsDimBuf ? SPIRV::OpImageQuerySize : SPIRV::OpImageQuerySizeLod;
  auto MIB = MIRBuilder.buildInstr(Opcode)
                 .addDef(QueryResult)
                 .addUse(GR->getSPIRVTypeID(QueryResultType))
                 .addUse(Call->Arguments[0]);
  if (!IsDimBuf)
    MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Lod id.
  if (NumExpectedRetComponents == NumActualRetComponents)
    return true;
  if (NumExpectedRetComponents == 1) {
    // Only 1 component is expected, build OpCompositeExtract instruction.
    unsigned ExtractedComposite =
        Component == 3 ? NumActualRetComponents - 1 : Component;
    assert(ExtractedComposite < NumActualRetComponents &&
           "Invalid composite index!");
    MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
        .addDef(Call->ReturnRegister)
        .addUse(GR->getSPIRVTypeID(Call->ReturnType))
        .addUse(QueryResult)
        .addImm(ExtractedComposite);
  } else {
    // More than 1 component is expected, fill a new vector.
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpVectorShuffle)
                   .addDef(Call->ReturnRegister)
                   .addUse(GR->getSPIRVTypeID(Call->ReturnType))
                   .addUse(QueryResult)
                   .addUse(QueryResult);
    for (unsigned i = 0; i < NumExpectedRetComponents; ++i)
      MIB.addImm(i < NumActualRetComponents ? i : 0xffffffff);
  }
  return true;
}
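
// E.g. "get_image_dim(image2d_t)" returning int2 matches the two actual size
// components of a 2D image, so the OpImageQuerySizeLod result is used
// directly; "get_image_width(image2d_t)" expects one component, so it is
// extracted from the intermediate vector with OpCompositeExtract (a sketch of
// the two mismatch paths above).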

static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
                                       MachineIRBuilder &MIRBuilder,
                                       SPIRVGlobalRegistry *GR) {
  assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt &&
         "Image samples query result must be of int type!");

  // Lookup the instruction opcode in the TableGen records.
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  unsigned Opcode =
      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;

  Register Image = Call->Arguments[0];
  SPIRV::Dim::Dim ImageDimensionality = static_cast<SPIRV::Dim::Dim>(
      GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm());

  switch (Opcode) {
  case SPIRV::OpImageQuerySamples:
    assert(ImageDimensionality == SPIRV::Dim::DIM_2D &&
           "Image must be of 2D dimensionality");
    break;
  case SPIRV::OpImageQueryLevels:
    assert((ImageDimensionality == SPIRV::Dim::DIM_1D ||
            ImageDimensionality == SPIRV::Dim::DIM_2D ||
            ImageDimensionality == SPIRV::Dim::DIM_3D ||
            ImageDimensionality == SPIRV::Dim::DIM_Cube) &&
           "Image must be of 1D/2D/3D/Cube dimensionality");
    break;
  }

  MIRBuilder.buildInstr(Opcode)
      .addDef(Call->ReturnRegister)
      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
      .addUse(Image);
  return true;
}

// TODO: Move to TableGen.
static SPIRV::SamplerAddressingMode::SamplerAddressingMode
getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
  switch (Bitmask & SPIRV::CLK_ADDRESS_MODE_MASK) {
  case SPIRV::CLK_ADDRESS_CLAMP:
    return SPIRV::SamplerAddressingMode::Clamp;
  case SPIRV::CLK_ADDRESS_CLAMP_TO_EDGE:
    return SPIRV::SamplerAddressingMode::ClampToEdge;
  case SPIRV::CLK_ADDRESS_REPEAT:
    return SPIRV::SamplerAddressingMode::Repeat;
  case SPIRV::CLK_ADDRESS_MIRRORED_REPEAT:
    return SPIRV::SamplerAddressingMode::RepeatMirrored;
  case SPIRV::CLK_ADDRESS_NONE:
    return SPIRV::SamplerAddressingMode::None;
  default:
    llvm_unreachable("Unknown CL address mode");
  }
}

static unsigned getSamplerParamFromBitmask(unsigned Bitmask) {
  return (Bitmask & SPIRV::CLK_NORMALIZED_COORDS_TRUE) ? 1 : 0;
}

static SPIRV::SamplerFilterMode::SamplerFilterMode
getSamplerFilterModeFromBitmask(unsigned Bitmask) {
  if (Bitmask & SPIRV::CLK_FILTER_LINEAR)
    return SPIRV::SamplerFilterMode::Linear;
  if (Bitmask & SPIRV::CLK_FILTER_NEAREST)
    return SPIRV::SamplerFilterMode::Nearest;
  return SPIRV::SamplerFilterMode::Nearest;
}
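
// E.g. the OpenCL sampler literal (CLK_ADDRESS_CLAMP_TO_EDGE |
// CLK_NORMALIZED_COORDS_TRUE | CLK_FILTER_LINEAR) decodes via the helpers
// above to addressing mode ClampToEdge, normalized coordinates 1, and filter
// mode Linear.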
1235
1236 static bool generateReadImageInst(const StringRef DemangledCall,
1237                                   const SPIRV::IncomingCall *Call,
1238                                   MachineIRBuilder &MIRBuilder,
1239                                   SPIRVGlobalRegistry *GR) {
1240   Register Image = Call->Arguments[0];
1241   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1242
1243   if (DemangledCall.contains_insensitive("ocl_sampler")) {
1244     Register Sampler = Call->Arguments[1];
1245
1246     if (!GR->isScalarOfType(Sampler, SPIRV::OpTypeSampler) &&
1247         getDefInstrMaybeConstant(Sampler, MRI)->getOperand(1).isCImm()) {
1248       uint64_t SamplerMask = getIConstVal(Sampler, MRI);
1249       Sampler = GR->buildConstantSampler(
1250           Register(), getSamplerAddressingModeFromBitmask(SamplerMask),
1251           getSamplerParamFromBitmask(SamplerMask),
1252           getSamplerFilterModeFromBitmask(SamplerMask), MIRBuilder,
1253           GR->getSPIRVTypeForVReg(Sampler));
1254     }
1255     SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
1256     SPIRVType *SampledImageType =
1257         GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
1258     Register SampledImage = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1259
1260     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
1261         .addDef(SampledImage)
1262         .addUse(GR->getSPIRVTypeID(SampledImageType))
1263         .addUse(Image)
1264         .addUse(Sampler);
1265
1266     Register Lod = GR->buildConstantFP(APFloat::getZero(APFloat::IEEEsingle()),
1267                                        MIRBuilder);
1268     SPIRVType *TempType = Call->ReturnType;
1269     bool NeedsExtraction = false;
1270     if (TempType->getOpcode() != SPIRV::OpTypeVector) {
1271       TempType =
1272           GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder);
1273       NeedsExtraction = true;
1274     }
1275     LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(TempType));
1276     Register TempRegister = MRI->createGenericVirtualRegister(LLType);
1277     GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF());
1278
1279     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
1280         .addDef(NeedsExtraction ? TempRegister : Call->ReturnRegister)
1281         .addUse(GR->getSPIRVTypeID(TempType))
1282         .addUse(SampledImage)
1283         .addUse(Call->Arguments[2]) // Coordinate.
1284         .addImm(SPIRV::ImageOperand::Lod)
1285         .addUse(Lod);
1286
1287     if (NeedsExtraction)
1288       MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
1289           .addDef(Call->ReturnRegister)
1290           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1291           .addUse(TempRegister)
1292           .addImm(0);
1293   } else if (DemangledCall.contains_insensitive("msaa")) {
1294     MIRBuilder.buildInstr(SPIRV::OpImageRead)
1295         .addDef(Call->ReturnRegister)
1296         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1297         .addUse(Image)
1298         .addUse(Call->Arguments[1]) // Coordinate.
1299         .addImm(SPIRV::ImageOperand::Sample)
1300         .addUse(Call->Arguments[2]);
1301   } else {
1302     MIRBuilder.buildInstr(SPIRV::OpImageRead)
1303         .addDef(Call->ReturnRegister)
1304         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1305         .addUse(Image)
1306         .addUse(Call->Arguments[1]); // Coordinate.
1307   }
1308   return true;
1309 }
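// A rough sketch of the lowering above, with register names invented for
// readability: a sampled read such as read_imagef(img, smp, coord) becomes
//   %sampledImg = OpSampledImage %SampledImageTy %img %smp
//   %vec4       = OpImageSampleExplicitLod %v4f32 %sampledImg %coord Lod %lod0
// followed by an OpCompositeExtract when the builtin returns a scalar. The
// "msaa" and plain overloads instead lower to OpImageRead, with the sample
// index passed as an image operand in the msaa case.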
1310
1311 static bool generateWriteImageInst(const SPIRV::IncomingCall *Call,
1312                                    MachineIRBuilder &MIRBuilder,
1313                                    SPIRVGlobalRegistry *GR) {
1314   MIRBuilder.buildInstr(SPIRV::OpImageWrite)
1315       .addUse(Call->Arguments[0])  // Image.
1316       .addUse(Call->Arguments[1])  // Coordinate.
1317       .addUse(Call->Arguments[2]); // Texel.
1318   return true;
1319 }
1320
1321 static bool generateSampleImageInst(const StringRef DemangledCall,
1322                                     const SPIRV::IncomingCall *Call,
1323                                     MachineIRBuilder &MIRBuilder,
1324                                     SPIRVGlobalRegistry *GR) {
1325   if (Call->Builtin->Name.contains_insensitive(
1326           "__translate_sampler_initializer")) {
1327     // Build sampler literal.
1328     uint64_t Bitmask = getIConstVal(Call->Arguments[0], MIRBuilder.getMRI());
1329     Register Sampler = GR->buildConstantSampler(
1330         Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
1331         getSamplerParamFromBitmask(Bitmask),
1332         getSamplerFilterModeFromBitmask(Bitmask), MIRBuilder, Call->ReturnType);
1333     return Sampler.isValid();
1334   } else if (Call->Builtin->Name.contains_insensitive("__spirv_SampledImage")) {
1335     // Create OpSampledImage.
1336     Register Image = Call->Arguments[0];
1337     SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
1338     SPIRVType *SampledImageType =
1339         GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
1340     Register SampledImage =
1341         Call->ReturnRegister.isValid()
1342             ? Call->ReturnRegister
1343             : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
1344     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
1345         .addDef(SampledImage)
1346         .addUse(GR->getSPIRVTypeID(SampledImageType))
1347         .addUse(Image)
1348         .addUse(Call->Arguments[1]); // Sampler.
1349     return true;
1350   } else if (Call->Builtin->Name.contains_insensitive(
1351                  "__spirv_ImageSampleExplicitLod")) {
1352     // Sample an image using an explicit level of detail.
1353     std::string ReturnType = DemangledCall.str();
1354     if (DemangledCall.contains("_R")) {
1355       ReturnType = ReturnType.substr(ReturnType.find("_R") + 2);
1356       ReturnType = ReturnType.substr(0, ReturnType.find('('));
1357     }
1358     SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
1359     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
1360         .addDef(Call->ReturnRegister)
1361         .addUse(GR->getSPIRVTypeID(Type))
1362         .addUse(Call->Arguments[0]) // Image.
1363         .addUse(Call->Arguments[1]) // Coordinate.
1364         .addImm(SPIRV::ImageOperand::Lod)
1365         .addUse(Call->Arguments[3]);
1366     return true;
1367   }
1368   return false;
1369 }
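// Schematic example (the demangled name shape is assumed for illustration): a
// call such as __spirv_ImageSampleExplicitLod_Rfloat4(%img, %coord, ...)
// carries its result type after the "_R" marker, so the branch above extracts
// "float4" and resolves it with getOrCreateSPIRVTypeByName before emitting
// the sample instruction.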
1370
1371 static bool generateSelectInst(const SPIRV::IncomingCall *Call,
1372                                MachineIRBuilder &MIRBuilder) {
1373   MIRBuilder.buildSelect(Call->ReturnRegister, Call->Arguments[0],
1374                          Call->Arguments[1], Call->Arguments[2]);
1375   return true;
1376 }
1377
1378 static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
1379                                      MachineIRBuilder &MIRBuilder,
1380                                      SPIRVGlobalRegistry *GR) {
1381   // Lookup the instruction opcode in the TableGen records.
1382   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1383   unsigned Opcode =
1384       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1385   const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1386
1387   switch (Opcode) {
1388   case SPIRV::OpSpecConstant: {
1389     // Build the SpecID decoration.
1390     unsigned SpecId =
1391         static_cast<unsigned>(getIConstVal(Call->Arguments[0], MRI));
1392     buildOpDecorate(Call->ReturnRegister, MIRBuilder, SPIRV::Decoration::SpecId,
1393                     {SpecId});
1394     // Determine the constant MI.
1395     Register ConstRegister = Call->Arguments[1];
1396     const MachineInstr *Const = getDefInstrMaybeConstant(ConstRegister, MRI);
1397     assert(Const &&
1398            (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
1399             Const->getOpcode() == TargetOpcode::G_FCONSTANT) &&
1400            "Argument should be either an int or floating-point constant");
1401     // Determine the opcode and build the OpSpecConstant MI.
1402     const MachineOperand &ConstOperand = Const->getOperand(1);
1403     if (Call->ReturnType->getOpcode() == SPIRV::OpTypeBool) {
1404       assert(ConstOperand.isCImm() && "Int constant operand is expected");
1405       Opcode = ConstOperand.getCImm()->getValue().getZExtValue()
1406                    ? SPIRV::OpSpecConstantTrue
1407                    : SPIRV::OpSpecConstantFalse;
1408     }
1409     auto MIB = MIRBuilder.buildInstr(Opcode)
1410                    .addDef(Call->ReturnRegister)
1411                    .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1412
1413     if (Call->ReturnType->getOpcode() != SPIRV::OpTypeBool) {
1414       if (Const->getOpcode() == TargetOpcode::G_CONSTANT)
1415         addNumImm(ConstOperand.getCImm()->getValue(), MIB);
1416       else
1417         addNumImm(ConstOperand.getFPImm()->getValueAPF().bitcastToAPInt(), MIB);
1418     }
1419     return true;
1420   }
1421   case SPIRV::OpSpecConstantComposite: {
1422     auto MIB = MIRBuilder.buildInstr(Opcode)
1423                    .addDef(Call->ReturnRegister)
1424                    .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1425     for (unsigned i = 0; i < Call->Arguments.size(); i++)
1426       MIB.addUse(Call->Arguments[i]);
1427     return true;
1428   }
1429   default:
1430     return false;
1431   }
1432 }
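// For instance (a hedged sketch, operand spelling abbreviated): a call such
// as __spirv_SpecConstant(42, 3.5f) yields
//   OpDecorate %ret SpecId 42
//   %ret = OpSpecConstant %float 3.5
// while a boolean default value selects OpSpecConstantTrue/False instead of
// encoding an immediate.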
1433
1434 static MachineInstr *getBlockStructInstr(Register ParamReg,
1435                                          MachineRegisterInfo *MRI) {
1436   // We expect the following sequence of instructions:
1437   //   %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
1438   //   or       = G_GLOBAL_VALUE @block_literal_global
1439   //   %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
1440   //   %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
1441   MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
1442   assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
1443          MI->getOperand(1).isReg());
1444   Register BitcastReg = MI->getOperand(1).getReg();
1445   MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
1446   assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
1447          BitcastMI->getOperand(2).isReg());
1448   Register ValueReg = BitcastMI->getOperand(2).getReg();
1449   MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
1450   return ValueMI;
1451 }
1452
1453 // Return the integer constant corresponding to the given register, which is
1454 // expected to be defined via the spv_track_constant intrinsic.
1455 // TODO: maybe unify with prelegalizer pass.
1456 static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) {
1457   MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg);
1458   assert(isSpvIntrinsic(*DefMI, Intrinsic::spv_track_constant) &&
1459          DefMI->getOperand(2).isReg());
1460   MachineInstr *DefMI2 = MRI->getUniqueVRegDef(DefMI->getOperand(2).getReg());
1461   assert(DefMI2->getOpcode() == TargetOpcode::G_CONSTANT &&
1462          DefMI2->getOperand(1).isCImm());
1463   return DefMI2->getOperand(1).getCImm()->getValue().getZExtValue();
1464 }
1465
1466 // Return the type of the instruction result, per the spv_assign_type intrinsic.
1467 // TODO: maybe unify with prelegalizer pass.
1468 static const Type *getMachineInstrType(MachineInstr *MI) {
1469   MachineInstr *NextMI = MI->getNextNode();
1470   if (isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_name))
1471     NextMI = NextMI->getNextNode();
1472   Register ValueReg = MI->getOperand(0).getReg();
1473   if (!isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_type) ||
1474       NextMI->getOperand(1).getReg() != ValueReg)
1475     return nullptr;
1476   Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
1477   assert(Ty && "Type is expected");
1478   return getTypedPtrEltType(Ty);
1479 }
1480
1481 static const Type *getBlockStructType(Register ParamReg,
1482                                       MachineRegisterInfo *MRI) {
1483   // In principle, this information should be passed to us from Clang via
1484   // an elementtype attribute. However, said attribute requires that
1485   // the function call be an intrinsic, which it is not. Instead, we rely on being
1486   // able to trace this to the declaration of a variable: OpenCL C specification
1487   // section 6.12.5 should guarantee that we can do this.
1488   MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
1489   if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
1490     return getTypedPtrEltType(MI->getOperand(1).getGlobal()->getType());
1491   assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
1492          "Blocks in OpenCL C must be traceable to allocation site");
1493   return getMachineInstrType(MI);
1494 }
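// For context (OpenCL C sketch, not taken from this file): the block literal
// passed to enqueue_kernel, e.g.
//   void (^blk)(void) = ^{ ... };
//   enqueue_kernel(queue, flags, ndrange, blk);
// is either a global (G_GLOBAL_VALUE) or a local variable (spv_alloca), which
// is what makes the tracing above possible.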
1495
1496 // TODO: maybe move to the global register.
1497 static SPIRVType *
1498 getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder &MIRBuilder,
1499                                    SPIRVGlobalRegistry *GR) {
1500   LLVMContext &Context = MIRBuilder.getMF().getFunction().getContext();
1501   Type *OpaqueType = StructType::getTypeByName(Context, "spirv.DeviceEvent");
1502   if (!OpaqueType)
1503     OpaqueType = StructType::getTypeByName(Context, "opencl.clk_event_t");
1504   if (!OpaqueType)
1505     OpaqueType = StructType::create(Context, "spirv.DeviceEvent");
1506   unsigned SC0 = storageClassToAddressSpace(SPIRV::StorageClass::Function);
1507   unsigned SC1 = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
1508   Type *PtrType = PointerType::get(PointerType::get(OpaqueType, SC0), SC1);
1509   return GR->getOrCreateSPIRVType(PtrType, MIRBuilder);
1510 }
1511
1512 static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
1513                                MachineIRBuilder &MIRBuilder,
1514                                SPIRVGlobalRegistry *GR) {
1515   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1516   const DataLayout &DL = MIRBuilder.getDataLayout();
1517   bool HasEvents = Call->Builtin->Name.find("events") != StringRef::npos;
1518   const SPIRVType *Int32Ty = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
1519
1520   // Build the variadic operands before OpEnqueueKernel.
1521   // Local size arguments: sizes of the block invoke arguments. Clang
1522   // generates the local size operands as an array, so we need to unpack them.
1523   SmallVector<Register, 16> LocalSizes;
1524   if (Call->Builtin->Name.find("_varargs") != StringRef::npos) {
1525     const unsigned LocalSizeArrayIdx = HasEvents ? 9 : 6;
1526     Register GepReg = Call->Arguments[LocalSizeArrayIdx];
1527     MachineInstr *GepMI = MRI->getUniqueVRegDef(GepReg);
1528     assert(isSpvIntrinsic(*GepMI, Intrinsic::spv_gep) &&
1529            GepMI->getOperand(3).isReg());
1530     Register ArrayReg = GepMI->getOperand(3).getReg();
1531     MachineInstr *ArrayMI = MRI->getUniqueVRegDef(ArrayReg);
1532     const Type *LocalSizeTy = getMachineInstrType(ArrayMI);
1533     assert(LocalSizeTy && "Local size type is expected");
1534     const uint64_t LocalSizeNum =
1535         cast<ArrayType>(LocalSizeTy)->getNumElements();
1536     unsigned SC = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
1537     const LLT LLType = LLT::pointer(SC, GR->getPointerSize());
1538     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
1539         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
1540     for (unsigned I = 0; I < LocalSizeNum; ++I) {
1541       Register Reg =
1542           MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
1543       MIRBuilder.getMRI()->setType(Reg, LLType);
1544       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
1545       auto GEPInst = MIRBuilder.buildIntrinsic(Intrinsic::spv_gep,
1546                                                ArrayRef<Register>{Reg}, true);
1547       GEPInst
1548           .addImm(GepMI->getOperand(2).getImm())          // In bounds flag.
1549           .addUse(ArrayMI->getOperand(0).getReg())        // Alloca.
1550           .addUse(buildConstantIntReg(0, MIRBuilder, GR)) // Indices.
1551           .addUse(buildConstantIntReg(I, MIRBuilder, GR));
1552       LocalSizes.push_back(Reg);
1553     }
1554   }
1555
1556   // The SPIR-V OpEnqueueKernel instruction has 10+ arguments.
1557   auto MIB = MIRBuilder.buildInstr(SPIRV::OpEnqueueKernel)
1558                  .addDef(Call->ReturnRegister)
1559                  .addUse(GR->getSPIRVTypeID(Int32Ty));
1560
1561   // Copy all arguments before block invoke function pointer.
1562   const unsigned BlockFIdx = HasEvents ? 6 : 3;
1563   for (unsigned i = 0; i < BlockFIdx; i++)
1564     MIB.addUse(Call->Arguments[i]);
1565
1566   // If there are no event arguments in the original call, add dummy ones.
1567   if (!HasEvents) {
1568     MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Dummy num events.
1569     Register NullPtr = GR->getOrCreateConstNullPtr(
1570         MIRBuilder, getOrCreateSPIRVDeviceEventPointer(MIRBuilder, GR));
1571     MIB.addUse(NullPtr); // Dummy wait events.
1572     MIB.addUse(NullPtr); // Dummy ret event.
1573   }
1574
1575   MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
1576   assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
1577   // Invoke: Pointer to invoke function.
1578   MIB.addGlobalAddress(BlockMI->getOperand(1).getGlobal());
1579
1580   Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
1581   // Param: Pointer to block literal.
1582   MIB.addUse(BlockLiteralReg);
1583
1584   Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
1585   // TODO: these numbers should be obtained from block literal structure.
1586   // Param Size: Size of block literal structure.
1587   MIB.addUse(buildConstantIntReg(DL.getTypeStoreSize(PType), MIRBuilder, GR));
1588   // Param Aligment: Aligment of block literal structure.
1589   // Param Alignment: Alignment of block literal structure.
1590       buildConstantIntReg(DL.getPrefTypeAlignment(PType), MIRBuilder, GR));
1591
1592   for (unsigned i = 0; i < LocalSizes.size(); i++)
1593     MIB.addUse(LocalSizes[i]);
1594   return true;
1595 }
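// Resulting operand layout (schematic; see the SPIR-V spec for the normative
// definition of OpEnqueueKernel):
//   %ret = OpEnqueueKernel %int32 %queue %flags %ndrange
//          %numEvents %waitEvents %retEvent   ; real or dummy event operands
//          %invoke %param %paramSize %paramAlign [%localSize...]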
1596
1597 static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
1598                                 MachineIRBuilder &MIRBuilder,
1599                                 SPIRVGlobalRegistry *GR) {
1600   // Lookup the instruction opcode in the TableGen records.
1601   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1602   unsigned Opcode =
1603       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1604
1605   switch (Opcode) {
1606   case SPIRV::OpRetainEvent:
1607   case SPIRV::OpReleaseEvent:
1608     return MIRBuilder.buildInstr(Opcode).addUse(Call->Arguments[0]);
1609   case SPIRV::OpCreateUserEvent:
1610   case SPIRV::OpGetDefaultQueue:
1611     return MIRBuilder.buildInstr(Opcode)
1612         .addDef(Call->ReturnRegister)
1613         .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1614   case SPIRV::OpIsValidEvent:
1615     return MIRBuilder.buildInstr(Opcode)
1616         .addDef(Call->ReturnRegister)
1617         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1618         .addUse(Call->Arguments[0]);
1619   case SPIRV::OpSetUserEventStatus:
1620     return MIRBuilder.buildInstr(Opcode)
1621         .addUse(Call->Arguments[0])
1622         .addUse(Call->Arguments[1]);
1623   case SPIRV::OpCaptureEventProfilingInfo:
1624     return MIRBuilder.buildInstr(Opcode)
1625         .addUse(Call->Arguments[0])
1626         .addUse(Call->Arguments[1])
1627         .addUse(Call->Arguments[2]);
1628   case SPIRV::OpBuildNDRange: {
1629     MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1630     SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
1631     assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
1632            PtrType->getOperand(2).isReg());
1633     Register TypeReg = PtrType->getOperand(2).getReg();
1634     SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
1635     Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1636     GR->assignSPIRVTypeToVReg(StructType, TmpReg, MIRBuilder.getMF());
1637     // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
1638     // three other arguments, so pass a zero constant for any that are absent.
1639     unsigned NumArgs = Call->Arguments.size();
1640     assert(NumArgs >= 2);
1641     Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
1642     Register LocalWorkSize =
1643         NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
1644     Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
1645     if (NumArgs < 4) {
1646       Register Const;
1647       SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
1648       if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
1649         MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
1650         assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
1651                DefInstr->getOperand(3).isReg());
1652         Register GWSPtr = DefInstr->getOperand(3).getReg();
1653         // TODO: Maybe simplify generation of the type of the fields.
1654         unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2;
1655         unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
1656         Type *BaseTy = IntegerType::get(
1657             MIRBuilder.getMF().getFunction().getContext(), BitWidth);
1658         Type *FieldTy = ArrayType::get(BaseTy, Size);
1659         SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
1660         GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1661         GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize,
1662                                   MIRBuilder.getMF());
1663         MIRBuilder.buildInstr(SPIRV::OpLoad)
1664             .addDef(GlobalWorkSize)
1665             .addUse(GR->getSPIRVTypeID(SpvFieldTy))
1666             .addUse(GWSPtr);
1667         Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
1668       } else {
1669         Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
1670       }
1671       if (!LocalWorkSize.isValid())
1672         LocalWorkSize = Const;
1673       if (!GlobalWorkOffset.isValid())
1674         GlobalWorkOffset = Const;
1675     }
1676     MIRBuilder.buildInstr(Opcode)
1677         .addDef(TmpReg)
1678         .addUse(TypeReg)
1679         .addUse(GlobalWorkSize)
1680         .addUse(LocalWorkSize)
1681         .addUse(GlobalWorkOffset);
1682     return MIRBuilder.buildInstr(SPIRV::OpStore)
1683         .addUse(Call->Arguments[0])
1684         .addUse(TmpReg);
1685   }
1686   case SPIRV::OpEnqueueKernel:
1687     return buildEnqueueKernel(Call, MIRBuilder, GR);
1688   default:
1689     return false;
1690   }
1691 }
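// Example for the OpBuildNDRange path (sketch): ndrange_1D(gws) has only two
// arguments after demangling, the destination pointer and the global work
// size, so the local work size and global work offset above default to a zero
// constant, and the resulting struct is written back through the destination
// pointer with OpStore.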
1692
1693 static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
1694                               MachineIRBuilder &MIRBuilder,
1695                               SPIRVGlobalRegistry *GR) {
1696   // Lookup the instruction opcode in the TableGen records.
1697   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1698   unsigned Opcode =
1699       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1700   auto Scope = buildConstantIntReg(SPIRV::Scope::Workgroup, MIRBuilder, GR);
1701
1702   switch (Opcode) {
1703   case SPIRV::OpGroupAsyncCopy:
1704     return MIRBuilder.buildInstr(Opcode)
1705         .addDef(Call->ReturnRegister)
1706         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1707         .addUse(Scope)
1708         .addUse(Call->Arguments[0])
1709         .addUse(Call->Arguments[1])
1710         .addUse(Call->Arguments[2])
1711         .addUse(buildConstantIntReg(1, MIRBuilder, GR))
1712         .addUse(Call->Arguments[3]);
1713   case SPIRV::OpGroupWaitEvents:
1714     return MIRBuilder.buildInstr(Opcode)
1715         .addUse(Scope)
1716         .addUse(Call->Arguments[0])
1717         .addUse(Call->Arguments[1]);
1718   default:
1719     return false;
1720   }
1721 }
1722
1723 static bool generateConvertInst(const StringRef DemangledCall,
1724                                 const SPIRV::IncomingCall *Call,
1725                                 MachineIRBuilder &MIRBuilder,
1726                                 SPIRVGlobalRegistry *GR) {
1727   // Lookup the conversion builtin in the TableGen records.
1728   const SPIRV::ConvertBuiltin *Builtin =
1729       SPIRV::lookupConvertBuiltin(Call->Builtin->Name, Call->Builtin->Set);
1730
1731   if (Builtin->IsSaturated)
1732     buildOpDecorate(Call->ReturnRegister, MIRBuilder,
1733                     SPIRV::Decoration::SaturatedConversion, {});
1734   if (Builtin->IsRounded)
1735     buildOpDecorate(Call->ReturnRegister, MIRBuilder,
1736                     SPIRV::Decoration::FPRoundingMode, {Builtin->RoundingMode});
1737
1738   unsigned Opcode = SPIRV::OpNop;
1739   if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
1740     // Int -> ...
1741     if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
1742       // Int -> Int
1743       if (Builtin->IsSaturated)
1744         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpSatConvertUToS
1745                                               : SPIRV::OpSatConvertSToU;
1746       else
1747         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpUConvert
1748                                               : SPIRV::OpSConvert;
1749     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
1750                                           SPIRV::OpTypeFloat)) {
1751       // Int -> Float
1752       bool IsSourceSigned =
1753           DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
1754       Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
1755     }
1756   } else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
1757                                         SPIRV::OpTypeFloat)) {
1758     // Float -> ...
1759     if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt))
1760       // Float -> Int
1761       Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
1762                                             : SPIRV::OpConvertFToU;
1763     else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
1764                                         SPIRV::OpTypeFloat))
1765       // Float -> Float
1766       Opcode = SPIRV::OpFConvert;
1767   }
1768
1769   assert(Opcode != SPIRV::OpNop &&
1770          "Conversion between the types not implemented!");
1771
1772   MIRBuilder.buildInstr(Opcode)
1773       .addDef(Call->ReturnRegister)
1774       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1775       .addUse(Call->Arguments[0]);
1776   return true;
1777 }
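// A concrete instance (assuming the usual OpenCL naming): convert_char_sat_rte
// applied to a float source decorates the result with SaturatedConversion and
// FPRoundingMode RTE, then emits OpConvertFToS since the destination is a
// signed integer type.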
1778
1779 static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
1780                                         MachineIRBuilder &MIRBuilder,
1781                                         SPIRVGlobalRegistry *GR) {
1782   // Lookup the vector load/store builtin in the TableGen records.
1783   const SPIRV::VectorLoadStoreBuiltin *Builtin =
1784       SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
1785                                           Call->Builtin->Set);
1786   // Build extended instruction.
1787   auto MIB =
1788       MIRBuilder.buildInstr(SPIRV::OpExtInst)
1789           .addDef(Call->ReturnRegister)
1790           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1791           .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
1792           .addImm(Builtin->Number);
1793   for (auto Argument : Call->Arguments)
1794     MIB.addUse(Argument);
1795
1796   // The rounding mode should be passed as the last argument in the MI for
1797   // builtins like "vstorea_halfn_r".
1798   if (Builtin->IsRounded)
1799     MIB.addImm(static_cast<uint32_t>(Builtin->RoundingMode));
1800   return true;
1801 }
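// Sketch of the emitted form for a rounded variant such as vstorea_half4_rte:
//   OpExtInst %ResultTy %ret OpenCL.std <inst-number> %data %offset %ptr RTE
// where the trailing rounding-mode literal is appended only when the TableGen
// record marks the builtin as rounded (operand spelling here is approximate).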
1802
1803 static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
1804                                   MachineIRBuilder &MIRBuilder,
1805                                   SPIRVGlobalRegistry *GR) {
1806   // Lookup the instruction opcode in the TableGen records.
1807   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1808   unsigned Opcode =
1809       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1810   bool IsLoad = Opcode == SPIRV::OpLoad;
1811   // Build the instruction.
1812   auto MIB = MIRBuilder.buildInstr(Opcode);
1813   if (IsLoad) {
1814     MIB.addDef(Call->ReturnRegister);
1815     MIB.addUse(GR->getSPIRVTypeID(Call->ReturnType));
1816   }
1817   // Add a pointer to the value to load/store.
1818   MIB.addUse(Call->Arguments[0]);
1819   // Add a value to store.
1820   if (!IsLoad)
1821     MIB.addUse(Call->Arguments[1]);
1822   // Add optional memory attributes and an alignment.
1823   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1824   unsigned NumArgs = Call->Arguments.size();
1825   if ((IsLoad && NumArgs >= 2) || NumArgs >= 3)
1826     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 1 : 2], MRI));
1827   if ((IsLoad && NumArgs >= 3) || NumArgs >= 4)
1828     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 2 : 3], MRI));
1829   return true;
1830 }
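// Operand sketch for the load path (store is symmetric, minus the result):
//   %res = OpLoad %ResultTy %ptr [MemoryOperand [Alignment]]
// The optional trailing literals come from spv_track_constant-defined
// arguments, as decoded by getConstFromIntrinsic above.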
1831
1832 /// Lowers a builtin function call using the provided \p DemangledCall skeleton
1833 /// and external instruction \p Set.
1834 namespace SPIRV {
1835 std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
1836                                  SPIRV::InstructionSet::InstructionSet Set,
1837                                  MachineIRBuilder &MIRBuilder,
1838                                  const Register OrigRet, const Type *OrigRetTy,
1839                                  const SmallVectorImpl<Register> &Args,
1840                                  SPIRVGlobalRegistry *GR) {
1841   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
1842
1843   // SPIR-V type and return register.
1844   Register ReturnRegister = OrigRet;
1845   SPIRVType *ReturnType = nullptr;
1846   if (OrigRetTy && !OrigRetTy->isVoidTy()) {
1847     ReturnType = GR->assignTypeToVReg(OrigRetTy, OrigRet, MIRBuilder);
1848   } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
1849     ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
1850     MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32));
1851     ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
1852   }
1853
1854   // Lookup the builtin in the TableGen records.
1855   std::unique_ptr<const IncomingCall> Call =
1856       lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
1857
1858   if (!Call) {
1859     LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
1860     return std::nullopt;
1861   }
1862
1863   // TODO: check if the provided args meet the builtin requirements.
1864   assert(Args.size() >= Call->Builtin->MinNumArgs &&
1865          "Too few arguments to generate the builtin");
1866   if (Call->Builtin->MaxNumArgs && Args.size() > Call->Builtin->MaxNumArgs)
1867     LLVM_DEBUG(dbgs() << "More arguments provided than the builtin accepts!\n");
1868
1869   // Match the builtin with implementation based on the grouping.
1870   switch (Call->Builtin->Group) {
1871   case SPIRV::Extended:
1872     return generateExtInst(Call.get(), MIRBuilder, GR);
1873   case SPIRV::Relational:
1874     return generateRelationalInst(Call.get(), MIRBuilder, GR);
1875   case SPIRV::Group:
1876     return generateGroupInst(Call.get(), MIRBuilder, GR);
1877   case SPIRV::Variable:
1878     return generateBuiltinVar(Call.get(), MIRBuilder, GR);
1879   case SPIRV::Atomic:
1880     return generateAtomicInst(Call.get(), MIRBuilder, GR);
1881   case SPIRV::Barrier:
1882     return generateBarrierInst(Call.get(), MIRBuilder, GR);
1883   case SPIRV::Dot:
1884     return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
1885   case SPIRV::GetQuery:
1886     return generateGetQueryInst(Call.get(), MIRBuilder, GR);
1887   case SPIRV::ImageSizeQuery:
1888     return generateImageSizeQueryInst(Call.get(), MIRBuilder, GR);
1889   case SPIRV::ImageMiscQuery:
1890     return generateImageMiscQueryInst(Call.get(), MIRBuilder, GR);
1891   case SPIRV::ReadImage:
1892     return generateReadImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
1893   case SPIRV::WriteImage:
1894     return generateWriteImageInst(Call.get(), MIRBuilder, GR);
1895   case SPIRV::SampleImage:
1896     return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
1897   case SPIRV::Select:
1898     return generateSelectInst(Call.get(), MIRBuilder);
1899   case SPIRV::SpecConstant:
1900     return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
1901   case SPIRV::Enqueue:
1902     return generateEnqueueInst(Call.get(), MIRBuilder, GR);
1903   case SPIRV::AsyncCopy:
1904     return generateAsyncCopy(Call.get(), MIRBuilder, GR);
1905   case SPIRV::Convert:
1906     return generateConvertInst(DemangledCall, Call.get(), MIRBuilder, GR);
1907   case SPIRV::VectorLoadStore:
1908     return generateVectorLoadStoreInst(Call.get(), MIRBuilder, GR);
1909   case SPIRV::LoadStore:
1910     return generateLoadStoreInst(Call.get(), MIRBuilder, GR);
1911   }
1912   return false;
1913 }
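// Caller-side sketch (hypothetical usage, not code from this file): the
// tri-state result lets call lowering distinguish "not a builtin" from a
// failed lowering:
//   if (std::optional<bool> Res = SPIRV::lowerBuiltin(...))
//     return *Res;        // Recognized builtin; true on success.
//   // std::nullopt: lower as an ordinary function call instead.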
1914
1915 struct DemangledType {
1916   StringRef Name;
1917   uint32_t Opcode;
1918 };
1919
1920 #define GET_DemangledTypes_DECL
1921 #define GET_DemangledTypes_IMPL
1922
1923 struct ImageType {
1924   StringRef Name;
1925   StringRef SampledType;
1926   AccessQualifier::AccessQualifier Qualifier;
1927   Dim::Dim Dimensionality;
1928   bool Arrayed;
1929   bool Depth;
1930   bool Multisampled;
1931   bool Sampled;
1932   ImageFormat::ImageFormat Format;
1933 };
1934
1935 struct PipeType {
1936   StringRef Name;
1937   AccessQualifier::AccessQualifier Qualifier;
1938 };
1939
1940 using namespace AccessQualifier;
1941 using namespace Dim;
1942 using namespace ImageFormat;
1943 #define GET_ImageTypes_DECL
1944 #define GET_ImageTypes_IMPL
1945 #define GET_PipeTypes_DECL
1946 #define GET_PipeTypes_IMPL
1947 #include "SPIRVGenTables.inc"
1948 } // namespace SPIRV
1949
1950 //===----------------------------------------------------------------------===//
1951 // Misc functions for parsing builtin types and looking up implementation
1952 // details in TableGenerated tables.
1953 //===----------------------------------------------------------------------===//
1954
1955 static const SPIRV::DemangledType *findBuiltinType(StringRef Name) {
1956   if (Name.startswith("opencl."))
1957     return SPIRV::lookupBuiltinType(Name);
1958   if (!Name.startswith("spirv."))
1959     return nullptr;
1960   // Some SPIR-V builtin types have a complex list of parameters as part of
1961   // their name (e.g. spirv.Image._void_1_0_0_0_0_0_0). Those parameters often
1962   // are numeric literals which cannot be easily represented by TableGen
1963   // records and should be parsed instead.
1964   unsigned BaseTypeNameLength =
1965       Name.contains('_') ? Name.find('_') - 1 : Name.size();
1966   return SPIRV::lookupBuiltinType(Name.substr(0, BaseTypeNameLength).str());
1967 }
1968
1969 static std::unique_ptr<const SPIRV::ImageType>
1970 lookupOrParseBuiltinImageType(StringRef Name) {
1971   if (Name.startswith("opencl.")) {
1972     // Lookup OpenCL builtin image type lowering details in TableGen records.
1973     const SPIRV::ImageType *Record = SPIRV::lookupImageType(Name);
1974     return std::unique_ptr<SPIRV::ImageType>(new SPIRV::ImageType(*Record));
1975   }
1976   if (!Name.startswith("spirv."))
1977     llvm_unreachable("Unknown builtin image type name/literal");
1978   // Parse the literals of SPIR-V image builtin parameters. The name should
1979   // have the following format:
1980   // spirv.Image._Type_Dim_Depth_Arrayed_MS_Sampled_ImageFormat_AccessQualifier
1981   // e.g. %spirv.Image._void_1_0_0_0_0_0_0
1982   StringRef TypeParametersString = Name.substr(strlen("spirv.Image."));
1983   SmallVector<StringRef> TypeParameters;
1984   SplitString(TypeParametersString, TypeParameters, "_");
1985   assert(TypeParameters.size() == 8 &&
1986          "Wrong number of literals in SPIR-V builtin image type");
1987
1988   StringRef SampledType = TypeParameters[0];
1989   unsigned Dim, Depth, Arrayed, Multisampled, Sampled, Format, AccessQual;
1990   bool AreParameterLiteralsValid =
1991       !(TypeParameters[1].getAsInteger(10, Dim) ||
1992         TypeParameters[2].getAsInteger(10, Depth) ||
1993         TypeParameters[3].getAsInteger(10, Arrayed) ||
1994         TypeParameters[4].getAsInteger(10, Multisampled) ||
1995         TypeParameters[5].getAsInteger(10, Sampled) ||
1996         TypeParameters[6].getAsInteger(10, Format) ||
1997         TypeParameters[7].getAsInteger(10, AccessQual));
1998   assert(AreParameterLiteralsValid &&
1999          "Invalid format of SPIR-V image type parameter literals.");
2000
2001   return std::unique_ptr<SPIRV::ImageType>(new SPIRV::ImageType{
2002       Name, SampledType, SPIRV::AccessQualifier::AccessQualifier(AccessQual),
2003       SPIRV::Dim::Dim(Dim), static_cast<bool>(Arrayed),
2004       static_cast<bool>(Depth), static_cast<bool>(Multisampled),
2005       static_cast<bool>(Sampled), SPIRV::ImageFormat::ImageFormat(Format)});
2006 }
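// Worked decode (following the parameter order above): the name
// "spirv.Image._void_1_0_0_0_0_0_0" yields SampledType = void, Dim = 1,
// Depth = 0, Arrayed = 0, Multisampled = 0, Sampled = 0, ImageFormat = 0 and
// AccessQualifier = 0.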
2007
2008 static std::unique_ptr<const SPIRV::PipeType>
2009 lookupOrParseBuiltinPipeType(StringRef Name) {
2010   if (Name.startswith("opencl.")) {
2011     // Lookup OpenCL builtin pipe type lowering details in TableGen records.
2012     const SPIRV::PipeType *Record = SPIRV::lookupPipeType(Name);
2013     return std::unique_ptr<SPIRV::PipeType>(new SPIRV::PipeType(*Record));
2014   }
2015   if (!Name.startswith("spirv."))
2016     llvm_unreachable("Unknown builtin pipe type name/literal");
2017   // Parse the access qualifier literal in the name of the SPIR-V pipe type.
2018   // The name should have the following format:
2019   // spirv.Pipe._AccessQualifier
2020   // e.g. %spirv.Pipe._1
2021   if (Name.endswith("_0"))
2022     return std::unique_ptr<SPIRV::PipeType>(
2023         new SPIRV::PipeType{Name, SPIRV::AccessQualifier::ReadOnly});
2024   if (Name.endswith("_1"))
2025     return std::unique_ptr<SPIRV::PipeType>(
2026         new SPIRV::PipeType{Name, SPIRV::AccessQualifier::WriteOnly});
2027   if (Name.endswith("_2"))
2028     return std::unique_ptr<SPIRV::PipeType>(
2029         new SPIRV::PipeType{Name, SPIRV::AccessQualifier::ReadWrite});
2030   llvm_unreachable("Unknown pipe type access qualifier literal");
2031 }
2032
2033 //===----------------------------------------------------------------------===//
2034 // Implementation functions for builtin types.
2035 //===----------------------------------------------------------------------===//
2036
2037 static SPIRVType *getNonParametrizedType(const StructType *OpaqueType,
2038                                          const SPIRV::DemangledType *TypeRecord,
2039                                          MachineIRBuilder &MIRBuilder,
2040                                          SPIRVGlobalRegistry *GR) {
2041   unsigned Opcode = TypeRecord->Opcode;
2042   // Create or get an existing type from GlobalRegistry.
2043   return GR->getOrCreateOpTypeByOpcode(OpaqueType, MIRBuilder, Opcode);
2044 }
2045
2046 static SPIRVType *getSamplerType(MachineIRBuilder &MIRBuilder,
2047                                  SPIRVGlobalRegistry *GR) {
2048   // Create or get an existing type from GlobalRegistry.
2049   return GR->getOrCreateOpTypeSampler(MIRBuilder);
2050 }
2051
2052 static SPIRVType *getPipeType(const StructType *OpaqueType,
2053                               MachineIRBuilder &MIRBuilder,
2054                               SPIRVGlobalRegistry *GR) {
2055   // Lookup pipe type lowering details in TableGen records or parse the
2056   // name/literal for details.
2057   std::unique_ptr<const SPIRV::PipeType> Record =
2058       lookupOrParseBuiltinPipeType(OpaqueType->getName());
2059   // Create or get an existing type from GlobalRegistry.
2060   return GR->getOrCreateOpTypePipe(MIRBuilder, Record.get()->Qualifier);
2061 }
2062
2063 static SPIRVType *
2064 getImageType(const StructType *OpaqueType,
2065              SPIRV::AccessQualifier::AccessQualifier AccessQual,
2066              MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
2067   // Lookup image type lowering details in TableGen records or parse the
2068   // name/literal for details.
2069   std::unique_ptr<const SPIRV::ImageType> Record =
2070       lookupOrParseBuiltinImageType(OpaqueType->getName());
2071
2072   SPIRVType *SampledType =
2073       GR->getOrCreateSPIRVTypeByName(Record.get()->SampledType, MIRBuilder);
2074   return GR->getOrCreateOpTypeImage(
2075       MIRBuilder, SampledType, Record.get()->Dimensionality,
2076       Record.get()->Depth, Record.get()->Arrayed, Record.get()->Multisampled,
2077       Record.get()->Sampled, Record.get()->Format,
2078       AccessQual == SPIRV::AccessQualifier::WriteOnly
2079           ? SPIRV::AccessQualifier::WriteOnly
2080           : Record.get()->Qualifier);
2081 }
2082
2083 static SPIRVType *getSampledImageType(const StructType *OpaqueType,
2084                                       MachineIRBuilder &MIRBuilder,
2085                                       SPIRVGlobalRegistry *GR) {
2086   StringRef TypeParametersString =
2087       OpaqueType->getName().substr(strlen("spirv.SampledImage."));
2088   LLVMContext &Context = MIRBuilder.getMF().getFunction().getContext();
2089   Type *ImageOpaqueType = StructType::getTypeByName(
2090       Context, "spirv.Image." + TypeParametersString.str());
2091   SPIRVType *TargetImageType =
2092       GR->getOrCreateSPIRVType(ImageOpaqueType, MIRBuilder);
2093   return GR->getOrCreateOpTypeSampledImage(TargetImageType, MIRBuilder);
2094 }
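// For example (sketch): "spirv.SampledImage._void_1_0_0_0_0_0_0" strips its
// prefix, looks up or creates the matching "spirv.Image._void_1_0_0_0_0_0_0"
// struct, and wraps the lowered image type in OpTypeSampledImage.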
2095
2096 namespace SPIRV {
2097 SPIRVType *lowerBuiltinType(const StructType *OpaqueType,
2098                             SPIRV::AccessQualifier::AccessQualifier AccessQual,
2099                             MachineIRBuilder &MIRBuilder,
2100                             SPIRVGlobalRegistry *GR) {
2101   assert(OpaqueType->hasName() &&
2102          "Structs representing builtin types must have a parsable name");
2103   unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();
2104
2105   const StringRef Name = OpaqueType->getName();
2106   LLVM_DEBUG(dbgs() << "Lowering builtin type: " << Name << "\n");
2107
2108   // Lookup the demangled builtin type in the TableGen records.
2109   const SPIRV::DemangledType *TypeRecord = findBuiltinType(Name);
2110   if (!TypeRecord)
2111     report_fatal_error("Missing TableGen record for builtin type: " + Name);
2112
2113   // "Lower" the BuiltinType into TargetType. The following get<...>Type methods
2114   // use the implementation details from TableGen records to either create a new
2115   // OpType<...> machine instruction or get an existing equivalent SPIRVType
2116   // from GlobalRegistry.
2117   SPIRVType *TargetType;
2118   switch (TypeRecord->Opcode) {
2119   case SPIRV::OpTypeImage:
2120     TargetType = getImageType(OpaqueType, AccessQual, MIRBuilder, GR);
2121     break;
2122   case SPIRV::OpTypePipe:
2123     TargetType = getPipeType(OpaqueType, MIRBuilder, GR);
2124     break;
2125   case SPIRV::OpTypeDeviceEvent:
2126     TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
2127     break;
2128   case SPIRV::OpTypeSampler:
2129     TargetType = getSamplerType(MIRBuilder, GR);
2130     break;
2131   case SPIRV::OpTypeSampledImage:
2132     TargetType = getSampledImageType(OpaqueType, MIRBuilder, GR);
2133     break;
2134   default:
2135     TargetType = getNonParametrizedType(OpaqueType, TypeRecord, MIRBuilder, GR);
2136     break;
2137   }
2138
2139   // Emit OpName instruction if a new OpType<...> instruction was added
2140   // (equivalent type was not found in GlobalRegistry).
2141   if (NumStartingVRegs < MIRBuilder.getMRI()->getNumVirtRegs())
2142     buildOpName(GR->getSPIRVTypeID(TargetType), OpaqueType->getName(),
2143                 MIRBuilder);
2144
2145   return TargetType;
2146 }
2147 } // namespace SPIRV
2148 } // namespace llvm