Home | History | Annotate | Line # | Download | only in Analysis
      1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file defines some vectorizer utilities.
     10 //
     11 //===----------------------------------------------------------------------===//
     12 
     13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
     14 #define LLVM_ANALYSIS_VECTORUTILS_H
     15 
     16 #include "llvm/ADT/MapVector.h"
     17 #include "llvm/ADT/SmallVector.h"
     18 #include "llvm/Analysis/LoopAccessAnalysis.h"
     19 #include "llvm/Support/CheckedArithmetic.h"
     20 
     21 namespace llvm {
     22 class TargetLibraryInfo;
     23 
/// Describes the kind of a vector function parameter.
enum class VFParamKind {
  /// No semantic information.
  Vector,
  /// declare simd linear(i)
  OMP_Linear,
  /// declare simd linear(ref(i))
  OMP_LinearRef,
  /// declare simd linear(val(i))
  OMP_LinearVal,
  /// declare simd linear(uval(i))
  OMP_LinearUVal,
  /// declare simd linear(i:c) uniform(c)
  OMP_LinearPos,
  /// declare simd linear(val(i:c)) uniform(c)
  OMP_LinearValPos,
  /// declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearRefPos,
  /// declare simd linear(uval(i:c)) uniform(c)
  OMP_LinearUValPos,
  /// declare simd uniform(i)
  OMP_Uniform,
  /// Global logical predicate that acts on all lanes of the input and output
  /// mask concurrently. For example, it is implied by the `M` token in the
  /// Vector Function ABI mangled name.
  GlobalPredicate,
  Unknown
};
     42 
/// Describes the type of Instruction Set Architecture.
enum class VFISAKind {
  /// AArch64 Advanced SIMD (NEON)
  AdvancedSIMD,
  /// AArch64 Scalable Vector Extension
  SVE,
  /// x86 SSE
  SSE,
  /// x86 AVX
  AVX,
  /// x86 AVX2
  AVX2,
  /// x86 AVX512
  AVX512,
  /// LLVM internal ISA for functions that are not attached to an existing ABI
  /// via name mangling.
  LLVM,
  /// Unknown ISA
  Unknown
};
     55 
     56 /// Encapsulates information needed to describe a parameter.
     57 ///
     58 /// The description of the parameter is not linked directly to
     59 /// OpenMP or any other vector function description. This structure
     60 /// is extendible to handle other paradigms that describe vector
     61 /// functions and their parameters.
     62 struct VFParameter {
     63   unsigned ParamPos;         // Parameter Position in Scalar Function.
     64   VFParamKind ParamKind;     // Kind of Parameter.
     65   int LinearStepOrPos = 0;   // Step or Position of the Parameter.
     66   Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.
     67 
     68   // Comparison operator.
     69   bool operator==(const VFParameter &Other) const {
     70     return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
     71            std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
     72                     Other.Alignment);
     73   }
     74 };
     75 
     76 /// Contains the information about the kind of vectorization
     77 /// available.
     78 ///
     79 /// This object in independent on the paradigm used to
     80 /// represent vector functions. in particular, it is not attached to
     81 /// any target-specific ABI.
     82 struct VFShape {
     83   unsigned VF;     // Vectorization factor.
     84   bool IsScalable; // True if the function is a scalable function.
     85   SmallVector<VFParameter, 8> Parameters; // List of parameter information.
     86   // Comparison operator.
     87   bool operator==(const VFShape &Other) const {
     88     return std::tie(VF, IsScalable, Parameters) ==
     89            std::tie(Other.VF, Other.IsScalable, Other.Parameters);
     90   }
     91 
     92   /// Update the parameter in position P.ParamPos to P.
     93   void updateParam(VFParameter P) {
     94     assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
     95     Parameters[P.ParamPos] = P;
     96     assert(hasValidParameterList() && "Invalid parameter list");
     97   }
     98 
     99   // Retrieve the VFShape that can be used to map a (scalar) function to itself,
    100   // with VF = 1.
    101   static VFShape getScalarShape(const CallInst &CI) {
    102     return VFShape::get(CI, ElementCount::getFixed(1),
    103                         /*HasGlobalPredicate*/ false);
    104   }
    105 
    106   // Retrieve the basic vectorization shape of the function, where all
    107   // parameters are mapped to VFParamKind::Vector with \p EC
    108   // lanes. Specifies whether the function has a Global Predicate
    109   // argument via \p HasGlobalPred.
    110   static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
    111     SmallVector<VFParameter, 8> Parameters;
    112     for (unsigned I = 0; I < CI.arg_size(); ++I)
    113       Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
    114     if (HasGlobalPred)
    115       Parameters.push_back(
    116           VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));
    117 
    118     return {EC.getKnownMinValue(), EC.isScalable(), Parameters};
    119   }
    120   /// Sanity check on the Parameters in the VFShape.
    121   bool hasValidParameterList() const;
    122 };
    123 
    124 /// Holds the VFShape for a specific scalar to vector function mapping.
    125 struct VFInfo {
    126   VFShape Shape;          /// Classification of the vector function.
    127   std::string ScalarName; /// Scalar Function Name.
    128   std::string VectorName; /// Vector Function Name associated to this VFInfo.
    129   VFISAKind ISA;          /// Instruction Set Architecture.
    130 };
    131 
    132 namespace VFABI {
    133 /// LLVM Internal VFABI ISA token for vector functions.
    134 static constexpr char const *_LLVM_ = "_LLVM_";
    135 /// Prefix for internal name redirection for vector function that
    136 /// tells the compiler to scalarize the call using the scalar name
    137 /// of the function. For example, a mangled name like
    138 /// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
    139 /// vectorizer to vectorize the scalar call `foo`, and to scalarize
    140 /// it once vectorization is done.
    141 static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
    142 
/// Function to construct a VFInfo out of a mangled name in the
/// following format:
    145 ///
    146 /// <VFABI_name>{(<redirection>)}
    147 ///
    148 /// where <VFABI_name> is the name of the vector function, mangled according
    149 /// to the rules described in the Vector Function ABI of the target vector
    150 /// extension (or <isa> from now on). The <VFABI_name> is in the following
    151 /// format:
    152 ///
    153 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
    154 ///
/// This method supports demangling rules for the following <isa>:
    156 ///
    157 /// * AArch64: https://developer.arm.com/docs/101129/latest
    158 ///
    159 /// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
    160 ///  https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
    161 ///
    162 /// \param MangledName -> input string in the format
    163 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
/// \param M -> Module used to retrieve information about the vector
/// function that is not possible to retrieve from the mangled
/// name. At the moment, this parameter is needed only to retrieve the
/// Vectorization Factor of scalable vector functions from their
/// respective IR declarations.
    169 Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName, const Module &M);
    170 
    171 /// This routine mangles the given VectorName according to the LangRef
    172 /// specification for vector-function-abi-variant attribute and is specific to
    173 /// the TLI mappings. It is the responsibility of the caller to make sure that
    174 /// this is only used if all parameters in the vector function are vector type.
    175 /// This returned string holds scalar-to-vector mapping:
    176 ///    _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
    177 ///
    178 /// where:
    179 ///
    180 /// <isa> = "_LLVM_"
    181 /// <mask> = "N". Note: TLI does not support masked interfaces.
    182 /// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
    183 ///          field of the `VecDesc` struct. If the number of lanes is scalable
    184 ///          then 'x' is printed instead.
    185 /// <vparams> = "v", as many as are the numArgs.
    186 /// <scalarname> = the name of the scalar function.
    187 /// <vectorname> = the name of the vector function.
    188 std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName,
    189                                 unsigned numArgs, ElementCount VF);
    190 
    191 /// Retrieve the `VFParamKind` from a string token.
    192 VFParamKind getVFParamKindFromString(const StringRef Token);
    193 
    194 // Name of the attribute where the variant mappings are stored.
    195 static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
    196 
    197 /// Populates a set of strings representing the Vector Function ABI variants
    198 /// associated to the CallInst CI. If the CI does not contain the
    199 /// vector-function-abi-variant attribute, we return without populating
    200 /// VariantMappings, i.e. callers of getVectorVariantNames need not check for
    201 /// the presence of the attribute (see InjectTLIMappings).
    202 void getVectorVariantNames(const CallInst &CI,
    203                            SmallVectorImpl<std::string> &VariantMappings);
    204 } // end namespace VFABI
    205 
    206 /// The Vector Function Database.
    207 ///
    208 /// Helper class used to find the vector functions associated to a
    209 /// scalar CallInst.
    210 class VFDatabase {
    211   /// The Module of the CallInst CI.
    212   const Module *M;
    213   /// The CallInst instance being queried for scalar to vector mappings.
    214   const CallInst &CI;
    215   /// List of vector functions descriptors associated to the call
    216   /// instruction.
    217   const SmallVector<VFInfo, 8> ScalarToVectorMappings;
    218 
    219   /// Retrieve the scalar-to-vector mappings associated to the rule of
    220   /// a vector Function ABI.
    221   static void getVFABIMappings(const CallInst &CI,
    222                                SmallVectorImpl<VFInfo> &Mappings) {
    223     if (!CI.getCalledFunction())
    224       return;
    225 
    226     const StringRef ScalarName = CI.getCalledFunction()->getName();
    227 
    228     SmallVector<std::string, 8> ListOfStrings;
    229     // The check for the vector-function-abi-variant attribute is done when
    230     // retrieving the vector variant names here.
    231     VFABI::getVectorVariantNames(CI, ListOfStrings);
    232     if (ListOfStrings.empty())
    233       return;
    234     for (const auto &MangledName : ListOfStrings) {
    235       const Optional<VFInfo> Shape =
    236           VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
    237       // A match is found via scalar and vector names, and also by
    238       // ensuring that the variant described in the attribute has a
    239       // corresponding definition or declaration of the vector
    240       // function in the Module M.
    241       if (Shape.hasValue() && (Shape.getValue().ScalarName == ScalarName)) {
    242         assert(CI.getModule()->getFunction(Shape.getValue().VectorName) &&
    243                "Vector function is missing.");
    244         Mappings.push_back(Shape.getValue());
    245       }
    246     }
    247   }
    248 
    249 public:
    250   /// Retrieve all the VFInfo instances associated to the CallInst CI.
    251   static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    252     SmallVector<VFInfo, 8> Ret;
    253 
    254     // Get mappings from the Vector Function ABI variants.
    255     getVFABIMappings(CI, Ret);
    256 
    257     // Other non-VFABI variants should be retrieved here.
    258 
    259     return Ret;
    260   }
    261 
    262   /// Constructor, requires a CallInst instance.
    263   VFDatabase(CallInst &CI)
    264       : M(CI.getModule()), CI(CI),
    265         ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
    266   /// \defgroup VFDatabase query interface.
    267   ///
    268   /// @{
    269   /// Retrieve the Function with VFShape \p Shape.
    270   Function *getVectorizedFunction(const VFShape &Shape) const {
    271     if (Shape == VFShape::getScalarShape(CI))
    272       return CI.getCalledFunction();
    273 
    274     for (const auto &Info : ScalarToVectorMappings)
    275       if (Info.Shape == Shape)
    276         return M->getFunction(Info.VectorName);
    277 
    278     return nullptr;
    279   }
    280   /// @}
    281 };
    282 
    283 template <typename T> class ArrayRef;
    284 class DemandedBits;
    285 class GetElementPtrInst;
    286 template <typename InstTy> class InterleaveGroup;
    287 class IRBuilderBase;
    288 class Loop;
    289 class ScalarEvolution;
    290 class TargetTransformInfo;
    291 class Type;
    292 class Value;
    293 
    294 namespace Intrinsic {
    295 typedef unsigned ID;
    296 }
    297 
    298 /// A helper function for converting Scalar types to vector types. If
    299 /// the incoming type is void, we return void. If the EC represents a
    300 /// scalar, we return the scalar type.
    301 inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
    302   if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
    303     return Scalar;
    304   return VectorType::get(Scalar, EC);
    305 }
    306 
    307 inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
    308   return ToVectorTy(Scalar, ElementCount::getFixed(VF));
    309 }
    310 
    311 /// Identify if the intrinsic is trivially vectorizable.
    312 /// This method returns true if the intrinsic's argument types are all scalars
    313 /// for the scalar form of the intrinsic and all vectors (or scalars handled by
    314 /// hasVectorInstrinsicScalarOpd) for the vector form of the intrinsic.
    315 bool isTriviallyVectorizable(Intrinsic::ID ID);
    316 
    317 /// Identifies if the vector form of the intrinsic has a scalar operand.
    318 bool hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, unsigned ScalarOpdIdx);
    319 
    320 /// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID. If no mapping intrinsic is found, it returns
/// not_intrinsic.
    323 Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
    324                                           const TargetLibraryInfo *TLI);
    325 
    326 /// Find the operand of the GEP that should be checked for consecutive
    327 /// stores. This ignores trailing indices that have no effect on the final
    328 /// pointer.
    329 unsigned getGEPInductionOperand(const GetElementPtrInst *Gep);
    330 
    331 /// If the argument is a GEP, then returns the operand identified by
    332 /// getGEPInductionOperand. However, if there is some other non-loop-invariant
    333 /// operand, it returns that instead.
    334 Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
    335 
    336 /// If a value has only one user that is a CastInst, return it.
    337 Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty);
    338 
    339 /// Get the stride of a pointer access in a loop. Looks for symbolic
    340 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
    341 Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
    342 
    343 /// Given a vector and an element number, see if the scalar value is
    344 /// already around as a register, for example if it were inserted then extracted
    345 /// from the vector.
    346 Value *findScalarElement(Value *V, unsigned EltNo);
    347 
    348 /// If all non-negative \p Mask elements are the same value, return that value.
    349 /// If all elements are negative (undefined) or \p Mask contains different
    350 /// non-negative values, return -1.
    351 int getSplatIndex(ArrayRef<int> Mask);
    352 
    353 /// Get splat value if the input is a splat vector or return nullptr.
    354 /// The value may be extracted from a splat constants vector or from
    355 /// a sequence of instructions that broadcast a single value into a vector.
    356 Value *getSplatValue(const Value *V);
    357 
    358 /// Return true if each element of the vector value \p V is poisoned or equal to
    359 /// every other non-poisoned element. If an index element is specified, either
    360 /// every element of the vector is poisoned or the element at that index is not
    361 /// poisoned and equal to every other non-poisoned element.
    362 /// This may be more powerful than the related getSplatValue() because it is
    363 /// not limited by finding a scalar source value to a splatted vector.
    364 bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
    365 
    366 /// Replace each shuffle mask index with the scaled sequential indices for an
    367 /// equivalent mask of narrowed elements. Mask elements that are less than 0
    368 /// (sentinel values) are repeated in the output mask.
    369 ///
    370 /// Example with Scale = 4:
    371 ///   <4 x i32> <3, 2, 0, -1> -->
    372 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
    373 ///
    374 /// This is the reverse process of widening shuffle mask elements, but it always
    375 /// succeeds because the indexes can always be multiplied (scaled up) to map to
    376 /// narrower vector elements.
    377 void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
    378                            SmallVectorImpl<int> &ScaledMask);
    379 
    380 /// Try to transform a shuffle mask by replacing elements with the scaled index
    381 /// for an equivalent mask of widened elements. If all mask elements that would
    382 /// map to a wider element of the new mask are the same negative number
    383 /// (sentinel value), that element of the new mask is the same value. If any
    384 /// element in a given slice is negative and some other element in that slice is
    385 /// not the same value, return false (partial matches with sentinel values are
    386 /// not allowed).
    387 ///
    388 /// Example with Scale = 4:
    389 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
    390 ///   <4 x i32> <3, 2, 0, -1>
    391 ///
    392 /// This is the reverse process of narrowing shuffle mask elements if it
    393 /// succeeds. This transform is not always possible because indexes may not
    394 /// divide evenly (scale down) to map to wider vector elements.
    395 bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
    396                           SmallVectorImpl<int> &ScaledMask);
    397 
    398 /// Compute a map of integer instructions to their minimum legal type
    399 /// size.
    400 ///
    401 /// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
    402 /// type (e.g. i32) whenever arithmetic is performed on them.
    403 ///
    404 /// For targets with native i8 or i16 operations, usually InstCombine can shrink
    405 /// the arithmetic type down again. However InstCombine refuses to create
    406 /// illegal types, so for targets without i8 or i16 registers, the lengthening
    407 /// and shrinking remains.
    408 ///
    409 /// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
    410 /// their scalar equivalents do not, so during vectorization it is important to
    411 /// remove these lengthens and truncates when deciding the profitability of
    412 /// vectorization.
    413 ///
    414 /// This function analyzes the given range of instructions and determines the
    415 /// minimum type size each can be converted to. It attempts to remove or
    416 /// minimize type size changes across each def-use chain, so for example in the
    417 /// following code:
    418 ///
    419 ///   %1 = load i8, i8*
    420 ///   %2 = add i8 %1, 2
    421 ///   %3 = load i16, i16*
    422 ///   %4 = zext i8 %2 to i32
    423 ///   %5 = zext i16 %3 to i32
    424 ///   %6 = add i32 %4, %5
    425 ///   %7 = trunc i32 %6 to i16
    426 ///
    427 /// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
    428 /// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
    429 ///
    430 /// If the optional TargetTransformInfo is provided, this function tries harder
    431 /// to do less work by only looking at illegal types.
    432 MapVector<Instruction*, uint64_t>
    433 computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
    434                          DemandedBits &DB,
    435                          const TargetTransformInfo *TTI=nullptr);
    436 
    437 /// Compute the union of two access-group lists.
    438 ///
    439 /// If the list contains just one access group, it is returned directly. If the
    440 /// list is empty, returns nullptr.
    441 MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);
    442 
    443 /// Compute the access-group list of access groups that @p Inst1 and @p Inst2
    444 /// are both in. If either instruction does not access memory at all, it is
    445 /// considered to be in every list.
    446 ///
    447 /// If the list contains just one access group, it is returned directly. If the
    448 /// list is empty, returns nullptr.
    449 MDNode *intersectAccessGroups(const Instruction *Inst1,
    450                               const Instruction *Inst2);
    451 
    452 /// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
    453 /// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
    458 ///
    459 /// This function always sets a (possibly null) value for each K in Kinds.
    460 Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
    461 
    462 /// Create a mask that filters the members of an interleave group where there
    463 /// are gaps.
    464 ///
    465 /// For example, the mask for \p Group with interleave-factor 3
    466 /// and \p VF 4, that has only its first member present is:
    467 ///
    468 ///   <1,0,0,1,0,0,1,0,0,1,0,0>
    469 ///
    470 /// Note: The result is a mask of 0's and 1's, as opposed to the other
    471 /// create[*]Mask() utilities which create a shuffle mask (mask that
    472 /// consists of indices).
    473 Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
    474                                const InterleaveGroup<Instruction> &Group);
    475 
    476 /// Create a mask with replicated elements.
    477 ///
    478 /// This function creates a shuffle mask for replicating each of the \p VF
    479 /// elements in a vector \p ReplicationFactor times. It can be used to
    480 /// transform a mask of \p VF elements into a mask of
    481 /// \p VF * \p ReplicationFactor elements used by a predicated
    482 /// interleaved-group of loads/stores whose Interleaved-factor ==
    483 /// \p ReplicationFactor.
    484 ///
    485 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
    486 ///
    487 ///   <0,0,0,1,1,1,2,2,2,3,3,3>
    488 llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
    489                                                 unsigned VF);
    490 
    491 /// Create an interleave shuffle mask.
    492 ///
    493 /// This function creates a shuffle mask for interleaving \p NumVecs vectors of
    494 /// vectorization factor \p VF into a single wide vector. The mask is of the
    495 /// form:
    496 ///
    497 ///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
    498 ///
    499 /// For example, the mask for VF = 4 and NumVecs = 2 is:
    500 ///
    501 ///   <0, 4, 1, 5, 2, 6, 3, 7>.
    502 llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);
    503 
    504 /// Create a stride shuffle mask.
    505 ///
    506 /// This function creates a shuffle mask whose elements begin at \p Start and
    507 /// are incremented by \p Stride. The mask can be used to deinterleave an
    508 /// interleaved vector into separate vectors of vectorization factor \p VF. The
    509 /// mask is of the form:
    510 ///
    511 ///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
    512 ///
    513 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
    514 ///
    515 ///   <0, 2, 4, 6>
    516 llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
    517                                             unsigned VF);
    518 
    519 /// Create a sequential shuffle mask.
    520 ///
    521 /// This function creates shuffle mask whose elements are sequential and begin
    522 /// at \p Start.  The mask contains \p NumInts integers and is padded with \p
    523 /// NumUndefs undef values. The mask is of the form:
    524 ///
    525 ///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
    526 ///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
    528 ///
    529 ///   <0, 1, 2, 3, undef, undef, undef, undef>
    530 llvm::SmallVector<int, 16>
    531 createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);
    532 
    533 /// Concatenate a list of vectors.
    534 ///
/// This function generates code that concatenates the vectors in \p Vecs into a
    536 /// single large vector. The number of vectors should be greater than one, and
    537 /// their element types should be the same. The number of elements in the
    538 /// vectors should also be the same; however, if the last vector has fewer
    539 /// elements, it will be padded with undefs.
    540 Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
    541 
    542 /// Given a mask vector of i1, Return true if all of the elements of this
    543 /// predicate mask are known to be false or undef.  That is, return true if all
    544 /// lanes can be assumed inactive.
    545 bool maskIsAllZeroOrUndef(Value *Mask);
    546 
    547 /// Given a mask vector of i1, Return true if all of the elements of this
    548 /// predicate mask are known to be true or undef.  That is, return true if all
    549 /// lanes can be assumed active.
    550 bool maskIsAllOneOrUndef(Value *Mask);
    551 
    552 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
    553 /// for each lane which may be active.
    554 APInt possiblyDemandedEltsInMask(Value *Mask);
    555 
    556 /// The group of interleaved loads/stores sharing the same stride and
    557 /// close to each other.
    558 ///
    559 /// Each member in this group has an index starting from 0, and the largest
    560 /// index should be less than interleaved factor, which is equal to the absolute
    561 /// value of the access's stride.
    562 ///
    563 /// E.g. An interleaved load group of factor 4:
    564 ///        for (unsigned i = 0; i < 1024; i+=4) {
    565 ///          a = A[i];                           // Member of index 0
    566 ///          b = A[i+1];                         // Member of index 1
    567 ///          d = A[i+3];                         // Member of index 3
    568 ///          ...
    569 ///        }
    570 ///
    571 ///      An interleaved store group of factor 4:
    572 ///        for (unsigned i = 0; i < 1024; i+=4) {
    573 ///          ...
    574 ///          A[i]   = a;                         // Member of index 0
    575 ///          A[i+1] = b;                         // Member of index 1
    576 ///          A[i+2] = c;                         // Member of index 2
    577 ///          A[i+3] = d;                         // Member of index 3
    578 ///        }
    579 ///
    580 /// Note: the interleaved load group could have gaps (missing members), but
    581 /// the interleaved store group doesn't allow gaps.
    582 template <typename InstTy> class InterleaveGroup {
    583 public:
    584   InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
    585       : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
    586         InsertPos(nullptr) {}
    587 
    588   InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
    589       : Alignment(Alignment), InsertPos(Instr) {
    590     Factor = std::abs(Stride);
    591     assert(Factor > 1 && "Invalid interleave factor");
    592 
    593     Reverse = Stride < 0;
    594     Members[0] = Instr;
    595   }
    596 
    597   bool isReverse() const { return Reverse; }
    598   uint32_t getFactor() const { return Factor; }
    599   Align getAlign() const { return Alignment; }
    600   uint32_t getNumMembers() const { return Members.size(); }
    601 
    602   /// Try to insert a new member \p Instr with index \p Index and
    603   /// alignment \p NewAlign. The index is related to the leader and it could be
    604   /// negative if it is the new leader.
    605   ///
    606   /// \returns false if the instruction doesn't belong to the group.
    607   bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    608     // Make sure the key fits in an int32_t.
    609     Optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    610     if (!MaybeKey)
    611       return false;
    612     int32_t Key = *MaybeKey;
    613 
    614     // Skip if the key is used for either the tombstone or empty special values.
    615     if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
    616         DenseMapInfo<int32_t>::getEmptyKey() == Key)
    617       return false;
    618 
    619     // Skip if there is already a member with the same index.
    620     if (Members.find(Key) != Members.end())
    621       return false;
    622 
    623     if (Key > LargestKey) {
    624       // The largest index is always less than the interleave factor.
    625       if (Index >= static_cast<int32_t>(Factor))
    626         return false;
    627 
    628       LargestKey = Key;
    629     } else if (Key < SmallestKey) {
    630 
    631       // Make sure the largest index fits in an int32_t.
    632       Optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
    633       if (!MaybeLargestIndex)
    634         return false;
    635 
    636       // The largest index is always less than the interleave factor.
    637       if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
    638         return false;
    639 
    640       SmallestKey = Key;
    641     }
    642 
    643     // It's always safe to select the minimum alignment.
    644     Alignment = std::min(Alignment, NewAlign);
    645     Members[Key] = Instr;
    646     return true;
    647   }
    648 
    649   /// Get the member with the given index \p Index
    650   ///
    651   /// \returns nullptr if the group contains no such member.
    652   InstTy *getMember(uint32_t Index) const {
    653     int32_t Key = SmallestKey + Index;
    654     return Members.lookup(Key);
    655   }
    656 
    657   /// Get the index for the given member. Unlike the key in the member
    658   /// map, the index starts from 0.
    659   uint32_t getIndex(const InstTy *Instr) const {
    660     for (auto I : Members) {
    661       if (I.second == Instr)
    662         return I.first - SmallestKey;
    663     }
    664 
    665     llvm_unreachable("InterleaveGroup contains no such member");
    666   }
    667 
    668   InstTy *getInsertPos() const { return InsertPos; }
    669   void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
    670 
    671   /// Add metadata (e.g. alias info) from the instructions in this group to \p
    672   /// NewInst.
    673   ///
    674   /// FIXME: this function currently does not add noalias metadata a'la
    675   /// addNewMetadata.  To do that we need to compute the intersection of the
    676   /// noalias info from all members.
    677   void addMetadata(InstTy *NewInst) const;
    678 
    679   /// Returns true if this Group requires a scalar iteration to handle gaps.
    680   bool requiresScalarEpilogue() const {
    681     // If the last member of the Group exists, then a scalar epilog is not
    682     // needed for this group.
    683     if (getMember(getFactor() - 1))
    684       return false;
    685 
    686     // We have a group with gaps. It therefore cannot be a group of stores,
    687     // and it can't be a reversed access, because such groups get invalidated.
    688     assert(!getMember(0)->mayWriteToMemory() &&
    689            "Group should have been invalidated");
    690     assert(!isReverse() && "Group should have been invalidated");
    691 
    692     // This is a group of loads, with gaps, and without a last-member
    693     return true;
    694   }
    695 
    696 private:
    697   uint32_t Factor; // Interleave Factor.
    698   bool Reverse;
    699   Align Alignment;
    700   DenseMap<int32_t, InstTy *> Members;
    701   int32_t SmallestKey = 0;
    702   int32_t LargestKey = 0;
    703 
    704   // To avoid breaking dependences, vectorized instructions of an interleave
    705   // group should be inserted at either the first load or the last store in
    706   // program order.
    707   //
    708   // E.g. %even = load i32             // Insert Position
    709   //      %add = add i32 %even         // Use of %even
    710   //      %odd = load i32
    711   //
    712   //      store i32 %even
    713   //      %odd = add i32               // Def of %odd
    714   //      store i32 %odd               // Insert Position
    715   InstTy *InsertPos;
    716 };
    717 
    718 /// Drive the analysis of interleaved memory accesses in the loop.
    719 ///
    720 /// Use this class to analyze interleaved accesses only when we can vectorize
    721 /// a loop. Otherwise it's meaningless to do analysis as the vectorization
    722 /// on interleaved accesses is unsafe.
    723 ///
    724 /// The analysis collects interleave groups and records the relationships
    725 /// between the member and the group in a map.
    726 class InterleavedAccessInfo {
    727 public:
    728   InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
    729                         DominatorTree *DT, LoopInfo *LI,
    730                         const LoopAccessInfo *LAI)
    731       : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
    732 
    733   ~InterleavedAccessInfo() { invalidateGroups(); }
    734 
    735   /// Analyze the interleaved accesses and collect them in interleave
    736   /// groups. Substitute symbolic strides using \p Strides.
    737   /// Consider also predicated loads/stores in the analysis if
    738   /// \p EnableMaskedInterleavedGroup is true.
    739   void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
    740 
    741   /// Invalidate groups, e.g., in case all blocks in loop will be predicated
    742   /// contrary to original assumption. Although we currently prevent group
    743   /// formation for predicated accesses, we may be able to relax this limitation
    744   /// in the future once we handle more complicated blocks. Returns true if any
    745   /// groups were invalidated.
    746   bool invalidateGroups() {
    747     if (InterleaveGroups.empty()) {
    748       assert(
    749           !RequiresScalarEpilogue &&
    750           "RequiresScalarEpilog should not be set without interleave groups");
    751       return false;
    752     }
    753 
    754     InterleaveGroupMap.clear();
    755     for (auto *Ptr : InterleaveGroups)
    756       delete Ptr;
    757     InterleaveGroups.clear();
    758     RequiresScalarEpilogue = false;
    759     return true;
    760   }
    761 
    762   /// Check if \p Instr belongs to any interleave group.
    763   bool isInterleaved(Instruction *Instr) const {
    764     return InterleaveGroupMap.find(Instr) != InterleaveGroupMap.end();
    765   }
    766 
    767   /// Get the interleave group that \p Instr belongs to.
    768   ///
    769   /// \returns nullptr if doesn't have such group.
    770   InterleaveGroup<Instruction> *
    771   getInterleaveGroup(const Instruction *Instr) const {
    772     return InterleaveGroupMap.lookup(Instr);
    773   }
    774 
    775   iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
    776   getInterleaveGroups() {
    777     return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
    778   }
    779 
    780   /// Returns true if an interleaved group that may access memory
    781   /// out-of-bounds requires a scalar epilogue iteration for correctness.
    782   bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
    783 
    784   /// Invalidate groups that require a scalar epilogue (due to gaps). This can
    785   /// happen when optimizing for size forbids a scalar epilogue, and the gap
    786   /// cannot be filtered by masking the load/store.
    787   void invalidateGroupsRequiringScalarEpilogue();
    788 
    789 private:
    790   /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
    791   /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
    792   /// The interleaved access analysis can also add new predicates (for example
    793   /// by versioning strides of pointers).
    794   PredicatedScalarEvolution &PSE;
    795 
    796   Loop *TheLoop;
    797   DominatorTree *DT;
    798   LoopInfo *LI;
    799   const LoopAccessInfo *LAI;
    800 
    801   /// True if the loop may contain non-reversed interleaved groups with
    802   /// out-of-bounds accesses. We ensure we don't speculatively access memory
    803   /// out-of-bounds by executing at least one scalar epilogue iteration.
    804   bool RequiresScalarEpilogue = false;
    805 
    806   /// Holds the relationships between the members and the interleave group.
    807   DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
    808 
    809   SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
    810 
    811   /// Holds dependences among the memory accesses in the loop. It maps a source
    812   /// access to a set of dependent sink accesses.
    813   DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
    814 
    815   /// The descriptor for a strided memory access.
    816   struct StrideDescriptor {
    817     StrideDescriptor() = default;
    818     StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
    819                      Align Alignment)
    820         : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
    821 
    822     // The access's stride. It is negative for a reverse access.
    823     int64_t Stride = 0;
    824 
    825     // The scalar expression of this access.
    826     const SCEV *Scev = nullptr;
    827 
    828     // The size of the memory object.
    829     uint64_t Size = 0;
    830 
    831     // The alignment of this access.
    832     Align Alignment;
    833   };
    834 
    835   /// A type for holding instructions and their stride descriptors.
    836   using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
    837 
    838   /// Create a new interleave group with the given instruction \p Instr,
    839   /// stride \p Stride and alignment \p Align.
    840   ///
    841   /// \returns the newly created interleave group.
    842   InterleaveGroup<Instruction> *
    843   createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    844     assert(!InterleaveGroupMap.count(Instr) &&
    845            "Already in an interleaved access group");
    846     InterleaveGroupMap[Instr] =
    847         new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    848     InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    849     return InterleaveGroupMap[Instr];
    850   }
    851 
    852   /// Release the group and remove all the relationships.
    853   void releaseGroup(InterleaveGroup<Instruction> *Group) {
    854     for (unsigned i = 0; i < Group->getFactor(); i++)
    855       if (Instruction *Member = Group->getMember(i))
    856         InterleaveGroupMap.erase(Member);
    857 
    858     InterleaveGroups.erase(Group);
    859     delete Group;
    860   }
    861 
    862   /// Collect all the accesses with a constant stride in program order.
    863   void collectConstStrideAccesses(
    864       MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
    865       const ValueToValueMap &Strides);
    866 
    867   /// Returns true if \p Stride is allowed in an interleaved group.
    868   static bool isStrided(int Stride);
    869 
    870   /// Returns true if \p BB is a predicated block.
    871   bool isPredicated(BasicBlock *BB) const {
    872     return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
    873   }
    874 
    875   /// Returns true if LoopAccessInfo can be used for dependence queries.
    876   bool areDependencesValid() const {
    877     return LAI && LAI->getDepChecker().getDependences();
    878   }
    879 
    880   /// Returns true if memory accesses \p A and \p B can be reordered, if
    881   /// necessary, when constructing interleaved groups.
    882   ///
    883   /// \p A must precede \p B in program order. We return false if reordering is
    884   /// not necessary or is prevented because \p A and \p B may be dependent.
    885   bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
    886                                                  StrideEntry *B) const {
    887     // Code motion for interleaved accesses can potentially hoist strided loads
    888     // and sink strided stores. The code below checks the legality of the
    889     // following two conditions:
    890     //
    891     // 1. Potentially moving a strided load (B) before any store (A) that
    892     //    precedes B, or
    893     //
    894     // 2. Potentially moving a strided store (A) after any load or store (B)
    895     //    that A precedes.
    896     //
    897     // It's legal to reorder A and B if we know there isn't a dependence from A
    898     // to B. Note that this determination is conservative since some
    899     // dependences could potentially be reordered safely.
    900 
    901     // A is potentially the source of a dependence.
    902     auto *Src = A->first;
    903     auto SrcDes = A->second;
    904 
    905     // B is potentially the sink of a dependence.
    906     auto *Sink = B->first;
    907     auto SinkDes = B->second;
    908 
    909     // Code motion for interleaved accesses can't violate WAR dependences.
    910     // Thus, reordering is legal if the source isn't a write.
    911     if (!Src->mayWriteToMemory())
    912       return true;
    913 
    914     // At least one of the accesses must be strided.
    915     if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
    916       return true;
    917 
    918     // If dependence information is not available from LoopAccessInfo,
    919     // conservatively assume the instructions can't be reordered.
    920     if (!areDependencesValid())
    921       return false;
    922 
    923     // If we know there is a dependence from source to sink, assume the
    924     // instructions can't be reordered. Otherwise, reordering is legal.
    925     return Dependences.find(Src) == Dependences.end() ||
    926            !Dependences.lookup(Src).count(Sink);
    927   }
    928 
    929   /// Collect the dependences from LoopAccessInfo.
    930   ///
    931   /// We process the dependences once during the interleaved access analysis to
    932   /// enable constant-time dependence queries.
    933   void collectDependences() {
    934     if (!areDependencesValid())
    935       return;
    936     auto *Deps = LAI->getDepChecker().getDependences();
    937     for (auto Dep : *Deps)
    938       Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
    939   }
    940 };
    941 
    942 } // llvm namespace
    943 
    944 #endif
    945