Skip to content

Commit 9a902f3

Browse files
committed
Hoist waits above the warp specialized region.
1 parent c328bae commit 9a902f3

File tree

2 files changed

6 additions and 18 deletions (lines changed)

2 files changed

6 additions and 18 deletions (lines changed)

include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -608,15 +608,15 @@ class GemmUniversal<
608608
// Get pipeline stage increments from tensor shapes
609609
auto k_tile_count = size<3>(gA_mkl);
610610
611+
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
612+
cutlass::arch::wait_on_dependent_grids();
613+
#endif
611614
if (warp_group_role == WarpGroupRole::Producer) {
612615
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
613616
614617
if (producer_warp_role == ProducerWarpRole::Scheduler) {
615618
// GroupScheduler requires a producer warp to iterate over the group infos and push
616619
// the work tile infos to the downstream pipelines.
617-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
618-
cutlass::arch::wait_on_dependent_grids();
619-
#endif
620620
if constexpr (cute::is_same_v<SchedulerTag, GroupScheduler>) {
621621
do {
622622
auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(tile_scheduler_pipeline, tile_scheduler_pipe_producer_state);
@@ -630,9 +630,6 @@ class GemmUniversal<
630630
}
631631
// Mainloop Producer Warp
632632
else if (producer_warp_role == ProducerWarpRole::Mainloop) {
633-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
634-
cutlass::arch::wait_on_dependent_grids();
635-
#endif
636633
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
637634
int32_t const mock_l_coord = 0;
638635
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
@@ -797,9 +794,6 @@ class GemmUniversal<
797794
} // Mainloop Auxiliary Load Producer Warp End
798795
// Epilogue Producer Warp
799796
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
800-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
801-
cutlass::arch::wait_on_dependent_grids();
802-
#endif
803797
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
804798
int32_t const sm_count = params.hw_info.sm_count;
805799

include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -646,6 +646,9 @@ class GemmUniversal<
646646
// Get pipeline stage increments from tensor shapes
647647
auto k_tile_count = size<3>(gA_mkl);
648648
649+
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
650+
cutlass::arch::wait_on_dependent_grids();
651+
#endif
649652
if (warp_group_role == WarpGroupRole::Producer) {
650653
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
651654
@@ -665,9 +668,6 @@ class GemmUniversal<
665668
}
666669
// Mainloop Producer Warp
667670
else if (producer_warp_role == ProducerWarpRole::Mainloop) {
668-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
669-
cutlass::arch::wait_on_dependent_grids();
670-
#endif
671671
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
672672
int32_t const mock_l_coord = 0;
673673
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
@@ -770,9 +770,6 @@ class GemmUniversal<
770770
} // Mainloop Producer Warp End
771771
else if (producer_warp_role == ProducerWarpRole::MainloopAux) {
772772
if constexpr (IsMainloopAuxiliaryLoadNeeded) {
773-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
774-
cutlass::arch::wait_on_dependent_grids();
775-
#endif
776773
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
777774
int32_t const mock_l_coord = 0;
778775

@@ -835,9 +832,6 @@ class GemmUniversal<
835832
} // Mainloop Auxiliary Load Producer Warp End
836833
// Epilogue Producer Warp
837834
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
838-
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
839-
cutlass::arch::wait_on_dependent_grids();
840-
#endif
841835
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
842836
int32_t const sm_count = params.hw_info.sm_count;
843837

0 commit comments

Comments
 (0)