@@ -566,10 +566,13 @@ def build(
566566 out = padded_local_cu_chunk_seq_lens_cpu [:, 1 :],
567567 dtype = torch .int32 ,
568568 )
569- chunked_context_metadata = \
570- AscendMLAPrefillMetadata .ChunkedContextMetadata (
571- cu_seq_lens = cu_seq_lens_cpu .to (device , non_blocking = True ),
572- starts = local_chunk_starts .to (device , non_blocking = True ),
569+ chunked_context_metadata = AscendMLAPrefillMetadata .ChunkedContextMetadata (
570+ cu_seq_lens = cu_seq_lens_cpu .pin_memory ().to (
571+ device , non_blocking = True
572+ ),
573+ starts = local_chunk_starts .pin_memory ().to (
574+ device , non_blocking = True
575+ ),
573576 seq_tot = padded_local_chunk_seq_lens .sum (dim = 1 ).tolist (),
574577 max_seq_lens = chunk_seq_lens .max (dim = 1 ).values .tolist (),
575578 chunk_seq_lens = chunk_seq_lens ,
@@ -578,22 +581,27 @@ def build(
578581 padded_chunk_seq_lens_npu = padded_local_chunk_seq_lens .npu (),
579582 padded_local_chunk_seq_lens = padded_local_chunk_seq_lens .tolist (),
580583 local_context_lens_allranks = local_context_lens_allranks .tolist (),
581- padded_local_cu_seq_lens = padded_local_cu_chunk_seq_lens_cpu .to (
584+ padded_local_cu_seq_lens = padded_local_cu_chunk_seq_lens_cpu .pin_memory (). to (
582585 device , non_blocking = True
583586 ),
584587 cu_seq_lens_lst = cu_seq_lens_cpu .tolist (),
585588 chunk_size = padded_local_max_context_chunk_across_ranks ,
586589 )
587590 else :
588- chunked_context_metadata = \
591+ chunked_context_metadata = (
589592 AscendMLAPrefillMetadata .ChunkedContextMetadata (
590- cu_seq_lens = cu_seq_lens_cpu .to (device , non_blocking = True ),
591- starts = chunk_starts .to (device , non_blocking = True ),
592- seq_tot = chunk_seq_lens .sum (dim = 1 ).tolist (),
593- max_seq_lens = chunk_seq_lens .max (dim = 1 ).values .tolist (),
594- chunk_seq_lens = chunk_seq_lens ,
595- chunk_seq_lens_npu = chunk_seq_lens .npu (),
596- workspace = self .chunked_prefill_workspace ,
593+ cu_seq_lens = cu_seq_lens_cpu .pin_memory ().to (
594+ device , non_blocking = True
595+ ),
596+ starts = chunk_starts .pin_memory ().to (
597+ device , non_blocking = True
598+ ),
599+ seq_tot = chunk_seq_lens .sum (dim = 1 ).tolist (),
600+ max_seq_lens = chunk_seq_lens .max (dim = 1 ).values .tolist (),
601+ chunk_seq_lens = chunk_seq_lens ,
602+ chunk_seq_lens_npu = chunk_seq_lens .npu (),
603+ workspace = self .chunked_prefill_workspace ,
604+ )
597605 )
598606 prefill_input_positions = input_positions [tokens_start :]
599607 cos = self .cos_cache [
@@ -626,7 +634,7 @@ def build(
626634 cos = common_attn_metadata .cos
627635 sin = common_attn_metadata .sin
628636 # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
629- actual_seq_lengths_q = query_start_loc [1 :num_decodes + 1 ].tolist ()
637+ actual_seq_lengths_q = query_start_loc_cpu [1 :num_decodes + 1 ].tolist ()
630638 max_seq_lens = seq_lens [:num_decodes ].max ().item ()
631639 seq_lens = seq_lens [:num_decodes ]
632640 input_positions = input_positions [:num_decode_tokens ]
0 commit comments