Home | History | Annotate | Line # | Download | only in Support
      1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file contains a class for representing known zeros and ones used by
     10 // computeKnownBits.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "llvm/Support/KnownBits.h"
     15 #include "llvm/Support/Debug.h"
     16 #include "llvm/Support/raw_ostream.h"
     17 #include <cassert>
     18 
     19 using namespace llvm;
     20 
     21 static KnownBits computeForAddCarry(
     22     const KnownBits &LHS, const KnownBits &RHS,
     23     bool CarryZero, bool CarryOne) {
     24   assert(!(CarryZero && CarryOne) &&
     25          "Carry can't be zero and one at the same time");
     26 
     27   APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
     28   APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;
     29 
     30   // Compute known bits of the carry.
     31   APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
     32   APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
     33 
     34   // Compute set of known bits (where all three relevant bits are known).
     35   APInt LHSKnownUnion = LHS.Zero | LHS.One;
     36   APInt RHSKnownUnion = RHS.Zero | RHS.One;
     37   APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
     38   APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;
     39 
     40   assert((PossibleSumZero & Known) == (PossibleSumOne & Known) &&
     41          "known bits of sum differ");
     42 
     43   // Compute known bits of the result.
     44   KnownBits KnownOut;
     45   KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
     46   KnownOut.One = std::move(PossibleSumOne) & Known;
     47   return KnownOut;
     48 }
     49 
     50 KnownBits KnownBits::computeForAddCarry(
     51     const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
     52   assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
     53   return ::computeForAddCarry(
     54       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
     55 }
     56 
     57 KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
     58                                       const KnownBits &LHS, KnownBits RHS) {
     59   KnownBits KnownOut;
     60   if (Add) {
     61     // Sum = LHS + RHS + 0
     62     KnownOut = ::computeForAddCarry(
     63         LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
     64   } else {
     65     // Sum = LHS + ~RHS + 1
     66     std::swap(RHS.Zero, RHS.One);
     67     KnownOut = ::computeForAddCarry(
     68         LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
     69   }
     70 
     71   // Are we still trying to solve for the sign bit?
     72   if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
     73     if (NSW) {
     74       // Adding two non-negative numbers, or subtracting a negative number from
     75       // a non-negative one, can't wrap into negative.
     76       if (LHS.isNonNegative() && RHS.isNonNegative())
     77         KnownOut.makeNonNegative();
     78       // Adding two negative numbers, or subtracting a non-negative number from
     79       // a negative one, can't wrap into non-negative.
     80       else if (LHS.isNegative() && RHS.isNegative())
     81         KnownOut.makeNegative();
     82     }
     83   }
     84 
     85   return KnownOut;
     86 }
     87 
     88 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
     89   unsigned BitWidth = getBitWidth();
     90   assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
     91          "Illegal sext-in-register");
     92 
     93   if (SrcBitWidth == BitWidth)
     94     return *this;
     95 
     96   unsigned ExtBits = BitWidth - SrcBitWidth;
     97   KnownBits Result;
     98   Result.One = One << ExtBits;
     99   Result.Zero = Zero << ExtBits;
    100   Result.One.ashrInPlace(ExtBits);
    101   Result.Zero.ashrInPlace(ExtBits);
    102   return Result;
    103 }
    104 
    105 KnownBits KnownBits::makeGE(const APInt &Val) const {
    106   // Count the number of leading bit positions where our underlying value is
    107   // known to be less than or equal to Val.
    108   unsigned N = (Zero | Val).countLeadingOnes();
    109 
    110   // For each of those bit positions, if Val has a 1 in that bit then our
    111   // underlying value must also have a 1.
    112   APInt MaskedVal(Val);
    113   MaskedVal.clearLowBits(getBitWidth() - N);
    114   return KnownBits(Zero, One | MaskedVal);
    115 }
    116 
    117 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
    118   // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
    119   // RHS. Ideally our caller would already have spotted these cases and
    120   // optimized away the umax operation, but we handle them here for
    121   // completeness.
    122   if (LHS.getMinValue().uge(RHS.getMaxValue()))
    123     return LHS;
    124   if (RHS.getMinValue().uge(LHS.getMaxValue()))
    125     return RHS;
    126 
    127   // If the result of the umax is LHS then it must be greater than or equal to
    128   // the minimum possible value of RHS. Likewise for RHS. Any known bits that
    129   // are common to these two values are also known in the result.
    130   KnownBits L = LHS.makeGE(RHS.getMinValue());
    131   KnownBits R = RHS.makeGE(LHS.getMinValue());
    132   return KnownBits::commonBits(L, R);
    133 }
    134 
    135 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
    136   // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
    137   auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
    138   return Flip(umax(Flip(LHS), Flip(RHS)));
    139 }
    140 
    141 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
    142   // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
    143   auto Flip = [](const KnownBits &Val) {
    144     unsigned SignBitPosition = Val.getBitWidth() - 1;
    145     APInt Zero = Val.Zero;
    146     APInt One = Val.One;
    147     Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
    148     One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
    149     return KnownBits(Zero, One);
    150   };
    151   return Flip(umax(Flip(LHS), Flip(RHS)));
    152 }
    153 
    154 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
    155   // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
    156   auto Flip = [](const KnownBits &Val) {
    157     unsigned SignBitPosition = Val.getBitWidth() - 1;
    158     APInt Zero = Val.One;
    159     APInt One = Val.Zero;
    160     Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
    161     One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
    162     return KnownBits(Zero, One);
    163   };
    164   return Flip(umax(Flip(LHS), Flip(RHS)));
    165 }
    166 
    167 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
    168   unsigned BitWidth = LHS.getBitWidth();
    169   KnownBits Known(BitWidth);
    170 
    171   // If the shift amount is a valid constant then transform LHS directly.
    172   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
    173     unsigned Shift = RHS.getConstant().getZExtValue();
    174     Known = LHS;
    175     Known.Zero <<= Shift;
    176     Known.One <<= Shift;
    177     // Low bits are known zero.
    178     Known.Zero.setLowBits(Shift);
    179     return Known;
    180   }
    181 
    182   // No matter the shift amount, the trailing zeros will stay zero.
    183   unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
    184 
    185   // Minimum shift amount low bits are known zero.
    186   APInt MinShiftAmount = RHS.getMinValue();
    187   if (MinShiftAmount.ult(BitWidth)) {
    188     MinTrailingZeros += MinShiftAmount.getZExtValue();
    189     MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
    190   }
    191 
    192   // If the maximum shift is in range, then find the common bits from all
    193   // possible shifts.
    194   APInt MaxShiftAmount = RHS.getMaxValue();
    195   if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
    196     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
    197     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
    198     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
    199     Known.Zero.setAllBits();
    200     Known.One.setAllBits();
    201     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
    202                   MaxShiftAmt = MaxShiftAmount.getZExtValue();
    203          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
    204       // Skip if the shift amount is impossible.
    205       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
    206           (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
    207         continue;
    208       KnownBits SpecificShift;
    209       SpecificShift.Zero = LHS.Zero << ShiftAmt;
    210       SpecificShift.One = LHS.One << ShiftAmt;
    211       Known = KnownBits::commonBits(Known, SpecificShift);
    212       if (Known.isUnknown())
    213         break;
    214     }
    215   }
    216 
    217   Known.Zero.setLowBits(MinTrailingZeros);
    218   return Known;
    219 }
    220 
    221 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
    222   unsigned BitWidth = LHS.getBitWidth();
    223   KnownBits Known(BitWidth);
    224 
    225   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
    226     unsigned Shift = RHS.getConstant().getZExtValue();
    227     Known = LHS;
    228     Known.Zero.lshrInPlace(Shift);
    229     Known.One.lshrInPlace(Shift);
    230     // High bits are known zero.
    231     Known.Zero.setHighBits(Shift);
    232     return Known;
    233   }
    234 
    235   // No matter the shift amount, the leading zeros will stay zero.
    236   unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
    237 
    238   // Minimum shift amount high bits are known zero.
    239   APInt MinShiftAmount = RHS.getMinValue();
    240   if (MinShiftAmount.ult(BitWidth)) {
    241     MinLeadingZeros += MinShiftAmount.getZExtValue();
    242     MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
    243   }
    244 
    245   // If the maximum shift is in range, then find the common bits from all
    246   // possible shifts.
    247   APInt MaxShiftAmount = RHS.getMaxValue();
    248   if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
    249     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
    250     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
    251     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
    252     Known.Zero.setAllBits();
    253     Known.One.setAllBits();
    254     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
    255                   MaxShiftAmt = MaxShiftAmount.getZExtValue();
    256          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
    257       // Skip if the shift amount is impossible.
    258       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
    259           (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
    260         continue;
    261       KnownBits SpecificShift = LHS;
    262       SpecificShift.Zero.lshrInPlace(ShiftAmt);
    263       SpecificShift.One.lshrInPlace(ShiftAmt);
    264       Known = KnownBits::commonBits(Known, SpecificShift);
    265       if (Known.isUnknown())
    266         break;
    267     }
    268   }
    269 
    270   Known.Zero.setHighBits(MinLeadingZeros);
    271   return Known;
    272 }
    273 
    274 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
    275   unsigned BitWidth = LHS.getBitWidth();
    276   KnownBits Known(BitWidth);
    277 
    278   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
    279     unsigned Shift = RHS.getConstant().getZExtValue();
    280     Known = LHS;
    281     Known.Zero.ashrInPlace(Shift);
    282     Known.One.ashrInPlace(Shift);
    283     return Known;
    284   }
    285 
    286   // No matter the shift amount, the leading sign bits will stay.
    287   unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
    288   unsigned MinLeadingOnes = LHS.countMinLeadingOnes();
    289 
    290   // Minimum shift amount high bits are known sign bits.
    291   APInt MinShiftAmount = RHS.getMinValue();
    292   if (MinShiftAmount.ult(BitWidth)) {
    293     if (MinLeadingZeros) {
    294       MinLeadingZeros += MinShiftAmount.getZExtValue();
    295       MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
    296     }
    297     if (MinLeadingOnes) {
    298       MinLeadingOnes += MinShiftAmount.getZExtValue();
    299       MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
    300     }
    301   }
    302 
    303   // If the maximum shift is in range, then find the common bits from all
    304   // possible shifts.
    305   APInt MaxShiftAmount = RHS.getMaxValue();
    306   if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
    307     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
    308     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
    309     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
    310     Known.Zero.setAllBits();
    311     Known.One.setAllBits();
    312     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
    313                   MaxShiftAmt = MaxShiftAmount.getZExtValue();
    314          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
    315       // Skip if the shift amount is impossible.
    316       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
    317           (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
    318         continue;
    319       KnownBits SpecificShift = LHS;
    320       SpecificShift.Zero.ashrInPlace(ShiftAmt);
    321       SpecificShift.One.ashrInPlace(ShiftAmt);
    322       Known = KnownBits::commonBits(Known, SpecificShift);
    323       if (Known.isUnknown())
    324         break;
    325     }
    326   }
    327 
    328   Known.Zero.setHighBits(MinLeadingZeros);
    329   Known.One.setHighBits(MinLeadingOnes);
    330   return Known;
    331 }
    332 
    333 Optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
    334   if (LHS.isConstant() && RHS.isConstant())
    335     return Optional<bool>(LHS.getConstant() == RHS.getConstant());
    336   if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
    337     return Optional<bool>(false);
    338   return None;
    339 }
    340 
    341 Optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
    342   if (Optional<bool> KnownEQ = eq(LHS, RHS))
    343     return Optional<bool>(!KnownEQ.getValue());
    344   return None;
    345 }
    346 
    347 Optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
    348   // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
    349   if (LHS.getMaxValue().ule(RHS.getMinValue()))
    350     return Optional<bool>(false);
    351   // LHS >u RHS -> true if umin(LHS) > umax(RHS)
    352   if (LHS.getMinValue().ugt(RHS.getMaxValue()))
    353     return Optional<bool>(true);
    354   return None;
    355 }
    356 
    357 Optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
    358   if (Optional<bool> IsUGT = ugt(RHS, LHS))
    359     return Optional<bool>(!IsUGT.getValue());
    360   return None;
    361 }
    362 
    363 Optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
    364   return ugt(RHS, LHS);
    365 }
    366 
    367 Optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
    368   return uge(RHS, LHS);
    369 }
    370 
    371 Optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
    372   // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
    373   if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
    374     return Optional<bool>(false);
    375   // LHS >s RHS -> true if smin(LHS) > smax(RHS)
    376   if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
    377     return Optional<bool>(true);
    378   return None;
    379 }
    380 
    381 Optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
    382   if (Optional<bool> KnownSGT = sgt(RHS, LHS))
    383     return Optional<bool>(!KnownSGT.getValue());
    384   return None;
    385 }
    386 
    387 Optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
    388   return sgt(RHS, LHS);
    389 }
    390 
    391 Optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
    392   return sge(RHS, LHS);
    393 }
    394 
    395 KnownBits KnownBits::abs(bool IntMinIsPoison) const {
    396   // If the source's MSB is zero then we know the rest of the bits already.
    397   if (isNonNegative())
    398     return *this;
    399 
    400   // Absolute value preserves trailing zero count.
    401   KnownBits KnownAbs(getBitWidth());
    402   KnownAbs.Zero.setLowBits(countMinTrailingZeros());
    403 
    404   // We only know that the absolute values's MSB will be zero if INT_MIN is
    405   // poison, or there is a set bit that isn't the sign bit (otherwise it could
    406   // be INT_MIN).
    407   if (IntMinIsPoison || (!One.isNullValue() && !One.isMinSignedValue()))
    408     KnownAbs.Zero.setSignBit();
    409 
    410   // FIXME: Handle known negative input?
    411   // FIXME: Calculate the negated Known bits and combine them?
    412   return KnownAbs;
    413 }
    414 
    415 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS) {
    416   unsigned BitWidth = LHS.getBitWidth();
    417   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
    418          !RHS.hasConflict() && "Operand mismatch");
    419 
    420   // Compute a conservative estimate for high known-0 bits.
    421   unsigned LeadZ =
    422       std::max(LHS.countMinLeadingZeros() + RHS.countMinLeadingZeros(),
    423                BitWidth) -
    424       BitWidth;
    425   LeadZ = std::min(LeadZ, BitWidth);
    426 
    427   // The result of the bottom bits of an integer multiply can be
    428   // inferred by looking at the bottom bits of both operands and
    429   // multiplying them together.
    430   // We can infer at least the minimum number of known trailing bits
    431   // of both operands. Depending on number of trailing zeros, we can
    432   // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
    433   // a and b are divisible by m and n respectively.
    434   // We then calculate how many of those bits are inferrable and set
    435   // the output. For example, the i8 mul:
    436   //  a = XXXX1100 (12)
    437   //  b = XXXX1110 (14)
    438   // We know the bottom 3 bits are zero since the first can be divided by
    439   // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
    440   // Applying the multiplication to the trimmed arguments gets:
    441   //    XX11 (3)
    442   //    X111 (7)
    443   // -------
    444   //    XX11
    445   //   XX11
    446   //  XX11
    447   // XX11
    448   // -------
    449   // XXXXX01
    450   // Which allows us to infer the 2 LSBs. Since we're multiplying the result
    451   // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
    452   // The proof for this can be described as:
    453   // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
    454   //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
    455   //                    umin(countTrailingZeros(C2), C6) +
    456   //                    umin(C5 - umin(countTrailingZeros(C1), C5),
    457   //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
    458   // %aa = shl i8 %a, C5
    459   // %bb = shl i8 %b, C6
    460   // %aaa = or i8 %aa, C1
    461   // %bbb = or i8 %bb, C2
    462   // %mul = mul i8 %aaa, %bbb
    463   // %mask = and i8 %mul, C7
    464   //   =>
    465   // %mask = i8 ((C1*C2)&C7)
    466   // Where C5, C6 describe the known bits of %a, %b
    467   // C1, C2 describe the known bottom bits of %a, %b.
    468   // C7 describes the mask of the known bits of the result.
    469   const APInt &Bottom0 = LHS.One;
    470   const APInt &Bottom1 = RHS.One;
    471 
    472   // How many times we'd be able to divide each argument by 2 (shr by 1).
    473   // This gives us the number of trailing zeros on the multiplication result.
    474   unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countTrailingOnes();
    475   unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countTrailingOnes();
    476   unsigned TrailZero0 = LHS.countMinTrailingZeros();
    477   unsigned TrailZero1 = RHS.countMinTrailingZeros();
    478   unsigned TrailZ = TrailZero0 + TrailZero1;
    479 
    480   // Figure out the fewest known-bits operand.
    481   unsigned SmallestOperand =
    482       std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
    483   unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
    484 
    485   APInt BottomKnown =
    486       Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);
    487 
    488   KnownBits Res(BitWidth);
    489   Res.Zero.setHighBits(LeadZ);
    490   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
    491   Res.One = BottomKnown.getLoBits(ResultBitsKnown);
    492   return Res;
    493 }
    494 
    495 KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
    496   unsigned BitWidth = LHS.getBitWidth();
    497   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
    498          !RHS.hasConflict() && "Operand mismatch");
    499   KnownBits WideLHS = LHS.sext(2 * BitWidth);
    500   KnownBits WideRHS = RHS.sext(2 * BitWidth);
    501   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
    502 }
    503 
    504 KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
    505   unsigned BitWidth = LHS.getBitWidth();
    506   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
    507          !RHS.hasConflict() && "Operand mismatch");
    508   KnownBits WideLHS = LHS.zext(2 * BitWidth);
    509   KnownBits WideRHS = RHS.zext(2 * BitWidth);
    510   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
    511 }
    512 
    513 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
    514   unsigned BitWidth = LHS.getBitWidth();
    515   assert(!LHS.hasConflict() && !RHS.hasConflict());
    516   KnownBits Known(BitWidth);
    517 
    518   // For the purposes of computing leading zeros we can conservatively
    519   // treat a udiv as a logical right shift by the power of 2 known to
    520   // be less than the denominator.
    521   unsigned LeadZ = LHS.countMinLeadingZeros();
    522   unsigned RHSMaxLeadingZeros = RHS.countMaxLeadingZeros();
    523 
    524   if (RHSMaxLeadingZeros != BitWidth)
    525     LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1);
    526 
    527   Known.Zero.setHighBits(LeadZ);
    528   return Known;
    529 }
    530 
    531 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
    532   unsigned BitWidth = LHS.getBitWidth();
    533   assert(!LHS.hasConflict() && !RHS.hasConflict());
    534   KnownBits Known(BitWidth);
    535 
    536   if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
    537     // The upper bits are all zero, the lower ones are unchanged.
    538     APInt LowBits = RHS.getConstant() - 1;
    539     Known.Zero = LHS.Zero | ~LowBits;
    540     Known.One = LHS.One & LowBits;
    541     return Known;
    542   }
    543 
    544   // Since the result is less than or equal to either operand, any leading
    545   // zero bits in either operand must also exist in the result.
    546   uint32_t Leaders =
    547       std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros());
    548   Known.Zero.setHighBits(Leaders);
    549   return Known;
    550 }
    551 
    552 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
    553   unsigned BitWidth = LHS.getBitWidth();
    554   assert(!LHS.hasConflict() && !RHS.hasConflict());
    555   KnownBits Known(BitWidth);
    556 
    557   if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
    558     // The low bits of the first operand are unchanged by the srem.
    559     APInt LowBits = RHS.getConstant() - 1;
    560     Known.Zero = LHS.Zero & LowBits;
    561     Known.One = LHS.One & LowBits;
    562 
    563     // If the first operand is non-negative or has all low bits zero, then
    564     // the upper bits are all zero.
    565     if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero))
    566       Known.Zero |= ~LowBits;
    567 
    568     // If the first operand is negative and not all low bits are zero, then
    569     // the upper bits are all one.
    570     if (LHS.isNegative() && LowBits.intersects(LHS.One))
    571       Known.One |= ~LowBits;
    572     return Known;
    573   }
    574 
    575   // The sign bit is the LHS's sign bit, except when the result of the
    576   // remainder is zero. The magnitude of the result should be less than or
    577   // equal to the magnitude of the LHS. Therefore any leading zeros that exist
    578   // in the left hand side must also exist in the result.
    579   Known.Zero.setHighBits(LHS.countMinLeadingZeros());
    580   return Known;
    581 }
    582 
    583 KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
    584   // Result bit is 0 if either operand bit is 0.
    585   Zero |= RHS.Zero;
    586   // Result bit is 1 if both operand bits are 1.
    587   One &= RHS.One;
    588   return *this;
    589 }
    590 
    591 KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
    592   // Result bit is 0 if both operand bits are 0.
    593   Zero &= RHS.Zero;
    594   // Result bit is 1 if either operand bit is 1.
    595   One |= RHS.One;
    596   return *this;
    597 }
    598 
    599 KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
    600   // Result bit is 0 if both operand bits are 0 or both are 1.
    601   APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
    602   // Result bit is 1 if one operand bit is 0 and the other is 1.
    603   One = (Zero & RHS.One) | (One & RHS.Zero);
    604   Zero = std::move(Z);
    605   return *this;
    606 }
    607 
    608 void KnownBits::print(raw_ostream &OS) const {
    609   OS << "{Zero=" << Zero << ", One=" << One << "}";
    610 }
    611 void KnownBits::dump() const {
    612   print(dbgs());
    613   dbgs() << "\n";
    614 }
    615