Home | History | Annotate | Line # | Download | only in AMDGPU
      1 //===- GCNRegPressure.cpp -------------------------------------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 ///
      9 /// \file
     10 /// This file implements the GCNRegPressure class.
     11 ///
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "GCNRegPressure.h"
     15 #include "llvm/CodeGen/RegisterPressure.h"
     16 
     17 using namespace llvm;
     18 
     19 #define DEBUG_TYPE "machine-scheduler"
     20 
     21 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
     22 LLVM_DUMP_METHOD
     23 void llvm::printLivesAt(SlotIndex SI,
     24                         const LiveIntervals &LIS,
     25                         const MachineRegisterInfo &MRI) {
     26   dbgs() << "Live regs at " << SI << ": "
     27          << *LIS.getInstructionFromIndex(SI);
     28   unsigned Num = 0;
     29   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
     30     const unsigned Reg = Register::index2VirtReg(I);
     31     if (!LIS.hasInterval(Reg))
     32       continue;
     33     const auto &LI = LIS.getInterval(Reg);
     34     if (LI.hasSubRanges()) {
     35       bool firstTime = true;
     36       for (const auto &S : LI.subranges()) {
     37         if (!S.liveAt(SI)) continue;
     38         if (firstTime) {
     39           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
     40                  << '\n';
     41           firstTime = false;
     42         }
     43         dbgs() << "  " << S << '\n';
     44         ++Num;
     45       }
     46     } else if (LI.liveAt(SI)) {
     47       dbgs() << "  " << LI << '\n';
     48       ++Num;
     49     }
     50   }
     51   if (!Num) dbgs() << "  <none>\n";
     52 }
     53 #endif
     54 
     55 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
     56                    const GCNRPTracker::LiveRegSet &S2) {
     57   if (S1.size() != S2.size())
     58     return false;
     59 
     60   for (const auto &P : S1) {
     61     auto I = S2.find(P.first);
     62     if (I == S2.end() || I->second != P.second)
     63       return false;
     64   }
     65   return true;
     66 }
     67 
     68 
     69 ///////////////////////////////////////////////////////////////////////////////
     70 // GCNRegPressure
     71 
     72 unsigned GCNRegPressure::getRegKind(Register Reg,
     73                                     const MachineRegisterInfo &MRI) {
     74   assert(Reg.isVirtual());
     75   const auto RC = MRI.getRegClass(Reg);
     76   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
     77   return STI->isSGPRClass(RC) ?
     78     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
     79     STI->hasAGPRs(RC) ?
     80       (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
     81       (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
     82 }
     83 
     84 void GCNRegPressure::inc(unsigned Reg,
     85                          LaneBitmask PrevMask,
     86                          LaneBitmask NewMask,
     87                          const MachineRegisterInfo &MRI) {
     88   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
     89       SIRegisterInfo::getNumCoveredRegs(PrevMask))
     90     return;
     91 
     92   int Sign = 1;
     93   if (NewMask < PrevMask) {
     94     std::swap(NewMask, PrevMask);
     95     Sign = -1;
     96   }
     97 
     98   switch (auto Kind = getRegKind(Reg, MRI)) {
     99   case SGPR32:
    100   case VGPR32:
    101   case AGPR32:
    102     Value[Kind] += Sign;
    103     break;
    104 
    105   case SGPR_TUPLE:
    106   case VGPR_TUPLE:
    107   case AGPR_TUPLE:
    108     assert(PrevMask < NewMask);
    109 
    110     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
    111       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
    112 
    113     if (PrevMask.none()) {
    114       assert(NewMask.any());
    115       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
    116     }
    117     break;
    118 
    119   default: llvm_unreachable("Unknown register kind");
    120   }
    121 }
    122 
    123 bool GCNRegPressure::less(const GCNSubtarget &ST,
    124                           const GCNRegPressure& O,
    125                           unsigned MaxOccupancy) const {
    126   const auto SGPROcc = std::min(MaxOccupancy,
    127                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
    128   const auto VGPROcc =
    129     std::min(MaxOccupancy,
    130              ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
    131   const auto OtherSGPROcc = std::min(MaxOccupancy,
    132                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
    133   const auto OtherVGPROcc =
    134     std::min(MaxOccupancy,
    135              ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
    136 
    137   const auto Occ = std::min(SGPROcc, VGPROcc);
    138   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
    139   if (Occ != OtherOcc)
    140     return Occ > OtherOcc;
    141 
    142   bool SGPRImportant = SGPROcc < VGPROcc;
    143   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
    144 
    145   // if both pressures disagree on what is more important compare vgprs
    146   if (SGPRImportant != OtherSGPRImportant) {
    147     SGPRImportant = false;
    148   }
    149 
    150   // compare large regs pressure
    151   bool SGPRFirst = SGPRImportant;
    152   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
    153     if (SGPRFirst) {
    154       auto SW = getSGPRTuplesWeight();
    155       auto OtherSW = O.getSGPRTuplesWeight();
    156       if (SW != OtherSW)
    157         return SW < OtherSW;
    158     } else {
    159       auto VW = getVGPRTuplesWeight();
    160       auto OtherVW = O.getVGPRTuplesWeight();
    161       if (VW != OtherVW)
    162         return VW < OtherVW;
    163     }
    164   }
    165   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
    166                          (getVGPRNum(ST.hasGFX90AInsts()) <
    167                           O.getVGPRNum(ST.hasGFX90AInsts()));
    168 }
    169 
    170 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
    171 LLVM_DUMP_METHOD
    172 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
    173   OS << "VGPRs: " << Value[VGPR32] << ' ';
    174   OS << "AGPRs: " << Value[AGPR32];
    175   if (ST) OS << "(O"
    176              << ST->getOccupancyWithNumVGPRs(getVGPRNum(ST->hasGFX90AInsts()))
    177              << ')';
    178   OS << ", SGPRs: " << getSGPRNum();
    179   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
    180   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
    181      << ", LSGPR WT: " << getSGPRTuplesWeight();
    182   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
    183   OS << '\n';
    184 }
    185 #endif
    186 
    187 static LaneBitmask getDefRegMask(const MachineOperand &MO,
    188                                  const MachineRegisterInfo &MRI) {
    189   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
    190 
    191   // We don't rely on read-undef flag because in case of tentative schedule
    192   // tracking it isn't set correctly yet. This works correctly however since
    193   // use mask has been tracked before using LIS.
    194   return MO.getSubReg() == 0 ?
    195     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
    196     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
    197 }
    198 
    199 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
    200                                   const MachineRegisterInfo &MRI,
    201                                   const LiveIntervals &LIS) {
    202   assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
    203 
    204   if (auto SubReg = MO.getSubReg())
    205     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
    206 
    207   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
    208   if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
    209     return MaxMask;
    210 
    211   // For a tentative schedule LIS isn't updated yet but livemask should remain
    212   // the same on any schedule. Subreg defs can be reordered but they all must
    213   // dominate uses anyway.
    214   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
    215   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
    216 }
    217 
    218 static SmallVector<RegisterMaskPair, 8>
    219 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
    220                       const MachineRegisterInfo &MRI) {
    221   SmallVector<RegisterMaskPair, 8> Res;
    222   for (const auto &MO : MI.operands()) {
    223     if (!MO.isReg() || !MO.getReg().isVirtual())
    224       continue;
    225     if (!MO.isUse() || !MO.readsReg())
    226       continue;
    227 
    228     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
    229 
    230     auto Reg = MO.getReg();
    231     auto I = llvm::find_if(
    232         Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
    233     if (I != Res.end())
    234       I->LaneMask |= UsedMask;
    235     else
    236       Res.push_back(RegisterMaskPair(Reg, UsedMask));
    237   }
    238   return Res;
    239 }
    240 
    241 ///////////////////////////////////////////////////////////////////////////////
    242 // GCNRPTracker
    243 
    244 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
    245                                   SlotIndex SI,
    246                                   const LiveIntervals &LIS,
    247                                   const MachineRegisterInfo &MRI) {
    248   LaneBitmask LiveMask;
    249   const auto &LI = LIS.getInterval(Reg);
    250   if (LI.hasSubRanges()) {
    251     for (const auto &S : LI.subranges())
    252       if (S.liveAt(SI)) {
    253         LiveMask |= S.LaneMask;
    254         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
    255                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
    256       }
    257   } else if (LI.liveAt(SI)) {
    258     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
    259   }
    260   return LiveMask;
    261 }
    262 
    263 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
    264                                            const LiveIntervals &LIS,
    265                                            const MachineRegisterInfo &MRI) {
    266   GCNRPTracker::LiveRegSet LiveRegs;
    267   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
    268     auto Reg = Register::index2VirtReg(I);
    269     if (!LIS.hasInterval(Reg))
    270       continue;
    271     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
    272     if (LiveMask.any())
    273       LiveRegs[Reg] = LiveMask;
    274   }
    275   return LiveRegs;
    276 }
    277 
    278 void GCNRPTracker::reset(const MachineInstr &MI,
    279                          const LiveRegSet *LiveRegsCopy,
    280                          bool After) {
    281   const MachineFunction &MF = *MI.getMF();
    282   MRI = &MF.getRegInfo();
    283   if (LiveRegsCopy) {
    284     if (&LiveRegs != LiveRegsCopy)
    285       LiveRegs = *LiveRegsCopy;
    286   } else {
    287     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
    288                      : getLiveRegsBefore(MI, LIS);
    289   }
    290 
    291   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
    292 }
    293 
    294 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
    295                                const LiveRegSet *LiveRegsCopy) {
    296   GCNRPTracker::reset(MI, LiveRegsCopy, true);
    297 }
    298 
    299 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
    300   assert(MRI && "call reset first");
    301 
    302   LastTrackedMI = &MI;
    303 
    304   if (MI.isDebugInstr())
    305     return;
    306 
    307   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
    308 
    309   // calc pressure at the MI (defs + uses)
    310   auto AtMIPressure = CurPressure;
    311   for (const auto &U : RegUses) {
    312     auto LiveMask = LiveRegs[U.RegUnit];
    313     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
    314   }
    315   // update max pressure
    316   MaxPressure = max(AtMIPressure, MaxPressure);
    317 
    318   for (const auto &MO : MI.operands()) {
    319     if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
    320       continue;
    321 
    322     auto Reg = MO.getReg();
    323     auto I = LiveRegs.find(Reg);
    324     if (I == LiveRegs.end())
    325       continue;
    326     auto &LiveMask = I->second;
    327     auto PrevMask = LiveMask;
    328     LiveMask &= ~getDefRegMask(MO, *MRI);
    329     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
    330     if (LiveMask.none())
    331       LiveRegs.erase(I);
    332   }
    333   for (const auto &U : RegUses) {
    334     auto &LiveMask = LiveRegs[U.RegUnit];
    335     auto PrevMask = LiveMask;
    336     LiveMask |= U.LaneMask;
    337     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
    338   }
    339   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
    340 }
    341 
    342 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
    343                                  const LiveRegSet *LiveRegsCopy) {
    344   MRI = &MI.getParent()->getParent()->getRegInfo();
    345   LastTrackedMI = nullptr;
    346   MBBEnd = MI.getParent()->end();
    347   NextMI = &MI;
    348   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
    349   if (NextMI == MBBEnd)
    350     return false;
    351   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
    352   return true;
    353 }
    354 
    355 bool GCNDownwardRPTracker::advanceBeforeNext() {
    356   assert(MRI && "call reset first");
    357 
    358   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
    359   if (NextMI == MBBEnd)
    360     return false;
    361 
    362   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
    363   assert(SI.isValid());
    364 
    365   // Remove dead registers or mask bits.
    366   for (auto &It : LiveRegs) {
    367     const LiveInterval &LI = LIS.getInterval(It.first);
    368     if (LI.hasSubRanges()) {
    369       for (const auto &S : LI.subranges()) {
    370         if (!S.liveAt(SI)) {
    371           auto PrevMask = It.second;
    372           It.second &= ~S.LaneMask;
    373           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
    374         }
    375       }
    376     } else if (!LI.liveAt(SI)) {
    377       auto PrevMask = It.second;
    378       It.second = LaneBitmask::getNone();
    379       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
    380     }
    381     if (It.second.none())
    382       LiveRegs.erase(It.first);
    383   }
    384 
    385   MaxPressure = max(MaxPressure, CurPressure);
    386 
    387   return true;
    388 }
    389 
    390 void GCNDownwardRPTracker::advanceToNext() {
    391   LastTrackedMI = &*NextMI++;
    392   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
    393 
    394   // Add new registers or mask bits.
    395   for (const auto &MO : LastTrackedMI->operands()) {
    396     if (!MO.isReg() || !MO.isDef())
    397       continue;
    398     Register Reg = MO.getReg();
    399     if (!Reg.isVirtual())
    400       continue;
    401     auto &LiveMask = LiveRegs[Reg];
    402     auto PrevMask = LiveMask;
    403     LiveMask |= getDefRegMask(MO, *MRI);
    404     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
    405   }
    406 
    407   MaxPressure = max(MaxPressure, CurPressure);
    408 }
    409 
    410 bool GCNDownwardRPTracker::advance() {
    411   // If we have just called reset live set is actual.
    412   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
    413     return false;
    414   advanceToNext();
    415   return true;
    416 }
    417 
    418 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
    419   while (NextMI != End)
    420     if (!advance()) return false;
    421   return true;
    422 }
    423 
    424 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
    425                                    MachineBasicBlock::const_iterator End,
    426                                    const LiveRegSet *LiveRegsCopy) {
    427   reset(*Begin, LiveRegsCopy);
    428   return advance(End);
    429 }
    430 
    431 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
    432 LLVM_DUMP_METHOD
    433 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
    434                            const GCNRPTracker::LiveRegSet &TrackedLR,
    435                            const TargetRegisterInfo *TRI) {
    436   for (auto const &P : TrackedLR) {
    437     auto I = LISLR.find(P.first);
    438     if (I == LISLR.end()) {
    439       dbgs() << "  " << printReg(P.first, TRI)
    440              << ":L" << PrintLaneMask(P.second)
    441              << " isn't found in LIS reported set\n";
    442     }
    443     else if (I->second != P.second) {
    444       dbgs() << "  " << printReg(P.first, TRI)
    445         << " masks doesn't match: LIS reported "
    446         << PrintLaneMask(I->second)
    447         << ", tracked "
    448         << PrintLaneMask(P.second)
    449         << '\n';
    450     }
    451   }
    452   for (auto const &P : LISLR) {
    453     auto I = TrackedLR.find(P.first);
    454     if (I == TrackedLR.end()) {
    455       dbgs() << "  " << printReg(P.first, TRI)
    456              << ":L" << PrintLaneMask(P.second)
    457              << " isn't found in tracked set\n";
    458     }
    459   }
    460 }
    461 
    462 bool GCNUpwardRPTracker::isValid() const {
    463   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
    464   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
    465   const auto &TrackedLR = LiveRegs;
    466 
    467   if (!isEqual(LISLR, TrackedLR)) {
    468     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
    469               " LIS reported livesets mismatch:\n";
    470     printLivesAt(SI, LIS, *MRI);
    471     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
    472     return false;
    473   }
    474 
    475   auto LISPressure = getRegPressure(*MRI, LISLR);
    476   if (LISPressure != CurPressure) {
    477     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
    478     CurPressure.print(dbgs());
    479     dbgs() << "LIS rpt: ";
    480     LISPressure.print(dbgs());
    481     return false;
    482   }
    483   return true;
    484 }
    485 
    486 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
    487                                  const MachineRegisterInfo &MRI) {
    488   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
    489   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
    490     unsigned Reg = Register::index2VirtReg(I);
    491     auto It = LiveRegs.find(Reg);
    492     if (It != LiveRegs.end() && It->second.any())
    493       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
    494          << PrintLaneMask(It->second);
    495   }
    496   OS << '\n';
    497 }
    498 #endif
    499