@@ -356,60 +356,32 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
KernelHardwareInfo hw_info;
};

// Helper function to calculate the number of previous K blocks that this
// block needs to wait for
template <class BlkCoord, class ProblemShape_>
CUTLASS_DEVICE int calculate_participating_k_blocks(
BlkCoord const& blk_coord,
template <class ProblemShape_>
CUTLASS_DEVICE int compute_expected_turn(
int iter_index,
int block_k,
ProblemShape_ const& problem_shape,
MainloopParams const& mainloop_params) {
auto
[blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] =
blk_coord;

// For local attention, we need to calculate which K blocks actually
// participate. Due to attention window properties, only early blocks can
// exit, so we can loop backwards and stop at first non-participating block.
// If mask is causal or local, reverse ordering of reduction
if constexpr (
std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::LocalMask<true>, Mask> ||
std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>) {
auto [Q, K, D, D_VO, HB] = problem_shape;

int total_k_blocks = ceil_div(K, TileShapeK{});
int offset = 0;
if constexpr (std::is_base_of_v<
cutlass::fmha::collective::LocalMask<false>,
Mask>) {
offset = K - Q;
}

// Loop backwards to find the first non-participating block
// This is efficient because participation is contiguous after the first
// participating block
for (int k_blk = blk_coord_k - 1; k_blk >= 0; --k_blk) {
int k_max = (k_blk + 1) * TileShapeK{};
int q_max = min(Q, k_max - offset + mainloop_params.window_size_left);
int iter_end_for_k = ceil_div(q_max, TileShapeQ{});

int k_min = k_blk * TileShapeK{};
int q_min = max(0, k_min - offset - mainloop_params.window_size_right);
int iter_start_for_k = q_min / (int)TileShapeQ{};

if (iter_end_for_k <= iter_start_for_k) {
// Found first non-participating block from the end
// Blocks 0 through k_blk don't participate
// Blocks k_blk+1 through blk_coord_k-1 participate
return blk_coord_k - 1 - k_blk;
    if (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>
        || std::is_base_of_v<cutlass::fmha::collective::LocalMask<false>, Mask>) {
offset = (get<1>(problem_shape) - get<0>(problem_shape));
}
}

// If we reach here, all previous blocks participate
return blk_coord_k;
} else {
// For causal, no mask or residual mask, block x waits for x previous
// blocks
return blk_coord_k;
}
    int k_global_max = cute::ceil_div(get<1>(problem_shape), TileShapeK{});
    int k_max_for_q_block = std::min(
        k_global_max,
        cute::ceil_div(
            (iter_index + 1) * TileShapeQ{} + offset + mainloop_params.window_size_right,
            TileShapeK{}));
    int last_k_block = k_max_for_q_block - 1;
    return last_k_block - block_k;
}
return block_k;
}
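
Reviewer note, not part of the diff: a minimal host-side sketch of the turn arithmetic in compute_expected_turn, written against plain ints instead of CuTe shapes. The helper name expected_turn_scalar and the concrete tile and window values are illustrative assumptions, not the kernel's API.

#include <algorithm>

// For causal/local masks the dQ reduction over K blocks runs in reverse, so a
// K block's "turn" is counted back from the last K block that overlaps the
// current Q tile. offset = K - Q models the bottom-right-aligned mask variants.
int expected_turn_scalar(int iter_index, int block_k, int K,
                         int tile_q, int tile_k, int window_right, int offset) {
  int k_global_max = (K + tile_k - 1) / tile_k;                  // total K blocks
  int k_hi = (iter_index + 1) * tile_q + offset + window_right;  // one past the last visible column
  int k_max_for_q_block = std::min(k_global_max, (k_hi + tile_k - 1) / tile_k);
  int last_k_block = k_max_for_q_block - 1;
  return last_k_block - block_k;  // the last participating K block gets turn 0
}

For example, with Q = K = 256, 64x64 tiles, a top-left causal mask (offset = 0, window_right = 0) and iter_index = 3, k_max_for_q_block is 4, so block_k = 3 reduces first (turn 0) and block_k = 0 last (turn 3). Without a causal or local mask the function above simply returns block_k, i.e. the natural left-to-right order.
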

static bool can_implement(Arguments const& args) {
@@ -1506,10 +1478,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
PipelineMmaReduceDQ& pipeline_mma_reduce_dq,
typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state,
PipelineReduceTmaStore& pipeline_reduce_tma_store,
typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state,
int max_iter_count,
int max_iter_end) {

typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state
) {
using X = Underscore;

auto [Q, K, D, D_VO, HB] = problem_shape;
@@ -1552,32 +1522,38 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0;
using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
int *lock_ptr = !IsDeterministic
? nullptr
: (mainloop_args.ptr_dq_semaphore + blx_b * H_R * H_K + blx_h_k * H_R);

// Calculate the actual number of participating K blocks for deterministic
// mode
int barrier_target = blk_coord_k; // Default for backward compatibility
if constexpr (IsDeterministic) {
barrier_target = calculate_participating_k_blocks(
blk_coord, problem_shape, mainloop_params);
}

auto full_iter_count = IsDeterministic ? max_iter_count : iter_count;
auto full_iter_index = 0;
? nullptr
: (mainloop_args.ptr_dq_semaphore + blx_b * H_R * H_K + blx_h_k * H_R);

int expected_turn = 0;
  // Optimization: only iterate over the Q blocks this K block actually processes
while (iter_count > 0) {
__threadfence();
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier)
.arrive_and_wait();

while (full_iter_count > 0) {
if constexpr (IsDeterministic) {
// Wait until the semaphore flag reaches the actual number of
// participating blocks
expected_turn = compute_expected_turn(
iter_index,
blk_coord_k,
problem_shape,
mainloop_params);
Barrier::wait_eq(
lock_ptr,
thread_idx,
full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
barrier_target);
lock_ptr,
thread_idx,
iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch),
expected_turn);
}
if (!IsDeterministic || (full_iter_index >= iter_start && full_iter_index < iter_end)) {
pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state);
__threadfence();
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier)
.arrive_and_wait();
{
pipeline_mma_reduce_dq.consumer_wait(
pipeline_mma_reduce_dq_consumer_state);

Tensor tTR_rDQ = make_tensor<ElementAcc>(shape(tTR_cDQ));

@@ -1609,41 +1585,48 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
cutlass::arch::ReservedNamedBarriers::TransposeBarrier
).arrive_and_wait();
if (lane_predicate) {
// launch tma store
copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch));
pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state);
// TMA REDUCE ADD - atomic operation to global memory
copy(
mainloop_params.tma_red_dq,
tDQsDQ(
_,
_,
_0{},
pipeline_reduce_tma_store_producer_state.index()),
tDQgDQ(_, _, i, iter_index, blk_coord_batch));
/// tma_store_arrive();
pipeline_reduce_tma_store.producer_commit(
pipeline_reduce_tma_store_producer_state);
}

++pipeline_reduce_tma_store_producer_state;
}

// Update iter index
iter_index += 1;
}

}
__threadfence();
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier)
.arrive_and_wait();
if constexpr (IsDeterministic) {
// Increment the semaphore flag
Barrier::arrive_inc(
lock_ptr,
lock_ptr,
thread_idx,
full_iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch));

full_iter_index += 1;
iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch));

if (full_iter_index == max_iter_end) {
iter_index = iter_start;
full_iter_index = 0;
get<0,0>(blk_coord_batch) += 1;
}
}
else {
if (iter_index == iter_end) {
iter_index += 1;
if (iter_index == iter_end) {
iter_index = iter_start;
get<0,0>(blk_coord_batch) += 1;
}
}
}
__threadfence();
cutlass::arch::NamedBarrier(
kNumReduceWarps * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::TransposeBarrier)
.arrive_and_wait();

full_iter_count -= 1;
iter_count -= 1;
}
}
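
Reviewer note, not part of the diff: an illustrative CPU analogue of the hand-off that Barrier::wait_eq and Barrier::arrive_inc implement above. The std::atomic flag and the function name are stand-ins for GenericBarrier, i.e. assumptions, not the kernel's API.

#include <atomic>

// Each "K block" spins until the flag for this Q tile equals its turn, adds its
// partial dQ contribution, then bumps the flag to release the next turn. Because
// the order of the additions is fixed, the accumulated dQ is deterministic.
void accumulate_in_turn(std::atomic<int>& flag, int my_turn,
                        float partial, float& dq) {
  while (flag.load(std::memory_order_acquire) != my_turn) { /* spin */ }
  dq += partial;                                  // deterministic position in the sum
  flag.fetch_add(1, std::memory_order_release);   // analogue of arrive_inc
}

In the kernel, my_turn corresponds to the expected_turn computed per Q tile, and the flag offset passed to wait_eq and arrive_inc is iter_index * H_R * H_K * B + get<0, 0>(blk_coord_batch) on top of the per-(batch, kv-head) base in ptr_dq_semaphore.
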

@@ -1856,7 +1839,6 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
);
int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{});
int iter_start = 0;
int max_iter_end = IsDeterministic ? iter_end : 0;
if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<true>, Mask>) {
iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{};
} else if constexpr (std::is_base_of_v<cutlass::fmha::collective::CausalMask<false>, Mask>) {
@@ (collapsed context lines) @@
}

int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape);
int max_iter_count = IsDeterministic ? max_iter_end * get<4,0,0>(problem_shape) : 0;
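
Reviewer note with assumed sizes, not part of the diff: a quick check of these iteration bounds, assuming CausalMask<true> is the top-left-aligned causal variant.

// Assumed: Q = 256, TileShapeQ = TileShapeK = 64, and a single group so
// get<4,0,0>(problem_shape) == 1; names below are illustrative only.
constexpr int ceil_div_i(int a, int b) { return (a + b - 1) / b; }

void causal_iter_bounds_example() {
  int Q = 256, tile_q = 64, tile_k = 64;
  int blk_coord_k = 2;                               // this CTA's K block
  int iter_end   = ceil_div_i(Q, tile_q);            // 4 Q tiles in total
  int iter_start = (blk_coord_k * tile_k) / tile_q;  // 2: Q tiles 0 and 1 never see this K block
  int iter_count = (iter_end - iter_start) * 1;      // 2 iterations for this K block
  (void)iter_count;                                  // silence unused-variable warnings
}
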

if (iter_count <= 0) {
epilogue_clear(
@@ -1986,13 +1967,14 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state,
pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state,
max_iter_count, max_iter_end
);

pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state);
}
pipeline_mma_reduce_dq,
pipeline_mma_reduce_dq_consumer_state,
pipeline_reduce_tma_store,
pipeline_reduce_tma_store_producer_state);

pipeline_reduce_tma_store.producer_tail(
pipeline_reduce_tma_store_producer_state);
}
else {
warpgroup_reg_set<RegisterAllocation::kEmpty>();
