Home | History | Annotate | Line # | Download | only in IPO
      1 //===- WholeProgramDevirt.h - Whole-program devirt pass ---------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file defines parts of the whole-program devirtualization pass
     10 // implementation that may be usefully unit tested.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_TRANSFORMS_IPO_WHOLEPROGRAMDEVIRT_H
     15 #define LLVM_TRANSFORMS_IPO_WHOLEPROGRAMDEVIRT_H
     16 
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/IPO/FunctionImport.h"
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <utility>
#include <vector>
     25 
     26 namespace llvm {
     27 
     28 template <typename T> class ArrayRef;
     29 template <typename T> class MutableArrayRef;
     30 class Function;
     31 class GlobalVariable;
     32 class ModuleSummaryIndex;
     33 struct ValueInfo;
     34 
     35 namespace wholeprogramdevirt {
     36 
// A bit vector that keeps track of which bits are used. We use this to
// pack constant values compactly before and after each virtual table.
struct AccumBitVector {
  // The packed data bytes.
  std::vector<uint8_t> Bytes;

  // Bits in BytesUsed[I] are 1 if matching bit in Bytes[I] is used, 0 if not.
  // getPtrToData grows this vector together with Bytes.
  std::vector<uint8_t> BytesUsed;

  // Return pointers to byte Pos of Bytes and of BytesUsed, first growing both
  // vectors (new bytes are zero) so that bytes [Pos, Pos + Size) exist.
  // The returned pointers are invalidated by any later call that grows the
  // vectors.
  std::pair<uint8_t *, uint8_t *> getPtrToData(uint64_t Pos, uint8_t Size) {
    if (Bytes.size() < Pos + Size) {
      Bytes.resize(Pos + Size);
      BytesUsed.resize(Pos + Size);
    }
    return std::make_pair(Bytes.data() + Pos, BytesUsed.data() + Pos);
  }

  // Set little-endian value Val with size Size (in bytes) at bit position
  // Pos, and mark those bytes as used. Pos must be byte-aligned, the bytes
  // must not already be in use, and Size must be at most 8: Val only holds
  // 8 bytes, and shifting a uint64_t by 64 or more bits is undefined.
  void setLE(uint64_t Pos, uint64_t Val, uint8_t Size) {
    assert(Pos % 8 == 0);
    assert(Size <= 8 && "value wider than 64 bits");
    auto DataUsed = getPtrToData(Pos / 8, Size);
    for (unsigned I = 0; I != Size; ++I) {
      DataUsed.first[I] = static_cast<uint8_t>(Val >> (I * 8));
      assert(!DataUsed.second[I]);
      DataUsed.second[I] = 0xff;
    }
  }

  // Set big-endian value Val with size Size (in bytes) at bit position Pos,
  // and mark those bytes as used. Same preconditions as setLE.
  void setBE(uint64_t Pos, uint64_t Val, uint8_t Size) {
    assert(Pos % 8 == 0);
    assert(Size <= 8 && "value wider than 64 bits");
    auto DataUsed = getPtrToData(Pos / 8, Size);
    for (unsigned I = 0; I != Size; ++I) {
      // Least significant byte goes last.
      DataUsed.first[Size - I - 1] = static_cast<uint8_t>(Val >> (I * 8));
      assert(!DataUsed.second[Size - I - 1]);
      DataUsed.second[Size - I - 1] = 0xff;
    }
  }

  // Set bit at bit position Pos to b and mark bit as used. A false b leaves
  // the data bit untouched (it is zero unless previously set), but the bit is
  // still marked as used. The bit must not already be in use.
  void setBit(uint64_t Pos, bool b) {
    auto DataUsed = getPtrToData(Pos / 8, 1);
    const uint8_t Mask = static_cast<uint8_t>(1u << (Pos % 8));
    if (b)
      *DataUsed.first |= Mask;
    assert(!(*DataUsed.second & Mask));
    *DataUsed.second |= Mask;
  }
};
     86 
     87 // The bits that will be stored before and after a particular vtable.
     88 struct VTableBits {
     89   // The vtable global.
     90   GlobalVariable *GV;
     91 
     92   // Cache of the vtable's size in bytes.
     93   uint64_t ObjectSize = 0;
     94 
     95   // The bit vector that will be laid out before the vtable. Note that these
     96   // bytes are stored in reverse order until the globals are rebuilt. This means
     97   // that any values in the array must be stored using the opposite endianness
     98   // from the target.
     99   AccumBitVector Before;
    100 
    101   // The bit vector that will be laid out after the vtable.
    102   AccumBitVector After;
    103 };
    104 
    105 // Information about a member of a particular type identifier.
    106 struct TypeMemberInfo {
    107   // The VTableBits for the vtable.
    108   VTableBits *Bits;
    109 
    110   // The offset in bytes from the start of the vtable (i.e. the address point).
    111   uint64_t Offset;
    112 
    113   bool operator<(const TypeMemberInfo &other) const {
    114     return Bits < other.Bits || (Bits == other.Bits && Offset < other.Offset);
    115   }
    116 };
    117 
    118 // A virtual call target, i.e. an entry in a particular vtable.
    119 struct VirtualCallTarget {
    120   VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM);
    121 
    122   // For testing only.
    123   VirtualCallTarget(const TypeMemberInfo *TM, bool IsBigEndian)
    124       : Fn(nullptr), TM(TM), IsBigEndian(IsBigEndian), WasDevirt(false) {}
    125 
    126   // The function stored in the vtable.
    127   Function *Fn;
    128 
    129   // A pointer to the type identifier member through which the pointer to Fn is
    130   // accessed.
    131   const TypeMemberInfo *TM;
    132 
    133   // When doing virtual constant propagation, this stores the return value for
    134   // the function when passed the currently considered argument list.
    135   uint64_t RetVal;
    136 
    137   // Whether the target is big endian.
    138   bool IsBigEndian;
    139 
    140   // Whether at least one call site to the target was devirtualized.
    141   bool WasDevirt;
    142 
    143   // The minimum byte offset before the address point. This covers the bytes in
    144   // the vtable object before the address point (e.g. RTTI, access-to-top,
    145   // vtables for other base classes) and is equal to the offset from the start
    146   // of the vtable object to the address point.
    147   uint64_t minBeforeBytes() const { return TM->Offset; }
    148 
    149   // The minimum byte offset after the address point. This covers the bytes in
    150   // the vtable object after the address point (e.g. the vtable for the current
    151   // class and any later base classes) and is equal to the size of the vtable
    152   // object minus the offset from the start of the vtable object to the address
    153   // point.
    154   uint64_t minAfterBytes() const { return TM->Bits->ObjectSize - TM->Offset; }
    155 
    156   // The number of bytes allocated (for the vtable plus the byte array) before
    157   // the address point.
    158   uint64_t allocatedBeforeBytes() const {
    159     return minBeforeBytes() + TM->Bits->Before.Bytes.size();
    160   }
    161 
    162   // The number of bytes allocated (for the vtable plus the byte array) after
    163   // the address point.
    164   uint64_t allocatedAfterBytes() const {
    165     return minAfterBytes() + TM->Bits->After.Bytes.size();
    166   }
    167 
    168   // Set the bit at position Pos before the address point to RetVal.
    169   void setBeforeBit(uint64_t Pos) {
    170     assert(Pos >= 8 * minBeforeBytes());
    171     TM->Bits->Before.setBit(Pos - 8 * minBeforeBytes(), RetVal);
    172   }
    173 
    174   // Set the bit at position Pos after the address point to RetVal.
    175   void setAfterBit(uint64_t Pos) {
    176     assert(Pos >= 8 * minAfterBytes());
    177     TM->Bits->After.setBit(Pos - 8 * minAfterBytes(), RetVal);
    178   }
    179 
    180   // Set the bytes at position Pos before the address point to RetVal.
    181   // Because the bytes in Before are stored in reverse order, we use the
    182   // opposite endianness to the target.
    183   void setBeforeBytes(uint64_t Pos, uint8_t Size) {
    184     assert(Pos >= 8 * minBeforeBytes());
    185     if (IsBigEndian)
    186       TM->Bits->Before.setLE(Pos - 8 * minBeforeBytes(), RetVal, Size);
    187     else
    188       TM->Bits->Before.setBE(Pos - 8 * minBeforeBytes(), RetVal, Size);
    189   }
    190 
    191   // Set the bytes at position Pos after the address point to RetVal.
    192   void setAfterBytes(uint64_t Pos, uint8_t Size) {
    193     assert(Pos >= 8 * minAfterBytes());
    194     if (IsBigEndian)
    195       TM->Bits->After.setBE(Pos - 8 * minAfterBytes(), RetVal, Size);
    196     else
    197       TM->Bits->After.setLE(Pos - 8 * minAfterBytes(), RetVal, Size);
    198   }
    199 };
    200 
// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset after the object, otherwise look for an
// offset before the object.
// NOTE(review): the result is presumably a bit offset from the address point,
// suitable as the AllocBefore/AllocAfter argument of the set*ReturnValues
// functions declared below — confirm against the definition.
uint64_t findLowestOffset(ArrayRef<VirtualCallTarget> Targets, bool IsAfter,
                          uint64_t Size);
    206 
// Set the stored value in each of Targets to VirtualCallTarget::RetVal at the
// given allocation offset before the vtable address. Stores the computed
// byte/bit offset to OffsetByte/OffsetBit.
//   Targets     - the call targets whose RetVal is written (mutable view).
//   AllocBefore - the allocation offset before the vtable address.
//   BitWidth    - width of the stored value in bits.
//   OffsetByte  - out: computed byte offset. NOTE(review): it is signed,
//                 presumably because offsets before the address point are
//                 negative — confirm against the definition.
//   OffsetBit   - out: computed bit offset.
void setBeforeReturnValues(MutableArrayRef<VirtualCallTarget> Targets,
                           uint64_t AllocBefore, unsigned BitWidth,
                           int64_t &OffsetByte, uint64_t &OffsetBit);
    213 
// Set the stored value in each of Targets to VirtualCallTarget::RetVal at the
// given allocation offset after the vtable address. Stores the computed
// byte/bit offset to OffsetByte/OffsetBit.
//   Targets    - the call targets whose RetVal is written (mutable view).
//   AllocAfter - the allocation offset after the vtable address.
//   BitWidth   - width of the stored value in bits.
//   OffsetByte - out: computed byte offset.
//   OffsetBit  - out: computed bit offset.
void setAfterReturnValues(MutableArrayRef<VirtualCallTarget> Targets,
                          uint64_t AllocAfter, unsigned BitWidth,
                          int64_t &OffsetByte, uint64_t &OffsetBit);
    220 
    221 } // end namespace wholeprogramdevirt
    222 
    223 struct WholeProgramDevirtPass : public PassInfoMixin<WholeProgramDevirtPass> {
    224   ModuleSummaryIndex *ExportSummary;
    225   const ModuleSummaryIndex *ImportSummary;
    226   bool UseCommandLine = false;
    227   WholeProgramDevirtPass()
    228       : ExportSummary(nullptr), ImportSummary(nullptr), UseCommandLine(true) {}
    229   WholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary,
    230                          const ModuleSummaryIndex *ImportSummary)
    231       : ExportSummary(ExportSummary), ImportSummary(ImportSummary) {
    232     assert(!(ExportSummary && ImportSummary));
    233   }
    234   PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
    235 };
    236 
    237 struct VTableSlotSummary {
    238   StringRef TypeID;
    239   uint64_t ByteOffset;
    240 };
    241 
// Module-level variant. Both functions in this pair are declared here and
// defined elsewhere.
// NOTE(review): judging by the names, these update the vcall visibility
// recorded for vtables, taking into account whether whole-program visibility
// is enabled for LTO and which symbols are dynamically exported (identified by
// GUID in DynamicExportSymbols) — confirm against the definitions.
void updateVCallVisibilityInModule(
    Module &M, bool WholeProgramVisibilityEnabledInLTO,
    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols);
// Summary-index counterpart of updateVCallVisibilityInModule; operates on a
// ModuleSummaryIndex (taken by non-const reference) instead of a Module.
void updateVCallVisibilityInIndex(
    ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols);
    248 
/// Perform index-based whole program devirtualization on the \p Summary
/// index. Any devirtualized targets used by a type test in another module
/// are added to the \p ExportedGUIDs set. For any local devirtualized targets
/// only used within the defining module, the information necessary for
/// locating the corresponding WPD resolution is recorded for the ValueInfo
/// in case it is exported by cross module importing (in which case the
/// devirtualized target name will need adjustment).
///
/// \param Summary the index to devirtualize (non-const reference).
/// \param ExportedGUIDs out: receives the GUIDs of devirtualized targets used
///        by a type test in another module.
/// \param LocalWPDTargetsMap out: maps each local devirtualized target's
///        ValueInfo to the VTableSlotSummary entries that locate its WPD
///        resolutions; later consumed by updateIndexWPDForExports.
void runWholeProgramDevirtOnIndex(
    ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap);
    259 
/// Call after cross-module importing to update the recorded single impl
/// devirt target names for any locals that were exported.
///
/// \param Summary the index whose recorded WPD resolutions are updated.
/// \param isExported predicate queried with a (StringRef, ValueInfo) pair.
///        NOTE(review): the StringRef is presumably a module identifier —
///        confirm against the definition.
/// \param LocalWPDTargetsMap the map of local devirtualized targets recorded
///        by runWholeProgramDevirtOnIndex.
void updateIndexWPDForExports(
    ModuleSummaryIndex &Summary,
    function_ref<bool(StringRef, ValueInfo)> isExported,
    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap);
    266 
    267 } // end namespace llvm
    268 
    269 #endif // LLVM_TRANSFORMS_IPO_WHOLEPROGRAMDEVIRT_H
    270