@@ -568,41 +568,41 @@ def build(
568568 )
569569 chunked_context_metadata = AscendMLAPrefillMetadata .ChunkedContextMetadata (
570570 cu_seq_lens = cu_seq_lens_cpu .pin_memory ().to (
571- device , non_blocking = True
572- ),
571+ device , non_blocking = True ),
573572 starts = local_chunk_starts .pin_memory ().to (
574- device , non_blocking = True
575- ),
576- seq_tot = padded_local_chunk_seq_lens . sum ( dim = 1 ).tolist (),
573+ device , non_blocking = True ),
574+ seq_tot = padded_local_chunk_seq_lens . sum (
575+ dim = 1 ).tolist (),
577576 max_seq_lens = chunk_seq_lens .max (dim = 1 ).values .tolist (),
578577 chunk_seq_lens = chunk_seq_lens ,
579578 chunk_seq_lens_npu = chunk_seq_lens .npu (),
580579 workspace = self .chunked_prefill_workspace ,
581- padded_chunk_seq_lens_npu = padded_local_chunk_seq_lens .npu (),
582- padded_local_chunk_seq_lens = padded_local_chunk_seq_lens .tolist (),
583- local_context_lens_allranks = local_context_lens_allranks .tolist (),
584- padded_local_cu_seq_lens = padded_local_cu_chunk_seq_lens_cpu .pin_memory ().to (
585- device , non_blocking = True
586- ),
580+ padded_chunk_seq_lens_npu = padded_local_chunk_seq_lens .
581+ npu (),
582+ padded_local_chunk_seq_lens = padded_local_chunk_seq_lens
583+ .tolist (),
584+ local_context_lens_allranks = local_context_lens_allranks
585+ .tolist (),
586+ padded_local_cu_seq_lens =
587+ padded_local_cu_chunk_seq_lens_cpu .pin_memory ().to (
588+ device , non_blocking = True ),
587589 cu_seq_lens_lst = cu_seq_lens_cpu .tolist (),
588590 chunk_size = padded_local_max_context_chunk_across_ranks ,
589591 )
590592 else :
591593 chunked_context_metadata = (
592594 AscendMLAPrefillMetadata .ChunkedContextMetadata (
593595 cu_seq_lens = cu_seq_lens_cpu .pin_memory ().to (
594- device , non_blocking = True
595- ),
596+ device , non_blocking = True ),
596597 starts = chunk_starts .pin_memory ().to (
597- device , non_blocking = True
598- ),
598+ device , non_blocking = True ),
599599 seq_tot = chunk_seq_lens .sum (dim = 1 ).tolist (),
600- max_seq_lens = chunk_seq_lens .max (dim = 1 ).values .tolist (),
600+ max_seq_lens = chunk_seq_lens .max (
601+ dim = 1 ).values .tolist (),
601602 chunk_seq_lens = chunk_seq_lens ,
602603 chunk_seq_lens_npu = chunk_seq_lens .npu (),
603604 workspace = self .chunked_prefill_workspace ,
604- )
605- )
605+ ))
606606 prefill_input_positions = input_positions [tokens_start :]
607607 cos = self .cos_cache [
608608 prefill_input_positions ].unsqueeze ( # type: ignore
@@ -634,7 +634,8 @@ def build(
634634 cos = common_attn_metadata .cos
635635 sin = common_attn_metadata .sin
636636 # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
637- actual_seq_lengths_q = query_start_loc_cpu [1 :num_decodes + 1 ].tolist ()
637+ actual_seq_lengths_q = query_start_loc_cpu [1 :num_decodes +
638+ 1 ].tolist ()
638639 max_seq_lens = seq_lens [:num_decodes ].max ().item ()
639640 seq_lens = seq_lens [:num_decodes ]
640641 input_positions = input_positions [:num_decode_tokens ]
0 commit comments