1 //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the TypeSwitch template, which mimics a switch()
10 // statement whose cases are type names.
12 //===-----------------------------------------------------------------------===/
14 #ifndef LLVM_ADT_TYPESWITCH_H
15 #define LLVM_ADT_TYPESWITCH_H
17 #include "llvm/ADT/Optional.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/Support/Casting.h"
24 template <typename DerivedT, typename T> class TypeSwitchBase {
26 TypeSwitchBase(const T &value) : value(value) {}
27 TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
28 ~TypeSwitchBase() = default;
30 /// TypeSwitchBase is not copyable.
31 TypeSwitchBase(const TypeSwitchBase &) = delete;
32 void operator=(const TypeSwitchBase &) = delete;
33 void operator=(TypeSwitchBase &&other) = delete;
35 /// Invoke a case on the derived class with multiple case types.
36 template <typename CaseT, typename CaseT2, typename... CaseTs,
38 DerivedT &Case(CallableT &&caseFn) {
39 DerivedT &derived = static_cast<DerivedT &>(*this);
40 return derived.template Case<CaseT>(caseFn)
41 .template Case<CaseT2, CaseTs...>(caseFn);
44 /// Invoke a case on the derived class, inferring the type of the Case from
45 /// the first input of the given callable.
46 /// Note: This inference rules for this overload are very simple: strip
47 /// pointers and references.
48 template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
49 using Traits = function_traits<std::decay_t<CallableT>>;
50 using CaseT = std::remove_cv_t<std::remove_pointer_t<
51 std::remove_reference_t<typename Traits::template arg_t<0>>>>;
53 DerivedT &derived = static_cast<DerivedT &>(*this);
54 return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
58 /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
60 template <typename ValueT, typename CastT>
61 using has_dyn_cast_t =
62 decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
64 /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
65 /// selected if `value` already has a suitable dyn_cast method.
66 template <typename CastT, typename ValueT>
67 static auto castValue(
69 typename std::enable_if_t<
70 is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
71 return value.template dyn_cast<CastT>();
74 /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
75 /// selected if llvm::dyn_cast should be used.
76 template <typename CastT, typename ValueT>
77 static auto castValue(
79 typename std::enable_if_t<
80 !is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
81 return dyn_cast<CastT>(value);
84 /// The root value we are switching on.
87 } // end namespace detail
89 /// This class implements a switch-like dispatch statement for a value of 'T'
90 /// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
91 /// if the root value isa<T>, the callable is invoked with the result of
92 /// dyn_cast<T>() as a parameter.
95 /// Operation *op = ...;
96 /// LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
97 /// .Case<ConstantOp>([](ConstantOp op) { ... })
98 /// .Default([](Operation *op) { ... });
100 template <typename T, typename ResultT = void>
101 class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
103 using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
106 TypeSwitch(TypeSwitch &&other) = default;
108 /// Add a case on the given type.
109 template <typename CaseT, typename CallableT>
110 TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
114 // Check to see if CaseT applies to 'value'.
115 if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
116 result = caseFn(caseValue);
120 /// As a default, invoke the given callable within the root value.
121 template <typename CallableT>
122 LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
124 return std::move(*result);
125 return defaultFn(this->value);
130 assert(result && "Fell off the end of a type-switch");
131 return std::move(*result);
135 /// The pointer to the result of this switch statement, once known,
136 /// null before that.
137 Optional<ResultT> result;
140 /// Specialization of TypeSwitch for void returning callables.
141 template <typename T>
142 class TypeSwitch<T, void>
143 : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
145 using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
148 TypeSwitch(TypeSwitch &&other) = default;
150 /// Add a case on the given type.
151 template <typename CaseT, typename CallableT>
152 TypeSwitch<T, void> &Case(CallableT &&caseFn) {
156 // Check to see if any of the types apply to 'value'.
157 if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
164 /// As a default, invoke the given callable within the root value.
165 template <typename CallableT> void Default(CallableT &&defaultFn) {
167 defaultFn(this->value);
171 /// A flag detailing if we have already found a match.
172 bool foundMatch = false;
174 } // end namespace llvm
176 #endif // LLVM_ADT_TYPESWITCH_H