@@ -2122,7 +2122,7 @@ def _calc_spec_decode_metadata(
21222122 cu_num_scheduled_tokens - num_sampled_tokens ,
21232123 num_sampled_tokens )
21242124 logits_indices_pcp += arange
2125- logits_indices_pcp = torch .from_numpy (logits_indices_pcp ). pin_memory ( ).to (
2125+ logits_indices_pcp = torch .tensor (logits_indices_pcp , pin_memory = True ).to (
21262126 self .device , non_blocking = True )
21272127
21282128 # Compute the bonus logits indices.
@@ -2145,28 +2145,23 @@ def _calc_spec_decode_metadata(
21452145
21462146 # TODO: Optimize the CPU -> NPU copy.
21472147 cu_num_draft_tokens = (
2148- torch .from_numpy (cu_num_draft_tokens )
2149- .pin_memory ()
2148+ torch .tensor (cu_num_draft_tokens , pin_memory = True )
21502149 .to (self .device , non_blocking = True )
21512150 )
21522151 cu_num_sampled_tokens = (
2153- torch .from_numpy (cu_num_sampled_tokens )
2154- .pin_memory ()
2152+ torch .tensor (cu_num_sampled_tokens , pin_memory = True )
21552153 .to (self .device , non_blocking = True )
21562154 )
21572155 logits_indices = (
2158- torch .from_numpy (logits_indices )
2159- .pin_memory ()
2156+ torch .tensor (logits_indices , pin_memory = True )
21602157 .to (self .device , non_blocking = True )
21612158 )
21622159 target_logits_indices = (
2163- torch .from_numpy (target_logits_indices )
2164- .pin_memory ()
2160+ torch .tensor (target_logits_indices , pin_memory = True )
21652161 .to (self .device , non_blocking = True )
21662162 )
21672163 bonus_logits_indices = (
2168- torch .from_numpy (bonus_logits_indices )
2169- .pin_memory ()
2164+ torch .tensor (bonus_logits_indices , pin_memory = True )
21702165 .to (self .device , non_blocking = True )
21712166 )
21722167
0 commit comments