Commit b99524b

fix softmax_value (#179)
1 parent 0d77476 commit b99524b

2 files changed: +9 additions, -7 deletions

dlinfer/graph/dicp/vendor/AtbGraph/conversion.py

Lines changed: 6 additions & 1 deletion
@@ -443,6 +443,7 @@ def prefill_attention(
         max_kv_seq_len,
         block_size,
         mask,
+        softmax_scale,
         is_unpaged_prefill,
         kv_scales,
         kv_zeros,
@@ -454,7 +455,11 @@ def prefill_attention(
         # inplace1 = self.get_proxy(atb_op.Inplace, (fill_kv_cache, k_cache, 0))
         # inplace2 = self.get_proxy(atb_op.Inplace, (fill_kv_cache, v_cache, 1))
         mask = mask[0]
-        scale = 1.0 / math.sqrt(query.node.meta["val"].shape[-1])
+        scale = (
+            softmax_scale
+            if softmax_scale
+            else 1.0 / math.sqrt(query.node.meta["val"].shape[-1])
+        )
         if query.node.meta["val"].dtype != mask.node.meta["val"].dtype:
             mask = self.get_proxy(atb_op.Cast, (mask, query.node.meta["val"].dtype))
         if is_unpaged_prefill:
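As a standalone illustration of the fallback this hunk adds (not part of the dlinfer API; resolve_softmax_scale is a hypothetical helper name): the caller-supplied softmax_scale is used when provided, otherwise the conventional 1/sqrt(head_dim) attention scale is computed.

import math
from typing import Optional

def resolve_softmax_scale(head_dim: int, softmax_scale: Optional[float] = None) -> float:
    # Hypothetical helper mirroring the pattern in the diff above:
    # a falsy softmax_scale (e.g. None) falls back to 1/sqrt(head_dim).
    return softmax_scale if softmax_scale else 1.0 / math.sqrt(head_dim)

print(resolve_softmax_scale(128))        # ~0.0884 (default scale)
print(resolve_softmax_scale(128, 0.05))  # 0.05 (caller override)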

dlinfer/vendor/ascend/torch_npu_ops.py

Lines changed: 3 additions & 6 deletions
@@ -195,7 +195,7 @@ def paged_decode_attention(
     query = query.contiguous()
     attn_output = attn_output.contiguous()
     query = query.view(bs, 1, num_q_heads * dim)
-    scale_value = 1.0 / math.sqrt(dim)
+    scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(dim)
 
     torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
         query,
@@ -252,15 +252,12 @@ def paged_prefill_attention(
         raise RuntimeError(
             "paged_decode_attention does not " "support alibi_slopes yet"
         )
-    if softmax_scale is not None:
-        raise RuntimeError(
-            "paged_decode_attention does not " "support softmax_scale yet"
-        )
+
     if block_table.dtype != torch.int32:
         block_table = block_table.to(torch.int32)
 
     kv_seq_len_list = kv_seq_len.tolist()
-    scale_value = 1.0 / math.sqrt(query.shape[-1])
+    scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
     query = query.contiguous().view(query.shape[0], 1, -1)
     torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
         query,
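For context on what scale_value feeds into, here is a minimal reference sketch of scaled-dot-product attention in plain PyTorch, assuming the same fallback rule; it only illustrates the math, while the actual kernels above go through torch.ops.npu_ext.npu_incre_flash_attention_v4_out on Ascend.

import math
import torch

def ref_attention(q, k, v, softmax_scale=None):
    # Same fallback rule as the diff: an explicit softmax_scale wins,
    # otherwise use 1/sqrt(head_dim).
    scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale_value
    return torch.matmul(torch.softmax(scores, dim=-1), v)

q = torch.randn(1, 8, 4, 64)    # (batch, heads, q_len, head_dim)
k = torch.randn(1, 8, 16, 64)   # (batch, heads, kv_len, head_dim)
v = torch.randn(1, 8, 16, 64)
out = ref_attention(q, k, v, softmax_scale=0.1)  # overrides the 1/sqrt(64) default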
