Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions vllm_ascend/attention/attention_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ class AscendMetadataForDecode:
class AscendMetadata:
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
fia_attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

Expand Down Expand Up @@ -312,21 +311,18 @@ def build(
num_actual_tokens_pcp_padded]
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
fia_attn_mask = common_attn_metadata.fia_attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
num_computed_tokens_cpu = (seq_lens - query_lens)

if attn_state == AscendAttentionState.DecodeOnly and \
common_attn_metadata.num_input_tokens > num_actual_tokens:
if common_attn_metadata.num_input_tokens > num_actual_tokens:
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
seq_lens = torch.cat([
seq_lens,
torch.ones(padded_num_tokens,
dtype=seq_lens.dtype,
device=seq_lens.device)
torch.tensor([padded_num_tokens
]).to(seq_lens.device).to(seq_lens.dtype)
])
block_table_padding = torch.zeros(
(padded_num_tokens, ) + block_table.shape[1:],
Expand All @@ -335,10 +331,8 @@ def build(
block_table = torch.cat([block_table, block_table_padding], dim=0)
query_start_loc_cpu = torch.cat([
query_start_loc_cpu,
torch.arange(query_start_loc_cpu[-1] + 1,
query_start_loc_cpu[-1] + padded_num_tokens,
dtype=query_start_loc_cpu.dtype,
device=query_start_loc_cpu.device)
torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
])

query_start_loc = query_start_loc_cpu.to(self.device,
Expand Down Expand Up @@ -471,7 +465,6 @@ def build(
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
slot_mapping=slot_mapping,
attn_mask=attn_mask,
fia_attn_mask=fia_attn_mask,
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes,
Expand Down Expand Up @@ -604,7 +597,6 @@ def full_graph_attention(self,
actual_seq_lengths_kv = attn_metadata.seq_lens_list

num_tokens = attn_metadata.query_start_loc_list[-1]
query = query[:num_tokens]
graph_params = get_graph_params()
query_start_loc = attn_metadata.query_start_loc_list
# Prepare tensors for attention output
Expand All @@ -618,7 +610,7 @@ def full_graph_attention(self,
query=query,
key=key,
value=value,
atten_mask=attn_metadata.fia_attn_mask,
atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
Expand All @@ -641,7 +633,7 @@ def full_graph_attention(self,
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(query), weak_ref_tensors(key),
weak_ref_tensors(value), weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
weak_ref_tensors(attn_metadata.attn_mask), block_size,
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
self.num_heads, self.scale, weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)))
Expand All @@ -651,7 +643,7 @@ def full_graph_attention(self,
query=query,
key=key,
value=value,
atten_mask=attn_metadata.fia_attn_mask,
atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
Expand Down
29 changes: 4 additions & 25 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.attn_groups: list[list[AttentionGroup]] = []
self.encoder_cache: Dict[str, torch.Tensor] = {}
self.attn_mask = None
self.fia_attn_mask = None
self.attn_state = None
self.requests: Dict[str, CachedRequestState] = {}
self.intermediate_tensors: Optional[IntermediateTensors] = None
Expand Down Expand Up @@ -984,23 +983,6 @@ def _make_attention_mask(self, seq_lens, position,
# Pooling situation.
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
return self.attn_mask_builder.get_pooling_mask(self.device)
# fia prefill situation.
if attn_state in [
AscendAttentionState.PrefillNoCache,
AscendAttentionState.PrefillCacheHit,
AscendAttentionState.ChunkedPrefill
]:
return self.attn_mask_builder.get_splitfuse_attn_mask()

# Decode-only situation.
return None

def _make_fia_attention_mask(self) -> torch.Tensor:
# pcp situation.
if self.pcp_size > 1:
return None
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
return self.attn_mask_builder.get_splitfuse_attn_mask()

def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
Expand Down Expand Up @@ -1581,7 +1563,6 @@ def _prepare_inputs(
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
self.fia_attn_mask = self._make_fia_attention_mask()
self.attn_state = attn_state # type: ignore

self.with_prefill = with_prefill
Expand Down Expand Up @@ -1806,7 +1787,6 @@ def _prepare_inputs(
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
fia_attn_mask=self.fia_attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
Expand Down Expand Up @@ -2729,10 +2709,10 @@ def _build_dummy_attn_metadata(
self.query_lens = torch.from_numpy(num_scheduled_tokens)

assigned_mask_dim = 2048
self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
assigned_mask_dim),
diagonal=1).to(torch.int8).to(
self.device)
self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
assigned_mask_dim),
diagonal=1).to(torch.int8).to(
self.device)

num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
Expand Down Expand Up @@ -2776,7 +2756,6 @@ def _build_dummy_attn_metadata(
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
fia_attn_mask=self.fia_attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
max_query_len=max_query_len,
Expand Down
Loading