//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
class TargetLibraryInfo;

/// Describes the type of Parameters
enum class VFParamKind {
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown
};

/// Describes the type of Instruction Set Architecture
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown       // Unknown ISA
};

/// Encapsulates information needed to describe a parameter.
///
/// The description of the parameter is not linked directly to
/// OpenMP or any other vector function description. This structure
/// is extendible to handle other paradigms that describe vector
/// functions and their parameters.
struct VFParameter {
  unsigned ParamPos;         // Parameter Position in Scalar Function.
  VFParamKind ParamKind;     // Kind of Parameter.
  int LinearStepOrPos = 0;   // Step or Position of the Parameter.
  Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.

  // Comparison operator.
  bool operator==(const VFParameter &Other) const {
    return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
           std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
                    Other.Alignment);
  }
};

/// Contains the information about the kind of vectorization
/// available.
///
/// This object is independent of the paradigm used to
/// represent vector functions. In particular, it is not attached to
/// any target-specific ABI.
struct VFShape {
  unsigned VF;     // Vectorization factor.
  bool IsScalable; // True if the function is a scalable function.
  SmallVector<VFParameter, 8> Parameters; // List of parameter information.
  // Comparison operator.
  bool operator==(const VFShape &Other) const {
    return std::tie(VF, IsScalable, Parameters) ==
           std::tie(Other.VF, Other.IsScalable, Other.Parameters);
  }

  /// Update the parameter in position P.ParamPos to P.
  void updateParam(VFParameter P) {
    assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
    Parameters[P.ParamPos] = P;
    assert(hasValidParameterList() && "Invalid parameter list");
  }

  // Retrieve the VFShape that can be used to map a (scalar) function to
  // itself, with VF = 1.
  static VFShape getScalarShape(const CallInst &CI) {
    return VFShape::get(CI, ElementCount::getFixed(1),
                        /*HasGlobalPredicate*/ false);
  }

  // Retrieve the basic vectorization shape of the function, where all
  // parameters are mapped to VFParamKind::Vector with \p EC
  // lanes. Specifies whether the function has a Global Predicate
  // argument via \p HasGlobalPred.
  static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
    SmallVector<VFParameter, 8> Parameters;
    for (unsigned I = 0; I < CI.arg_size(); ++I)
      Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
    if (HasGlobalPred)
      Parameters.push_back(
          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));

    return {EC.getKnownMinValue(), EC.isScalable(), Parameters};
  }
  /// Sanity check on the Parameters in the VFShape.
  bool hasValidParameterList() const;
};

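// Example usage (illustrative sketch, not a normative recipe): building a
// VFShape for an existing call site and then refining one parameter. `CI` is
// assumed to be a CallInst taking two arguments.
//
//   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(4),
//                                /*HasGlobalPred=*/false);
//   // Mark the second argument as uniform instead of vector.
//   Shape.updateParam({/*ParamPos=*/1, VFParamKind::OMP_Uniform});
//   assert(Shape.hasValidParameterList());
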
/// Holds the VFShape for a specific scalar to vector function mapping.
struct VFInfo {
  VFShape Shape;          /// Classification of the vector function.
  std::string ScalarName; /// Scalar Function Name.
  std::string VectorName; /// Vector Function Name associated to this VFInfo.
  VFISAKind ISA;          /// Instruction Set Architecture.
};

namespace VFABI {
/// LLVM Internal VFABI ISA token for vector functions.
static constexpr char const *_LLVM_ = "_LLVM_";
/// Prefix for internal name redirection for vector functions that
/// tells the compiler to scalarize the call using the scalar name
/// of the function. For example, a mangled name like
/// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
/// vectorizer to vectorize the scalar call `foo`, and to scalarize
/// it once vectorization is done.
static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";

/// Function to construct a VFInfo out of a mangled name in the
/// following format:
///
/// <VFABI_name>{(<redirection>)}
///
/// where <VFABI_name> is the name of the vector function, mangled according
/// to the rules described in the Vector Function ABI of the target vector
/// extension (or <isa> from now on). The <VFABI_name> is in the following
/// format:
///
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
///
/// This method supports demangling rules for the following <isa>:
///
/// * AArch64: https://developer.arm.com/docs/101129/latest
///
/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
///   https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
/// \param M -> Module used to retrieve information about the vector
/// function that is not possible to retrieve from the mangled
/// name. At the moment, this parameter is needed only to retrieve the
/// Vectorization Factor of scalable vector functions from their
/// respective IR declarations.
Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName, const Module &M);

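// Example usage (illustrative sketch): demangling an assumed VFABI name for a
// fixed-width, 2-lane, unmasked mapping of `sin`. `M` is assumed to be the
// enclosing Module; the field values in the comment are what the format
// documented above implies for this name.
//
//   if (Optional<VFInfo> Info =
//           VFABI::tryDemangleForVFABI("_ZGV_LLVM_N2v_sin(__vec_sin)", M)) {
//     // Info->Shape.VF == 2, Info->ScalarName == "sin",
//     // Info->VectorName == "__vec_sin", Info->ISA == VFISAKind::LLVM.
//   }
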
/// This routine mangles the given VectorName according to the LangRef
/// specification for the vector-function-abi-variant attribute and is specific
/// to the TLI mappings. It is the responsibility of the caller to make sure
/// that this is only used if all parameters in the vector function are of
/// vector type. The returned string holds the scalar-to-vector mapping:
///
///    _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
///
/// where:
///
/// <isa> = "_LLVM_"
/// <mask> = "N". Note: TLI does not support masked interfaces.
/// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
///          field of the `VecDesc` struct. If the number of lanes is scalable
///          then 'x' is printed instead.
/// <vparams> = "v", one for each of the numArgs parameters.
/// <scalarname> = the name of the scalar function.
/// <vectorname> = the name of the vector function.
std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName,
                                unsigned numArgs, ElementCount VF);

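// Example usage (illustrative sketch): the mangling scheme above applied to a
// 4-lane, single-argument mapping. The scalar/vector names are placeholders.
//
//   // Expected to produce "_ZGV_LLVM_N4v_sin(__vec_sin)" under the rules
//   // documented for mangleTLIVectorName.
//   std::string Mangled = VFABI::mangleTLIVectorName(
//       "__vec_sin", "sin", /*numArgs=*/1, ElementCount::getFixed(4));
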
/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);

// Name of the attribute where the variant mappings are stored.
static constexpr char const *MappingsAttrName = "vector-function-abi-variant";

/// Populates a set of strings representing the Vector Function ABI variants
/// associated to the CallInst CI. If the CI does not contain the
/// vector-function-abi-variant attribute, we return without populating
/// VariantMappings, i.e. callers of getVectorVariantNames need not check for
/// the presence of the attribute (see InjectTLIMappings).
void getVectorVariantNames(const CallInst &CI,
                           SmallVectorImpl<std::string> &VariantMappings);
} // end namespace VFABI

/// The Vector Function Database.
///
/// Helper class used to find the vector functions associated to a
/// scalar CallInst.
class VFDatabase {
  /// The Module of the CallInst CI.
  const Module *M;
  /// The CallInst instance being queried for scalar to vector mappings.
  const CallInst &CI;
  /// List of vector function descriptors associated to the call
  /// instruction.
  const SmallVector<VFInfo, 8> ScalarToVectorMappings;

  /// Retrieve the scalar-to-vector mappings associated to the rules of
  /// the Vector Function ABI.
  static void getVFABIMappings(const CallInst &CI,
                               SmallVectorImpl<VFInfo> &Mappings) {
    if (!CI.getCalledFunction())
      return;

    const StringRef ScalarName = CI.getCalledFunction()->getName();

    SmallVector<std::string, 8> ListOfStrings;
    // The check for the vector-function-abi-variant attribute is done when
    // retrieving the vector variant names here.
    VFABI::getVectorVariantNames(CI, ListOfStrings);
    if (ListOfStrings.empty())
      return;
    for (const auto &MangledName : ListOfStrings) {
      const Optional<VFInfo> Shape =
          VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
      // A match is found via scalar and vector names, and also by
      // ensuring that the variant described in the attribute has a
      // corresponding definition or declaration of the vector
      // function in the Module M.
      if (Shape.hasValue() && (Shape.getValue().ScalarName == ScalarName)) {
        assert(CI.getModule()->getFunction(Shape.getValue().VectorName) &&
               "Vector function is missing.");
        Mappings.push_back(Shape.getValue());
      }
    }
  }

public:
  /// Retrieve all the VFInfo instances associated to the CallInst CI.
  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    SmallVector<VFInfo, 8> Ret;

    // Get mappings from the Vector Function ABI variants.
    getVFABIMappings(CI, Ret);

    // Other non-VFABI variants should be retrieved here.

    return Ret;
  }

  /// Constructor, requires a CallInst instance.
  VFDatabase(CallInst &CI)
      : M(CI.getModule()), CI(CI),
        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
  /// \defgroup VFDatabase query interface.
  ///
  /// @{
  /// Retrieve the Function with VFShape \p Shape.
  Function *getVectorizedFunction(const VFShape &Shape) const {
    if (Shape == VFShape::getScalarShape(CI))
      return CI.getCalledFunction();

    for (const auto &Info : ScalarToVectorMappings)
      if (Info.Shape == Shape)
        return M->getFunction(Info.VectorName);

    return nullptr;
  }
  /// @}
};

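// Example usage (illustrative sketch): querying the database for a 2-lane,
// unpredicated variant of a call. `CI` is assumed to be a CallInst carrying a
// "vector-function-abi-variant" attribute that names such a variant.
//
//   VFDatabase DB(CI);
//   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(2),
//                                /*HasGlobalPred=*/false);
//   if (Function *VecF = DB.getVectorizedFunction(Shape)) {
//     // VecF is the matching vector variant; cost it or rewrite the call.
//   }
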
template <typename T> class ArrayRef;
class DemandedBits;
class GetElementPtrInst;
template <typename InstTy> class InterleaveGroup;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void, we return void. If the EC represents a
/// scalar, we return the scalar type.
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
    return Scalar;
  return VectorType::get(Scalar, EC);
}

inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// hasVectorInstrinsicScalarOpd) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, unsigned ScalarOpdIdx);

/// Returns the intrinsic ID for the call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if no mapping is found, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
unsigned getGEPInductionOperand(const GetElementPtrInst *Gep);

/// If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns that instead.
Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// If a value has only one user that is a CastInst, return it.
Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty);

/// Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it was inserted then extracted
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constant vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited to finding a scalar source value for a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

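// Example usage (illustrative sketch): round-tripping the documented
// Scale = 4 example through the two helpers above.
//
//   SmallVector<int, 16> Narrow;
//   narrowShuffleMaskElts(4, {3, 2, 0, -1}, Narrow);
//   // Narrow == <12,13,14,15, 8,9,10,11, 0,1,2,3, -1,-1,-1,-1>
//
//   SmallVector<int, 16> Wide;
//   bool Ok = widenShuffleMaskElts(4, Narrow, Wide);
//   // Ok is true and Wide == <3, 2, 0, -1> again.
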
/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However, InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these extensions and truncations when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction *, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                         const TargetTransformInfo *TTI = nullptr);

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and
/// begin at \p Start. The mask contains \p NumInts integers and is padded with
/// \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

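// Example usage (illustrative sketch): the documented examples for the mask
// helpers above, expressed as calls.
//
//   // <0, 4, 1, 5, 2, 6, 3, 7>
//   SmallVector<int, 16> Interleave =
//       createInterleaveMask(/*VF=*/4, /*NumVecs=*/2);
//   // <0, 2, 4, 6>
//   SmallVector<int, 16> Stride =
//       createStrideMask(/*Start=*/0, /*Stride=*/2, /*VF=*/4);
//   // <0, 1, 2, 3, undef, undef, undef, undef>
//   SmallVector<int, 16> Seq =
//       createSequentialMask(/*Start=*/0, /*NumInts=*/4, /*NumUndefs=*/4);
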
/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one,
/// and their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);

/// A group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is related to the leader and it could be
  /// negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    Optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special
    // values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.find(Key) != Members.end())
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      Optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index
  ///
  /// \returns nullptr if it contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a'la
  /// addNewMedata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilogue is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore cannot be a group of stores,
    // and it can't be a reversed access, because such groups get invalidated.
    assert(!getMember(0)->mayWriteToMemory() &&
           "Group should have been invalidated");
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32             // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
};

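// Example usage (illustrative sketch): building a factor-2 load group by
// hand. `LoadA` and `LoadB` are assumed to be adjacent strided loads, with
// `LoadA` the leader at index 0.
//
//   InterleaveGroup<Instruction> Group(LoadA, /*Stride=*/2, Align(4));
//   if (Group.insertMember(LoadB, /*Index=*/1, Align(4))) {
//     assert(Group.getFactor() == 2);
//     assert(Group.getMember(1) == LoadB);
//     // With the last member present there are no gaps, so no scalar
//     // epilogue is required for this group.
//     assert(!Group.requiresScalarEpilogue());
//   }
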
/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do analysis as the vectorization
/// on interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { invalidateGroups(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using \p Strides.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);

  /// Invalidate groups, e.g., in case all blocks in the loop will be
  /// predicated, contrary to the original assumption. Although we currently
  /// prevent group formation for predicated accesses, we may be able to relax
  /// this limitation in the future once we handle more complicated blocks.
  /// Returns true if any groups were invalidated.
  bool invalidateGroups() {
    if (InterleaveGroups.empty()) {
      assert(
          !RequiresScalarEpilogue &&
          "RequiresScalarEpilog should not be set without interleave groups");
      return false;
    }

    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
    return true;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.find(Instr) != InterleaveGroupMap.end();
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if \p Instr does not belong to any group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    return InterleaveGroupMap.lookup(Instr);
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;

  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Alignment.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    InterleaveGroups.erase(Group);
    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const ValueToValueMap &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering is
  /// not necessary or is prevented because \p A and \p B may be dependent.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return Dependences.find(Src) == Dependences.end() ||
           !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;
    auto *Deps = LAI->getDepChecker().getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
  }
};

} // llvm namespace

#endif