@@ -2122,7 +2122,7 @@ def _calc_spec_decode_metadata(
21222122 cu_num_scheduled_tokens - num_sampled_tokens ,
21232123 num_sampled_tokens )
21242124 logits_indices_pcp += arange
2125- logits_indices_pcp = torch .tensor (logits_indices_pcp , pin_memory = True ).to (
2125+ logits_indices_pcp = torch .from_numpy (logits_indices_pcp ). pin_memory ( ).to (
21262126 self .device , non_blocking = True )
21272127
21282128 # Compute the bonus logits indices.
@@ -2145,25 +2145,27 @@ def _calc_spec_decode_metadata(
21452145
21462146 # TODO: Optimize the CPU -> NPU copy.
21472147 cu_num_draft_tokens = (
2148- torch .tensor (cu_num_draft_tokens , pin_memory = True )
2148+ torch .from_numpy (cu_num_draft_tokens )
2149+ .pin_memory ()
21492150 .to (self .device , non_blocking = True )
21502151 )
21512152 cu_num_sampled_tokens = (
2152- torch .tensor (cu_num_sampled_tokens , pin_memory = True )
2153+ torch .from_numpy (cu_num_sampled_tokens )
2154+ .pin_memory ()
21532155 .to (self .device , non_blocking = True )
21542156 )
21552157 logits_indices = (
2156- torch .tensor (logits_indices , pin_memory = True )
2158+ torch .from_numpy (logits_indices )
2159+ .pin_memory ()
21572160 .to (self .device , non_blocking = True )
21582161 )
21592162 target_logits_indices = (
2160- torch .tensor (target_logits_indices , pin_memory = True )
2161- .to (self .device , non_blocking = True )
2162- )
2163- bonus_logits_indices = (
2164- torch .tensor (bonus_logits_indices , pin_memory = True )
2163+ torch .from_numpy (target_logits_indices )
2164+ .pin_memory ()
21652165 .to (self .device , non_blocking = True )
21662166 )
2167+ bonus_logits_indices = torch .from_numpy (bonus_logits_indices ).pin_memory ().to (
2168+ self .device , non_blocking = True )
21672169
21682170 # Compute the draft token ids.
21692171 # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
0 commit comments