diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index c69defbfa..e758059e1 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -994,6 +994,11 @@ class GemmUniversal< // Get next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); + + if (!next_work_tile_info.is_valid()) { + cutlass::arch::launch_dependent_grids(); + } + work_tile_info = next_work_tile_info; if (increment_pipe) { ++tile_scheduler_pipe_consumer_state; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index d380b2cd0..7120db78d 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -1038,6 +1038,11 @@ class GemmUniversal< // Get next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); + + if (!next_work_tile_info.is_valid()) { + cutlass::arch::launch_dependent_grids(); + } + work_tile_info = next_work_tile_info; if (increment_pipe) { ++tile_scheduler_pipe_consumer_state; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 5c04259bd..e9908fad5 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -804,7 +804,7 @@ class GemmUniversal< // Update starting mainloop pipeline state for the next tile mainloop_pipe_consumer_state.advance(work_k_tile_count); } - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 + if (scheduler.is_last_tile(work_tile_info)) { // Hint on an early release of global memory resources. // The timing of calling this function only influences performance, @@ -812,7 +812,6 @@ class GemmUniversal< cutlass::arch::launch_dependent_grids(); } - #endif // Index of warp group within consumer warp groups int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index 073f3a50e..f3d99d4e4 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -800,7 +800,6 @@ class GemmUniversal< else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 // It is possible to have work tiles start off invalid, // so we have to check that first. if (not work_tile_info.is_valid()) { @@ -811,7 +810,6 @@ class GemmUniversal< return; } - #endif if constexpr (IsSchedDynamicPersistent) { // Consumer0's initial tile is static. It starts consuming the 2nd tile. @@ -868,7 +866,6 @@ class GemmUniversal< // Update starting mainloop pipeline state for the next tile mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) { // Hint on an early release of global memory resources. // The timing of calling this function only influences performance, @@ -876,7 +873,6 @@ class GemmUniversal< cutlass::arch::launch_dependent_grids(); } - #endif // Order two Math WG's Epilogue one after the other math_wg_order_barrier.wait();