@@ -642,6 +642,9 @@ class GemmUniversal<
642642 // Get pipeline stage increments from tensor shapes
643643 auto k_tile_count = size<3>(gA_mkl);
644644
645+ #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
646+ cutlass::arch::wait_on_dependent_grids();
647+ #endif
645648 if (warp_group_role == WarpGroupRole::Producer) {
646649 cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
647650
@@ -661,9 +664,6 @@ class GemmUniversal<
661664 }
662665 // Mainloop Producer Warp
663666 else if (producer_warp_role == ProducerWarpRole::Mainloop) {
664- #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
665- cutlass::arch::wait_on_dependent_grids();
666- #endif
667667 int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
668668 int32_t const mock_l_coord = 0;
669669 int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
@@ -766,9 +766,6 @@ class GemmUniversal<
766766 } // Mainloop Producer Warp End
767767 else if (producer_warp_role == ProducerWarpRole::MainloopAux) {
768768 if constexpr (IsMainloopAuxiliaryLoadNeeded) {
769- #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
770- cutlass::arch::wait_on_dependent_grids ();
771- #endif
772769 int32_t curr_batch = idx2crd (work_tile_info.L_idx , shape<4 >(gB_nkl )); // Usually just returns work_tile_info.L_idx;
773770 int32_t const mock_l_coord = 0 ;
774771
@@ -831,9 +828,6 @@ class GemmUniversal<
831828 } // Mainloop Auxiliary Load Producer Warp End
832829 // Epilogue Producer Warp
833830 else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
834- #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
835- cutlass::arch::wait_on_dependent_grids ();
836- #endif
837831 int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x );
838832 int32_t const sm_count = params.hw_info .sm_count ;
839833
0 commit comments