
Commit 373f75b

fix maximum update in triton flash attn
1 parent: d08ddb4

File tree

2 files changed: +3 -3 lines


ring_attention_pytorch/triton_flash_attn.py

Lines changed: 2 additions & 2 deletions

@@ -231,10 +231,10 @@ def _fwd_kernel(

             bias = bias.to(tl.float32)
             qk = qk * softmax_scale + bias
-            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
+            m_ij = tl.maximum(tl.max(qk, 1), m_i)
             p = tl.exp(qk - m_ij[:, None])
         else:
-            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, m_i)
             p = tl.exp(qk * softmax_scale - m_ij[:, None])

         l_ij = tl.sum(p, 1)
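These lines belong to the kernel's online (streaming) softmax: each key/value tile contributes a tile-wise score maximum, and the fix compares that maximum against the running maximum m_i rather than the running log-sum-exp lse_i before the running statistics are rescaled. Below is a minimal PyTorch sketch of that recurrence, written outside Triton for clarity; the tile shapes, the helper name streaming_attention, and the use of plain tensors are assumptions for illustration, and only the variable names m_i, m_ij, p, and l_ij mirror the kernel's.

import torch

def streaming_attention(qk_tiles, v_tiles):
    # qk_tiles: list of (n_q, tile_n) pre-scaled score tiles
    # v_tiles:  list of (tile_n, d) value tiles
    m_i = None   # running row-wise maximum
    l_i = None   # running row-wise sum of exp(score - max)
    acc = None   # running un-normalized output
    for qk, v in zip(qk_tiles, v_tiles):
        tile_max = qk.max(dim = 1).values
        # the new maximum is taken against the running maximum m_i,
        # not against the running log-sum-exp
        m_ij = tile_max if m_i is None else torch.maximum(tile_max, m_i)
        p = torch.exp(qk - m_ij[:, None])
        l_ij = p.sum(dim = 1)
        if m_i is None:
            l_i, acc = l_ij, p @ v
        else:
            rescale = torch.exp(m_i - m_ij)        # shrink old stats to the new maximum
            l_i = l_i * rescale + l_ij
            acc = acc * rescale[:, None] + p @ v
        m_i = m_ij
    return acc / l_i[:, None]                      # normalized attention output

A brute-force torch.softmax over the concatenated score tiles followed by a matmul with the concatenated values should give the same output, which is a convenient way to sanity-check the recurrence.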

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.19',
+  version = '0.5.20',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
