Skip to content

Commit b0f28c1

Browse files
committed
Hoist waits above the warp specialized region.
1 parent d9f74f8 commit b0f28c1

File tree

2 files changed

+6
-18
lines changed

2 files changed

+6
-18
lines changed

include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -604,15 +604,15 @@ class GemmUniversal<
604604
// Get pipeline stage increments from tensor shapes
605605
auto k_tile_count = size<3>(gA_mkl);
606606
607+
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
608+
cutlass::arch::wait_on_dependent_grids();
609+
#endif
607610
if (warp_group_role == WarpGroupRole::Producer) {
608611
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
609612
610613
if (producer_warp_role == ProducerWarpRole::Scheduler) {
611614
// GroupScheduler requires a producer warp to iterate over the group infos and push
612615
// the work tile infos to the downstream pipelines.
613-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
614-
cutlass::arch::wait_on_dependent_grids();
615-
#endif
616616
if constexpr (cute::is_same_v<SchedulerTag, GroupScheduler>) {
617617
do {
618618
auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(tile_scheduler_pipeline, tile_scheduler_pipe_producer_state);
@@ -626,9 +626,6 @@ class GemmUniversal<
626626
}
627627
// Mainloop Producer Warp
628628
else if (producer_warp_role == ProducerWarpRole::Mainloop) {
629-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
630-
cutlass::arch::wait_on_dependent_grids();
631-
#endif
632629
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
633630
int32_t const mock_l_coord = 0;
634631
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
@@ -793,9 +790,6 @@ class GemmUniversal<
793790
} // Mainloop Auxiliary Load Producer Warp End
794791
// Epilogue Producer Warp
795792
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
796-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
797-
cutlass::arch::wait_on_dependent_grids();
798-
#endif
799793
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
800794
int32_t const sm_count = params.hw_info.sm_count;
801795

include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -642,6 +642,9 @@ class GemmUniversal<
642642
// Get pipeline stage increments from tensor shapes
643643
auto k_tile_count = size<3>(gA_mkl);
644644
645+
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
646+
cutlass::arch::wait_on_dependent_grids();
647+
#endif
645648
if (warp_group_role == WarpGroupRole::Producer) {
646649
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
647650
@@ -661,9 +664,6 @@ class GemmUniversal<
661664
}
662665
// Mainloop Producer Warp
663666
else if (producer_warp_role == ProducerWarpRole::Mainloop) {
664-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
665-
cutlass::arch::wait_on_dependent_grids();
666-
#endif
667667
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
668668
int32_t const mock_l_coord = 0;
669669
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
@@ -766,9 +766,6 @@ class GemmUniversal<
766766
} // Mainloop Producer Warp End
767767
else if (producer_warp_role == ProducerWarpRole::MainloopAux) {
768768
if constexpr (IsMainloopAuxiliaryLoadNeeded) {
769-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
770-
cutlass::arch::wait_on_dependent_grids();
771-
#endif
772769
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
773770
int32_t const mock_l_coord = 0;
774771

@@ -831,9 +828,6 @@ class GemmUniversal<
831828
} // Mainloop Auxiliary Load Producer Warp End
832829
// Epilogue Producer Warp
833830
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
834-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
835-
cutlass::arch::wait_on_dependent_grids();
836-
#endif
837831
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
838832
int32_t const sm_count = params.hw_info.sm_count;
839833

0 commit comments

Comments
 (0)