Skip to content

Commit b0a83c0

Browse files
committed
Support PDL in sm90_gemm_array_tma_warpspecialized_cooperative
1 parent 6513a85 commit b0a83c0

File tree

1 file changed: +14 -0 lines changed

1 file changed: +14 -0 lines changed

include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -610,6 +610,9 @@ class GemmUniversal<
610610
if (producer_warp_role == ProducerWarpRole::Scheduler) {
611611
// GroupScheduler requires a producer warp to iterate over the group infos and push
612612
// the work tile infos to the downstream pipelines.
613+
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
614+
cutlass::arch::wait_on_dependent_grids();
615+
#endif
613616
if constexpr (cute::is_same_v<SchedulerTag, GroupScheduler>) {
614617
do {
615618
auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(tile_scheduler_pipeline, tile_scheduler_pipe_producer_state);
@@ -623,6 +626,9 @@ class GemmUniversal<
623626
}
624627
// Mainloop Producer Warp
625628
else if (producer_warp_role == ProducerWarpRole::Mainloop) {
629+
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
630+
cutlass::arch::wait_on_dependent_grids();
631+
#endif
626632
int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx;
627633
int32_t const mock_l_coord = 0;
628634
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
@@ -787,6 +793,9 @@ class GemmUniversal<
787793
} // Mainloop Auxiliary Load Producer Warp End
788794
// Epilogue Producer Warp
789795
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {
796+
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
797+
cutlass::arch::wait_on_dependent_grids();
798+
#endif
790799
int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x);
791800
int32_t const sm_count = params.hw_info.sm_count;
792801

@@ -990,6 +999,11 @@ class GemmUniversal<
990999

9911000
// Get next work tile
9921001
auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state);
1002+
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
1003+
if (!next_work_tile_info.is_valid()) {
1004+
cutlass::arch::launch_dependent_grids();
1005+
}
1006+
#endif
9931007
work_tile_info = next_work_tile_info;
9941008
if (increment_pipe) {
9951009
++tile_scheduler_pipe_consumer_state;

0 commit comments

Comments
 (0)