11 changes: 10 additions & 1 deletion aiter/ops/triton/_triton_kernels/mha.py
@@ -271,6 +271,7 @@ def _attn_fwd_inner(
"VARLEN",
"NUM_XCD",
"USE_INT64_STRIDES",
"ENABLE_SINK",
],
)

@@ -288,6 +289,7 @@ def _attn_fwd(
s_dmask_ptr: torch.Tensor,
dropout_mask_ptr: torch.Tensor,
softmax_lse_ptr: torch.Tensor,
sink_ptr: torch.Tensor,
stride_qz_in,
stride_qh_in,
stride_qm_in,
@@ -341,6 +343,7 @@ def _attn_fwd(
BATCH,
NUM_XCD: tl.constexpr,
USE_INT64_STRIDES: tl.constexpr,
ENABLE_SINK: tl.constexpr,
):
NUM_BLOCKS = (SEQLEN_Q + BLOCK_M - 1) // BLOCK_M
# calculate offsets
@@ -631,7 +634,13 @@ def _attn_fwd(
dropout_mask_ptrs = None
philox_ptrs = None

m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
if ENABLE_SINK:
RCP_LN2: tl.constexpr = 1.4426950408889634
m_i_value = tl.load(sink_ptr + off_q_head).to(tl.float32) * RCP_LN2
else:
m_i_value = float("-inf")

m_i = tl.full([BLOCK_M], m_i_value, dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32)
if BLOCK_DMODEL == BLOCK_DMODEL_POW2:
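Note on the forward change: the per-head sink logit is folded into the online-softmax running statistics at initialization. Since the kernel tracks exponents in base 2 (via tl.math.exp2), the sink is scaled by RCP_LN2 = 1/ln(2) and used as the starting value of the running max m_i, with l_i starting at 1.0; that initial state already carries exp(sink) in the softmax denominator while contributing nothing to the output accumulator, because the sink has no value vector. A minimal NumPy sketch of the intended math, with illustrative names rather than the kernel's API:

```python
import numpy as np

def attn_with_sink_reference(q, k, v, sink, sm_scale):
    """Single-head reference for the sink behaviour sketched above.

    q: (M, d), k/v: (N, d), sink: scalar per-head logit (assumed learned,
    with no associated value vector). The sink adds exp(sink) to each row's
    softmax denominator but contributes nothing to the output.
    """
    scores = (q @ k.T) * sm_scale                  # (M, N) attention logits
    m = np.maximum(scores.max(axis=-1), sink)      # running max, seeded by the sink
    p = np.exp(scores - m[:, None])                # un-normalized probabilities
    l = np.exp(sink - m) + p.sum(axis=-1)          # denominator includes the sink term
    return (p / l[:, None]) @ v                    # sink adds no value contribution
```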
50 changes: 38 additions & 12 deletions aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py
@@ -334,7 +334,7 @@ def _bwd_dq_inner(
V,
do,
m,
Delta,
Di, # D (= delta) is pre-divided by ds_scale.
sm_scale, # input
# shared by Q/K/V.
stride_qm,
@@ -345,7 +345,6 @@
stride_vk,
stride_dropoutm,
stride_dropoutn, # stride for dropout
stride_deltam,
seqlen_q,
seqlen_k, #
BLOCK_M2: tl.constexpr, #
@@ -393,8 +392,6 @@
if HAS_PE:
kT_pe_ptrs = K + offs_n[None, :] * stride_kn + offs_k_pe[:, None] * stride_kk
vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
curr_n = start_n
@@ -514,6 +511,7 @@ def _bwd_dq_inner(
"USE_EXP2",
"IS_FP8",
"USE_INT64_STRIDES",
"ENABLE_SINK",
],
)

@@ -523,11 +521,13 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
Q,
K,
V,
Sink,
sm_scale,
DO,
DQ,
DK,
DV,
DSink,
M,
Delta,
stride_qb_in,
@@ -603,6 +603,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
DEBUG_TRITON: tl.constexpr,
DEBUG_TRITON_DETAIL: tl.constexpr,
USE_INT64_STRIDES: tl.constexpr,
ENABLE_SINK: tl.constexpr,
):
if USE_INT64_STRIDES:
stride_qb = tl.cast(stride_qb_in, tl.int64)
@@ -1053,8 +1054,20 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
else:
q_pe = None
do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0)
m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q)
mask_m = offs_m < seqlen_q
m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=mask_m, other=0.0)
m = m[:, None]
delta = tl.load(Delta_ptr + offs_m * stride_deltam, mask=mask_m, other=0.0)

if ENABLE_SINK:
sink = tl.load(Sink + hqid).to(tl.float32)
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634
psink = tl.math.exp2(sink * RCP_LN2 - m * RCP_LN2)
else:
psink = tl.math.exp(sink - m)
dsink = tl.sum(-psink * delta[:, None])
tl.atomic_add(DSink + hqid, dsink, sem="relaxed")

MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
# start can only be 0 at minimum
@@ -1083,7 +1096,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
V,
do,
m,
Delta_ptr,
delta,
sm_scale,
stride_qm,
stride_qd,
@@ -1093,7 +1106,6 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
stride_vd,
stride_dropoutm,
stride_dropoutn,
stride_deltam,
seqlen_q,
seqlen_k,
BLOCK_M2,
@@ -1139,7 +1151,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
V,
do,
m,
Delta_ptr,
delta,
sm_scale,
stride_qm,
stride_qd,
@@ -1149,7 +1161,6 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
stride_vd,
stride_dropoutm,
stride_dropoutn,
stride_deltam,
seqlen_q,
seqlen_k,
BLOCK_M2,
@@ -1208,6 +1219,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
"USE_EXP2",
"IS_FP8",
"USE_INT64_STRIDES",
"ENABLE_SINK",
],
)

@@ -1217,11 +1229,13 @@ def bwd_kernel_noncausal(
Q,
K,
V,
Sink,
sm_scale,
DO,
DQ,
DK,
DV,
DSink,
M,
Delta,
stride_qb_in,
@@ -1297,6 +1311,7 @@ def bwd_kernel_noncausal(
DEBUG_TRITON: tl.constexpr,
DEBUG_TRITON_DETAIL: tl.constexpr,
USE_INT64_STRIDES: tl.constexpr,
ENABLE_SINK: tl.constexpr,
):
if USE_INT64_STRIDES:
stride_qb = tl.cast(stride_qb_in, tl.int64)
@@ -1613,8 +1628,20 @@ def bwd_kernel_noncausal(
else:
q_pe = None
do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0)
m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q)
mask_m = offs_m < seqlen_q
m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=mask_m, other=0.0)
m = m[:, None]
delta = tl.load(Delta_ptr + offs_m * stride_deltam, mask=mask_m, other=0.0)

if ENABLE_SINK:
sink = tl.load(Sink + hqid).to(tl.float32)
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634
psink = tl.math.exp2(sink * RCP_LN2 - m * RCP_LN2)
else:
psink = tl.math.exp(sink - m)
dsink = tl.sum(-psink * delta[:, None])
tl.atomic_add(DSink + hqid, dsink, sem="relaxed")

if IS_FP8:
descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid)
@@ -1643,7 +1670,7 @@ def bwd_kernel_noncausal(
V,
do,
m,
Delta_ptr,
delta,
sm_scale,
stride_qm,
stride_qd,
Expand All @@ -1653,7 +1680,6 @@ def bwd_kernel_noncausal(
stride_vd,
stride_dropoutm,
stride_dropoutn,
stride_deltam,
seqlen_q,
seqlen_k,
BLOCK_M2,
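Note on the backward change: each query tile now loads delta = rowsum(dO ⊙ O) once, masked to the valid rows, and passes it into _bwd_dq_inner by value instead of re-reading it from the Delta pointer inside the inner loop. The same delta feeds the sink gradient: with p_sink,i = exp(sink − LSE_i), the sink enters only the softmax denominator, so ∂loss/∂sink = −Σ_i p_sink,i · δ_i, which each program accumulates into DSink per query head with a relaxed atomic add. A hedged PyTorch reference of that per-head reduction, with illustrative names only:

```python
import torch

def dsink_reference(sink, lse, o, do):
    """Per-head reference for the DSink accumulation in the kernels above.

    sink: scalar sink logit, lse: (M,) softmax log-sum-exp per query row,
    o / do: (M, d) attention output and its incoming gradient. Assumes the
    sink contributes only to the softmax denominator, giving
    d(loss)/d(sink) = -sum_i p_sink_i * delta_i.
    """
    delta = (o * do).sum(dim=-1)       # (M,) rowsum(dO ⊙ O), i.e. the Delta tensor
    p_sink = torch.exp(sink - lse)     # sink probability for each query row
    return -(p_sink * delta).sum()     # value accumulated via tl.atomic_add(DSink + hqid, ...)
```

The relaxed memory ordering on the atomic add is presumably sufficient here because the per-program contributions are simply summed and DSink is only consumed after the grid completes.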