Home | History | Annotate | Line # | Download | only in IR
      1 //===- llvm/FixedPointBuilder.h - Builder for fixed-point ops ---*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file defines the FixedPointBuilder class, which is used as a convenient
     10 // way to lower fixed-point arithmetic operations to LLVM IR.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_IR_FIXEDPOINTBUILDER_H
     15 #define LLVM_IR_FIXEDPOINTBUILDER_H
     16 
#include "llvm/ADT/APFixedPoint.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"

#include <algorithm>
#include <cmath>
     27 
     28 namespace llvm {
     29 
     30 template <class IRBuilderTy> class FixedPointBuilder {
     31   IRBuilderTy &B;
     32 
     33   Value *Convert(Value *Src, const FixedPointSemantics &SrcSema,
     34                  const FixedPointSemantics &DstSema, bool DstIsInteger) {
     35     unsigned SrcWidth = SrcSema.getWidth();
     36     unsigned DstWidth = DstSema.getWidth();
     37     unsigned SrcScale = SrcSema.getScale();
     38     unsigned DstScale = DstSema.getScale();
     39     bool SrcIsSigned = SrcSema.isSigned();
     40     bool DstIsSigned = DstSema.isSigned();
     41 
     42     Type *DstIntTy = B.getIntNTy(DstWidth);
     43 
     44     Value *Result = Src;
     45     unsigned ResultWidth = SrcWidth;
     46 
     47     // Downscale.
     48     if (DstScale < SrcScale) {
     49       // When converting to integers, we round towards zero. For negative
     50       // numbers, right shifting rounds towards negative infinity. In this case,
     51       // we can just round up before shifting.
     52       if (DstIsInteger && SrcIsSigned) {
     53         Value *Zero = Constant::getNullValue(Result->getType());
     54         Value *IsNegative = B.CreateICmpSLT(Result, Zero);
     55         Value *LowBits = ConstantInt::get(
     56             B.getContext(), APInt::getLowBitsSet(ResultWidth, SrcScale));
     57         Value *Rounded = B.CreateAdd(Result, LowBits);
     58         Result = B.CreateSelect(IsNegative, Rounded, Result);
     59       }
     60 
     61       Result = SrcIsSigned
     62                    ? B.CreateAShr(Result, SrcScale - DstScale, "downscale")
     63                    : B.CreateLShr(Result, SrcScale - DstScale, "downscale");
     64     }
     65 
     66     if (!DstSema.isSaturated()) {
     67       // Resize.
     68       Result = B.CreateIntCast(Result, DstIntTy, SrcIsSigned, "resize");
     69 
     70       // Upscale.
     71       if (DstScale > SrcScale)
     72         Result = B.CreateShl(Result, DstScale - SrcScale, "upscale");
     73     } else {
     74       // Adjust the number of fractional bits.
     75       if (DstScale > SrcScale) {
     76         // Compare to DstWidth to prevent resizing twice.
     77         ResultWidth = std::max(SrcWidth + DstScale - SrcScale, DstWidth);
     78         Type *UpscaledTy = B.getIntNTy(ResultWidth);
     79         Result = B.CreateIntCast(Result, UpscaledTy, SrcIsSigned, "resize");
     80         Result = B.CreateShl(Result, DstScale - SrcScale, "upscale");
     81       }
     82 
     83       // Handle saturation.
     84       bool LessIntBits = DstSema.getIntegralBits() < SrcSema.getIntegralBits();
     85       if (LessIntBits) {
     86         Value *Max = ConstantInt::get(
     87             B.getContext(),
     88             APFixedPoint::getMax(DstSema).getValue().extOrTrunc(ResultWidth));
     89         Value *TooHigh = SrcIsSigned ? B.CreateICmpSGT(Result, Max)
     90                                      : B.CreateICmpUGT(Result, Max);
     91         Result = B.CreateSelect(TooHigh, Max, Result, "satmax");
     92       }
     93       // Cannot overflow min to dest type if src is unsigned since all fixed
     94       // point types can cover the unsigned min of 0.
     95       if (SrcIsSigned && (LessIntBits || !DstIsSigned)) {
     96         Value *Min = ConstantInt::get(
     97             B.getContext(),
     98             APFixedPoint::getMin(DstSema).getValue().extOrTrunc(ResultWidth));
     99         Value *TooLow = B.CreateICmpSLT(Result, Min);
    100         Result = B.CreateSelect(TooLow, Min, Result, "satmin");
    101       }
    102 
    103       // Resize the integer part to get the final destination size.
    104       if (ResultWidth != DstWidth)
    105         Result = B.CreateIntCast(Result, DstIntTy, SrcIsSigned, "resize");
    106     }
    107     return Result;
    108   }
    109 
    110   /// Get the common semantic for two semantics, with the added imposition that
    111   /// saturated padded types retain the padding bit.
    112   FixedPointSemantics
    113   getCommonBinopSemantic(const FixedPointSemantics &LHSSema,
    114                          const FixedPointSemantics &RHSSema) {
    115     auto C = LHSSema.getCommonSemantics(RHSSema);
    116     bool BothPadded =
    117         LHSSema.hasUnsignedPadding() && RHSSema.hasUnsignedPadding();
    118     return FixedPointSemantics(
    119         C.getWidth() + (unsigned)(BothPadded && C.isSaturated()), C.getScale(),
    120         C.isSigned(), C.isSaturated(), BothPadded);
    121   }
    122 
    123   /// Given a floating point type and a fixed-point semantic, return a floating
    124   /// point type which can accommodate the fixed-point semantic. This is either
    125   /// \p Ty, or a floating point type with a larger exponent than Ty.
    126   Type *getAccommodatingFloatType(Type *Ty, const FixedPointSemantics &Sema) {
    127     const fltSemantics *FloatSema = &Ty->getFltSemantics();
    128     while (!Sema.fitsInFloatSemantics(*FloatSema))
    129       FloatSema = APFixedPoint::promoteFloatSemantics(FloatSema);
    130     return Type::getFloatingPointTy(Ty->getContext(), *FloatSema);
    131   }
    132 
    133 public:
    134   FixedPointBuilder(IRBuilderTy &Builder) : B(Builder) {}
    135 
    136   /// Convert an integer value representing a fixed-point number from one
    137   /// fixed-point semantic to another fixed-point semantic.
    138   /// \p Src     - The source value
    139   /// \p SrcSema - The fixed-point semantic of the source value
    140   /// \p DstSema - The resulting fixed-point semantic
    141   Value *CreateFixedToFixed(Value *Src, const FixedPointSemantics &SrcSema,
    142                             const FixedPointSemantics &DstSema) {
    143     return Convert(Src, SrcSema, DstSema, false);
    144   }
    145 
    146   /// Convert an integer value representing a fixed-point number to an integer
    147   /// with the given bit width and signedness.
    148   /// \p Src         - The source value
    149   /// \p SrcSema     - The fixed-point semantic of the source value
    150   /// \p DstWidth    - The bit width of the result value
    151   /// \p DstIsSigned - The signedness of the result value
    152   Value *CreateFixedToInteger(Value *Src, const FixedPointSemantics &SrcSema,
    153                               unsigned DstWidth, bool DstIsSigned) {
    154     return Convert(
    155         Src, SrcSema,
    156         FixedPointSemantics::GetIntegerSemantics(DstWidth, DstIsSigned), true);
    157   }
    158 
    159   /// Convert an integer value with the given signedness to an integer value
    160   /// representing the given fixed-point semantic.
    161   /// \p Src         - The source value
    162   /// \p SrcIsSigned - The signedness of the source value
    163   /// \p DstSema     - The resulting fixed-point semantic
    164   Value *CreateIntegerToFixed(Value *Src, unsigned SrcIsSigned,
    165                               const FixedPointSemantics &DstSema) {
    166     return Convert(Src,
    167                    FixedPointSemantics::GetIntegerSemantics(
    168                        Src->getType()->getScalarSizeInBits(), SrcIsSigned),
    169                    DstSema, false);
    170   }
    171 
    172   Value *CreateFixedToFloating(Value *Src, const FixedPointSemantics &SrcSema,
    173                                Type *DstTy) {
    174     Value *Result;
    175     Type *OpTy = getAccommodatingFloatType(DstTy, SrcSema);
    176     // Convert the raw fixed-point value directly to floating point. If the
    177     // value is too large to fit, it will be rounded, not truncated.
    178     Result = SrcSema.isSigned() ? B.CreateSIToFP(Src, OpTy)
    179                                 : B.CreateUIToFP(Src, OpTy);
    180     // Rescale the integral-in-floating point by the scaling factor. This is
    181     // lossless, except for overflow to infinity which is unlikely.
    182     Result = B.CreateFMul(Result,
    183         ConstantFP::get(OpTy, std::pow(2, -(int)SrcSema.getScale())));
    184     if (OpTy != DstTy)
    185       Result = B.CreateFPTrunc(Result, DstTy);
    186     return Result;
    187   }
    188 
    189   Value *CreateFloatingToFixed(Value *Src, const FixedPointSemantics &DstSema) {
    190     bool UseSigned = DstSema.isSigned() || DstSema.hasUnsignedPadding();
    191     Value *Result = Src;
    192     Type *OpTy = getAccommodatingFloatType(Src->getType(), DstSema);
    193     if (OpTy != Src->getType())
    194       Result = B.CreateFPExt(Result, OpTy);
    195     // Rescale the floating point value so that its significant bits (for the
    196     // purposes of the conversion) are in the integral range.
    197     Result = B.CreateFMul(Result,
    198         ConstantFP::get(OpTy, std::pow(2, DstSema.getScale())));
    199 
    200     Type *ResultTy = B.getIntNTy(DstSema.getWidth());
    201     if (DstSema.isSaturated()) {
    202       Intrinsic::ID IID =
    203           UseSigned ? Intrinsic::fptosi_sat : Intrinsic::fptoui_sat;
    204       Result = B.CreateIntrinsic(IID, {ResultTy, OpTy}, {Result});
    205     } else {
    206       Result = UseSigned ? B.CreateFPToSI(Result, ResultTy)
    207                          : B.CreateFPToUI(Result, ResultTy);
    208     }
    209 
    210     // When saturating unsigned-with-padding using signed operations, we may
    211     // get negative values. Emit an extra clamp to zero.
    212     if (DstSema.isSaturated() && DstSema.hasUnsignedPadding()) {
    213       Constant *Zero = Constant::getNullValue(Result->getType());
    214       Result =
    215           B.CreateSelect(B.CreateICmpSLT(Result, Zero), Zero, Result, "satmin");
    216     }
    217 
    218     return Result;
    219   }
    220 
    221   /// Add two fixed-point values and return the result in their common semantic.
    222   /// \p LHS     - The left hand side
    223   /// \p LHSSema - The semantic of the left hand side
    224   /// \p RHS     - The right hand side
    225   /// \p RHSSema - The semantic of the right hand side
    226   Value *CreateAdd(Value *LHS, const FixedPointSemantics &LHSSema,
    227                    Value *RHS, const FixedPointSemantics &RHSSema) {
    228     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    229     bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
    230 
    231     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    232     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    233 
    234     Value *Result;
    235     if (CommonSema.isSaturated()) {
    236       Intrinsic::ID IID = UseSigned ? Intrinsic::sadd_sat : Intrinsic::uadd_sat;
    237       Result = B.CreateBinaryIntrinsic(IID, WideLHS, WideRHS);
    238     } else {
    239       Result = B.CreateAdd(WideLHS, WideRHS);
    240     }
    241 
    242     return CreateFixedToFixed(Result, CommonSema,
    243                               LHSSema.getCommonSemantics(RHSSema));
    244   }
    245 
    246   /// Subtract two fixed-point values and return the result in their common
    247   /// semantic.
    248   /// \p LHS     - The left hand side
    249   /// \p LHSSema - The semantic of the left hand side
    250   /// \p RHS     - The right hand side
    251   /// \p RHSSema - The semantic of the right hand side
    252   Value *CreateSub(Value *LHS, const FixedPointSemantics &LHSSema,
    253                    Value *RHS, const FixedPointSemantics &RHSSema) {
    254     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    255     bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
    256 
    257     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    258     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    259 
    260     Value *Result;
    261     if (CommonSema.isSaturated()) {
    262       Intrinsic::ID IID = UseSigned ? Intrinsic::ssub_sat : Intrinsic::usub_sat;
    263       Result = B.CreateBinaryIntrinsic(IID, WideLHS, WideRHS);
    264     } else {
    265       Result = B.CreateSub(WideLHS, WideRHS);
    266     }
    267 
    268     // Subtraction can end up below 0 for padded unsigned operations, so emit
    269     // an extra clamp in that case.
    270     if (CommonSema.isSaturated() && CommonSema.hasUnsignedPadding()) {
    271       Constant *Zero = Constant::getNullValue(Result->getType());
    272       Result =
    273           B.CreateSelect(B.CreateICmpSLT(Result, Zero), Zero, Result, "satmin");
    274     }
    275 
    276     return CreateFixedToFixed(Result, CommonSema,
    277                               LHSSema.getCommonSemantics(RHSSema));
    278   }
    279 
    280   /// Multiply two fixed-point values and return the result in their common
    281   /// semantic.
    282   /// \p LHS     - The left hand side
    283   /// \p LHSSema - The semantic of the left hand side
    284   /// \p RHS     - The right hand side
    285   /// \p RHSSema - The semantic of the right hand side
    286   Value *CreateMul(Value *LHS, const FixedPointSemantics &LHSSema,
    287                    Value *RHS, const FixedPointSemantics &RHSSema) {
    288     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    289     bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
    290 
    291     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    292     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    293 
    294     Intrinsic::ID IID;
    295     if (CommonSema.isSaturated()) {
    296       IID = UseSigned ? Intrinsic::smul_fix_sat : Intrinsic::umul_fix_sat;
    297     } else {
    298       IID = UseSigned ? Intrinsic::smul_fix : Intrinsic::umul_fix;
    299     }
    300     Value *Result = B.CreateIntrinsic(
    301         IID, {WideLHS->getType()},
    302         {WideLHS, WideRHS, B.getInt32(CommonSema.getScale())});
    303 
    304     return CreateFixedToFixed(Result, CommonSema,
    305                               LHSSema.getCommonSemantics(RHSSema));
    306   }
    307 
    308   /// Divide two fixed-point values and return the result in their common
    309   /// semantic.
    310   /// \p LHS     - The left hand side
    311   /// \p LHSSema - The semantic of the left hand side
    312   /// \p RHS     - The right hand side
    313   /// \p RHSSema - The semantic of the right hand side
    314   Value *CreateDiv(Value *LHS, const FixedPointSemantics &LHSSema,
    315                    Value *RHS, const FixedPointSemantics &RHSSema) {
    316     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    317     bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
    318 
    319     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    320     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    321 
    322     Intrinsic::ID IID;
    323     if (CommonSema.isSaturated()) {
    324       IID = UseSigned ? Intrinsic::sdiv_fix_sat : Intrinsic::udiv_fix_sat;
    325     } else {
    326       IID = UseSigned ? Intrinsic::sdiv_fix : Intrinsic::udiv_fix;
    327     }
    328     Value *Result = B.CreateIntrinsic(
    329         IID, {WideLHS->getType()},
    330         {WideLHS, WideRHS, B.getInt32(CommonSema.getScale())});
    331 
    332     return CreateFixedToFixed(Result, CommonSema,
    333                               LHSSema.getCommonSemantics(RHSSema));
    334   }
    335 
    336   /// Left shift a fixed-point value by an unsigned integer value. The integer
    337   /// value can be any bit width.
    338   /// \p LHS     - The left hand side
    339   /// \p LHSSema - The semantic of the left hand side
    340   /// \p RHS     - The right hand side
    341   Value *CreateShl(Value *LHS, const FixedPointSemantics &LHSSema, Value *RHS) {
    342     bool UseSigned = LHSSema.isSigned() || LHSSema.hasUnsignedPadding();
    343 
    344     RHS = B.CreateIntCast(RHS, LHS->getType(), /*IsSigned=*/false);
    345 
    346     Value *Result;
    347     if (LHSSema.isSaturated()) {
    348       Intrinsic::ID IID = UseSigned ? Intrinsic::sshl_sat : Intrinsic::ushl_sat;
    349       Result = B.CreateBinaryIntrinsic(IID, LHS, RHS);
    350     } else {
    351       Result = B.CreateShl(LHS, RHS);
    352     }
    353 
    354     return Result;
    355   }
    356 
    357   /// Right shift a fixed-point value by an unsigned integer value. The integer
    358   /// value can be any bit width.
    359   /// \p LHS     - The left hand side
    360   /// \p LHSSema - The semantic of the left hand side
    361   /// \p RHS     - The right hand side
    362   Value *CreateShr(Value *LHS, const FixedPointSemantics &LHSSema, Value *RHS) {
    363     RHS = B.CreateIntCast(RHS, LHS->getType(), false);
    364 
    365     return LHSSema.isSigned() ? B.CreateAShr(LHS, RHS) : B.CreateLShr(LHS, RHS);
    366   }
    367 
    368   /// Compare two fixed-point values for equality.
    369   /// \p LHS     - The left hand side
    370   /// \p LHSSema - The semantic of the left hand side
    371   /// \p RHS     - The right hand side
    372   /// \p RHSSema - The semantic of the right hand side
    373   Value *CreateEQ(Value *LHS, const FixedPointSemantics &LHSSema,
    374                   Value *RHS, const FixedPointSemantics &RHSSema) {
    375     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    376 
    377     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    378     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    379 
    380     return B.CreateICmpEQ(WideLHS, WideRHS);
    381   }
    382 
    383   /// Compare two fixed-point values for inequality.
    384   /// \p LHS     - The left hand side
    385   /// \p LHSSema - The semantic of the left hand side
    386   /// \p RHS     - The right hand side
    387   /// \p RHSSema - The semantic of the right hand side
    388   Value *CreateNE(Value *LHS, const FixedPointSemantics &LHSSema,
    389                   Value *RHS, const FixedPointSemantics &RHSSema) {
    390     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    391 
    392     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    393     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    394 
    395     return B.CreateICmpNE(WideLHS, WideRHS);
    396   }
    397 
    398   /// Compare two fixed-point values as LHS < RHS.
    399   /// \p LHS     - The left hand side
    400   /// \p LHSSema - The semantic of the left hand side
    401   /// \p RHS     - The right hand side
    402   /// \p RHSSema - The semantic of the right hand side
    403   Value *CreateLT(Value *LHS, const FixedPointSemantics &LHSSema,
    404                   Value *RHS, const FixedPointSemantics &RHSSema) {
    405     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    406 
    407     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    408     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    409 
    410     return CommonSema.isSigned() ? B.CreateICmpSLT(WideLHS, WideRHS)
    411                                  : B.CreateICmpULT(WideLHS, WideRHS);
    412   }
    413 
    414   /// Compare two fixed-point values as LHS <= RHS.
    415   /// \p LHS     - The left hand side
    416   /// \p LHSSema - The semantic of the left hand side
    417   /// \p RHS     - The right hand side
    418   /// \p RHSSema - The semantic of the right hand side
    419   Value *CreateLE(Value *LHS, const FixedPointSemantics &LHSSema,
    420                   Value *RHS, const FixedPointSemantics &RHSSema) {
    421     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    422 
    423     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    424     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    425 
    426     return CommonSema.isSigned() ? B.CreateICmpSLE(WideLHS, WideRHS)
    427                                  : B.CreateICmpULE(WideLHS, WideRHS);
    428   }
    429 
    430   /// Compare two fixed-point values as LHS > RHS.
    431   /// \p LHS     - The left hand side
    432   /// \p LHSSema - The semantic of the left hand side
    433   /// \p RHS     - The right hand side
    434   /// \p RHSSema - The semantic of the right hand side
    435   Value *CreateGT(Value *LHS, const FixedPointSemantics &LHSSema,
    436                   Value *RHS, const FixedPointSemantics &RHSSema) {
    437     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    438 
    439     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    440     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    441 
    442     return CommonSema.isSigned() ? B.CreateICmpSGT(WideLHS, WideRHS)
    443                                  : B.CreateICmpUGT(WideLHS, WideRHS);
    444   }
    445 
    446   /// Compare two fixed-point values as LHS >= RHS.
    447   /// \p LHS     - The left hand side
    448   /// \p LHSSema - The semantic of the left hand side
    449   /// \p RHS     - The right hand side
    450   /// \p RHSSema - The semantic of the right hand side
    451   Value *CreateGE(Value *LHS, const FixedPointSemantics &LHSSema,
    452                   Value *RHS, const FixedPointSemantics &RHSSema) {
    453     auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
    454 
    455     Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
    456     Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
    457 
    458     return CommonSema.isSigned() ? B.CreateICmpSGE(WideLHS, WideRHS)
    459                                  : B.CreateICmpUGE(WideLHS, WideRHS);
    460   }
    461 };
    462 
    463 } // end namespace llvm
    464 
    465 #endif // LLVM_IR_FIXEDPOINTBUILDER_H
    466