Home | History | Annotate | Line # | Download | only in ADT
      1 //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
      2 //
      3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4 // See https://llvm.org/LICENSE.txt for license information.
      5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 //  This file implements the TypeSwitch template, which mimics a switch()
     10 //  statement whose cases are type names.
     11 //
     12 //===-----------------------------------------------------------------------===/
     13 
     14 #ifndef LLVM_ADT_TYPESWITCH_H
     15 #define LLVM_ADT_TYPESWITCH_H
     16 
     17 #include "llvm/ADT/Optional.h"
     18 #include "llvm/ADT/STLExtras.h"
     19 #include "llvm/Support/Casting.h"
     20 
     21 namespace llvm {
     22 namespace detail {
     23 
     24 template <typename DerivedT, typename T> class TypeSwitchBase {
     25 public:
     26   TypeSwitchBase(const T &value) : value(value) {}
     27   TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
     28   ~TypeSwitchBase() = default;
     29 
     30   /// TypeSwitchBase is not copyable.
     31   TypeSwitchBase(const TypeSwitchBase &) = delete;
     32   void operator=(const TypeSwitchBase &) = delete;
     33   void operator=(TypeSwitchBase &&other) = delete;
     34 
     35   /// Invoke a case on the derived class with multiple case types.
     36   template <typename CaseT, typename CaseT2, typename... CaseTs,
     37             typename CallableT>
     38   DerivedT &Case(CallableT &&caseFn) {
     39     DerivedT &derived = static_cast<DerivedT &>(*this);
     40     return derived.template Case<CaseT>(caseFn)
     41         .template Case<CaseT2, CaseTs...>(caseFn);
     42   }
     43 
     44   /// Invoke a case on the derived class, inferring the type of the Case from
     45   /// the first input of the given callable.
     46   /// Note: This inference rules for this overload are very simple: strip
     47   ///       pointers and references.
     48   template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
     49     using Traits = function_traits<std::decay_t<CallableT>>;
     50     using CaseT = std::remove_cv_t<std::remove_pointer_t<
     51         std::remove_reference_t<typename Traits::template arg_t<0>>>>;
     52 
     53     DerivedT &derived = static_cast<DerivedT &>(*this);
     54     return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
     55   }
     56 
     57 protected:
     58   /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
     59   /// `CastT`.
     60   template <typename ValueT, typename CastT>
     61   using has_dyn_cast_t =
     62       decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
     63 
     64   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
     65   /// selected if `value` already has a suitable dyn_cast method.
     66   template <typename CastT, typename ValueT>
     67   static auto castValue(
     68       ValueT value,
     69       typename std::enable_if_t<
     70           is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
     71     return value.template dyn_cast<CastT>();
     72   }
     73 
     74   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
     75   /// selected if llvm::dyn_cast should be used.
     76   template <typename CastT, typename ValueT>
     77   static auto castValue(
     78       ValueT value,
     79       typename std::enable_if_t<
     80           !is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
     81     return dyn_cast<CastT>(value);
     82   }
     83 
     84   /// The root value we are switching on.
     85   const T value;
     86 };
     87 } // end namespace detail
     88 
     89 /// This class implements a switch-like dispatch statement for a value of 'T'
     90 /// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
     91 /// if the root value isa<T>, the callable is invoked with the result of
     92 /// dyn_cast<T>() as a parameter.
     93 ///
     94 /// Example:
     95 ///  Operation *op = ...;
     96 ///  LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
     97 ///    .Case<ConstantOp>([](ConstantOp op) { ... })
     98 ///    .Default([](Operation *op) { ... });
     99 ///
    100 template <typename T, typename ResultT = void>
    101 class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
    102 public:
    103   using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
    104   using BaseT::BaseT;
    105   using BaseT::Case;
    106   TypeSwitch(TypeSwitch &&other) = default;
    107 
    108   /// Add a case on the given type.
    109   template <typename CaseT, typename CallableT>
    110   TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
    111     if (result)
    112       return *this;
    113 
    114     // Check to see if CaseT applies to 'value'.
    115     if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
    116       result = caseFn(caseValue);
    117     return *this;
    118   }
    119 
    120   /// As a default, invoke the given callable within the root value.
    121   template <typename CallableT>
    122   LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
    123     if (result)
    124       return std::move(*result);
    125     return defaultFn(this->value);
    126   }
    127   /// As a default, return the given value.
    128   LLVM_NODISCARD ResultT Default(ResultT defaultResult) {
    129     if (result)
    130       return std::move(*result);
    131     return defaultResult;
    132   }
    133 
    134   LLVM_NODISCARD
    135   operator ResultT() {
    136     assert(result && "Fell off the end of a type-switch");
    137     return std::move(*result);
    138   }
    139 
    140 private:
    141   /// The pointer to the result of this switch statement, once known,
    142   /// null before that.
    143   Optional<ResultT> result;
    144 };
    145 
    146 /// Specialization of TypeSwitch for void returning callables.
    147 template <typename T>
    148 class TypeSwitch<T, void>
    149     : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
    150 public:
    151   using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
    152   using BaseT::BaseT;
    153   using BaseT::Case;
    154   TypeSwitch(TypeSwitch &&other) = default;
    155 
    156   /// Add a case on the given type.
    157   template <typename CaseT, typename CallableT>
    158   TypeSwitch<T, void> &Case(CallableT &&caseFn) {
    159     if (foundMatch)
    160       return *this;
    161 
    162     // Check to see if any of the types apply to 'value'.
    163     if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
    164       caseFn(caseValue);
    165       foundMatch = true;
    166     }
    167     return *this;
    168   }
    169 
    170   /// As a default, invoke the given callable within the root value.
    171   template <typename CallableT> void Default(CallableT &&defaultFn) {
    172     if (!foundMatch)
    173       defaultFn(this->value);
    174   }
    175 
    176 private:
    177   /// A flag detailing if we have already found a match.
    178   bool foundMatch = false;
    179 };
    180 } // end namespace llvm
    181 
    182 #endif // LLVM_ADT_TYPESWITCH_H
    183