Skip to content

Commit a919aef

Browse files
author
weijinqian_v1
committed
[Refactor] add fia_v3 attention & remove other attention operators.
Signed-off-by: weijinqian_v1 <[email protected]>
1 parent 0250679 commit a919aef

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -890,14 +890,14 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
890890

891891
def _make_attention_mask(self, seq_lens, position,
892892
attn_state) -> torch.Tensor:
893+
if self.vllm_config.model_config.use_mla:
894+
return None
893895
# Pooling situation.
894896
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
895897
return self.attn_mask_builder.get_pooling_mask(self.device)
896-
# Chunk Prefill situation.
897-
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
898-
return self.attn_mask_builder.get_splitfuse_attn_mask()
899-
# Prefill without cache situation and Prefill with cache hit.
900-
if attn_state == AscendAttentionState.PrefillNoCache or attn_state == AscendAttentionState.PrefillCacheHit:
898+
# fia prefill situation.
899+
if attn_state in [AscendAttentionState.PrefillNoCache, AscendAttentionState.PrefillCacheHit,
900+
AscendAttentionState.ChunkedPrefill]:
901901
return self.attn_mask_builder.get_splitfuse_attn_mask()
902902
# Decode-only situation.
903903
return None

0 commit comments

Comments (0)