Home | History | Annotate | Line # | Download | only in CodeGen
      1 //===- RDFRegisters.cpp ---------------------------------------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 
      9 #include "llvm/ADT/BitVector.h"
     10 #include "llvm/CodeGen/MachineFunction.h"
     11 #include "llvm/CodeGen/MachineInstr.h"
     12 #include "llvm/CodeGen/MachineOperand.h"
     13 #include "llvm/CodeGen/RDFRegisters.h"
     14 #include "llvm/CodeGen/TargetRegisterInfo.h"
     15 #include "llvm/MC/LaneBitmask.h"
     16 #include "llvm/MC/MCRegisterInfo.h"
     17 #include "llvm/Support/ErrorHandling.h"
     18 #include "llvm/Support/raw_ostream.h"
     19 #include <cassert>
     20 #include <cstdint>
     21 #include <set>
     22 #include <utility>
     23 
     24 using namespace llvm;
     25 using namespace rdf;
     26 
     27 PhysicalRegisterInfo::PhysicalRegisterInfo(const TargetRegisterInfo &tri,
     28       const MachineFunction &mf)
     29     : TRI(tri) {
     30   RegInfos.resize(TRI.getNumRegs());
     31 
     32   BitVector BadRC(TRI.getNumRegs());
     33   for (const TargetRegisterClass *RC : TRI.regclasses()) {
     34     for (MCPhysReg R : *RC) {
     35       RegInfo &RI = RegInfos[R];
     36       if (RI.RegClass != nullptr && !BadRC[R]) {
     37         if (RC->LaneMask != RI.RegClass->LaneMask) {
     38           BadRC.set(R);
     39           RI.RegClass = nullptr;
     40         }
     41       } else
     42         RI.RegClass = RC;
     43     }
     44   }
     45 
     46   UnitInfos.resize(TRI.getNumRegUnits());
     47 
     48   for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
     49     if (UnitInfos[U].Reg != 0)
     50       continue;
     51     MCRegUnitRootIterator R(U, &TRI);
     52     assert(R.isValid());
     53     RegisterId F = *R;
     54     ++R;
     55     if (R.isValid()) {
     56       UnitInfos[U].Mask = LaneBitmask::getAll();
     57       UnitInfos[U].Reg = F;
     58     } else {
     59       for (MCRegUnitMaskIterator I(F, &TRI); I.isValid(); ++I) {
     60         std::pair<uint32_t,LaneBitmask> P = *I;
     61         UnitInfo &UI = UnitInfos[P.first];
     62         UI.Reg = F;
     63         if (P.second.any()) {
     64           UI.Mask = P.second;
     65         } else {
     66           if (const TargetRegisterClass *RC = RegInfos[F].RegClass)
     67             UI.Mask = RC->LaneMask;
     68           else
     69             UI.Mask = LaneBitmask::getAll();
     70         }
     71       }
     72     }
     73   }
     74 
     75   for (const uint32_t *RM : TRI.getRegMasks())
     76     RegMasks.insert(RM);
     77   for (const MachineBasicBlock &B : mf)
     78     for (const MachineInstr &In : B)
     79       for (const MachineOperand &Op : In.operands())
     80         if (Op.isRegMask())
     81           RegMasks.insert(Op.getRegMask());
     82 
     83   MaskInfos.resize(RegMasks.size()+1);
     84   for (uint32_t M = 1, NM = RegMasks.size(); M <= NM; ++M) {
     85     BitVector PU(TRI.getNumRegUnits());
     86     const uint32_t *MB = RegMasks.get(M);
     87     for (unsigned I = 1, E = TRI.getNumRegs(); I != E; ++I) {
     88       if (!(MB[I / 32] & (1u << (I % 32))))
     89         continue;
     90       for (MCRegUnitIterator U(MCRegister::from(I), &TRI); U.isValid(); ++U)
     91         PU.set(*U);
     92     }
     93     MaskInfos[M].Units = PU.flip();
     94   }
     95 
     96   AliasInfos.resize(TRI.getNumRegUnits());
     97   for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
     98     BitVector AS(TRI.getNumRegs());
     99     for (MCRegUnitRootIterator R(U, &TRI); R.isValid(); ++R)
    100       for (MCSuperRegIterator S(*R, &TRI, true); S.isValid(); ++S)
    101         AS.set(*S);
    102     AliasInfos[U].Regs = AS;
    103   }
    104 }
    105 
    106 std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
    107   // Do not include RR in the alias set.
    108   std::set<RegisterId> AS;
    109   assert(isRegMaskId(Reg) || Register::isPhysicalRegister(Reg));
    110   if (isRegMaskId(Reg)) {
    111     // XXX SLOW
    112     const uint32_t *MB = getRegMaskBits(Reg);
    113     for (unsigned i = 1, e = TRI.getNumRegs(); i != e; ++i) {
    114       if (MB[i/32] & (1u << (i%32)))
    115         continue;
    116       AS.insert(i);
    117     }
    118     for (const uint32_t *RM : RegMasks) {
    119       RegisterId MI = getRegMaskId(RM);
    120       if (MI != Reg && aliasMM(RegisterRef(Reg), RegisterRef(MI)))
    121         AS.insert(MI);
    122     }
    123     return AS;
    124   }
    125 
    126   for (MCRegAliasIterator AI(Reg, &TRI, false); AI.isValid(); ++AI)
    127     AS.insert(*AI);
    128   for (const uint32_t *RM : RegMasks) {
    129     RegisterId MI = getRegMaskId(RM);
    130     if (aliasRM(RegisterRef(Reg), RegisterRef(MI)))
    131       AS.insert(MI);
    132   }
    133   return AS;
    134 }
    135 
    136 bool PhysicalRegisterInfo::aliasRR(RegisterRef RA, RegisterRef RB) const {
    137   assert(Register::isPhysicalRegister(RA.Reg));
    138   assert(Register::isPhysicalRegister(RB.Reg));
    139 
    140   MCRegUnitMaskIterator UMA(RA.Reg, &TRI);
    141   MCRegUnitMaskIterator UMB(RB.Reg, &TRI);
    142   // Reg units are returned in the numerical order.
    143   while (UMA.isValid() && UMB.isValid()) {
    144     // Skip units that are masked off in RA.
    145     std::pair<RegisterId,LaneBitmask> PA = *UMA;
    146     if (PA.second.any() && (PA.second & RA.Mask).none()) {
    147       ++UMA;
    148       continue;
    149     }
    150     // Skip units that are masked off in RB.
    151     std::pair<RegisterId,LaneBitmask> PB = *UMB;
    152     if (PB.second.any() && (PB.second & RB.Mask).none()) {
    153       ++UMB;
    154       continue;
    155     }
    156 
    157     if (PA.first == PB.first)
    158       return true;
    159     if (PA.first < PB.first)
    160       ++UMA;
    161     else if (PB.first < PA.first)
    162       ++UMB;
    163   }
    164   return false;
    165 }
    166 
    167 bool PhysicalRegisterInfo::aliasRM(RegisterRef RR, RegisterRef RM) const {
    168   assert(Register::isPhysicalRegister(RR.Reg) && isRegMaskId(RM.Reg));
    169   const uint32_t *MB = getRegMaskBits(RM.Reg);
    170   bool Preserved = MB[RR.Reg/32] & (1u << (RR.Reg%32));
    171   // If the lane mask information is "full", e.g. when the given lane mask
    172   // is a superset of the lane mask from the register class, check the regmask
    173   // bit directly.
    174   if (RR.Mask == LaneBitmask::getAll())
    175     return !Preserved;
    176   const TargetRegisterClass *RC = RegInfos[RR.Reg].RegClass;
    177   if (RC != nullptr && (RR.Mask & RC->LaneMask) == RC->LaneMask)
    178     return !Preserved;
    179 
    180   // Otherwise, check all subregisters whose lane mask overlaps the given
    181   // mask. For each such register, if it is preserved by the regmask, then
    182   // clear the corresponding bits in the given mask. If at the end, all
    183   // bits have been cleared, the register does not alias the regmask (i.e.
    184   // is it preserved by it).
    185   LaneBitmask M = RR.Mask;
    186   for (MCSubRegIndexIterator SI(RR.Reg, &TRI); SI.isValid(); ++SI) {
    187     LaneBitmask SM = TRI.getSubRegIndexLaneMask(SI.getSubRegIndex());
    188     if ((SM & RR.Mask).none())
    189       continue;
    190     unsigned SR = SI.getSubReg();
    191     if (!(MB[SR/32] & (1u << (SR%32))))
    192       continue;
    193     // The subregister SR is preserved.
    194     M &= ~SM;
    195     if (M.none())
    196       return false;
    197   }
    198 
    199   return true;
    200 }
    201 
    202 bool PhysicalRegisterInfo::aliasMM(RegisterRef RM, RegisterRef RN) const {
    203   assert(isRegMaskId(RM.Reg) && isRegMaskId(RN.Reg));
    204   unsigned NumRegs = TRI.getNumRegs();
    205   const uint32_t *BM = getRegMaskBits(RM.Reg);
    206   const uint32_t *BN = getRegMaskBits(RN.Reg);
    207 
    208   for (unsigned w = 0, nw = NumRegs/32; w != nw; ++w) {
    209     // Intersect the negations of both words. Disregard reg=0,
    210     // i.e. 0th bit in the 0th word.
    211     uint32_t C = ~BM[w] & ~BN[w];
    212     if (w == 0)
    213       C &= ~1;
    214     if (C)
    215       return true;
    216   }
    217 
    218   // Check the remaining registers in the last word.
    219   unsigned TailRegs = NumRegs % 32;
    220   if (TailRegs == 0)
    221     return false;
    222   unsigned TW = NumRegs / 32;
    223   uint32_t TailMask = (1u << TailRegs) - 1;
    224   if (~BM[TW] & ~BN[TW] & TailMask)
    225     return true;
    226 
    227   return false;
    228 }
    229 
    230 RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, unsigned R) const {
    231   if (RR.Reg == R)
    232     return RR;
    233   if (unsigned Idx = TRI.getSubRegIndex(R, RR.Reg))
    234     return RegisterRef(R, TRI.composeSubRegIndexLaneMask(Idx, RR.Mask));
    235   if (unsigned Idx = TRI.getSubRegIndex(RR.Reg, R)) {
    236     const RegInfo &RI = RegInfos[R];
    237     LaneBitmask RCM = RI.RegClass ? RI.RegClass->LaneMask
    238                                   : LaneBitmask::getAll();
    239     LaneBitmask M = TRI.reverseComposeSubRegIndexLaneMask(Idx, RR.Mask);
    240     return RegisterRef(R, M & RCM);
    241   }
    242   llvm_unreachable("Invalid arguments: unrelated registers?");
    243 }
    244 
    245 bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
    246   if (PhysicalRegisterInfo::isRegMaskId(RR.Reg))
    247     return Units.anyCommon(PRI.getMaskUnits(RR.Reg));
    248 
    249   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
    250     std::pair<uint32_t,LaneBitmask> P = *U;
    251     if (P.second.none() || (P.second & RR.Mask).any())
    252       if (Units.test(P.first))
    253         return true;
    254   }
    255   return false;
    256 }
    257 
    258 bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
    259   if (PhysicalRegisterInfo::isRegMaskId(RR.Reg)) {
    260     BitVector T(PRI.getMaskUnits(RR.Reg));
    261     return T.reset(Units).none();
    262   }
    263 
    264   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
    265     std::pair<uint32_t,LaneBitmask> P = *U;
    266     if (P.second.none() || (P.second & RR.Mask).any())
    267       if (!Units.test(P.first))
    268         return false;
    269   }
    270   return true;
    271 }
    272 
    273 RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
    274   if (PhysicalRegisterInfo::isRegMaskId(RR.Reg)) {
    275     Units |= PRI.getMaskUnits(RR.Reg);
    276     return *this;
    277   }
    278 
    279   for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
    280     std::pair<uint32_t,LaneBitmask> P = *U;
    281     if (P.second.none() || (P.second & RR.Mask).any())
    282       Units.set(P.first);
    283   }
    284   return *this;
    285 }
    286 
    287 RegisterAggr &RegisterAggr::insert(const RegisterAggr &RG) {
    288   Units |= RG.Units;
    289   return *this;
    290 }
    291 
    292 RegisterAggr &RegisterAggr::intersect(RegisterRef RR) {
    293   return intersect(RegisterAggr(PRI).insert(RR));
    294 }
    295 
    296 RegisterAggr &RegisterAggr::intersect(const RegisterAggr &RG) {
    297   Units &= RG.Units;
    298   return *this;
    299 }
    300 
    301 RegisterAggr &RegisterAggr::clear(RegisterRef RR) {
    302   return clear(RegisterAggr(PRI).insert(RR));
    303 }
    304 
    305 RegisterAggr &RegisterAggr::clear(const RegisterAggr &RG) {
    306   Units.reset(RG.Units);
    307   return *this;
    308 }
    309 
    310 RegisterRef RegisterAggr::intersectWith(RegisterRef RR) const {
    311   RegisterAggr T(PRI);
    312   T.insert(RR).intersect(*this);
    313   if (T.empty())
    314     return RegisterRef();
    315   RegisterRef NR = T.makeRegRef();
    316   assert(NR);
    317   return NR;
    318 }
    319 
    320 RegisterRef RegisterAggr::clearIn(RegisterRef RR) const {
    321   return RegisterAggr(PRI).insert(RR).clear(*this).makeRegRef();
    322 }
    323 
    324 RegisterRef RegisterAggr::makeRegRef() const {
    325   int U = Units.find_first();
    326   if (U < 0)
    327     return RegisterRef();
    328 
    329   // Find the set of all registers that are aliased to all the units
    330   // in this aggregate.
    331 
    332   // Get all the registers aliased to the first unit in the bit vector.
    333   BitVector Regs = PRI.getUnitAliases(U);
    334   U = Units.find_next(U);
    335 
    336   // For each other unit, intersect it with the set of all registers
    337   // aliased that unit.
    338   while (U >= 0) {
    339     Regs &= PRI.getUnitAliases(U);
    340     U = Units.find_next(U);
    341   }
    342 
    343   // If there is at least one register remaining, pick the first one,
    344   // and consolidate the masks of all of its units contained in this
    345   // aggregate.
    346 
    347   int F = Regs.find_first();
    348   if (F <= 0)
    349     return RegisterRef();
    350 
    351   LaneBitmask M;
    352   for (MCRegUnitMaskIterator I(F, &PRI.getTRI()); I.isValid(); ++I) {
    353     std::pair<uint32_t,LaneBitmask> P = *I;
    354     if (Units.test(P.first))
    355       M |= P.second.none() ? LaneBitmask::getAll() : P.second;
    356   }
    357   return RegisterRef(F, M);
    358 }
    359 
    360 void RegisterAggr::print(raw_ostream &OS) const {
    361   OS << '{';
    362   for (int U = Units.find_first(); U >= 0; U = Units.find_next(U))
    363     OS << ' ' << printRegUnit(U, &PRI.getTRI());
    364   OS << " }";
    365 }
    366 
    367 RegisterAggr::rr_iterator::rr_iterator(const RegisterAggr &RG,
    368       bool End)
    369     : Owner(&RG) {
    370   for (int U = RG.Units.find_first(); U >= 0; U = RG.Units.find_next(U)) {
    371     RegisterRef R = RG.PRI.getRefForUnit(U);
    372     Masks[R.Reg] |= R.Mask;
    373   }
    374   Pos = End ? Masks.end() : Masks.begin();
    375   Index = End ? Masks.size() : 0;
    376 }
    377 
    378 raw_ostream &rdf::operator<<(raw_ostream &OS, const RegisterAggr &A) {
    379   A.print(OS);
    380   return OS;
    381 }
    382