Commit 5ba34e4

Optimize memory reuse and reduce data copying
Signed-off-by: anon189Ty <[email protected]>
1 parent 343b465 commit 5ba34e4

File tree: 2 files changed, +29 -65 lines

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 1 deletion
@@ -451,7 +451,8 @@ def build(
             num_reqs_pad_size = graph_pad_size - num_reqs
             actual_seq_lengths_q = self.pad_actual_seq_len_q(
                 num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
-            seq_lens_list = seq_lens_list + [0] * num_reqs_pad_size
+            seq_lens_list = seq_lens_list + [0] * (
+                graph_pad_size - num_decodes)
             num_block_pad_size = graph_pad_size - block_table.shape[0]
             if num_block_pad_size > 0:
                 block_table_padding = torch.zeros(
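
For intuition, here is a standalone sketch of the padding arithmetic with invented sizes. The variable names mirror the diff; the assumption that seq_lens_list carries one entry per decode request is ours, suggested by the new pad length. Padding by graph_pad_size - num_decodes brings the list to exactly graph_pad_size, while the old num_reqs_pad_size pad only does so when every request is a decode.

# Hypothetical sizes, purely to illustrate the padding change above;
# names mirror the diff, values are invented.
graph_pad_size = 8      # padded batch size expected by the captured graph
num_reqs = 6            # scheduled requests
num_decodes = 4         # decode requests contributing entries to seq_lens_list

seq_lens_list = [17, 33, 9, 25]            # one length per decode request
num_reqs_pad_size = graph_pad_size - num_reqs

before = seq_lens_list + [0] * num_reqs_pad_size              # length 6
after = seq_lens_list + [0] * (graph_pad_size - num_decodes)  # length 8

print(len(before), len(after))   # 6 8 -> only "after" matches graph_pad_size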

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 27 additions & 64 deletions
@@ -75,24 +75,6 @@ def __init__(
         self.use_sparse = hasattr(vllm_config.model_config.hf_config,
                                   "index_topk")
 
-        self.query_start_loc = torch.zeros(
-            self.runner.max_num_reqs * (self.num_speculative_tokens + 1) + 1,
-            dtype=torch.int32,
-            device=self.device)
-        self.query_start_loc_cpu = torch.zeros(
-            self.runner.max_num_reqs * (self.num_speculative_tokens + 1) + 1,
-            dtype=torch.int32,
-            device="cpu",
-            pin_memory=True)
-        self.slot_mapping = torch.zeros(self.runner.max_num_tokens,
-                                        dtype=torch.int32,
-                                        device=self.device)
-        self.seq_lens_cpu = torch.zeros(self.runner.max_num_reqs *
-                                        (self.num_speculative_tokens + 1),
-                                        dtype=torch.int32,
-                                        device="cpu",
-                                        pin_memory=True)
-
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
 
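
The buffers deleted above are not replaced with new allocations; the proposer now borrows tensors the runner already owns (seq_lens_cpu, positions, the block table's slot_mapping). A minimal sketch of that reuse pattern, with illustrative class and attribute names rather than the real vllm-ascend ones:

import torch

# Illustrative names only, not the real vllm-ascend classes.
class Runner:
    def __init__(self, max_tokens: int):
        # One persistent buffer owned by the runner for the whole engine ...
        self.positions = torch.zeros(max_tokens, dtype=torch.int64)

class Proposer:
    def __init__(self, runner: Runner):
        # ... which the proposer borrows instead of allocating a private
        # copy and mirroring data into it every step.
        self.runner = runner

    def prepare(self, target_positions: torch.Tensor) -> torch.Tensor:
        n = target_positions.shape[0]
        self.runner.positions[:n] = target_positions   # write in place
        return self.runner.positions[:n]               # view, no extra copy

runner = Runner(max_tokens=16)
out = Proposer(runner).prepare(torch.arange(5))
assert out.data_ptr() == runner.positions.data_ptr()   # same storage, reused
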
@@ -177,8 +159,6 @@ def dummy_run(self,
             # assert with_prefill is False, \
             #     "Full decode graph only supports uniform batch now."
             max_seq_lens = self.runner.model_config.max_model_len
-            self.seq_lens_cpu[:num_reqs] = max_seq_lens
-            self.seq_lens_cpu[num_reqs:] = 0
             if len(self.runner.attn_groups) > 0:
                 num_computed_tokens_cpu = (
                     self.runner.input_batch.
@@ -187,22 +167,24 @@
                    [0] + self.runner.actual_seq_lengths_q[:num_reqs],
                    device=self.runner.device,
                    dtype=torch.int32)
-                self.query_start_loc[:num_reqs + 1].copy_(query_start_loc)
                common_attn_metadata = AscendCommonAttentionMetadata(
-                    query_start_loc=self.query_start_loc[:num_reqs + 1],
+                    query_start_loc=torch.tensor(
+                        [0] + self.runner.actual_seq_lengths_q[:num_reqs],
+                        device=self.device,
+                        dtype=torch.int32),
                    query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                                 1],
-                    seq_lens_cpu=self.seq_lens_cpu,
-                    seq_lens=self.seq_lens_cpu[:num_reqs],
+                    seq_lens_cpu=self.runner.seq_lens_cpu,
+                    seq_lens=self.runner.seq_lens_cpu[:num_reqs],
                    num_reqs=num_reqs,
                    num_actual_tokens=num_tokens,
                    max_query_len=self.num_speculative_tokens + 1,
                    num_computed_tokens_cpu=num_computed_tokens_cpu,
                    actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                    block_table_tensor=self.runner.input_batch.block_table[0].
                    get_device_tensor()[:num_reqs],
-                    slot_mapping=self.slot_mapping,
-                    positions=self.positions,
+                    slot_mapping=self.runner.input_batch.block_table[0].slot_mapping,
+                    positions=self.runner.positions,
                    attn_mask=self.runner.attn_mask,
                    spec_attn_mask=self.runner.spec_attn_mask,
                    attn_state=self.runner.attn_state,
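
As a side note on the metadata built here, a small sketch of how query_start_loc is formed, under our assumption (suggested by the prepended 0) that actual_seq_lengths_q holds cumulative query end offsets; the numbers are made up:

import torch

# Made-up values; assumes actual_seq_lengths_q stores cumulative query end
# offsets per request, so prepending 0 yields the start-location boundaries.
num_reqs = 3
actual_seq_lengths_q = [2, 4, 6]

query_start_loc = torch.tensor([0] + actual_seq_lengths_q[:num_reqs],
                               dtype=torch.int32)
# Request i's query tokens span query_start_loc[i]:query_start_loc[i+1].
print(query_start_loc)   # tensor([0, 2, 4, 6], dtype=torch.int32)
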
@@ -319,22 +301,6 @@ def generate_token_ids(self,
             target_hidden_states = hidden_states[:num_scheduled_tokens]
             target_slot_mapping = attn_metadata.slot_mapping
             cu_num_tokens = attn_metadata.query_start_loc
-
-            query_start_loc_num = len(cu_num_tokens)
-            self.query_start_loc[:query_start_loc_num].copy_(
-                cu_num_tokens[:query_start_loc_num])
-            self.query_start_loc[query_start_loc_num:].fill_(0)
-            self.query_start_loc_cpu[:query_start_loc_num].copy_(
-                self.query_start_loc[:query_start_loc_num], non_blocking=True)
-            self.query_start_loc_cpu[query_start_loc_num:].fill_(0)
-
-            target_slot_mapping_len = target_slot_mapping.shape[0]
-            self.slot_mapping[:target_slot_mapping_len].copy_(
-                target_slot_mapping)
-            self.slot_mapping[target_slot_mapping_len:].fill_(0)
-            target_positions_len = target_positions.shape[0]
-            self.positions[:target_positions_len].copy_(target_positions)
-            self.positions[target_positions_len:].fill_(0)
         else:
             # TODO(woosuk): Refactor this.
             num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -432,20 +398,6 @@ def _prepare_inputs(
         target_hidden_states = hidden_states[token_indices]
         target_slot_mapping = slot_mapping[token_indices]
 
-        batch_size = num_rejected_tokens.shape[0]
-        self.query_start_loc[:batch_size + 1].copy_(cu_num_tokens[:batch_size +
-                                                                  1])
-        self.query_start_loc[batch_size + 1:].fill_(0)
-        self.query_start_loc_cpu[:batch_size + 1].copy_(
-            self.query_start_loc[:batch_size + 1], non_blocking=True)
-        self.query_start_loc_cpu[batch_size + 1:].fill_(0)
-        target_positions_len = target_positions.shape[0]
-        self.positions[:target_positions_len].copy_(target_positions)
-        self.positions[target_positions_len:].fill_(0)
-        target_slot_mapping_len = target_slot_mapping.shape[0]
-        self.slot_mapping[:target_slot_mapping_len].copy_(target_slot_mapping)
-        self.slot_mapping[target_slot_mapping_len:].fill_(0)
-
         return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
 
     def _propose(
@@ -517,8 +469,6 @@ def _propose(
         seq_lens = target_positions[last_token_indices] + 1
         seq_lens = seq_lens.int()
         seq_lens_len = seq_lens.shape[0]
-        self.seq_lens_cpu[:seq_lens_len].copy_(seq_lens, non_blocking=True)
-        self.seq_lens_cpu[seq_lens_len:].fill_(0)
 
         if not self.torchair_graph_enabled:
             # torch mode need to update num_tokens_across_dp
@@ -552,18 +502,27 @@ def _propose(
         # Currently, if not torchair, runner.graph_pad_size will always be -1.
         graph_pad_size = self.runner.graph_pad_size
 
+        runner_slot_mapping = self.runner.input_batch.block_table[0].slot_mapping
+        runner_slot_mapping[:target_slot_mapping.shape[0]].copy_(target_slot_mapping)
+        runner_slot_mapping[target_slot_mapping.shape[0]:num_input_tokens].fill_(0)
+
+        # NOTE: Currently only positions, slot_mapping, block_table and
+        # seq_lens are passed into MLAMetadata, and of those only
+        # block_table and slot_mapping are actually used, so we only pin
+        # the block_table and slot_mapping addresses. If attention needs
+        # other params one day, they should be pinned too.
         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:batch_size + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1],
-            seq_lens_cpu=self.seq_lens_cpu[:seq_lens_len],
+            query_start_loc=cu_num_tokens[:batch_size + 1],
+            query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
+            seq_lens_cpu=seq_lens.cpu(),
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping=self.slot_mapping[:target_slot_mapping.shape[0]],
-            positions=self.positions[:target_positions.shape[0]],
+            slot_mapping=runner_slot_mapping,
+            positions=target_positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
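
A plain-PyTorch sketch of the fixed-address idea described in the NOTE above (no actual graph capture here; buffer size and values are invented): a captured graph keeps reading the tensor it was captured with, so each step's slot mapping is copied into that same buffer and the padded tail is zeroed, rather than pointing the metadata at a freshly allocated tensor.

import torch

# Invented size/values; mirrors the copy-then-zero refresh done above.
runner_slot_mapping = torch.zeros(16, dtype=torch.int32)   # persistent buffer

def refresh(target_slot_mapping: torch.Tensor, num_input_tokens: int) -> None:
    n = target_slot_mapping.shape[0]
    runner_slot_mapping[:n].copy_(target_slot_mapping)    # overwrite in place
    runner_slot_mapping[n:num_input_tokens].fill_(0)      # zero the padded tail

refresh(torch.tensor([3, 7, 11], dtype=torch.int32), num_input_tokens=8)
print(runner_slot_mapping[:8])   # tensor([ 3,  7, 11,  0,  0,  0,  0,  0], dtype=torch.int32)
# The buffer's storage never changes, so its device address stays fixed.
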
@@ -585,6 +544,7 @@ def _propose(
         attn_metadata = self.runner.attn_metadata_builder.build(
             0, common_attn_metadata, self.runner.get_model())
 
+        self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
         self.hidden_states[num_tokens:].fill_(0)
 
@@ -734,7 +694,10 @@ def _propose(
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:hidden_states.shape[0]] = hidden_states
             attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
-
+            if not self.torchair_graph_enabled:
+                self.positions[batch_size:num_input_tokens] = 0
+                self.input_ids[batch_size:num_input_tokens] = 0
+                self.hidden_states[batch_size:num_input_tokens].fill_(0)
             if attn_metadata_i.prefill is not None:
                 attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
                 attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
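
Because these buffers persist across steps, the newly added zeroing guards against values left over from an earlier, larger batch leaking into the padded region. A tiny illustration with invented sizes:

import torch

# Invented sizes; shows why the padded tail of a persistent buffer is cleared.
positions = torch.arange(8)                 # leftovers from a previous, larger step
batch_size, num_input_tokens = 3, 6

positions[:batch_size] = torch.tensor([10, 11, 12])   # current batch
positions[batch_size:num_input_tokens] = 0            # clear the padded tail

print(positions)    # tensor([10, 11, 12,  0,  0,  0,  6,  7])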
