Skip to content

Commit 44d0f95

Browse files
Mapping utilities (#5073)
Summary: Pull Request resolved: #5073 X-link: https://github.com/facebookresearch/FBGEMM/pull/2079 Compile-time static/const mapping utilities for: 1. constexpr value -> constexpr value 2. constexpr value -> type Useful when developing template-heavy cutlass code. Reviewed By: jianyuh Differential Revision: D85893168 fbshipit-source-id: 691dbb90e17c88dfc384432908e8ffdb8c0b2a04
1 parent 9ec0d72 commit 44d0f95

File tree

1 file changed

+80
-9
lines changed
  • fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective

1 file changed

+80
-9
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ to_tiled_mma_sm100_ts(
7878
TiledMMA<MMA_Atom<
7979
MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
8080
cute::C<M>, cute::C<N>,
81-
cute::integral_constant<UMMA::Major, a_major>,
82-
cute::integral_constant<UMMA::Major, b_major>,
83-
cute::integral_constant<UMMA::ScaleIn, a_neg>,
84-
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
81+
cute::integral_constant<UMMA::Major, a_major>,
82+
cute::integral_constant<UMMA::Major, b_major>,
83+
cute::integral_constant<UMMA::ScaleIn, a_neg>,
84+
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
8585
TAs...>, TMs...>) {
8686

8787
return TiledMMA<MMA_Atom<
@@ -101,10 +101,10 @@ to_tiled_mma_sm100_ts(
101101
TiledMMA<MMA_Atom<
102102
SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
103103
M, N,
104-
a_major,
105-
b_major,
106-
a_neg,
107-
b_neg>,
104+
a_major,
105+
b_major,
106+
a_neg,
107+
b_neg>,
108108
TAs...>, TMs...>) {
109109
return TiledMMA<MMA_Atom<
110110
SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
@@ -125,4 +125,75 @@ void warpgroup_reg_set() {
125125
}
126126
}
127127

128-
} // namespace cutlass::fmha::collective
128+
} // namespace cutlass::fmha::collective
129+
130+
namespace constexpr_type_map {
131+
/*
132+
* The following utility type_traits allow mapping constexpr variable to type at
133+
* compile time.
134+
* The default return type defined for each map would be returned if queried key
135+
* does not exist in the map.
136+
*/
137+
138+
template <auto keyVal, typename _valueT>
139+
struct kValTyPair {
140+
static constexpr auto key = keyVal;
141+
using valueT = _valueT;
142+
};
143+
144+
template <typename Default, typename FirstMapping, typename... OtherMapping>
145+
struct kValTyMap {
146+
template <auto QueryKey>
147+
using query = std::conditional_t<
148+
QueryKey == FirstMapping::key,
149+
typename FirstMapping::valueT,
150+
typename kValTyMap<Default, OtherMapping...>::template query<QueryKey>>;
151+
};
152+
153+
template <typename Default, typename LastMapping>
154+
struct kValTyMap<Default, LastMapping> {
155+
template <auto QueryKey>
156+
using query = std::conditional_t<
157+
QueryKey == LastMapping::key,
158+
typename LastMapping::valueT,
159+
Default>;
160+
};
161+
162+
} // namespace constexpr_type_map
163+
164+
namespace constexpr_constexpr_map {
165+
166+
template <auto keyVal, auto valueVal>
167+
struct kValValPair {
168+
static constexpr auto key = keyVal;
169+
static constexpr auto value = valueVal;
170+
};
171+
172+
template <auto Default, typename FirstMapping, typename... OtherMapping>
173+
struct kValValMap {
174+
using ValType = std::add_const_t<decltype(Default)>;
175+
static_assert(
176+
std::is_same_v<ValType, decltype(FirstMapping::value)>,
177+
"Map value type mismatch");
178+
static_assert(
179+
(std::is_same_v<ValType, decltype(OtherMapping::value)> && ...),
180+
"Map value type mismatch");
181+
template <decltype(FirstMapping::key) QueryKey>
182+
static constexpr decltype(FirstMapping::value) query =
183+
(QueryKey == FirstMapping::key)
184+
? FirstMapping::value
185+
: kValValMap<Default, OtherMapping...>::template query<QueryKey>;
186+
};
187+
188+
template <auto Default, typename LastMapping>
189+
struct kValValMap<Default, LastMapping> {
190+
using ValType = std::add_const_t<decltype(Default)>;
191+
static_assert(
192+
std::is_same_v<ValType, decltype(LastMapping::value)>,
193+
"Map value type mismatch");
194+
template <decltype(LastMapping::key) QueryKey>
195+
static constexpr decltype(LastMapping::value) query =
196+
(QueryKey == LastMapping::key) ? LastMapping::value : Default;
197+
};
198+
199+
} // namespace constexpr_constexpr_map

0 commit comments

Comments
 (0)