@@ -646,6 +646,9 @@ class GemmUniversal<
646646 // Get pipeline stage increments from tensor shapes
647647 auto k_tile_count = size<3>(gA_mkl);
648648
649+ #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
650+ cutlass::arch::wait_on_dependent_grids();
651+ #endif
649652 if (warp_group_role == WarpGroupRole::Producer) {
650653 cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
651654
@@ -665,9 +668,6 @@ class GemmUniversal<
665668 }
666669 // Mainloop Producer Warp
667670 else if (producer_warp_role == ProducerWarpRole::Mainloop) {
668- #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
669- cutlass::arch::wait_on_dependent_grids();
670- #endif
671671 int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
672672 int32_t const mock_l_coord = 0;
673673 int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
@@ -770,9 +770,6 @@ class GemmUniversal<
770770 } // Mainloop Producer Warp End
771771 else if (producer_warp_role == ProducerWarpRole::MainloopAux) {
772772 if constexpr (IsMainloopAuxiliaryLoadNeeded) {
773- #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
774- cutlass::arch::wait_on_dependent_grids ();
775- #endif
776773 int32_t curr_batch = idx2crd (work_tile_info.L_idx , shape<4 >(gB_nkl )); // Usually just returns work_tile_info.L_idx;
777774 int32_t const mock_l_coord = 0 ;
778775
@@ -835,9 +832,6 @@ class GemmUniversal<
835832 } // Mainloop Auxiliary Load Producer Warp End
836833 // Epilogue Producer Warp
837834 else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
838- #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
839- cutlass::arch::wait_on_dependent_grids ();
840- #endif
841835 int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x );
842836 int32_t const sm_count = params.hw_info .sm_count ;
843837
0 commit comments