Home | History | Annotate | Line # | Download | only in ADT
      1 //===- llvm/ADT/EquivalenceClasses.h - Generic Equiv. Classes ---*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // Generic implementation of equivalence classes through the use of Tarjan's
     10 // efficient union-find algorithm.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_ADT_EQUIVALENCECLASSES_H
     15 #define LLVM_ADT_EQUIVALENCECLASSES_H
     16 
     17 #include <cassert>
     18 #include <cstddef>
     19 #include <cstdint>
     20 #include <iterator>
     21 #include <set>
     22 
     23 namespace llvm {
     24 
/// EquivalenceClasses - This represents a collection of equivalence classes and
/// supports three efficient operations: insert an element into a class of its
/// own, union two classes, and find the class for a given element.  In
/// addition to these modification methods, it is possible to iterate over all
/// of the equivalence classes and all of the elements in a class.
///
/// This implementation only stores one copy of the element being indexed per
/// entry in the set, and allows any arbitrary type to be indexed as long as it
/// is copy-constructible and can be ordered with operator< (isEquivalent()
/// additionally compares elements with operator==).
///
/// Union is not by rank or size: unionSets(L1, L2) always keeps L1's leader as
/// the leader of the merged class and appends L2's members after L1's.
///
/// Here is a simple example using integers:
///
/// \code
///  EquivalenceClasses<int> EC;
///  EC.unionSets(1, 2);                // insert 1, 2 into the same set
///  EC.insert(4); EC.insert(5);        // insert 4, 5 into own sets
///  EC.unionSets(5, 1);                // merge the set for 1 with 5's set.
///
///  for (EquivalenceClasses<int>::iterator I = EC.begin(), E = EC.end();
///       I != E; ++I) {           // Iterate over all of the equivalence sets.
///    if (!I->isLeader()) continue;   // Ignore non-leader sets.
///    for (EquivalenceClasses<int>::member_iterator MI = EC.member_begin(I);
///         MI != EC.member_end(); ++MI)   // Loop over members in this set.
///      std::cerr << *MI << " ";  // Print member.
///    std::cerr << "\n";   // Finish set.
///  }
/// \endcode
///
/// This example prints:
///   4
///   5 1 2
///
template <class ElemTy>
class EquivalenceClasses {
  /// ECValue - The EquivalenceClasses data structure is just a set of these.
  /// Each of these represents a relation for a value.  First it stores the
  /// value itself, which provides the ordering that the set queries.  Next, it
  /// provides a "next pointer", which is used to enumerate all of the elements
  /// in the unioned set.  Finally, it holds either an "end of list pointer" or
  /// a "leader pointer" (both stored in the Leader field), depending on whether
  /// the value itself is a leader.  A "leader pointer" points towards the node
  /// that leads this element's class, if the node is not a leader.  An "end of
  /// list pointer" points to the last node in the list of members of the
  /// class.  Whether or not a node is a leader is encoded in the low bit of the
  /// Next pointer (1 = leader).
  class ECValue {
    friend class EquivalenceClasses;

    // Mutable so that const lookups can perform path compression and so that
    // union operations can relink nodes stored as const elements of std::set.
    mutable const ECValue *Leader, *Next;
    ElemTy Data;

    // Start out as a singleton leader: Leader (acting as the end-of-list
    // pointer) refers to this node, and Next is null with the leader bit set.
    // Intentionally not explicit: findValue()/findLeader() pass an ElemTy to
    // std::set<ECValue>::find, which relies on this implicit conversion.
    ECValue(const ElemTy &Elt)
        : Leader(this),
          Next(reinterpret_cast<const ECValue *>(std::intptr_t(1))),
          Data(Elt) {}

    /// Return the leader of this node's class.  Every non-leader visited on
    /// the way is re-pointed directly at the leader (path compression).
    const ECValue *getLeader() const {
      if (isLeader())
        return this;
      // Locate the root iteratively: leader chains can grow as long as the
      // number of unions, so a recursive walk could exhaust the stack.
      const ECValue *Root = Leader;
      while (!Root->isLeader())
        Root = Root->Leader;
      // Second pass: compress the path.  Stop before Root, whose Leader field
      // is its end-of-list pointer and must not be overwritten.
      for (const ECValue *Cur = this; Cur != Root;) {
        const ECValue *Parent = Cur->Leader;
        Cur->Leader = Root;
        Cur = Parent;
      }
      return Root;
    }

    /// For a leader, Leader holds the last node of the member list.
    const ECValue *getEndOfList() const {
      assert(isLeader() && "Cannot get the end of a list for a non-leader!");
      return Leader;
    }

    /// Install NewNext as the successor of this (currently last) node while
    /// preserving this node's leader bit.
    void setNext(const ECValue *NewNext) const {
      assert(getNext() == nullptr && "Already has a next pointer!");
      Next = reinterpret_cast<const ECValue *>(
          reinterpret_cast<std::intptr_t>(NewNext) |
          static_cast<std::intptr_t>(isLeader()));
    }

  public:
    /// Copying is only needed by std::set when inserting a freshly created
    /// node, so only singleton nodes may be copied; the copy is a new
    /// singleton leader that refers to itself.
    ECValue(const ECValue &RHS)
        : Leader(this),
          Next(reinterpret_cast<const ECValue *>(std::intptr_t(1))),
          Data(RHS.Data) {
      // Only support copying of singleton nodes.
      assert(RHS.isLeader() && RHS.getNext() == nullptr && "Not a singleton!");
    }

    bool operator<(const ECValue &UFN) const { return Data < UFN.Data; }

    /// True if this node leads its class (low bit of Next is set).
    bool isLeader() const {
      return reinterpret_cast<std::intptr_t>(Next) & 1;
    }
    const ElemTy &getData() const { return Data; }

    /// The next member in this class's list, or nullptr at the end.
    const ECValue *getNext() const {
      return reinterpret_cast<const ECValue *>(
          reinterpret_cast<std::intptr_t>(Next) & ~std::intptr_t(1));
    }

    template <typename T>
    bool operator<(const T &Val) const { return Data < Val; }
  };

  /// TheMapping - This implicitly provides a mapping from ElemTy values to the
  /// ECValues, it just keeps the key as part of the value.
  std::set<ECValue> TheMapping;

public:
  EquivalenceClasses() = default;
  EquivalenceClasses(const EquivalenceClasses &RHS) { operator=(RHS); }

  /// Replace the contents of this object with a copy of RHS's classes.  Each
  /// class is rebuilt with the same leader and the same member order.
  /// Self-assignment is a no-op.
  const EquivalenceClasses &operator=(const EquivalenceClasses &RHS) {
    // Clearing TheMapping first would also wipe RHS when it aliases *this.
    if (this == &RHS)
      return *this;
    TheMapping.clear();
    for (iterator I = RHS.begin(), E = RHS.end(); I != E; ++I) {
      if (!I->isLeader())
        continue;
      // Re-insert the leader first, then chain every other member onto it in
      // the original list order.
      member_iterator MI = RHS.member_begin(I);
      member_iterator LeaderIt = member_begin(insert(*MI));
      for (++MI; MI != member_end(); ++MI)
        unionSets(LeaderIt, member_begin(insert(*MI)));
    }
    return *this;
  }

  //===--------------------------------------------------------------------===//
  // Inspection methods
  //

  /// iterator* - Provides a way to iterate over all values in the set, in
  /// operator< order of the elements.
  using iterator = typename std::set<ECValue>::const_iterator;

  iterator begin() const { return TheMapping.begin(); }
  iterator end() const { return TheMapping.end(); }

  bool empty() const { return TheMapping.empty(); }

  /// member_* Iterate over the members of an equivalence class.
  class member_iterator;
  /// Begin iterating the members of I's class.  Only leaders provide anything
  /// to iterate over; a non-leader yields member_end().
  member_iterator member_begin(iterator I) const {
    return member_iterator(I->isLeader() ? &*I : nullptr);
  }
  member_iterator member_end() const {
    return member_iterator(nullptr);
  }

  /// findValue - Return an iterator to the specified value.  If it does not
  /// exist, end() is returned.
  iterator findValue(const ElemTy &V) const {
    return TheMapping.find(V);
  }

  /// getLeaderValue - Return the leader for the specified value that is in the
  /// set.  It is an error to call this method for a value that is not yet in
  /// the set.  For that, call getOrInsertLeaderValue(V).
  const ElemTy &getLeaderValue(const ElemTy &V) const {
    member_iterator MI = findLeader(V);
    assert(MI != member_end() && "Value is not in the set!");
    return *MI;
  }

  /// getOrInsertLeaderValue - Return the leader for the specified value that is
  /// in the set.  If the member is not in the set, it is inserted as its own
  /// class, and V itself is returned as the leader.
  const ElemTy &getOrInsertLeaderValue(const ElemTy &V) {
    member_iterator MI = findLeader(insert(V));
    assert(MI != member_end() && "Value is not in the set!");
    return *MI;
  }

  /// getNumClasses - Return the number of equivalence classes in this set.
  /// Note that this is a linear time operation.
  unsigned getNumClasses() const {
    unsigned NC = 0;
    for (iterator I = begin(), E = end(); I != E; ++I)
      if (I->isLeader())
        ++NC;
    return NC;
  }

  //===--------------------------------------------------------------------===//
  // Mutation methods

  /// insert - Insert a new value into the union/find set as a singleton class,
  /// ignoring the request if the value already exists.  Returns an iterator to
  /// the (new or existing) entry.
  iterator insert(const ElemTy &Data) {
    return TheMapping.insert(ECValue(Data)).first;
  }

  /// findLeader - Given a value in the set, return a member iterator for the
  /// equivalence class it is in.  This does the path-compression part that
  /// makes union-find "union findy".  This returns an end iterator if the value
  /// is not in the equivalence class.
  member_iterator findLeader(iterator I) const {
    if (I == TheMapping.end())
      return member_end();
    return member_iterator(I->getLeader());
  }
  member_iterator findLeader(const ElemTy &V) const {
    return findLeader(TheMapping.find(V));
  }

  /// union - Merge the two equivalence sets for the specified values, inserting
  /// them if they do not already exist in the equivalence set.  V1's leader
  /// becomes the leader of the merged class, which is returned.
  member_iterator unionSets(const ElemTy &V1, const ElemTy &V2) {
    iterator V1I = insert(V1), V2I = insert(V2);
    return unionSets(findLeader(V1I), findLeader(V2I));
  }
  /// Merge the classes led by L1 and L2, both of which must be leader
  /// iterators (e.g. from findLeader() or member_begin()).  L2's members are
  /// appended after L1's and L1 is returned as the merged class's leader.
  member_iterator unionSets(member_iterator L1, member_iterator L2) {
    assert(L1 != member_end() && L2 != member_end() && "Illegal inputs!");
    if (L1 == L2)
      return L1; // Unifying the same two sets, noop.

    // Otherwise, this is a real union operation.  Set the end of the L1 list to
    // point to the L2 leader node.
    const ECValue &L1LV = *L1.Node, &L2LV = *L2.Node;
    L1LV.getEndOfList()->setNext(&L2LV);

    // Update L1LV's end of list pointer.
    L1LV.Leader = L2LV.getEndOfList();

    // Clear L2's leader flag:
    L2LV.Next = L2LV.getNext();

    // L2's leader is now L1.
    L2LV.Leader = &L1LV;
    return L1;
  }

  /// isEquivalent - Return true if V1 is equivalent to V2. This can happen if
  /// V1 is equal to V2 (even if neither is in the set) or if both are in the
  /// set and belong to one equivalence class.
  bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const {
    // Fast path: any element is equivalent to itself.
    if (V1 == V2)
      return true;
    auto It = findLeader(V1);
    return It != member_end() && It == findLeader(V2);
  }

  /// Forward iterator over the members of one equivalence class, starting at
  /// its leader.  A default-constructed iterator compares equal to
  /// member_end().
  class member_iterator {
    friend class EquivalenceClasses;

    /// Current node, or nullptr for the end iterator.
    const ECValue *Node = nullptr;

  public:
    using iterator_category = std::forward_iterator_tag;
    using value_type = const ElemTy;
    using size_type = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer = value_type *;
    using reference = value_type &;

    member_iterator() = default;
    explicit member_iterator(const ECValue *N) : Node(N) {}

    reference operator*() const {
      assert(Node != nullptr && "Dereferencing end()!");
      return Node->getData();
    }
    pointer operator->() const { return &operator*(); }

    member_iterator &operator++() {
      assert(Node != nullptr && "++'d off the end of the list!");
      Node = Node->getNext();
      return *this;
    }

    member_iterator operator++(int) { // postincrement operators.
      member_iterator tmp = *this;
      ++*this;
      return tmp;
    }

    bool operator==(const member_iterator &RHS) const {
      return Node == RHS.Node;
    }
    bool operator!=(const member_iterator &RHS) const {
      return Node != RHS.Node;
    }
  };
};
    293 
    294 } // end namespace llvm
    295 
    296 #endif // LLVM_ADT_EQUIVALENCECLASSES_H
    297