From 499740db85011460cdba5c6a2a30b1eef81b16ae Mon Sep 17 00:00:00 2001
From: Aya Ibrahim
Date: Tue, 28 Oct 2025 13:01:26 -0700
Subject: [PATCH] Efficient Causal/Local scheduling

Summary:
For causal and local causal masks, the deterministic dQ reduction should use the turn order below: for each Q tile (qi), the participating K blocks (ki) reduce in reverse order, so the last participating K block takes turn 0 and earlier K blocks follow.

```
qi\ki  ki=0  ki=1  ki=2  ki=3  ki=4  ki=5  ki=6  ki=7  ki=8
------------------------------------------------------------
qi=0      1     0     -     -     -     -     -     -     -
qi=1      2     1     0     -     -     -     -     -     -
qi=2      -     2     1     0     -     -     -     -     -
qi=3      -     -     2     1     0     -     -     -     -
qi=4      -     -     -     2     1     0     -     -     -
qi=5      -     -     -     -     2     1     0     -     -
qi=6      -     -     -     -     -     2     1     0     -
qi=7      -     -     -     -     -     -     2     1     0
```

Differential Revision: D85308820
---
 ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 190 ++++++++----------
 1 file changed, 86 insertions(+), 104 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
index b80491497e..43c321a871 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
@@ -356,60 +356,32 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     KernelHardwareInfo hw_info;
   };
 
-  // Helper function to calculate number of previous K blocks that this block
-  // needs to wait for
-  template
-  CUTLASS_DEVICE int calculate_participating_k_blocks(
-      BlkCoord const& blk_coord,
+  template
+  CUTLASS_DEVICE int compute_expected_turn(
+      int iter_index,
+      int block_k,
       ProblemShape_ const& problem_shape,
       MainloopParams const& mainloop_params) {
-    auto
-        [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] =
-            blk_coord;
-
-    // For local attention, we need to calculate which K blocks actually
-    // participate. Due to attention window properties, only early blocks can
-    // exit, so we can loop backwards and stop at first non-participating block.
+    // If mask is causal or local, reverse ordering of reduction
     if constexpr (
+        std::is_base_of_v, Mask>
+        || std::is_base_of_v, Mask>
         || std::is_base_of_v, Mask>
         || std::is_base_of_v, Mask>) {
-      auto [Q, K, D, D_VO, HB] = problem_shape;
-
-      int total_k_blocks = ceil_div(K, TileShapeK{});
       int offset = 0;
-      if constexpr (std::is_base_of_v<
-                        cutlass::fmha::collective::LocalMask,
-                        Mask>) {
-        offset = K - Q;
-      }
-
-      // Loop backwards to find the first non-participating block
-      // This is efficient because participation is contiguous after the first
-      // participating block
-      for (int k_blk = blk_coord_k - 1; k_blk >= 0; --k_blk) {
-        int k_max = (k_blk + 1) * TileShapeK{};
-        int q_max = min(Q, k_max - offset + mainloop_params.window_size_left);
-        int iter_end_for_k = ceil_div(q_max, TileShapeQ{});
-
-        int k_min = k_blk * TileShapeK{};
-        int q_min = max(0, k_min - offset - mainloop_params.window_size_right);
-        int iter_start_for_k = q_min / (int)TileShapeQ{};
-
-        if (iter_end_for_k <= iter_start_for_k) {
-          // Found first non-participating block from the end
-          // Blocks 0 through k_blk don't participate
-          // Blocks k_blk+1 through blk_coord_k-1 participate
-          return blk_coord_k - 1 - k_blk;
+      if (std::is_base_of_v, Mask>
+          || std::is_base_of_v, Mask>){
+        offset = (get<1>(problem_shape) - get<0>(problem_shape));
       }
-      }
-
-      // If we reach here, all previous blocks participate
-      return blk_coord_k;
-    } else {
-      // For causal, no mask or residual mask, block x waits for x previous
-      // blocks
-      return blk_coord_k;
-    }
+      int k_global_max = cute::ceil_div(get<1>(problem_shape) , TileShapeK{});
+      int k_max_for_q_block = std::min(
+          k_global_max,
+          cute::ceil_div((iter_index + 1) * TileShapeQ{} + offset + mainloop_params.window_size_right
+          , TileShapeK{}));
+      int last_k_block = k_max_for_q_block - 1;
+      return last_k_block - block_k;
+    }
+    return block_k;
   }
 
   static bool can_implement(Arguments const& args) {
@@ -1506,10 +1478,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
       PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
       typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
       PipelineReduceTmaStore& pipeline_reduce_tma_store,
-      typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state,
-      int max_iter_count,
-      int max_iter_end) {
-
+      typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state
+      ) {
     using X = Underscore;
 
     auto [Q, K, D, D_VO, HB] = problem_shape;
@@ -1552,32 +1522,38 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;
 
     using Barrier = cutlass::GenericBarrier;
     int *lock_ptr = !IsDeterministic
-        ? nullptr
-        : (mainloop_args.ptr_dq_semaphore + blx_b * H_R * H_K + blx_h_k * H_R);
-
-    // Calculate the actual number of participating K blocks for deterministic
-    // mode
-    int barrier_target = blk_coord_k; // Default for backward compatibility
-    if constexpr (IsDeterministic) {
-      barrier_target = calculate_participating_k_blocks(
-          blk_coord, problem_shape, mainloop_params);
-    }
-
-    auto full_iter_count = IsDeterministic ? max_iter_count : iter_count;
-    auto full_iter_index = 0;
+        ? nullptr
+        : (mainloop_args.ptr_dq_semaphore + blx_b * H_R * H_K + blx_h_k * H_R);
+
+    int expected_turn = 0;
+    // Optimized: Only iterate over Q blocks this K block actually processes
+    while (iter_count > 0) {
+      __threadfence();
+      cutlass::arch::NamedBarrier(
+          kNumReduceWarps * NumThreadsPerWarp,
+          cutlass::arch::ReservedNamedBarriers::TransposeBarrier)
+          .arrive_and_wait();
 
-    while (full_iter_count > 0) {
       if constexpr (IsDeterministic) {
-        // Wait until the semaphore flag reaches the actual number of
-        // participating blocks
+        expected_turn = compute_expected_turn(
+            iter_index,
+            blk_coord_k,
+            problem_shape,
+            mainloop_params);
         Barrier::wait_eq(
-            lock_ptr,
-            thread_idx,
-            full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
-            barrier_target);
+            lock_ptr,
+            thread_idx,
+            iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
+            expected_turn);
       }
-      if (!IsDeterministic || (full_iter_index >= iter_start && full_iter_index < iter_end)) {
-        pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
+      __threadfence();
+      cutlass::arch::NamedBarrier(
+          kNumReduceWarps * NumThreadsPerWarp,
+          cutlass::arch::ReservedNamedBarriers::TransposeBarrier)
+          .arrive_and_wait();
+      {
+        pipeline_mma_reduce_dq.consumer_wait(
+            pipeline_mma_reduce_dq_consumer_state);
 
         Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ));
 
@@ -1609,41 +1585,48 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
               cutlass::arch::ReservedNamedBarriers::TransposeBarrier
           ).arrive_and_wait();
           if (lane_predicate) {
-            // launch tma store
-            copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch));
-            pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
+            // TMA REDUCE ADD - atomic operation to global memory
+            copy(
+                mainloop_params.tma_red_dq,
+                tDQsDQ(
+                    _,
+                    _,
+                    _0{},
+                    pipeline_reduce_tma_store_producer_state.index()),
+                tDQgDQ(_, _, i, iter_index, blk_coord_batch));
+            /// tma_store_arrive();
+            pipeline_reduce_tma_store.producer_commit(
+                pipeline_reduce_tma_store_producer_state);
           }
           ++pipeline_reduce_tma_store_producer_state;
         }
-
-      // Update iter index
-      iter_index += 1;
-      }
-
+      }
+      __threadfence();
+      cutlass::arch::NamedBarrier(
+          kNumReduceWarps * NumThreadsPerWarp,
+          cutlass::arch::ReservedNamedBarriers::TransposeBarrier)
+          .arrive_and_wait();
      if constexpr (IsDeterministic) {
        // Increment the semaphore flag
        Barrier::arrive_inc(
-            lock_ptr,
+            lock_ptr,
            thread_idx,
-            full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch));
-
-        full_iter_index += 1;
+            iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch));
 
-        if (full_iter_index == max_iter_end) {
-          iter_index = iter_start;
-          full_iter_index = 0;
-          get<0,0>(blk_coord_batch) += 1;
-        }
      }
-      else {
-        if (iter_index == iter_end) {
+      iter_index += 1;
+      if (iter_index == iter_end) {
        iter_index = iter_start;
        get<0,0>(blk_coord_batch) += 1;
-        }
-      }
+      }
+      __threadfence();
+      cutlass::arch::NamedBarrier(
+          kNumReduceWarps * NumThreadsPerWarp,
+          cutlass::arch::ReservedNamedBarriers::TransposeBarrier)
+          .arrive_and_wait();
 
-      full_iter_count -= 1;
+      iter_count -= 1;
     }
   }
 
@@ -1856,7 +1839,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     );
     int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{});
     int iter_start = 0;
-    int max_iter_end = IsDeterministic ? iter_end : 0;
     if constexpr (std::is_base_of_v, Mask>) {
       iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
     } else if constexpr (std::is_base_of_v, Mask>) {
@@ -1883,7 +1865,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     }
 
     int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape);
-    int max_iter_count = IsDeterministic ? max_iter_end * get<4,0,0>(problem_shape) : 0;
 
     if (iter_count <= 0) {
       epilogue_clear(
@@ -1986,13 +1967,14 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
           params.mainloop,
           params.mainloop_params,
           shared_storage.tensors,
-          pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,
-          pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state,
-          max_iter_count, max_iter_end
-      );
-
-      pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);
-    }
+          pipeline_mma_reduce_dq,
+          pipeline_mma_reduce_dq_consumer_state,
+          pipeline_reduce_tma_store,
+          pipeline_reduce_tma_store_producer_state);
+
+      pipeline_reduce_tma_store.producer_tail(
+          pipeline_reduce_tma_store_producer_state);
+    }
     else {
       warpgroup_reg_set();
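
For reference, a minimal host-side C++ sketch (not part of the patch) that reproduces the turn table from the summary. The tile sizes, window sizes, block counts, and offset below are hypothetical and chosen only so the printed table matches the one above; in the kernel these values come from the problem shape, TileShapeQ/TileShapeK, and the mask's window parameters, and the turn itself is what compute_expected_turn returns (last participating K block for the Q tile, minus the current K block).

```
// Illustration only: reproduces the reverse turn order for the deterministic
// dQ reduction. All constants are hypothetical, not taken from the kernel.
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kTileQ = 128;        // hypothetical Q tile size
  constexpr int kTileK = 128;        // hypothetical K tile size
  constexpr int kWindowLeft = 128;   // hypothetical local window (left)
  constexpr int kWindowRight = 128;  // hypothetical local window (right)
  constexpr int kNumQBlocks = 8;
  constexpr int kNumKBlocks = 9;
  constexpr int kOffset = 0;         // K - Q for the local mask; 0 here

  std::printf("qi\\ki");
  for (int ki = 0; ki < kNumKBlocks; ++ki) std::printf("%6d", ki);
  std::printf("\n");

  for (int qi = 0; qi < kNumQBlocks; ++qi) {
    // Last and first K blocks this Q tile touches under the window.
    int k_max_for_q_block = std::min(
        kNumKBlocks,
        ((qi + 1) * kTileQ + kOffset + kWindowRight + kTileK - 1) / kTileK);
    int last_k_block = k_max_for_q_block - 1;
    int first_k_block =
        std::max(0, (qi * kTileQ + kOffset - kWindowLeft) / kTileK);

    std::printf("qi=%d ", qi);
    for (int ki = 0; ki < kNumKBlocks; ++ki) {
      if (ki < first_k_block || ki > last_k_block) {
        std::printf("%6s", "-");
      } else {
        // Reverse order: the last participating K block takes turn 0.
        std::printf("%6d", last_k_block - ki);
      }
    }
    std::printf("\n");
  }
  return 0;
}
```

Because the semaphore advances by one per completed K block, each K block can compute its own turn locally from the Q-tile index, which is why the reverse-order formula replaces the removed calculate_participating_k_blocks scan over earlier K blocks.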