Home | History | Annotate | Line # | Download | only in Utils
      1 //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file defines a utility class to perform loop versioning.  The versioned
     10 // loop speculates that otherwise may-aliasing memory accesses don't overlap and
     11 // emits checks to prove this.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #include "llvm/Transforms/Utils/LoopVersioning.h"
     16 #include "llvm/ADT/ArrayRef.h"
     17 #include "llvm/Analysis/LoopAccessAnalysis.h"
     18 #include "llvm/Analysis/LoopInfo.h"
     19 #include "llvm/Analysis/MemorySSA.h"
     20 #include "llvm/Analysis/ScalarEvolution.h"
     21 #include "llvm/Analysis/TargetLibraryInfo.h"
     22 #include "llvm/IR/Dominators.h"
     23 #include "llvm/IR/MDBuilder.h"
     24 #include "llvm/IR/PassManager.h"
     25 #include "llvm/InitializePasses.h"
     26 #include "llvm/Support/CommandLine.h"
     27 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     28 #include "llvm/Transforms/Utils/Cloning.h"
     29 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
     30 
     31 using namespace llvm;
     32 
     33 static cl::opt<bool>
     34     AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
     35                     cl::Hidden,
     36                     cl::desc("Add no-alias annotation for instructions that "
     37                              "are disambiguated by memchecks"));
     38 
     39 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
     40                                ArrayRef<RuntimePointerCheck> Checks, Loop *L,
     41                                LoopInfo *LI, DominatorTree *DT,
     42                                ScalarEvolution *SE)
     43     : VersionedLoop(L), NonVersionedLoop(nullptr),
     44       AliasChecks(Checks.begin(), Checks.end()),
     45       Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT),
     46       SE(SE) {
     47   assert(L->getUniqueExitBlock() && "No single exit block");
     48 }
     49 
     50 void LoopVersioning::versionLoop(
     51     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
     52   assert(VersionedLoop->isLoopSimplifyForm() &&
     53          "Loop is not in loop-simplify form");
     54 
     55   Instruction *FirstCheckInst;
     56   Instruction *MemRuntimeCheck;
     57   Value *SCEVRuntimeCheck;
     58   Value *RuntimeCheck = nullptr;
     59 
     60   // Add the memcheck in the original preheader (this is empty initially).
     61   BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
     62   const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
     63 
     64   SCEVExpander Exp2(*RtPtrChecking.getSE(),
     65                     VersionedLoop->getHeader()->getModule()->getDataLayout(),
     66                     "induction");
     67   std::tie(FirstCheckInst, MemRuntimeCheck) = addRuntimeChecks(
     68       RuntimeCheckBB->getTerminator(), VersionedLoop, AliasChecks, Exp2);
     69 
     70   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
     71                    "scev.check");
     72   SCEVRuntimeCheck =
     73       Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator());
     74   auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
     75 
     76   // Discard the SCEV runtime check if it is always true.
     77   if (CI && CI->isZero())
     78     SCEVRuntimeCheck = nullptr;
     79 
     80   if (MemRuntimeCheck && SCEVRuntimeCheck) {
     81     RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
     82                                           SCEVRuntimeCheck, "lver.safe");
     83     if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
     84       I->insertBefore(RuntimeCheckBB->getTerminator());
     85   } else
     86     RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
     87 
     88   assert(RuntimeCheck && "called even though we don't need "
     89                          "any runtime checks");
     90 
     91   // Rename the block to make the IR more readable.
     92   RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
     93                           ".lver.check");
     94 
     95   // Create empty preheader for the loop (and after cloning for the
     96   // non-versioned loop).
     97   BasicBlock *PH =
     98       SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
     99                  nullptr, VersionedLoop->getHeader()->getName() + ".ph");
    100 
    101   // Clone the loop including the preheader.
    102   //
    103   // FIXME: This does not currently preserve SimplifyLoop because the exit
    104   // block is a join between the two loops.
    105   SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
    106   NonVersionedLoop =
    107       cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
    108                              ".lver.orig", LI, DT, NonVersionedLoopBlocks);
    109   remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
    110 
    111   // Insert the conditional branch based on the result of the memchecks.
    112   Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
    113   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
    114                      VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
    115   OrigTerm->eraseFromParent();
    116 
    117   // The loops merge in the original exit block.  This is now dominated by the
    118   // memchecking block.
    119   DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
    120 
    121   // Adds the necessary PHI nodes for the versioned loops based on the
    122   // loop-defined values used outside of the loop.
    123   addPHINodes(DefsUsedOutside);
    124   formDedicatedExitBlocks(NonVersionedLoop, DT, LI, nullptr, true);
    125   formDedicatedExitBlocks(VersionedLoop, DT, LI, nullptr, true);
    126   assert(NonVersionedLoop->isLoopSimplifyForm() &&
    127          VersionedLoop->isLoopSimplifyForm() &&
    128          "The versioned loops should be in simplify form.");
    129 }
    130 
    131 void LoopVersioning::addPHINodes(
    132     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
    133   BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
    134   assert(PHIBlock && "No single successor to loop exit block");
    135   PHINode *PN;
    136 
    137   // First add a single-operand PHI for each DefsUsedOutside if one does not
    138   // exists yet.
    139   for (auto *Inst : DefsUsedOutside) {
    140     // See if we have a single-operand PHI with the value defined by the
    141     // original loop.
    142     for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
    143       if (PN->getIncomingValue(0) == Inst)
    144         break;
    145     }
    146     // If not create it.
    147     if (!PN) {
    148       PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
    149                            &PHIBlock->front());
    150       SmallVector<User*, 8> UsersToUpdate;
    151       for (User *U : Inst->users())
    152         if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
    153           UsersToUpdate.push_back(U);
    154       for (User *U : UsersToUpdate)
    155         U->replaceUsesOfWith(Inst, PN);
    156       PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
    157     }
    158   }
    159 
    160   // Then for each PHI add the operand for the edge from the cloned loop.
    161   for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
    162     assert(PN->getNumOperands() == 1 &&
    163            "Exit block should only have on predecessor");
    164 
    165     // If the definition was cloned used that otherwise use the same value.
    166     Value *ClonedValue = PN->getIncomingValue(0);
    167     auto Mapped = VMap.find(ClonedValue);
    168     if (Mapped != VMap.end())
    169       ClonedValue = Mapped->second;
    170 
    171     PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
    172   }
    173 }
    174 
    175 void LoopVersioning::prepareNoAliasMetadata() {
    176   // We need to turn the no-alias relation between pointer checking groups into
    177   // no-aliasing annotations between instructions.
    178   //
    179   // We accomplish this by mapping each pointer checking group (a set of
    180   // pointers memchecked together) to an alias scope and then also mapping each
    181   // group to the list of scopes it can't alias.
    182 
    183   const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
    184   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
    185 
    186   // First allocate an aliasing scope for each pointer checking group.
    187   //
    188   // While traversing through the checking groups in the loop, also create a
    189   // reverse map from pointers to the pointer checking group they were assigned
    190   // to.
    191   MDBuilder MDB(Context);
    192   MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
    193 
    194   for (const auto &Group : RtPtrChecking->CheckingGroups) {
    195     GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
    196 
    197     for (unsigned PtrIdx : Group.Members)
    198       PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
    199   }
    200 
    201   // Go through the checks and for each pointer group, collect the scopes for
    202   // each non-aliasing pointer group.
    203   DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
    204       GroupToNonAliasingScopes;
    205 
    206   for (const auto &Check : AliasChecks)
    207     GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
    208 
    209   // Finally, transform the above to actually map to scope list which is what
    210   // the metadata uses.
    211 
    212   for (auto Pair : GroupToNonAliasingScopes)
    213     GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
    214 }
    215 
    216 void LoopVersioning::annotateLoopWithNoAlias() {
    217   if (!AnnotateNoAlias)
    218     return;
    219 
    220   // First prepare the maps.
    221   prepareNoAliasMetadata();
    222 
    223   // Add the scope and no-alias metadata to the instructions.
    224   for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
    225     annotateInstWithNoAlias(I);
    226   }
    227 }
    228 
    229 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
    230                                              const Instruction *OrigInst) {
    231   if (!AnnotateNoAlias)
    232     return;
    233 
    234   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
    235   const Value *Ptr = isa<LoadInst>(OrigInst)
    236                          ? cast<LoadInst>(OrigInst)->getPointerOperand()
    237                          : cast<StoreInst>(OrigInst)->getPointerOperand();
    238 
    239   // Find the group for the pointer and then add the scope metadata.
    240   auto Group = PtrToGroup.find(Ptr);
    241   if (Group != PtrToGroup.end()) {
    242     VersionedInst->setMetadata(
    243         LLVMContext::MD_alias_scope,
    244         MDNode::concatenate(
    245             VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
    246             MDNode::get(Context, GroupToScope[Group->second])));
    247 
    248     // Add the no-alias metadata.
    249     auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
    250     if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
    251       VersionedInst->setMetadata(
    252           LLVMContext::MD_noalias,
    253           MDNode::concatenate(
    254               VersionedInst->getMetadata(LLVMContext::MD_noalias),
    255               NonAliasingScopeList->second));
    256   }
    257 }
    258 
    259 namespace {
    260 bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA,
    261              DominatorTree *DT, ScalarEvolution *SE) {
    262   // Build up a worklist of inner-loops to version. This is necessary as the
    263   // act of versioning a loop creates new loops and can invalidate iterators
    264   // across the loops.
    265   SmallVector<Loop *, 8> Worklist;
    266 
    267   for (Loop *TopLevelLoop : *LI)
    268     for (Loop *L : depth_first(TopLevelLoop))
    269       // We only handle inner-most loops.
    270       if (L->isInnermost())
    271         Worklist.push_back(L);
    272 
    273   // Now walk the identified inner loops.
    274   bool Changed = false;
    275   for (Loop *L : Worklist) {
    276     if (!L->isLoopSimplifyForm() || !L->isRotatedForm() ||
    277         !L->getExitingBlock())
    278       continue;
    279     const LoopAccessInfo &LAI = GetLAA(*L);
    280     if (!LAI.hasConvergentOp() &&
    281         (LAI.getNumRuntimePointerChecks() ||
    282          !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
    283       LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
    284                           LI, DT, SE);
    285       LVer.versionLoop();
    286       LVer.annotateLoopWithNoAlias();
    287       Changed = true;
    288     }
    289   }
    290 
    291   return Changed;
    292 }
    293 
    294 /// Also expose this is a pass.  Currently this is only used for
    295 /// unit-testing.  It adds all memchecks necessary to remove all may-aliasing
    296 /// array accesses from the loop.
    297 class LoopVersioningLegacyPass : public FunctionPass {
    298 public:
    299   LoopVersioningLegacyPass() : FunctionPass(ID) {
    300     initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
    301   }
    302 
    303   bool runOnFunction(Function &F) override {
    304     auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    305     auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
    306       return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(&L);
    307     };
    308 
    309     auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    310     auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    311 
    312     return runImpl(LI, GetLAA, DT, SE);
    313   }
    314 
    315   void getAnalysisUsage(AnalysisUsage &AU) const override {
    316     AU.addRequired<LoopInfoWrapperPass>();
    317     AU.addPreserved<LoopInfoWrapperPass>();
    318     AU.addRequired<LoopAccessLegacyAnalysis>();
    319     AU.addRequired<DominatorTreeWrapperPass>();
    320     AU.addPreserved<DominatorTreeWrapperPass>();
    321     AU.addRequired<ScalarEvolutionWrapperPass>();
    322   }
    323 
    324   static char ID;
    325 };
    326 }
    327 
    328 #define LVER_OPTION "loop-versioning"
    329 #define DEBUG_TYPE LVER_OPTION
    330 
    331 char LoopVersioningLegacyPass::ID;
    332 static const char LVer_name[] = "Loop Versioning";
    333 
    334 INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
    335                       false)
    336 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
    337 INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
    338 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
    339 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
    340 INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
    341                     false)
    342 
    343 namespace llvm {
    344 FunctionPass *createLoopVersioningLegacyPass() {
    345   return new LoopVersioningLegacyPass();
    346 }
    347 
    348 PreservedAnalyses LoopVersioningPass::run(Function &F,
    349                                           FunctionAnalysisManager &AM) {
    350   auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
    351   auto &LI = AM.getResult<LoopAnalysis>(F);
    352   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
    353   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
    354   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
    355   auto &AA = AM.getResult<AAManager>(F);
    356   auto &AC = AM.getResult<AssumptionAnalysis>(F);
    357   MemorySSA *MSSA = EnableMSSALoopDependency
    358                         ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA()
    359                         : nullptr;
    360 
    361   auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager();
    362   auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & {
    363     LoopStandardAnalysisResults AR = {AA,  AC,  DT,      LI,  SE,
    364                                       TLI, TTI, nullptr, MSSA};
    365     return LAM.getResult<LoopAccessAnalysis>(L, AR);
    366   };
    367 
    368   if (runImpl(&LI, GetLAA, &DT, &SE))
    369     return PreservedAnalyses::none();
    370   return PreservedAnalyses::all();
    371 }
    372 } // namespace llvm
    373