]> CyberLeo.Net >> Repos - FreeBSD/releng/9.0.git/blob - contrib/llvm/lib/Target/PTX/PTXISelLowering.cpp
Copy stable/9 to releng/9.0 as part of the FreeBSD 9.0-RELEASE release
[FreeBSD/releng/9.0.git] / contrib / llvm / lib / Target / PTX / PTXISelLowering.cpp
1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the PTXTargetLowering class.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PTX.h"
15 #include "PTXISelLowering.h"
16 #include "PTXMachineFunctionInfo.h"
17 #include "PTXRegisterInfo.h"
18 #include "PTXSubtarget.h"
19 #include "llvm/Function.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/CodeGen/CallingConvLower.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/SelectionDAG.h"
25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28
29 using namespace llvm;
30
31 //===----------------------------------------------------------------------===//
32 // TargetLowering Implementation
33 //===----------------------------------------------------------------------===//
34
35 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
36   : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
37   // Set up the register classes.
38   addRegisterClass(MVT::i1,  PTX::RegPredRegisterClass);
39   addRegisterClass(MVT::i16, PTX::RegI16RegisterClass);
40   addRegisterClass(MVT::i32, PTX::RegI32RegisterClass);
41   addRegisterClass(MVT::i64, PTX::RegI64RegisterClass);
42   addRegisterClass(MVT::f32, PTX::RegF32RegisterClass);
43   addRegisterClass(MVT::f64, PTX::RegF64RegisterClass);
44
45   setBooleanContents(ZeroOrOneBooleanContent);
46   setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct?
47   setMinFunctionAlignment(2);
48
49   ////////////////////////////////////
50   /////////// Expansion //////////////
51   ////////////////////////////////////
52
53   // (any/zero/sign) extload => load + (any/zero/sign) extend
54
55   setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand);
56   setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand);
57   setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand);
58
59   // f32 extload => load + fextend
60
61   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
62
63   // f64 truncstore => trunc + store
64
65   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
66
67   // sign_extend_inreg => sign_extend
68
69   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
70
71   // br_cc => brcond
72
73   setOperationAction(ISD::BR_CC, MVT::Other, Expand);
74
75   // select_cc => setcc
76
77   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
78   setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
79   setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
80
81   ////////////////////////////////////
82   //////////// Legal /////////////////
83   ////////////////////////////////////
84
85   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
86   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
87
88   ////////////////////////////////////
89   //////////// Custom ////////////////
90   ////////////////////////////////////
91
92   // customise setcc to use bitwise logic if possible
93
94   setOperationAction(ISD::SETCC, MVT::i1, Custom);
95
96   // customize translation of memory addresses
97
98   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
99   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
100
101   // Compute derived properties from the register classes
102   computeRegisterProperties();
103 }
104
105 EVT PTXTargetLowering::getSetCCResultType(EVT VT) const {
106   return MVT::i1;
107 }
108
109 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
110   switch (Op.getOpcode()) {
111     default:
112       llvm_unreachable("Unimplemented operand");
113     case ISD::SETCC:
114       return LowerSETCC(Op, DAG);
115     case ISD::GlobalAddress:
116       return LowerGlobalAddress(Op, DAG);
117   }
118 }
119
120 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
121   switch (Opcode) {
122     default:
123       llvm_unreachable("Unknown opcode");
124     case PTXISD::COPY_ADDRESS:
125       return "PTXISD::COPY_ADDRESS";
126     case PTXISD::LOAD_PARAM:
127       return "PTXISD::LOAD_PARAM";
128     case PTXISD::STORE_PARAM:
129       return "PTXISD::STORE_PARAM";
130     case PTXISD::READ_PARAM:
131       return "PTXISD::READ_PARAM";
132     case PTXISD::WRITE_PARAM:
133       return "PTXISD::WRITE_PARAM";
134     case PTXISD::EXIT:
135       return "PTXISD::EXIT";
136     case PTXISD::RET:
137       return "PTXISD::RET";
138     case PTXISD::CALL:
139       return "PTXISD::CALL";
140   }
141 }
142
143 //===----------------------------------------------------------------------===//
144 //                      Custom Lower Operation
145 //===----------------------------------------------------------------------===//
146
147 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
148   assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer");
149   SDValue Op0 = Op.getOperand(0);
150   SDValue Op1 = Op.getOperand(1);
151   SDValue Op2 = Op.getOperand(2);
152   DebugLoc dl = Op.getDebugLoc();
153   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
154
155   // Look for X == 0, X == 1, X != 0, or X != 1
156   // We can simplify these to bitwise logic
157
158   if (Op1.getOpcode() == ISD::Constant &&
159       (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
160        cast<ConstantSDNode>(Op1)->isNullValue()) &&
161       (CC == ISD::SETEQ || CC == ISD::SETNE)) {
162
163     return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
164   }
165
166   return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
167 }
168
169 SDValue PTXTargetLowering::
170 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
171   EVT PtrVT = getPointerTy();
172   DebugLoc dl = Op.getDebugLoc();
173   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
174
175   assert(PtrVT.isSimple() && "Pointer must be to primitive type.");
176
177   SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT);
178   SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS,
179                                  dl,
180                                  PtrVT.getSimpleVT(),
181                                  targetGlobal);
182
183   return movInstr;
184 }
185
186 //===----------------------------------------------------------------------===//
187 //                      Calling Convention Implementation
188 //===----------------------------------------------------------------------===//
189
190 SDValue PTXTargetLowering::
191   LowerFormalArguments(SDValue Chain,
192                        CallingConv::ID CallConv,
193                        bool isVarArg,
194                        const SmallVectorImpl<ISD::InputArg> &Ins,
195                        DebugLoc dl,
196                        SelectionDAG &DAG,
197                        SmallVectorImpl<SDValue> &InVals) const {
198   if (isVarArg) llvm_unreachable("PTX does not support varargs");
199
200   MachineFunction &MF = DAG.getMachineFunction();
201   const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
202   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
203   PTXParamManager &PM = MFI->getParamManager();
204
205   switch (CallConv) {
206     default:
207       llvm_unreachable("Unsupported calling convention");
208       break;
209     case CallingConv::PTX_Kernel:
210       MFI->setKernel(true);
211       break;
212     case CallingConv::PTX_Device:
213       MFI->setKernel(false);
214       break;
215   }
216
217   // We do one of two things here:
218   // IsKernel || SM >= 2.0  ->  Use param space for arguments
219   // SM < 2.0               ->  Use registers for arguments
220   if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) {
221     // We just need to emit the proper LOAD_PARAM ISDs
222     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
223       assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
224              "Kernels cannot take pred operands");
225
226       unsigned ParamSize = Ins[i].VT.getStoreSizeInBits();
227       unsigned Param = PM.addArgumentParam(ParamSize);
228       const std::string &ParamName = PM.getParamName(Param);
229       SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
230                                                        MVT::Other);
231       SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
232                                      ParamValue);
233       InVals.push_back(ArgValue);
234     }
235   }
236   else {
237     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
238       EVT                  RegVT = Ins[i].VT;
239       TargetRegisterClass* TRC   = getRegClassFor(RegVT);
240
241       // Use a unique index in the instruction to prevent instruction folding.
242       // Yes, this is a hack.
243       SDValue Index = DAG.getTargetConstant(i, MVT::i32);
244       unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
245       SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain,
246                                      Index);
247
248       InVals.push_back(ArgValue);
249
250       MFI->addArgReg(Reg);
251     }
252   }
253
254   return Chain;
255 }
256
257 SDValue PTXTargetLowering::
258   LowerReturn(SDValue Chain,
259               CallingConv::ID CallConv,
260               bool isVarArg,
261               const SmallVectorImpl<ISD::OutputArg> &Outs,
262               const SmallVectorImpl<SDValue> &OutVals,
263               DebugLoc dl,
264               SelectionDAG &DAG) const {
265   if (isVarArg) llvm_unreachable("PTX does not support varargs");
266
267   switch (CallConv) {
268     default:
269       llvm_unreachable("Unsupported calling convention.");
270     case CallingConv::PTX_Kernel:
271       assert(Outs.size() == 0 && "Kernel must return void.");
272       return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
273     case CallingConv::PTX_Device:
274       assert(Outs.size() <= 1 && "Can at most return one value.");
275       break;
276   }
277
278   MachineFunction& MF = DAG.getMachineFunction();
279   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
280   PTXParamManager &PM = MFI->getParamManager();
281
282   SDValue Flag;
283   const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
284
285   if (ST.useParamSpaceForDeviceArgs()) {
286     assert(Outs.size() < 2 && "Device functions can return at most one value");
287
288     if (Outs.size() == 1) {
289       unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
290       unsigned Param = PM.addReturnParam(ParamSize);
291       const std::string &ParamName = PM.getParamName(Param);
292       SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
293                                                        MVT::Other);
294       Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
295                           ParamValue, OutVals[0]);
296     }
297   } else {
298     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
299       EVT                  RegVT = Outs[i].VT;
300       TargetRegisterClass* TRC = 0;
301
302       // Determine which register class we need
303       if (RegVT == MVT::i1) {
304         TRC = PTX::RegPredRegisterClass;
305       }
306       else if (RegVT == MVT::i16) {
307         TRC = PTX::RegI16RegisterClass;
308       }
309       else if (RegVT == MVT::i32) {
310         TRC = PTX::RegI32RegisterClass;
311       }
312       else if (RegVT == MVT::i64) {
313         TRC = PTX::RegI64RegisterClass;
314       }
315       else if (RegVT == MVT::f32) {
316         TRC = PTX::RegF32RegisterClass;
317       }
318       else if (RegVT == MVT::f64) {
319         TRC = PTX::RegF64RegisterClass;
320       }
321       else {
322         llvm_unreachable("Unknown parameter type");
323       }
324
325       unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
326
327       SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
328       SDValue OutReg = DAG.getRegister(Reg, RegVT);
329
330       Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);
331
332       MFI->addRetReg(Reg);
333     }
334   }
335
336   if (Flag.getNode() == 0) {
337     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
338   }
339   else {
340     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
341   }
342 }
343
344 SDValue
345 PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
346                              CallingConv::ID CallConv, bool isVarArg,
347                              bool &isTailCall,
348                              const SmallVectorImpl<ISD::OutputArg> &Outs,
349                              const SmallVectorImpl<SDValue> &OutVals,
350                              const SmallVectorImpl<ISD::InputArg> &Ins,
351                              DebugLoc dl, SelectionDAG &DAG,
352                              SmallVectorImpl<SDValue> &InVals) const {
353
354   MachineFunction& MF = DAG.getMachineFunction();
355   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
356   PTXParamManager &PM = MFI->getParamManager();
357
358   assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
359          "Calls are not handled for the target device");
360
361   std::vector<SDValue> Ops;
362   // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
363   Ops.resize(Outs.size() + Ins.size() + 4);
364
365   Ops[0] = Chain;
366
367   // Identify the callee function
368   const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
369   assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
370          "PTX function calls must be to PTX device functions");
371   Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
372   Ops[Ins.size()+2] = Callee;
373
374   // Generate STORE_PARAM nodes for each function argument.  In PTX, function
375   // arguments are explicitly stored into .param variables and passed as
376   // arguments. There is no register/stack-based calling convention in PTX.
377   Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
378   for (unsigned i = 0; i != OutVals.size(); ++i) {
379     unsigned Size = OutVals[i].getValueType().getSizeInBits();
380     unsigned Param = PM.addLocalParam(Size);
381     const std::string &ParamName = PM.getParamName(Param);
382     SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
383                                                      MVT::Other);
384     Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
385                         ParamValue, OutVals[i]);
386     Ops[i+Ins.size()+4] = ParamValue;
387   }
388
389   std::vector<SDValue> InParams;
390
391   // Generate list of .param variables to hold the return value(s).
392   Ops[1] = DAG.getTargetConstant(Ins.size(), MVT::i32);
393   for (unsigned i = 0; i < Ins.size(); ++i) {
394     unsigned Size = Ins[i].VT.getStoreSizeInBits();
395     unsigned Param = PM.addLocalParam(Size);
396     const std::string &ParamName = PM.getParamName(Param);
397     SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
398                                                      MVT::Other);
399     Ops[i+2] = ParamValue;
400     InParams.push_back(ParamValue);
401   }
402
403   Ops[0] = Chain;
404
405   // Create the CALL node.
406   Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size());
407
408   // Create the LOAD_PARAM nodes that retrieve the function return value(s).
409   for (unsigned i = 0; i < Ins.size(); ++i) {
410     SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
411                                InParams[i]);
412     InVals.push_back(Load);
413   }
414
415   return Chain;
416 }
417
418 unsigned PTXTargetLowering::getNumRegisters(LLVMContext &Context, EVT VT) {
419   // All arguments consist of one "register," regardless of the type.
420   return 1;
421 }
422