Commit d9a1b9c

fix yapf error

Signed-off-by: Ronald1995 <[email protected]>

1 parent c0317c9 commit d9a1b9c
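All four files change for the same reason: the existing line wrapping did not match yapf's output, so the formatter check failed. For readers who want to reproduce this kind of rewrap locally, here is a minimal sketch against yapf's Python API (an assumption: yapf installed via pip, with the generic "pep8" style standing in for the repository's real style configuration):

# Hedged sketch: reproduce a yapf rewrap like the ones in this commit.
# Assumes `pip install yapf`; recent yapf releases return a
# (formatted_source, changed) tuple from FormatCode.
from yapf.yapflib.yapf_api import FormatCode

# An over-long call similar to the one fixed in attention_v1.py.
unformatted = ("query_start_loc = query_start_loc_cpu.pin_memory()"
               ".to(self.device, non_blocking=True)\n")

formatted, changed = FormatCode(unformatted, style_config="pep8")
print(changed)    # True here: the input line exceeds the column limit
print(formatted)  # yapf's canonical wrapping of the call

Running the repository's own lint tooling is the authoritative path; this snippet only illustrates how yapf rewraps an over-long call.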

File tree

4 files changed: +97 -93 lines

vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 2 deletions

@@ -348,8 +348,8 @@ def build(
                 device=query_start_loc_cpu.device)
         ])

-        query_start_loc = query_start_loc_cpu.pin_memory().to(self.device,
-                                                              non_blocking=True)
+        query_start_loc = query_start_loc_cpu.pin_memory().to(
+            self.device, non_blocking=True)

         if get_ascend_device_type() == AscendDeviceType._310P:
             if attn_state == AscendAttentionState.PrefillNoCache:
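Behavior is unchanged by the rewrap: the call still stages query_start_loc_cpu in pinned (page-locked) host memory so the host-to-device transfer can be issued without blocking the host thread. A minimal sketch of that pattern in stock PyTorch, with CUDA standing in for the Ascend device (on NPU the same calls go through torch_npu with device "npu"; tensor values are illustrative):

# Sketch of the pin_memory() + non_blocking copy pattern, assuming a
# CUDA build of PyTorch; on Ascend, torch_npu exposes the same API
# surface with device="npu".
import torch

query_start_loc_cpu = torch.tensor([0, 3, 7, 12], dtype=torch.int32)

if torch.cuda.is_available():
    # pin_memory() moves the tensor into page-locked host memory, which
    # is what allows .to(..., non_blocking=True) to return before the
    # DMA transfer completes.
    query_start_loc = query_start_loc_cpu.pin_memory().to(
        "cuda", non_blocking=True)
    # The copy is ordered on the current stream, so kernels enqueued
    # afterwards on that stream observe the finished transfer.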

vllm_ascend/attention/mla_v1.py

Lines changed: 20 additions & 19 deletions

@@ -568,41 +568,41 @@ def build(
             )
             chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
                 cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
-                    device, non_blocking=True
-                ),
+                    device, non_blocking=True),
                 starts=local_chunk_starts.pin_memory().to(
-                    device, non_blocking=True
-                ),
-                seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
+                    device, non_blocking=True),
+                seq_tot=padded_local_chunk_seq_lens.sum(
+                    dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 chunk_seq_lens=chunk_seq_lens,
                 chunk_seq_lens_npu=chunk_seq_lens.npu(),
                 workspace=self.chunked_prefill_workspace,
-                padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
-                padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
-                local_context_lens_allranks=local_context_lens_allranks.tolist(),
-                padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
-                    device, non_blocking=True
-                ),
+                padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
+                npu(),
+                padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
+                .tolist(),
+                local_context_lens_allranks=local_context_lens_allranks
+                .tolist(),
+                padded_local_cu_seq_lens=
+                padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
+                    device, non_blocking=True),
                 cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                 chunk_size=padded_local_max_context_chunk_across_ranks,
             )
         else:
             chunked_context_metadata = (
                 AscendMLAPrefillMetadata.ChunkedContextMetadata(
                     cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
-                        device, non_blocking=True
-                    ),
+                        device, non_blocking=True),
                     starts=chunk_starts.pin_memory().to(
-                        device, non_blocking=True
-                    ),
+                        device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
-                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                    max_seq_lens=chunk_seq_lens.max(
+                        dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
                     chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
-                )
-            )
+                ))
         prefill_input_positions = input_positions[tokens_start:]
         cos = self.cos_cache[
             prefill_input_positions].unsqueeze(  # type: ignore

@@ -634,7 +634,8 @@ def build(
         cos = common_attn_metadata.cos
         sin = common_attn_metadata.sin
         # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
-        actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist()
+        actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
+                                                   1].tolist()
         max_seq_lens = seq_lens[:num_decodes].max().item()
         seq_lens = seq_lens[:num_decodes]
         input_positions = input_positions[:num_decode_tokens]
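The wrapped expressions in these hunks are also formatting-only. Semantically, seq_tot and max_seq_lens remain per-request row reductions over a [num_requests, num_chunks] table of context-chunk lengths. A toy illustration of what the two reductions compute (values invented for clarity, not taken from the code):

# Toy example of the row reductions behind seq_tot and max_seq_lens;
# shapes and numbers are made up for illustration.
import torch

# One row per request, one column per context chunk.
chunk_seq_lens = torch.tensor([[128, 128, 64],
                               [256, 0, 0]])

# Total context tokens per request, summed across its chunks.
seq_tot = chunk_seq_lens.sum(dim=1).tolist()  # -> [320, 256]

# Longest single chunk per request; max(dim=1) returns (values, indices).
max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist()  # -> [128, 256]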

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 19 additions & 14 deletions

@@ -144,9 +144,9 @@ def __init__(
         self.arange = torch.arange(max_num_slots_for_arange,
                                    device=device,
                                    dtype=torch.int32)
-        self.arange_cpu = torch.arange(
-            max_num_slots_for_arange, device="cpu", dtype=torch.int32
-        )
+        self.arange_cpu = torch.arange(max_num_slots_for_arange,
+                                       device="cpu",
+                                       dtype=torch.int32)

         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),

@@ -346,7 +346,8 @@ def generate_token_ids(self,
                 self.runner.discard_request_indices.gpu,
                 self.runner.num_discarded_requests
             )
-            self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
+            self._copy_valid_sampled_token_count(next_token_ids,
+                                                 valid_sampled_tokens_count)

         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
         if self.pcp_size > 1:

@@ -426,24 +427,28 @@ def generate_token_ids(self,
             )

         return draft_token_ids
-
+
     def _copy_valid_sampled_token_count(
-        self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
-    ) -> None:
+            self, next_token_ids: torch.Tensor,
+            valid_sampled_tokens_count: torch.Tensor) -> None:
         if self.runner.valid_sampled_token_count_event is not None:
             default_stream = torch.npu.current_stream()
             # initialize a new stream to overlap the copy operation with
             # prepare_input of draft model.
-            with torch.npu.stream(self.runner.valid_sampled_token_count_copy_stream):
+            with torch.npu.stream(
+                    self.runner.valid_sampled_token_count_copy_stream):
                 self.runner.valid_sampled_token_count_copy_stream.wait_stream(
-                    default_stream
-                )  # type: ignore
-                self.runner.valid_sampled_token_count_cpu[
-                    : valid_sampled_tokens_count.shape[0]
-                ].copy_(valid_sampled_tokens_count, non_blocking=True)
+                    default_stream)  # type: ignore
+                self.runner.valid_sampled_token_count_cpu[:
+                                                          valid_sampled_tokens_count
+                                                          .shape[0]].copy_(
+                                                              valid_sampled_tokens_count,
+                                                              non_blocking=True
+                                                          )
             self.runner.valid_sampled_token_count_event.record()

-        self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
+        self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
+            1)

     def _init_mtp_model(self):
         architecture = self.vllm_config.model_config.architecture
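The reflowed _copy_valid_sampled_token_count keeps its overlap trick intact: the device-to-host copy of the sampled-token counts is issued on a secondary stream so it runs concurrently with the draft model's prepare_input, and an event marks when the pinned CPU buffer is safe to read. A minimal sketch of that pattern, with torch.cuda standing in for the torch.npu calls above (buffer and stream names are illustrative, not the runner's real attributes):

# Sketch of overlapping a device-to-host copy with default-stream work.
# CUDA analogue of the torch.npu code above; assumes a CUDA build.
import torch

if torch.cuda.is_available():
    counts_gpu = torch.randint(0, 8, (16,), device="cuda")
    counts_cpu = torch.empty(16, dtype=counts_gpu.dtype).pin_memory()

    copy_stream = torch.cuda.Stream()
    copy_done = torch.cuda.Event()
    default_stream = torch.cuda.current_stream()

    with torch.cuda.stream(copy_stream):
        # Order the copy after everything already queued on the default
        # stream so it reads fully written values.
        copy_stream.wait_stream(default_stream)
        counts_cpu.copy_(counts_gpu, non_blocking=True)
        copy_done.record()  # records on copy_stream (the current stream)

    # ... the default stream is free to run prepare_input-style work ...

    copy_done.synchronize()  # after this, counts_cpu is safe to read
    print(counts_cpu.tolist())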
