]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - contrib/llvm-project/llvm/lib/Target/X86/X86PreAMXConfig.cpp
Merge llvm-project main llvmorg-15-init-17485-ga3e38b4a206b
[FreeBSD/FreeBSD.git] / contrib / llvm-project / llvm / lib / Target / X86 / X86PreAMXConfig.cpp
1 //===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// Insert tilecfg for each area of key AMX intrinsic.
10 /// All the key AMX intrinsic's tile operand must come from tileload. And the
11 /// def tile of key AMX intrinsic must be tilestored.
12 /// take tdpbssd for example:
13 /// --------------------------------------------------------------------------
14 /// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
15 /// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
16 /// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
17 /// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
18 /// call void @llvm.x86.tilestored64.internal(... td)                     area
19 /// --------------------------------------------------------------------------
20 /// This pass will insert tilecfg before every key-amx-area, some like:
21 /// --------------------------------------------------------------------------
22 /// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
23 /// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
24 /// ...
25 /// ... pre-config shape of %t1                                 *
26 /// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
27 /// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
28 /// ...                                                         *
29 /// ... pre-config shape of %t2                                 * shapes
30 /// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
31 /// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
32 /// ...
33 /// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
34 //
35 //===----------------------------------------------------------------------===//
36 //
37 #include "X86.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/Analysis/TargetTransformInfo.h"
40 #include "llvm/CodeGen/Passes.h"
41 #include "llvm/CodeGen/TargetPassConfig.h"
42 #include "llvm/CodeGen/ValueTypes.h"
43 #include "llvm/IR/DataLayout.h"
44 #include "llvm/IR/Function.h"
45 #include "llvm/IR/IRBuilder.h"
46 #include "llvm/IR/Instructions.h"
47 #include "llvm/IR/IntrinsicInst.h"
48 #include "llvm/IR/IntrinsicsX86.h"
49 #include "llvm/IR/PatternMatch.h"
50 #include "llvm/InitializePasses.h"
51 #include "llvm/Pass.h"
52 #include "llvm/Support/raw_ostream.h"
53 #include "llvm/Target/TargetMachine.h"
54
55 using namespace llvm;
56 using namespace PatternMatch;
57
58 #define DEBUG_TYPE "pre-amx-config"
59
60 static bool isAMXIntrinsic(IntrinsicInst *II) {
61   for (Value *Operand : II->operands())
62     if (Operand->getType()->isX86_AMXTy())
63       return true;
64   return II->getType()->isX86_AMXTy();
65 }
66
67 static bool isTileLoad(IntrinsicInst *II) {
68   return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
69          II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
70 }
71
72 static bool isTileStore(IntrinsicInst *II) {
73   return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
74 }
75
76 #ifndef NDEBUG
77 static bool onlyTileDef(IntrinsicInst *II) {
78   for (Value *Operand : II->operands())
79     if (Operand->getType()->isX86_AMXTy())
80       return false;
81   return II->getType()->isX86_AMXTy();
82 }
83
84 static bool brokenVolatile(Instruction *I) {
85   // Todo: it is weak to identify a normal call here.
86   if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
87     return true;
88   return false;
89 }
90 #endif
91
92 namespace {
93 class X86PreAMXConfig {
94   using PosAndShapesMap = MapVector<Instruction *, SmallVector<Value *, 8>>;
95
96   Function &F;
97
98 public:
99   X86PreAMXConfig(Function &Func) : F(Func) {}
100   bool preTileConfig();
101   void addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
102   bool findConfigShapes(PosAndShapesMap &PosAndShapes);
103   bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
104   void preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
105                        SmallVector<Value *, 8> &Shapes);
106   BasicBlock::iterator
107   getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
108                            SmallVector<Value *, 8> &Shapes);
109   bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
110                           IntrinsicInst *KeyAMX);
111 };
112
113 // Orderly write the shapes in tilecfg's mem. This maybe not right.
114 // Because the first shape may not corresponding to the first tmm register,
115 // so we need to handle at at X86FastTileConfig::materializeTileCfg()
116 // after register allocation.
117 // For example:
118 // --------------------------------------------------------------------------
119 // zeroinitialize tilecfg's mem (of ldtilecfg)
120 // --------------------------------------------------------------------------
121 // ... pre-config shape of %t1                                 *
122 // %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
123 // %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
124 // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
125 // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
126 // ...                                                         *
127 // ... pre-config shape of %t2                                 *
128 // %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
129 // %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
130 // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
131 // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
132 // ...                                                         *
133 // ... pre-config shape of %t3                                 * of
134 // %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
135 // %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
136 // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
137 // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
138 // ...                                                         * tiles
139 // ... pre-config shape of %td                                 *
140 // %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
141 // %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
142 // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
143 // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
144 // --------------------------------------------------------------------------
145 // call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
146 // --------------------------------------------------------------------------
147 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
148 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
149 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
150 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
151 // call void @llvm.x86.tilestored64.internal(... td)                     area
152 // --------------------------------------------------------------------------
153 void X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
154                                       SmallVector<Value *, 8> &Shapes) {
155   LLVMContext &Ctx = Builder.getContext();
156   Type *I8Ty = Type::getInt8Ty(Ctx);
157   Type *I16Ty = Type::getInt16Ty(Ctx);
158
159   // TODO: Currently we defaultly set Palette = 1, it may be assigned to
160   // other value in the future.
161   Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
162   Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
163   Value *PalettePos = Builder.CreateGEP(I8Ty, I8Ptr, PaletteOffset);
164   Builder.CreateStore(PaletteValue, PalettePos);
165
166   for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
167     Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
168     Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
169     const std::string ShapeName = "amx.tmm." + itostr(I);
170     Value *RowPos = Builder.CreateGEP(I8Ty, I8Ptr, RowOffset,
171                                       ShapeName + ".shape.row");
172     Value *ColPos = Builder.CreateGEP(I8Ty, I8Ptr, ColOffset);
173     ColPos = Builder.CreateBitCast(ColPos, PointerType::get(I16Ty, 0),
174                                    ShapeName + ".shape.col");
175     Value *Row = Shapes[I * 2];
176     Value *Col = Shapes[I * 2 + 1];
177     Row = Builder.CreateTrunc(Row, I8Ty);
178     Builder.CreateStore(Row, RowPos);
179     Builder.CreateStore(Col, ColPos);
180   }
181 }
182
183 void X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
184                                     SmallVector<Value *, 8> &Shapes) {
185   Module *M = F.getParent();
186   IRBuilder<> Builder(ModelStart);
187   const DataLayout &DL = M->getDataLayout();
188   unsigned AddrSpace = DL.getAllocaAddrSpace();
189   LLVMContext &Ctx = Builder.getContext();
190   Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
191   Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));
192
193   AllocaInst *Addr =
194       new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
195   Addr->setAlignment(Alignment);
196   Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());
197
198   Builder.CreateAlignedStore(Constant::getNullValue(V512Ty), Addr, Alignment);
199
200   preWriteTileCfg(I8Ptr, Builder, Shapes);
201
202   Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, {I8Ptr});
203 }
204
205 // Todo: We may need to handle "more than one store" case in the future.
206 bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
207                                          IntrinsicInst *Store,
208                                          IntrinsicInst *KeyAMX) {
209   Value *ST = Store->getOperand(4);
210
211   // Only has tileload and tilestore.
212   if (!KeyAMX)
213     return (Loads.size() == 1) && Loads.contains(ST);
214
215   // All Loads should be operands of KeyAMX.
216   // All tile operands of KeyAMX should come from Loads.
217   for (Value *Op : KeyAMX->operands()) {
218     if (Op->getType()->isX86_AMXTy())
219       if (!Loads.erase(Op))
220         return false;
221   }
222
223   // The def of KeyAMX should be stored into mem.
224   // Todo: is it key amx can be no def?
225   return Loads.empty() && (ST == cast<Value>(KeyAMX));
226 }
227
228 bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
229                                       SmallVector<Value *, 8> &Shapes) {
230   for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
231     Value *Op = KeyAMX->getOperand(I);
232     if (!Op->getType()->isX86_AMXTy())
233       continue;
234     IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
235     assert((TileDef && isTileLoad(TileDef)) &&
236            "All KeyAMX's tile definiation should comes from TileLoad!");
237     Shapes.push_back(TileDef->getOperand(0));
238     Shapes.push_back(TileDef->getOperand(1));
239   }
240   if (!isTileStore(KeyAMX)) {
241     Shapes.push_back(KeyAMX->getOperand(0));
242     Shapes.push_back(KeyAMX->getOperand(1));
243   }
244   return Shapes.size() != 0;
245 }
246
247 // Collect the shapes and skip the area of current key amx intrinsic.
248 //
249 // For example:
250 // ...
251 // --------------------------------------------------------------------------
252 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
253 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (m,k)
254 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,k)
255 // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
256 // call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,k)
257 // --------------------------------------------------------------------------
258 BasicBlock::iterator
259 X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
260                                           SmallVector<Value *, 8> &Shapes) {
261   IntrinsicInst *KeyAMX = nullptr;
262   BasicBlock *BB = Iter->getParent();
263   BasicBlock::iterator PosEnd = BB->end();
264   SmallSet<Value *, 4> Loads;
265
266   // See TileStore as "Config Position End" and check volatile model.
267   for (auto I = Iter, E = BB->end(); I != E; ++I) {
268     assert(!brokenVolatile(&*I) && "Not reach tile store!");
269     IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
270     if (!II || !isAMXIntrinsic(II))
271       continue;
272
273     if (isTileLoad(II)) {
274       Loads.insert(II);
275     } else if (isTileStore(II)) {
276       if (!checkVolatileModel(Loads, II, KeyAMX))
277         report_fatal_error("Not Volatile AMX Model!");
278       PosEnd = I;
279       break;
280     } else {
281       assert(!KeyAMX && "Too many key amx intrinsic!");
282       KeyAMX = II;
283     }
284   }
285   assert(PosEnd != BB->end() && "Not find TileStore!");
286
287   // See KeyAMX as TileStore if only TileLoad and TileStore.
288   if (!KeyAMX)
289     KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);
290
291   // Get Shapes in order.
292   assert(Shapes.empty() && "Shapes should be clean.");
293   getKeyAMXShapes(KeyAMX, Shapes);
294
295   return PosEnd;
296 }
297
298 // Record a key amx area's shapes with its position.
299 // Use the first tileload as its position.
300 // For example:
301 // ...
302 // --------------------------------------------------------------------------
303 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <--  pos
304 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
305 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)     shapes:
306 // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)    (m,k)(k,n)
307 // call void @llvm.x86.tilestored64.internal(m, n,... td)          (m,n)(m,n)
308 // --------------------------------------------------------------------------
309 bool X86PreAMXConfig::findConfigShapes(PosAndShapesMap &PosAndShapes) {
310   bool Find = false;
311   for (BasicBlock &BB : F) {
312     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
313       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
314       if (!II)
315         continue;
316       if (!isAMXIntrinsic(II))
317         continue;
318       assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");
319
320       I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
321       Find = true;
322     }
323   }
324   return Find;
325 }
326
327 // Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic.
328 // e.g. (key amx = tdpbssd)
329 // --------------------------------------------------------------------------
330 // %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
331 // store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
332 // ...
333 // ... pre-config shape of %t1                                 *
334 // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
335 // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
336 // ...                                                         *
337 // ... pre-config shape of %t2                                 *
338 // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
339 // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
340 // ...                                                         *
341 // ... pre-config shape of %t3                                 * of
342 // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
343 // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
344 // ...                                                         * tiles
345 // ... pre-config shape of %td                                 *
346 // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
347 // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
348 //
349 // call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
350 // --------------------------------------------------------------------------
351 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
352 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
353 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
354 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
355 // call void @llvm.x86.tilestored64.internal(... td)                     area
356 // --------------------------------------------------------------------------
357 bool X86PreAMXConfig::preTileConfig() {
358   PosAndShapesMap PosAndShapes;
359   bool NeedCfg = findConfigShapes(PosAndShapes);
360   if (!NeedCfg)
361     return false;
362   for (auto &IPAndShapes : PosAndShapes)
363     addTileConfig(IPAndShapes.first, IPAndShapes.second);
364
365   return true;
366 }
367 } // anonymous namespace
368
369 namespace {
370
371 class X86PreAMXConfigPass : public FunctionPass {
372 public:
373   static char ID;
374
375   X86PreAMXConfigPass() : FunctionPass(ID) {
376     initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
377   }
378
379   bool runOnFunction(Function &F) override {
380     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
381     bool C = false;
382
383     // Prepare for fast register allocation at O0.
384     if (TM->getOptLevel() == CodeGenOpt::None) {
385
386       // We pre-config each key AMX intrinsic at O0.
387       // In theory, one tile config can cover several AMX intrinsics, but
388       // it is very diffcult to classify the tile shapes at O0. So here we
389       // let thing be easy, pre-config every key AMX intrinsic.
390       X86PreAMXConfig PCFG(F);
391       C = PCFG.preTileConfig();
392     }
393
394     return C;
395   }
396
397   void getAnalysisUsage(AnalysisUsage &AU) const override {
398     AU.setPreservesCFG();
399     AU.addRequired<TargetPassConfig>();
400   }
401 };
402
403 } // anonymous namespace
404
405 static const char PassName[] = "Pre AMX Tile Config";
406 char X86PreAMXConfigPass::ID = 0;
407 INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
408 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
409 INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
410
411 FunctionPass *llvm::createX86PreAMXConfigPass() {
412   return new X86PreAMXConfigPass();
413 }