
Commit cda2170

Add UT about MTP full graph

Signed-off-by: anon189Ty <[email protected]>
1 parent 9258fa3 · commit cda2170

File tree (4 files changed: +104 −23 lines)

tests/ut/attention/test_mla_v1.py
tests/ut/compilation/test_acl_graph.py
vllm_ascend/attention/mla_v1.py
vllm_ascend/spec_decode/mtp_proposer.py

tests/ut/attention/test_mla_v1.py

Lines changed: 57 additions & 0 deletions
@@ -229,6 +229,41 @@ def test_ascend_mla_metadata_builder_spec_decode(self):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)

+    def test_ascend_mla_metadata_builder_build_full_graph(self):
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.model_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'
+
+        mock_spec_config = MagicMock()
+        mock_spec_config.num_speculative_tokens = 3
+        mock_vllm_config.speculative_config = mock_spec_config
+
+        builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
+                                           mock_device)
+        common_metadata = MagicMock()
+        model = MagicMock()
+        common_metadata.graph_pad_size = 8
+        common_metadata.num_reqs = 4
+        common_metadata.num_actual_tokens = 5
+        common_metadata.max_query_len = 5
+        common_metadata.seq_lens_cpu = torch.Tensor([9, 10, 8, 8]).int()
+        common_metadata.query_start_loc = torch.Tensor([0, 1, 2, 4, 5]).int()
+        common_metadata.query_start_loc_cpu = torch.Tensor([0, 1, 2, 4,
+                                                            5]).int()
+        common_metadata.positions = torch.Tensor([1, 2, 3, 4, 5, 6]).int()
+        block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int()
+        common_metadata.block_table_tensor = block_table
+        metadata = builder.build(0, common_metadata, model)
+
+        self.assertEqual(metadata.decode.actual_seq_lengths_q,
+                         [1, 2, 4, 5, 6, 6, 7, 8])
+        self.assertEqual(metadata.decode.block_table.shape[0], 8)
+
     def test_reorder_batch(self):
         ascend_config = MagicMock()

@@ -266,6 +301,28 @@ def test_reorder_batch(self):
         self.assertTrue(modified)
         input_batch.swap_states.assert_called_once_with(1, 2)

+    def test_pad_actual_seq_lens_q(self):
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.model_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'
+        mock_vllm_config.speculative_config = None
+
+        builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
+                                           mock_device)
+        input_seq_lens = [1, 2, 4, 5]
+        expect_output = [1, 2, 4, 5, 6, 6, 7, 8]
+        num_reqs = 4
+        num_reqs_pad_size = 4
+        output_seq_lens = builder.pad_actual_seq_len_q(num_reqs_pad_size,
+                                                       num_reqs,
+                                                       input_seq_lens)
+        self.assertEqual(output_seq_lens, expect_output)
+

 class TestAscendMLAImpl(TestBase):
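The expected output [1, 2, 4, 5, 6, 6, 7, 8] pins down the padding rule for the cumulative query lengths. Below is a minimal sketch that reproduces it, assuming each padded slot is treated as a one-token dummy decode whose cumulative length is clamped to at least one past the last real entry; the helper's name and signature come from the diff, but this body is a hypothetical reconstruction, not the committed implementation:

```python
def pad_actual_seq_len_q(num_reqs_pad_size: int, num_reqs: int,
                         actual_seq_lengths_q: list[int]) -> list[int]:
    # Hypothetical reconstruction: pad slot i (overall request index
    # num_reqs + i) would hold cumulative length num_reqs + i + 1 in a
    # uniform one-token-per-request decode, clamped so the cumulative
    # sequence never drops below last_real + 1.
    last = actual_seq_lengths_q[-1]
    padded = [
        max(last + 1, num_reqs + i + 1) for i in range(num_reqs_pad_size)
    ]
    return actual_seq_lengths_q + padded

# Matches the test: [1, 2, 4, 5] with 4 pads -> [1, 2, 4, 5, 6, 6, 7, 8]
assert pad_actual_seq_len_q(4, 4, [1, 2, 4, 5]) == [1, 2, 4, 5, 6, 6, 7, 8]
```

Under that reading, a fully uniform batch would yield the identity ramp [1, 2, ..., graph_pad_size], and the max() clamp only matters when real requests carry more than one token each.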

tests/ut/compilation/test_acl_graph.py

Lines changed: 24 additions & 1 deletion
@@ -21,7 +21,9 @@
 from vllm.forward_context import BatchDescriptor, ForwardContext

 from tests.ut.base import TestBase
-from vllm_ascend.compilation.acl_graph import ACLGraphEntry, ACLGraphWrapper
+from vllm_ascend.compilation.acl_graph import (
+    ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params,
+    update_mtp_graph_params_workspaces)


 class TestACLGraphEntry(TestBase):

@@ -718,3 +720,24 @@ def test_unwrap_method(self):

         unwrapped = wrapper.unwrap()
         self.assertEqual(unwrapped, self.mock_runnable)
+
+
+class TestMTPGraphParams(TestBase):
+
+    def test_set_mtp_graph_params(self):
+        with patch('vllm_ascend.compilation.acl_graph._mtp_graph_params',
+                   new=None):
+            set_mtp_graph_params([4])
+            from vllm_ascend.compilation.acl_graph import _mtp_graph_params
+            self.assertIsNotNone(_mtp_graph_params)
+
+    @patch('vllm_ascend.compilation.acl_graph._mtp_graph_params')
+    def test_update_mtp_graph_params_workspaces(self, mtp_graph_params_mock):
+        mtp_graph_params_mock.workspaces = {4: 5}
+        update_mtp_graph_params_workspaces(4, 6)
+        self.assertEqual(mtp_graph_params_mock.workspaces[4], 6)
+
+    @patch('vllm_ascend.compilation.acl_graph._mtp_graph_params')
+    def test_get_mtp_graph_params(self, mtp_graph_params_mock):
+        graph_params = get_mtp_graph_params()
+        self.assertIs(mtp_graph_params_mock, graph_params)
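These tests exercise a module-level singleton, `_mtp_graph_params`. For orientation, here is a minimal sketch of the pattern under test, assuming a simple container with a `workspaces` dict; only the three function names and the `_mtp_graph_params` global appear in the diff, everything else is illustrative:

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class MTPGraphParams:
    # Illustrative container; only `workspaces` is exercised by the tests.
    workspaces: dict = field(default_factory=dict)


_mtp_graph_params: Optional[MTPGraphParams] = None


def set_mtp_graph_params(batch_sizes: list) -> None:
    # Initialize the module-level singleton once per process.
    global _mtp_graph_params
    _mtp_graph_params = MTPGraphParams(
        workspaces={size: None for size in batch_sizes})


def update_mtp_graph_params_workspaces(num_tokens: int,
                                       workspace: Any) -> None:
    # Replace the workspace recorded for a given capture size.
    if _mtp_graph_params is not None:
        _mtp_graph_params.workspaces[num_tokens] = workspace


def get_mtp_graph_params() -> Optional[MTPGraphParams]:
    return _mtp_graph_params
```

This matches the observable behavior the tests assert: `set` makes the global non-None, `update` mutates one `workspaces` entry in place, and `get` returns whatever object the module-level name currently binds.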

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 2 deletions
@@ -451,8 +451,8 @@ def build(
             num_reqs_pad_size = graph_pad_size - num_reqs
             actual_seq_lengths_q = self.pad_actual_seq_len_q(
                 num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
-            seq_lens_list = seq_lens_list + [0] * (
-                graph_pad_size - num_decodes)
+            seq_lens_list = seq_lens_list + [0] * (graph_pad_size -
+                                                   num_decodes)
             num_block_pad_size = graph_pad_size - block_table.shape[0]
             if num_block_pad_size > 0:
                 block_table_padding = torch.zeros(
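This hunk only reflows the zero-padding line. Plugging in the values from the new unit test makes the arithmetic concrete (a worked example, not code from the commit):

```python
# Worked example with the unit test's inputs (illustrative only).
graph_pad_size, num_decodes = 8, 4
seq_lens_list = [9, 10, 8, 8]  # per-request sequence lengths
seq_lens_list = seq_lens_list + [0] * (graph_pad_size - num_decodes)
assert seq_lens_list == [9, 10, 8, 8, 0, 0, 0, 0]
# The block table is likewise padded with zero rows up to graph_pad_size,
# which is why the test asserts metadata.decode.block_table.shape[0] == 8.
```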

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 21 additions & 20 deletions
@@ -36,6 +36,9 @@

 class MtpProposer(Proposer):

+    # TODO: Find out why ModuleRunner does not need this explicit typing.
+    model: Union[nn.Module, ACLGraphWrapper]
+
     def __init__(
         self,
         vllm_config: VllmConfig,

@@ -145,7 +148,8 @@ def dummy_run(self,
         if skip_attn:
             attn_metadata = None
         elif is_running_torchair:
-            common_attn_metadata = TorchairCommonAttentionMetadata(
+            common_attn_metadata: TorchairCommonAttentionMetadata = \
+                TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
                 num_actual_tokens=1,
                 actual_seq_lengths_q=self.runner.actual_seq_lengths_q,

@@ -156,24 +160,18 @@ def dummy_run(self,
             attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
                 common_attn_metadata)
         elif aclgraph_runtime_mode == CUDAGraphMode.FULL:
-            # assert with_prefill is False, \
-            #     "Full decode graph only supports uniform batch now."
-            max_seq_lens = self.runner.model_config.max_model_len
             if len(self.runner.attn_groups) > 0:
                 num_computed_tokens_cpu = (
                     self.runner.input_batch.
                     num_computed_tokens_cpu_tensor[:num_reqs])
-            query_start_loc = torch.tensor(
-                [0] + self.runner.actual_seq_lengths_q[:num_reqs],
-                device=self.runner.device,
-                dtype=torch.int32)
-            common_attn_metadata = AscendCommonAttentionMetadata(
+            common_attn_metadata: AscendCommonAttentionMetadata = \
+                AscendCommonAttentionMetadata(
                 query_start_loc=torch.tensor(
                     [0] + self.runner.actual_seq_lengths_q[:num_reqs],
                     device=self.device,
                     dtype=torch.int32),
-                query_start_loc_cpu=self.runner.query_start_loc_cpu[
-                    :num_reqs + 1],
+                query_start_loc_cpu=self.runner.
+                query_start_loc_cpu[:num_reqs + 1],
                 seq_lens_cpu=self.runner.seq_lens_cpu,
                 seq_lens=self.runner.seq_lens_cpu[:num_reqs],
                 num_reqs=num_reqs,

@@ -183,7 +181,8 @@
                 actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                 block_table_tensor=self.runner.input_batch.block_table[0].
                 get_device_tensor()[:num_reqs],
-                slot_mapping=self.runner.input_batch.block_table[0].slot_mapping,
+                slot_mapping=self.runner.input_batch.block_table[0].
+                slot_mapping,
                 positions=self.runner.positions,
                 attn_mask=self.runner.attn_mask,
                 spec_attn_mask=self.runner.spec_attn_mask,

@@ -466,7 +465,6 @@ def _propose(

         seq_lens = target_positions[last_token_indices] + 1
         seq_lens = seq_lens.int()
-        seq_lens_len = seq_lens.shape[0]

         if not self.torchair_graph_enabled:
             # torch mode need to update num_tokens_across_dp

@@ -481,10 +479,10 @@ def _propose(

         if scheduler_output:
             uniform_decode = (max_query_len in list(
-                range(1, self.num_speculative_tokens + 2))) and (
-                    scheduler_output.total_num_scheduled_tokens ==
-                    self.runner.input_batch.num_reqs *
-                    (self.num_speculative_tokens + 1))
+                range(1, self.num_speculative_tokens +
+                      2))) and (scheduler_output.total_num_scheduled_tokens
+                                == self.runner.input_batch.num_reqs *
+                                (self.num_speculative_tokens + 1))
             batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                                uniform_decode=uniform_decode)
         else:

@@ -500,9 +498,12 @@
         # Currently, if not torchair, runner.graph_pad_size will always be -1.
         graph_pad_size = self.runner.graph_pad_size

-        runner_slot_mapping = self.runner.input_batch.block_table[0].slot_mapping
-        runner_slot_mapping[:target_slot_mapping.shape[0]].copy_(target_slot_mapping)
-        runner_slot_mapping[target_slot_mapping.shape[0]:num_input_tokens].fill_(0)
+        runner_slot_mapping = self.runner.input_batch.block_table[
+            0].slot_mapping
+        runner_slot_mapping[:target_slot_mapping.shape[0]].copy_(
+            target_slot_mapping)
+        runner_slot_mapping[target_slot_mapping.
+                            shape[0]:num_input_tokens].fill_(0)

         # NOTE: Currently, just positions, slot_mapping, block_table and
         # seq_lens will be sent into MLAMetadata.
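The reflowed `uniform_decode` expression is hard to scan. An equivalent, flatter restatement is sketched below (illustrative names only; the commit keeps the single expression above, and the equivalence assumes `max_query_len` is an integer):

```python
def is_uniform_decode(max_query_len: int, total_scheduled: int,
                      num_reqs: int, num_speculative_tokens: int) -> bool:
    # A batch is a "uniform decode" when every scheduled request
    # contributes exactly one verified token plus its speculative draft
    # tokens, and no single request exceeds that per-request budget.
    tokens_per_req = num_speculative_tokens + 1
    return (1 <= max_query_len <= tokens_per_req
            and total_scheduled == num_reqs * tokens_per_req)
```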
