Home | History | Annotate | Line # | Download | only in CodeGen
      1 //===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This pass tries to expand memcmp() calls into optimally-sized loads and
     10 // compares for the target.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "llvm/ADT/Statistic.h"
     15 #include "llvm/Analysis/ConstantFolding.h"
     16 #include "llvm/Analysis/DomTreeUpdater.h"
     17 #include "llvm/Analysis/LazyBlockFrequencyInfo.h"
     18 #include "llvm/Analysis/ProfileSummaryInfo.h"
     19 #include "llvm/Analysis/TargetLibraryInfo.h"
     20 #include "llvm/Analysis/TargetTransformInfo.h"
     21 #include "llvm/Analysis/ValueTracking.h"
     22 #include "llvm/CodeGen/TargetLowering.h"
     23 #include "llvm/CodeGen/TargetPassConfig.h"
     24 #include "llvm/CodeGen/TargetSubtargetInfo.h"
     25 #include "llvm/IR/Dominators.h"
     26 #include "llvm/IR/IRBuilder.h"
     27 #include "llvm/InitializePasses.h"
     28 #include "llvm/Target/TargetMachine.h"
     29 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     30 #include "llvm/Transforms/Utils/Local.h"
     31 #include "llvm/Transforms/Utils/SizeOpts.h"
     32 
     33 using namespace llvm;
     34 
     35 #define DEBUG_TYPE "expandmemcmp"
     36 
     37 STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
     38 STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
     39 STATISTIC(NumMemCmpGreaterThanMax,
     40           "Number of memcmp calls with size greater than max size");
     41 STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");
     42 
     43 static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
     44     "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
     45     cl::desc("The number of loads per basic block for inline expansion of "
     46              "memcmp that is only being compared against zero."));
     47 
     48 static cl::opt<unsigned> MaxLoadsPerMemcmp(
     49     "max-loads-per-memcmp", cl::Hidden,
     50     cl::desc("Set maximum number of loads used in expanded memcmp"));
     51 
     52 static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
     53     "max-loads-per-memcmp-opt-size", cl::Hidden,
     54     cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));
     55 
     56 namespace {
     57 
     58 
     59 // This class provides helper functions to expand a memcmp library call into an
     60 // inline expansion.
     61 class MemCmpExpansion {
     62   struct ResultBlock {
     63     BasicBlock *BB = nullptr;
     64     PHINode *PhiSrc1 = nullptr;
     65     PHINode *PhiSrc2 = nullptr;
     66 
     67     ResultBlock() = default;
     68   };
     69 
     70   CallInst *const CI;
     71   ResultBlock ResBlock;
     72   const uint64_t Size;
     73   unsigned MaxLoadSize;
     74   uint64_t NumLoadsNonOneByte;
     75   const uint64_t NumLoadsPerBlockForZeroCmp;
     76   std::vector<BasicBlock *> LoadCmpBlocks;
     77   BasicBlock *EndBlock;
     78   PHINode *PhiRes;
     79   const bool IsUsedForZeroCmp;
     80   const DataLayout &DL;
     81   DomTreeUpdater *DTU;
     82   IRBuilder<> Builder;
     83   // Represents the decomposition in blocks of the expansion. For example,
     84   // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
     85   // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}.
     86   struct LoadEntry {
     87     LoadEntry(unsigned LoadSize, uint64_t Offset)
     88         : LoadSize(LoadSize), Offset(Offset) {
     89     }
     90 
     91     // The size of the load for this block, in bytes.
     92     unsigned LoadSize;
     93     // The offset of this load from the base pointer, in bytes.
     94     uint64_t Offset;
     95   };
     96   using LoadEntryVector = SmallVector<LoadEntry, 8>;
     97   LoadEntryVector LoadSequence;
     98 
     99   void createLoadCmpBlocks();
    100   void createResultBlock();
    101   void setupResultBlockPHINodes();
    102   void setupEndBlockPHINodes();
    103   Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
    104   void emitLoadCompareBlock(unsigned BlockIndex);
    105   void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
    106                                          unsigned &LoadIndex);
    107   void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
    108   void emitMemCmpResultBlock();
    109   Value *getMemCmpExpansionZeroCase();
    110   Value *getMemCmpEqZeroOneBlock();
    111   Value *getMemCmpOneBlock();
    112   struct LoadPair {
    113     Value *Lhs = nullptr;
    114     Value *Rhs = nullptr;
    115   };
    116   LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
    117                        unsigned OffsetBytes);
    118 
    119   static LoadEntryVector
    120   computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    121                             unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte);
    122   static LoadEntryVector
    123   computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
    124                                  unsigned MaxNumLoads,
    125                                  unsigned &NumLoadsNonOneByte);
    126 
    127 public:
    128   MemCmpExpansion(CallInst *CI, uint64_t Size,
    129                   const TargetTransformInfo::MemCmpExpansionOptions &Options,
    130                   const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout,
    131                   DomTreeUpdater *DTU);
    132 
    133   unsigned getNumBlocks();
    134   uint64_t getNumLoads() const { return LoadSequence.size(); }
    135 
    136   Value *getMemCmpExpansion();
    137 };
    138 
    139 MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    140     uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    141     const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
    142   NumLoadsNonOneByte = 0;
    143   LoadEntryVector LoadSequence;
    144   uint64_t Offset = 0;
    145   while (Size && !LoadSizes.empty()) {
    146     const unsigned LoadSize = LoadSizes.front();
    147     const uint64_t NumLoadsForThisSize = Size / LoadSize;
    148     if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
    149       // Do not expand if the total number of loads is larger than what the
    150       // target allows. Note that it's important that we exit before completing
    151       // the expansion to avoid using a ton of memory to store the expansion for
    152       // large sizes.
    153       return {};
    154     }
    155     if (NumLoadsForThisSize > 0) {
    156       for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
    157         LoadSequence.push_back({LoadSize, Offset});
    158         Offset += LoadSize;
    159       }
    160       if (LoadSize > 1)
    161         ++NumLoadsNonOneByte;
    162       Size = Size % LoadSize;
    163     }
    164     LoadSizes = LoadSizes.drop_front();
    165   }
    166   return LoadSequence;
    167 }
    168 
    169 MemCmpExpansion::LoadEntryVector
    170 MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
    171                                                 const unsigned MaxLoadSize,
    172                                                 const unsigned MaxNumLoads,
    173                                                 unsigned &NumLoadsNonOneByte) {
    174   // These are already handled by the greedy approach.
    175   if (Size < 2 || MaxLoadSize < 2)
    176     return {};
    177 
    178   // We try to do as many non-overlapping loads as possible starting from the
    179   // beginning.
    180   const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
    181   assert(NumNonOverlappingLoads && "there must be at least one load");
    182   // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
    183   // an overlapping load.
    184   Size = Size - NumNonOverlappingLoads * MaxLoadSize;
    185   // Bail if we do not need an overloapping store, this is already handled by
    186   // the greedy approach.
    187   if (Size == 0)
    188     return {};
    189   // Bail if the number of loads (non-overlapping + potential overlapping one)
    190   // is larger than the max allowed.
    191   if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
    192     return {};
    193 
    194   // Add non-overlapping loads.
    195   LoadEntryVector LoadSequence;
    196   uint64_t Offset = 0;
    197   for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
    198     LoadSequence.push_back({MaxLoadSize, Offset});
    199     Offset += MaxLoadSize;
    200   }
    201 
    202   // Add the last overlapping load.
    203   assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
    204   LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
    205   NumLoadsNonOneByte = 1;
    206   return LoadSequence;
    207 }
    208 
    209 // Initialize the basic block structure required for expansion of memcmp call
    210 // with given maximum load size and memcmp size parameter.
    211 // This structure includes:
    212 // 1. A list of load compare blocks - LoadCmpBlocks.
    213 // 2. An EndBlock, split from original instruction point, which is the block to
    214 // return from.
    215 // 3. ResultBlock, block to branch to for early exit when a
    216 // LoadCmpBlock finds a difference.
    217 MemCmpExpansion::MemCmpExpansion(
    218     CallInst *const CI, uint64_t Size,
    219     const TargetTransformInfo::MemCmpExpansionOptions &Options,
    220     const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout,
    221     DomTreeUpdater *DTU)
    222     : CI(CI), Size(Size), MaxLoadSize(0), NumLoadsNonOneByte(0),
    223       NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock),
    224       IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), DTU(DTU),
    225       Builder(CI) {
    226   assert(Size > 0 && "zero blocks");
    227   // Scale the max size down if the target can load more bytes than we need.
    228   llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
    229   while (!LoadSizes.empty() && LoadSizes.front() > Size) {
    230     LoadSizes = LoadSizes.drop_front();
    231   }
    232   assert(!LoadSizes.empty() && "cannot load Size bytes");
    233   MaxLoadSize = LoadSizes.front();
    234   // Compute the decomposition.
    235   unsigned GreedyNumLoadsNonOneByte = 0;
    236   LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, Options.MaxNumLoads,
    237                                            GreedyNumLoadsNonOneByte);
    238   NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
    239   assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
    240   // If we allow overlapping loads and the load sequence is not already optimal,
    241   // use overlapping loads.
    242   if (Options.AllowOverlappingLoads &&
    243       (LoadSequence.empty() || LoadSequence.size() > 2)) {
    244     unsigned OverlappingNumLoadsNonOneByte = 0;
    245     auto OverlappingLoads = computeOverlappingLoadSequence(
    246         Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
    247     if (!OverlappingLoads.empty() &&
    248         (LoadSequence.empty() ||
    249          OverlappingLoads.size() < LoadSequence.size())) {
    250       LoadSequence = OverlappingLoads;
    251       NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
    252     }
    253   }
    254   assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
    255 }
    256 
    257 unsigned MemCmpExpansion::getNumBlocks() {
    258   if (IsUsedForZeroCmp)
    259     return getNumLoads() / NumLoadsPerBlockForZeroCmp +
    260            (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
    261   return getNumLoads();
    262 }
    263 
    264 void MemCmpExpansion::createLoadCmpBlocks() {
    265   for (unsigned i = 0; i < getNumBlocks(); i++) {
    266     BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
    267                                         EndBlock->getParent(), EndBlock);
    268     LoadCmpBlocks.push_back(BB);
    269   }
    270 }
    271 
    272 void MemCmpExpansion::createResultBlock() {
    273   ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
    274                                    EndBlock->getParent(), EndBlock);
    275 }
    276 
    277 MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
    278                                                        bool NeedsBSwap,
    279                                                        Type *CmpSizeType,
    280                                                        unsigned OffsetBytes) {
    281   // Get the memory source at offset `OffsetBytes`.
    282   Value *LhsSource = CI->getArgOperand(0);
    283   Value *RhsSource = CI->getArgOperand(1);
    284   Align LhsAlign = LhsSource->getPointerAlignment(DL);
    285   Align RhsAlign = RhsSource->getPointerAlignment(DL);
    286   if (OffsetBytes > 0) {
    287     auto *ByteType = Type::getInt8Ty(CI->getContext());
    288     LhsSource = Builder.CreateConstGEP1_64(
    289         ByteType, Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()),
    290         OffsetBytes);
    291     RhsSource = Builder.CreateConstGEP1_64(
    292         ByteType, Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()),
    293         OffsetBytes);
    294     LhsAlign = commonAlignment(LhsAlign, OffsetBytes);
    295     RhsAlign = commonAlignment(RhsAlign, OffsetBytes);
    296   }
    297   LhsSource = Builder.CreateBitCast(LhsSource, LoadSizeType->getPointerTo());
    298   RhsSource = Builder.CreateBitCast(RhsSource, LoadSizeType->getPointerTo());
    299 
    300   // Create a constant or a load from the source.
    301   Value *Lhs = nullptr;
    302   if (auto *C = dyn_cast<Constant>(LhsSource))
    303     Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
    304   if (!Lhs)
    305     Lhs = Builder.CreateAlignedLoad(LoadSizeType, LhsSource, LhsAlign);
    306 
    307   Value *Rhs = nullptr;
    308   if (auto *C = dyn_cast<Constant>(RhsSource))
    309     Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
    310   if (!Rhs)
    311     Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);
    312 
    313   // Swap bytes if required.
    314   if (NeedsBSwap) {
    315     Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
    316                                                 Intrinsic::bswap, LoadSizeType);
    317     Lhs = Builder.CreateCall(Bswap, Lhs);
    318     Rhs = Builder.CreateCall(Bswap, Rhs);
    319   }
    320 
    321   // Zero extend if required.
    322   if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
    323     Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
    324     Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
    325   }
    326   return {Lhs, Rhs};
    327 }
    328 
    329 // This function creates the IR instructions for loading and comparing 1 byte.
    330 // It loads 1 byte from each source of the memcmp parameters with the given
    331 // GEPIndex. It then subtracts the two loaded values and adds this result to the
    332 // final phi node for selecting the memcmp result.
    333 void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
    334                                                unsigned OffsetBytes) {
    335   BasicBlock *BB = LoadCmpBlocks[BlockIndex];
    336   Builder.SetInsertPoint(BB);
    337   const LoadPair Loads =
    338       getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
    339                   Type::getInt32Ty(CI->getContext()), OffsetBytes);
    340   Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);
    341 
    342   PhiRes->addIncoming(Diff, BB);
    343 
    344   if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    345     // Early exit branch if difference found to EndBlock. Otherwise, continue to
    346     // next LoadCmpBlock,
    347     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
    348                                     ConstantInt::get(Diff->getType(), 0));
    349     BranchInst *CmpBr =
    350         BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    351     if (DTU)
    352       DTU->applyUpdates(
    353           {{DominatorTree::Insert, BB, EndBlock},
    354            {DominatorTree::Insert, BB, LoadCmpBlocks[BlockIndex + 1]}});
    355     Builder.Insert(CmpBr);
    356   } else {
    357     // The last block has an unconditional branch to EndBlock.
    358     BranchInst *CmpBr = BranchInst::Create(EndBlock);
    359     if (DTU)
    360       DTU->applyUpdates({{DominatorTree::Insert, BB, EndBlock}});
    361     Builder.Insert(CmpBr);
    362   }
    363 }
    364 
    365 /// Generate an equality comparison for one or more pairs of loaded values.
    366 /// This is used in the case where the memcmp() call is compared equal or not
    367 /// equal to zero.
    368 Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
    369                                             unsigned &LoadIndex) {
    370   assert(LoadIndex < getNumLoads() &&
    371          "getCompareLoadPairs() called with no remaining loads");
    372   std::vector<Value *> XorList, OrList;
    373   Value *Diff = nullptr;
    374 
    375   const unsigned NumLoads =
    376       std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);
    377 
    378   // For a single-block expansion, start inserting before the memcmp call.
    379   if (LoadCmpBlocks.empty())
    380     Builder.SetInsertPoint(CI);
    381   else
    382     Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
    383 
    384   Value *Cmp = nullptr;
    385   // If we have multiple loads per block, we need to generate a composite
    386   // comparison using xor+or. The type for the combinations is the largest load
    387   // type.
    388   IntegerType *const MaxLoadType =
    389       NumLoads == 1 ? nullptr
    390                     : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
    391   for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    392     const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
    393     const LoadPair Loads = getLoadPair(
    394         IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
    395         /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
    396 
    397     if (NumLoads != 1) {
    398       // If we have multiple loads per block, we need to generate a composite
    399       // comparison using xor+or.
    400       Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs);
    401       Diff = Builder.CreateZExt(Diff, MaxLoadType);
    402       XorList.push_back(Diff);
    403     } else {
    404       // If there's only one load per block, we just compare the loaded values.
    405       Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
    406     }
    407   }
    408 
    409   auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    410     std::vector<Value *> OutList;
    411     for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
    412       Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
    413       OutList.push_back(Or);
    414     }
    415     if (InList.size() % 2 != 0)
    416       OutList.push_back(InList.back());
    417     return OutList;
    418   };
    419 
    420   if (!Cmp) {
    421     // Pairwise OR the XOR results.
    422     OrList = pairWiseOr(XorList);
    423 
    424     // Pairwise OR the OR results until one result left.
    425     while (OrList.size() != 1) {
    426       OrList = pairWiseOr(OrList);
    427     }
    428 
    429     assert(Diff && "Failed to find comparison diff");
    430     Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
    431   }
    432 
    433   return Cmp;
    434 }
    435 
    436 void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
    437                                                         unsigned &LoadIndex) {
    438   Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
    439 
    440   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
    441                            ? EndBlock
    442                            : LoadCmpBlocks[BlockIndex + 1];
    443   // Early exit branch if difference found to ResultBlock. Otherwise,
    444   // continue to next LoadCmpBlock or EndBlock.
    445   BasicBlock *BB = Builder.GetInsertBlock();
    446   BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
    447   Builder.Insert(CmpBr);
    448   if (DTU)
    449     DTU->applyUpdates({{DominatorTree::Insert, BB, ResBlock.BB},
    450                        {DominatorTree::Insert, BB, NextBB}});
    451 
    452   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
    453   // since early exit to ResultBlock was not taken (no difference was found in
    454   // any of the bytes).
    455   if (BlockIndex == LoadCmpBlocks.size() - 1) {
    456     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    457     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
    458   }
    459 }
    460 
    461 // This function creates the IR intructions for loading and comparing using the
    462 // given LoadSize. It loads the number of bytes specified by LoadSize from each
    463 // source of the memcmp parameters. It then does a subtract to see if there was
    464 // a difference in the loaded values. If a difference is found, it branches
    465 // with an early exit to the ResultBlock for calculating which source was
    466 // larger. Otherwise, it falls through to the either the next LoadCmpBlock or
    467 // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
    468 // a special case through emitLoadCompareByteBlock. The special handling can
    469 // simply subtract the loaded values and add it to the result phi node.
    470 void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
    471   // There is one load per block in this case, BlockIndex == LoadIndex.
    472   const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
    473 
    474   if (CurLoadEntry.LoadSize == 1) {
    475     MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
    476     return;
    477   }
    478 
    479   Type *LoadSizeType =
    480       IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
    481   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
    482   assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
    483 
    484   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
    485 
    486   const LoadPair Loads =
    487       getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
    488                   CurLoadEntry.Offset);
    489 
    490   // Add the loaded values to the phi nodes for calculating memcmp result only
    491   // if result is not used in a zero equality.
    492   if (!IsUsedForZeroCmp) {
    493     ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
    494     ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
    495   }
    496 
    497   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
    498   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
    499                            ? EndBlock
    500                            : LoadCmpBlocks[BlockIndex + 1];
    501   // Early exit branch if difference found to ResultBlock. Otherwise, continue
    502   // to next LoadCmpBlock or EndBlock.
    503   BasicBlock *BB = Builder.GetInsertBlock();
    504   BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
    505   Builder.Insert(CmpBr);
    506   if (DTU)
    507     DTU->applyUpdates({{DominatorTree::Insert, BB, NextBB},
    508                        {DominatorTree::Insert, BB, ResBlock.BB}});
    509 
    510   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
    511   // since early exit to ResultBlock was not taken (no difference was found in
    512   // any of the bytes).
    513   if (BlockIndex == LoadCmpBlocks.size() - 1) {
    514     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    515     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
    516   }
    517 }
    518 
    519 // This function populates the ResultBlock with a sequence to calculate the
    520 // memcmp result. It compares the two loaded source values and returns -1 if
    521 // src1 < src2 and 1 if src1 > src2.
    522 void MemCmpExpansion::emitMemCmpResultBlock() {
    523   // Special case: if memcmp result is used in a zero equality, result does not
    524   // need to be calculated and can simply return 1.
    525   if (IsUsedForZeroCmp) {
    526     BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    527     Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    528     Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    529     PhiRes->addIncoming(Res, ResBlock.BB);
    530     BranchInst *NewBr = BranchInst::Create(EndBlock);
    531     Builder.Insert(NewBr);
    532     if (DTU)
    533       DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}});
    534     return;
    535   }
    536   BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    537   Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    538 
    539   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
    540                                   ResBlock.PhiSrc2);
    541 
    542   Value *Res =
    543       Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
    544                            ConstantInt::get(Builder.getInt32Ty(), 1));
    545 
    546   PhiRes->addIncoming(Res, ResBlock.BB);
    547   BranchInst *NewBr = BranchInst::Create(EndBlock);
    548   Builder.Insert(NewBr);
    549   if (DTU)
    550     DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}});
    551 }
    552 
    553 void MemCmpExpansion::setupResultBlockPHINodes() {
    554   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
    555   Builder.SetInsertPoint(ResBlock.BB);
    556   // Note: this assumes one load per block.
    557   ResBlock.PhiSrc1 =
    558       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
    559   ResBlock.PhiSrc2 =
    560       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
    561 }
    562 
    563 void MemCmpExpansion::setupEndBlockPHINodes() {
    564   Builder.SetInsertPoint(&EndBlock->front());
    565   PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
    566 }
    567 
    568 Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
    569   unsigned LoadIndex = 0;
    570   // This loop populates each of the LoadCmpBlocks with the IR sequence to
    571   // handle multiple loads per block.
    572   for (unsigned I = 0; I < getNumBlocks(); ++I) {
    573     emitLoadCompareBlockMultipleLoads(I, LoadIndex);
    574   }
    575 
    576   emitMemCmpResultBlock();
    577   return PhiRes;
    578 }
    579 
    580 /// A memcmp expansion that compares equality with 0 and only has one block of
    581 /// load and compare can bypass the compare, branch, and phi IR that is required
    582 /// in the general case.
    583 Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
    584   unsigned LoadIndex = 0;
    585   Value *Cmp = getCompareLoadPairs(0, LoadIndex);
    586   assert(LoadIndex == getNumLoads() && "some entries were not consumed");
    587   return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
    588 }
    589 
    590 /// A memcmp expansion that only has one block of load and compare can bypass
    591 /// the compare, branch, and phi IR that is required in the general case.
    592 Value *MemCmpExpansion::getMemCmpOneBlock() {
    593   Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
    594   bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
    595 
    596   // The i8 and i16 cases don't need compares. We zext the loaded values and
    597   // subtract them to get the suitable negative, zero, or positive i32 result.
    598   if (Size < 4) {
    599     const LoadPair Loads =
    600         getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
    601                     /*Offset*/ 0);
    602     return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
    603   }
    604 
    605   const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
    606                                      /*Offset*/ 0);
    607   // The result of memcmp is negative, zero, or positive, so produce that by
    608   // subtracting 2 extended compare bits: sub (ugt, ult).
    609   // If a target prefers to use selects to get -1/0/1, they should be able
    610   // to transform this later. The inverse transform (going from selects to math)
    611   // may not be possible in the DAG because the selects got converted into
    612   // branches before we got there.
    613   Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
    614   Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
    615   Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
    616   Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
    617   return Builder.CreateSub(ZextUGT, ZextULT);
    618 }
    619 
    620 // This function expands the memcmp call into an inline expansion and returns
    621 // the memcmp result.
    622 Value *MemCmpExpansion::getMemCmpExpansion() {
    623   // Create the basic block framework for a multi-block expansion.
    624   if (getNumBlocks() != 1) {
    625     BasicBlock *StartBlock = CI->getParent();
    626     EndBlock = SplitBlock(StartBlock, CI, DTU, /*LI=*/nullptr,
    627                           /*MSSAU=*/nullptr, "endblock");
    628     setupEndBlockPHINodes();
    629     createResultBlock();
    630 
    631     // If return value of memcmp is not used in a zero equality, we need to
    632     // calculate which source was larger. The calculation requires the
    633     // two loaded source values of each load compare block.
    634     // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    635     if (!IsUsedForZeroCmp) setupResultBlockPHINodes();
    636 
    637     // Create the number of required load compare basic blocks.
    638     createLoadCmpBlocks();
    639 
    640     // Update the terminator added by SplitBlock to branch to the first
    641     // LoadCmpBlock.
    642     StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
    643     if (DTU)
    644       DTU->applyUpdates({{DominatorTree::Insert, StartBlock, LoadCmpBlocks[0]},
    645                          {DominatorTree::Delete, StartBlock, EndBlock}});
    646   }
    647 
    648   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    649 
    650   if (IsUsedForZeroCmp)
    651     return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
    652                                : getMemCmpExpansionZeroCase();
    653 
    654   if (getNumBlocks() == 1)
    655     return getMemCmpOneBlock();
    656 
    657   for (unsigned I = 0; I < getNumBlocks(); ++I) {
    658     emitLoadCompareBlock(I);
    659   }
    660 
    661   emitMemCmpResultBlock();
    662   return PhiRes;
    663 }
    664 
    665 // This function checks to see if an expansion of memcmp can be generated.
    666 // It checks for constant compare size that is less than the max inline size.
    667 // If an expansion cannot occur, returns false to leave as a library call.
    668 // Otherwise, the library call is replaced with a new IR instruction sequence.
    669 /// We want to transform:
    670 /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
    671 /// To:
    672 /// loadbb:
    673 ///  %0 = bitcast i32* %buffer2 to i8*
    674 ///  %1 = bitcast i32* %buffer1 to i8*
    675 ///  %2 = bitcast i8* %1 to i64*
    676 ///  %3 = bitcast i8* %0 to i64*
    677 ///  %4 = load i64, i64* %2
    678 ///  %5 = load i64, i64* %3
    679 ///  %6 = call i64 @llvm.bswap.i64(i64 %4)
    680 ///  %7 = call i64 @llvm.bswap.i64(i64 %5)
    681 ///  %8 = sub i64 %6, %7
    682 ///  %9 = icmp ne i64 %8, 0
    683 ///  br i1 %9, label %res_block, label %loadbb1
    684 /// res_block:                                        ; preds = %loadbb2,
    685 /// %loadbb1, %loadbb
    686 ///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
    687 ///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
    688 ///  %10 = icmp ult i64 %phi.src1, %phi.src2
    689 ///  %11 = select i1 %10, i32 -1, i32 1
    690 ///  br label %endblock
    691 /// loadbb1:                                          ; preds = %loadbb
    692 ///  %12 = bitcast i32* %buffer2 to i8*
    693 ///  %13 = bitcast i32* %buffer1 to i8*
    694 ///  %14 = bitcast i8* %13 to i32*
    695 ///  %15 = bitcast i8* %12 to i32*
    696 ///  %16 = getelementptr i32, i32* %14, i32 2
    697 ///  %17 = getelementptr i32, i32* %15, i32 2
    698 ///  %18 = load i32, i32* %16
    699 ///  %19 = load i32, i32* %17
    700 ///  %20 = call i32 @llvm.bswap.i32(i32 %18)
    701 ///  %21 = call i32 @llvm.bswap.i32(i32 %19)
    702 ///  %22 = zext i32 %20 to i64
    703 ///  %23 = zext i32 %21 to i64
    704 ///  %24 = sub i64 %22, %23
    705 ///  %25 = icmp ne i64 %24, 0
    706 ///  br i1 %25, label %res_block, label %loadbb2
    707 /// loadbb2:                                          ; preds = %loadbb1
    708 ///  %26 = bitcast i32* %buffer2 to i8*
    709 ///  %27 = bitcast i32* %buffer1 to i8*
    710 ///  %28 = bitcast i8* %27 to i16*
    711 ///  %29 = bitcast i8* %26 to i16*
    712 ///  %30 = getelementptr i16, i16* %28, i16 6
    713 ///  %31 = getelementptr i16, i16* %29, i16 6
    714 ///  %32 = load i16, i16* %30
    715 ///  %33 = load i16, i16* %31
    716 ///  %34 = call i16 @llvm.bswap.i16(i16 %32)
    717 ///  %35 = call i16 @llvm.bswap.i16(i16 %33)
    718 ///  %36 = zext i16 %34 to i64
    719 ///  %37 = zext i16 %35 to i64
    720 ///  %38 = sub i64 %36, %37
    721 ///  %39 = icmp ne i64 %38, 0
    722 ///  br i1 %39, label %res_block, label %loadbb3
    723 /// loadbb3:                                          ; preds = %loadbb2
    724 ///  %40 = bitcast i32* %buffer2 to i8*
    725 ///  %41 = bitcast i32* %buffer1 to i8*
    726 ///  %42 = getelementptr i8, i8* %41, i8 14
    727 ///  %43 = getelementptr i8, i8* %40, i8 14
    728 ///  %44 = load i8, i8* %42
    729 ///  %45 = load i8, i8* %43
    730 ///  %46 = zext i8 %44 to i32
    731 ///  %47 = zext i8 %45 to i32
    732 ///  %48 = sub i32 %46, %47
    733 ///  br label %endblock
    734 /// endblock:                                         ; preds = %res_block,
    735 /// %loadbb3
    736 ///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
    737 ///  ret i32 %phi.res
    738 static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
    739                          const TargetLowering *TLI, const DataLayout *DL,
    740                          ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
    741                          DomTreeUpdater *DTU) {
    742   NumMemCmpCalls++;
    743 
    744   // Early exit from expansion if -Oz.
    745   if (CI->getFunction()->hasMinSize())
    746     return false;
    747 
    748   // Early exit from expansion if size is not a constant.
    749   ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
    750   if (!SizeCast) {
    751     NumMemCmpNotConstant++;
    752     return false;
    753   }
    754   const uint64_t SizeVal = SizeCast->getZExtValue();
    755 
    756   if (SizeVal == 0) {
    757     return false;
    758   }
    759   // TTI call to check if target would like to expand memcmp. Also, get the
    760   // available load sizes.
    761   const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
    762   bool OptForSize = CI->getFunction()->hasOptSize() ||
    763                     llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI);
    764   auto Options = TTI->enableMemCmpExpansion(OptForSize,
    765                                             IsUsedForZeroCmp);
    766   if (!Options) return false;
    767 
    768   if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences())
    769     Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock;
    770 
    771   if (OptForSize &&
    772       MaxLoadsPerMemcmpOptSize.getNumOccurrences())
    773     Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize;
    774 
    775   if (!OptForSize && MaxLoadsPerMemcmp.getNumOccurrences())
    776     Options.MaxNumLoads = MaxLoadsPerMemcmp;
    777 
    778   MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL, DTU);
    779 
    780   // Don't expand if this will require more loads than desired by the target.
    781   if (Expansion.getNumLoads() == 0) {
    782     NumMemCmpGreaterThanMax++;
    783     return false;
    784   }
    785 
    786   NumMemCmpInlined++;
    787 
    788   Value *Res = Expansion.getMemCmpExpansion();
    789 
    790   // Replace call with result of expansion and erase call.
    791   CI->replaceAllUsesWith(Res);
    792   CI->eraseFromParent();
    793 
    794   return true;
    795 }
    796 
    797 class ExpandMemCmpPass : public FunctionPass {
    798 public:
    799   static char ID;
    800 
    801   ExpandMemCmpPass() : FunctionPass(ID) {
    802     initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
    803   }
    804 
    805   bool runOnFunction(Function &F) override {
    806     if (skipFunction(F)) return false;
    807 
    808     auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    809     if (!TPC) {
    810       return false;
    811     }
    812     const TargetLowering* TL =
    813         TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();
    814 
    815     const TargetLibraryInfo *TLI =
    816         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    817     const TargetTransformInfo *TTI =
    818         &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    819     auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
    820     auto *BFI = (PSI && PSI->hasProfileSummary()) ?
    821            &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
    822            nullptr;
    823     DominatorTree *DT = nullptr;
    824     if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    825       DT = &DTWP->getDomTree();
    826     auto PA = runImpl(F, TLI, TTI, TL, PSI, BFI, DT);
    827     return !PA.areAllPreserved();
    828   }
    829 
    830 private:
    831   void getAnalysisUsage(AnalysisUsage &AU) const override {
    832     AU.addRequired<TargetLibraryInfoWrapperPass>();
    833     AU.addRequired<TargetTransformInfoWrapperPass>();
    834     AU.addRequired<ProfileSummaryInfoWrapperPass>();
    835     AU.addPreserved<DominatorTreeWrapperPass>();
    836     LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
    837     FunctionPass::getAnalysisUsage(AU);
    838   }
    839 
    840   PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
    841                             const TargetTransformInfo *TTI,
    842                             const TargetLowering *TL, ProfileSummaryInfo *PSI,
    843                             BlockFrequencyInfo *BFI, DominatorTree *DT);
    844   // Returns true if a change was made.
    845   bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
    846                   const TargetTransformInfo *TTI, const TargetLowering *TL,
    847                   const DataLayout &DL, ProfileSummaryInfo *PSI,
    848                   BlockFrequencyInfo *BFI, DomTreeUpdater *DTU);
    849 };
    850 
    851 bool ExpandMemCmpPass::runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
    852                                   const TargetTransformInfo *TTI,
    853                                   const TargetLowering *TL,
    854                                   const DataLayout &DL, ProfileSummaryInfo *PSI,
    855                                   BlockFrequencyInfo *BFI,
    856                                   DomTreeUpdater *DTU) {
    857   for (Instruction& I : BB) {
    858     CallInst *CI = dyn_cast<CallInst>(&I);
    859     if (!CI) {
    860       continue;
    861     }
    862     LibFunc Func;
    863     if (TLI->getLibFunc(*CI, Func) &&
    864         (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
    865         expandMemCmp(CI, TTI, TL, &DL, PSI, BFI, DTU)) {
    866       return true;
    867     }
    868   }
    869   return false;
    870 }
    871 
    872 PreservedAnalyses
    873 ExpandMemCmpPass::runImpl(Function &F, const TargetLibraryInfo *TLI,
    874                           const TargetTransformInfo *TTI,
    875                           const TargetLowering *TL, ProfileSummaryInfo *PSI,
    876                           BlockFrequencyInfo *BFI, DominatorTree *DT) {
    877   Optional<DomTreeUpdater> DTU;
    878   if (DT)
    879     DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    880 
    881   const DataLayout& DL = F.getParent()->getDataLayout();
    882   bool MadeChanges = false;
    883   for (auto BBIt = F.begin(); BBIt != F.end();) {
    884     if (runOnBlock(*BBIt, TLI, TTI, TL, DL, PSI, BFI,
    885                    DTU.hasValue() ? DTU.getPointer() : nullptr)) {
    886       MadeChanges = true;
    887       // If changes were made, restart the function from the beginning, since
    888       // the structure of the function was changed.
    889       BBIt = F.begin();
    890     } else {
    891       ++BBIt;
    892     }
    893   }
    894   if (MadeChanges)
    895     for (BasicBlock &BB : F)
    896       SimplifyInstructionsInBlock(&BB);
    897   if (!MadeChanges)
    898     return PreservedAnalyses::all();
    899   PreservedAnalyses PA;
    900   PA.preserve<DominatorTreeAnalysis>();
    901   return PA;
    902 }
    903 
    904 } // namespace
    905 
// Definition of the static pass identifier that the constructor passes to
// FunctionPass(ID).
char ExpandMemCmpPass::ID = 0;
// Registers the pass under the name "expandmemcmp" and lists the analyses it
// uses: those required in getAnalysisUsage, plus the dominator tree it
// fetches via getAnalysisIfAvailable. NOTE(review): these macros presumably
// define the initializeExpandMemCmpPassPass() called by the constructor, and
// the trailing `false, false` are assumed to be the CFG-only / is-analysis
// flags. Confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)
    916 
    917 FunctionPass *llvm::createExpandMemCmpPass() {
    918   return new ExpandMemCmpPass();
    919 }
    920