diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
index 87f426ba92e..a72d8d95ee6 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
@@ -48,6 +48,12 @@ namespace kernels::moe_comm
 #define SWITCH_TOP_K(top_k, TOP_K, ...)                                                                               \
     switch (top_k)                                                                                                    \
     {                                                                                                                 \
+    case 22:                                                                                                          \
+    {                                                                                                                 \
+        constexpr int TOP_K = 22;                                                                                     \
+        __VA_ARGS__;                                                                                                  \
+        break;                                                                                                        \
+    }                                                                                                                 \
     case 16:                                                                                                          \
     {                                                                                                                 \
         constexpr int TOP_K = 16;                                                                                     \
@@ -654,7 +660,69 @@ __device__ void vectorized_combine_impl(
         acc[k].load(recv_buffer + base_token + offset);
     }
     // Reduce acc[TOP_K] into acc[0]
-    if constexpr (TOP_K == 16)
+    if constexpr (TOP_K == 22)
+    {
+        T* a0 = reinterpret_cast<T*>(&acc[0]);
+        T* a1 = reinterpret_cast<T*>(&acc[1]);
+        T* a2 = reinterpret_cast<T*>(&acc[2]);
+        T* a3 = reinterpret_cast<T*>(&acc[3]);
+        T* a4 = reinterpret_cast<T*>(&acc[4]);
+        T* a5 = reinterpret_cast<T*>(&acc[5]);
+        T* a6 = reinterpret_cast<T*>(&acc[6]);
+        T* a7 = reinterpret_cast<T*>(&acc[7]);
+        T* a8 = reinterpret_cast<T*>(&acc[8]);
+        T* a9 = reinterpret_cast<T*>(&acc[9]);
+        T* a10 = reinterpret_cast<T*>(&acc[10]);
+        T* a11 = reinterpret_cast<T*>(&acc[11]);
+        T* a12 = reinterpret_cast<T*>(&acc[12]);
+        T* a13 = reinterpret_cast<T*>(&acc[13]);
+        T* a14 = reinterpret_cast<T*>(&acc[14]);
+        T* a15 = reinterpret_cast<T*>(&acc[15]);
+        T* a16 = reinterpret_cast<T*>(&acc[16]);
+        T* a17 = reinterpret_cast<T*>(&acc[17]);
+        T* a18 = reinterpret_cast<T*>(&acc[18]);
+        T* a19 = reinterpret_cast<T*>(&acc[19]);
+        T* a20 = reinterpret_cast<T*>(&acc[20]);
+        T* a21 = reinterpret_cast<T*>(&acc[21]);
+#pragma unroll
+        for (int j = 0; j < elems_per_vec; ++j)
+        {
+            a0[j] += a1[j];
+            a2[j] += a3[j];
+            a4[j] += a5[j];
+            a6[j] += a7[j];
+            a8[j] += a9[j];
+            a10[j] += a11[j];
+            a12[j] += a13[j];
+            a14[j] += a15[j];
+            a16[j] += a17[j];
+            a18[j] += a19[j];
+            a20[j] += a21[j];
+        }
+#pragma unroll
+        for (int j = 0; j < elems_per_vec; ++j)
+        {
+            a0[j] += a2[j];
+            a4[j] += a6[j];
+            a8[j] += a10[j];
+            a12[j] += a14[j];
+            a16[j] += a18[j];
+        }
+#pragma unroll
+        for (int j = 0; j < elems_per_vec; ++j)
+        {
+            a0[j] += a4[j];
+            a8[j] += a12[j];
+            a16[j] += a20[j];
+        }
+#pragma unroll
+        for (int j = 0; j < elems_per_vec; ++j)
+        {
+            a0[j] += a8[j];
+            a0[j] += a16[j];
+        }
+    }
+    else if constexpr (TOP_K == 16)
     {
         T* a0 = reinterpret_cast<T*>(&acc[0]);
         T* a1 = reinterpret_cast<T*>(&acc[1]);
diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
index 193a3806df6..20e68657fc9 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
@@ -26,7 +26,7 @@ namespace kernels::moe_comm
 {
 
 // Configuration constants
-static constexpr int kMaxTopK = 16;    // Maximum top-k experts per token
+static constexpr int kMaxTopK = 22;    // Maximum top-k experts per token
 static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
 static constexpr int kMaxRanks = 64;   // Maximum supported EP size
 
diff --git a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py
index c584942e492..49ed032aff1 100644
--- a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py
+++ b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py
@@ -565,6 +565,7 @@ def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k):
         (2, [100, 50], 2),
         (4, [32, 32, 32, 32], 4),
         (4, [32, 32, 32, 32], 10),  # (top_k=10 is used by Qwen3-next)
+        (4, [32, 32, 32, 32], 22),
         (4, [1, 1, 1, 1], 2),
         (8, [640, 640, 640, 640, 640, 640, 640, 640], 4),
         (4, [32, 0, 16, 0], 2),
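
For reference, the new TOP_K == 22 branch is a four-round pairwise tree reduction: 22 accumulators collapse to 11 partials (a0, a2, ..., a20), then to 6 (a0, a4, a8, a12, a16, a20), then to 3 (a0, a8, a16), and finally into a0. The standalone host-side sketch below mirrors those exact rounds on plain float arrays and checks the result against a straight sum; it is an illustration only, not part of the patch, and the names K and ELEMS are stand-ins for TOP_K and elems_per_vec.

    #include <cstdio>

    int main()
    {
        constexpr int K = 22;    // stand-in for TOP_K
        constexpr int ELEMS = 4; // stand-in for elems_per_vec
        float a[K][ELEMS];
        float expected[ELEMS] = {};
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < ELEMS; ++j)
            {
                a[k][j] = float((k + 1) * (j + 1)); // small exact values
                expected[j] += a[k][j];
            }
        for (int j = 0; j < ELEMS; ++j)
        {
            // Round 1: 22 -> 11 partials at even indices.
            for (int k = 0; k < K; k += 2)
                a[k][j] += a[k + 1][j];
            // Round 2: 11 -> 6 partials (a0 += a2, ..., a16 += a18; a20 carries over).
            for (int k = 0; k + 2 < K; k += 4)
                a[k][j] += a[k + 2][j];
            // Round 3: 6 -> 3 partials (a0 += a4, a8 += a12, a16 += a20).
            for (int k = 0; k + 4 < K; k += 8)
                a[k][j] += a[k + 4][j];
            // Round 4: 3 -> 1 (a0 += a8, a0 += a16).
            a[0][j] += a[8][j];
            a[0][j] += a[16][j];
        }
        for (int j = 0; j < ELEMS; ++j)
            if (a[0][j] != expected[j])
            {
                std::printf("mismatch at element %d\n", j);
                return 1;
            }
        std::printf("tree reduction matches straight sum\n");
        return 0;
    }

A pairwise tree like this keeps the dependency chain at roughly log2(K) additions per element rather than K - 1 for a running sum, which is presumably why the kernel unrolls the rounds by hand rather than looping over acc[k].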