Skip to content

Commit 433ec3a

Browse files
committed
Fix issue introduced by rebase
Signed-off-by: MengqingCao <[email protected]>
1 parent 409c83e commit 433ec3a

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -323,6 +323,10 @@ def build(
323323

324324
query_start_loc = query_start_loc_cpu.to(self.device,
325325
non_blocking=True)
326+
is_causal_pooling = None
327+
if self.model_config.runner_type == "pooling":
328+
is_causal_pooling = common_attn_metadata.causal if hasattr(
329+
common_attn_metadata, 'causal') else True
326330

327331
attn_metadata = AscendMetadata(
328332
num_actual_tokens=num_actual_tokens,
@@ -602,9 +606,10 @@ def _forward_decode_only(
602606
out=output)
603607
return output
604608

605-
def _forward_encoder_attention(self, query: torch.Tensor, key: torch.Tensor,
606-
value: torch.Tensor, attn_metadata: AscendMetadata,
607-
_: torch.Tensor) -> torch.Tensor:
609+
def _forward_encoder_attention(self, query: torch.Tensor,
610+
key: torch.Tensor, value: torch.Tensor,
611+
attn_metadata: AscendMetadata,
612+
_: torch.Tensor) -> torch.Tensor:
608613
assert attn_metadata is not None
609614
assert attn_metadata.is_causal_pooling is not None
610615

0 commit comments

Comments
 (0)