@@ -78,10 +78,10 @@ to_tiled_mma_sm100_ts(
7878 TiledMMA<MMA_Atom<
7979 MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
8080 cute::C<M>, cute::C<N>,
81- cute::integral_constant<UMMA::Major, a_major>,
82- cute::integral_constant<UMMA::Major, b_major>,
83- cute::integral_constant<UMMA::ScaleIn, a_neg>,
84- cute::integral_constant<UMMA::ScaleIn, b_neg>>,
81+ cute::integral_constant<UMMA::Major, a_major>,
82+ cute::integral_constant<UMMA::Major, b_major>,
83+ cute::integral_constant<UMMA::ScaleIn, a_neg>,
84+ cute::integral_constant<UMMA::ScaleIn, b_neg>>,
8585 TAs...>, TMs...>) {
8686
8787 return TiledMMA<MMA_Atom<
@@ -101,10 +101,10 @@ to_tiled_mma_sm100_ts(
101101 TiledMMA<MMA_Atom<
102102 SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
103103 M, N,
104- a_major,
105- b_major,
106- a_neg,
107- b_neg>,
104+ a_major,
105+ b_major,
106+ a_neg,
107+ b_neg>,
108108 TAs...>, TMs...>) {
109109 return TiledMMA<MMA_Atom<
110110 SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
@@ -125,4 +125,75 @@ void warpgroup_reg_set() {
125125 }
126126}
127127
128- } // namespace cutlass::fmha::collective
128+ } // namespace cutlass::fmha::collective
129+
130+ namespace constexpr_type_map {
131+ /*
132+ * The following utility type_traits allow mapping constexpr variable to type at
133+ * compile time.
134+ * The default return type defined for each map would be returned if queried key
135+ * does not exist in the map.
136+ */
137+
138+ template <auto keyVal, typename _valueT>
139+ struct kValTyPair {
140+ static constexpr auto key = keyVal;
141+ using valueT = _valueT;
142+ };
143+
144+ template <typename Default, typename FirstMapping, typename ... OtherMapping>
145+ struct kValTyMap {
146+ template <auto QueryKey>
147+ using query = std::conditional_t <
148+ QueryKey == FirstMapping::key,
149+ typename FirstMapping::valueT,
150+ typename kValTyMap <Default, OtherMapping...>::template query<QueryKey>>;
151+ };
152+
153+ template <typename Default, typename LastMapping>
154+ struct kValTyMap <Default, LastMapping> {
155+ template <auto QueryKey>
156+ using query = std::conditional_t <
157+ QueryKey == LastMapping::key,
158+ typename LastMapping::valueT,
159+ Default>;
160+ };
161+
162+ } // namespace constexpr_type_map
163+
164+ namespace constexpr_constexpr_map {
165+
166+ template <auto keyVal, auto valueVal>
167+ struct kValValPair {
168+ static constexpr auto key = keyVal;
169+ static constexpr auto value = valueVal;
170+ };
171+
172+ template <auto Default, typename FirstMapping, typename ... OtherMapping>
173+ struct kValValMap {
174+ using ValType = std::add_const_t <decltype (Default)>;
175+ static_assert (
176+ std::is_same_v<ValType, decltype (FirstMapping::value)>,
177+ " Map value type mismatch" );
178+ static_assert (
179+ (std::is_same_v<ValType, decltype (OtherMapping::value)> && ...),
180+ " Map value type mismatch" );
181+ template <decltype (FirstMapping::key) QueryKey>
182+ static constexpr decltype (FirstMapping::value) query =
183+ (QueryKey == FirstMapping::key)
184+ ? FirstMapping::value
185+ : kValValMap<Default, OtherMapping...>::template query<QueryKey>;
186+ };
187+
188+ template <auto Default, typename LastMapping>
189+ struct kValValMap <Default, LastMapping> {
190+ using ValType = std::add_const_t <decltype (Default)>;
191+ static_assert (
192+ std::is_same_v<ValType, decltype (LastMapping::value)>,
193+ " Map value type mismatch" );
194+ template <decltype (LastMapping::key) QueryKey>
195+ static constexpr decltype (LastMapping::value) query =
196+ (QueryKey == LastMapping::key) ? LastMapping::value : Default;
197+ };
198+
199+ } // namespace constexpr_constexpr_map
0 commit comments