@@ -97,6 +97,10 @@ def _init_dynamic_sampling_tensors(self):
         context = self.inference_wrapped_model.inference_context
         self._materialize_only_last = context.materialize_only_last_token_logits
         max_requests = context.max_total_requests
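+        # Size the full-logits buffer: one row per request when only the
+        # last token's logits are materialized, otherwise one row per token.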
+        if self._materialize_only_last:
+            max_logits = max_requests
+        else:
+            max_logits = context.max_tokens
 
         model_config = get_model_config(self.inference_wrapped_model.model)
         self._sampling_backend = model_config.sampling_backend
@@ -111,10 +115,15 @@ def _init_dynamic_sampling_tensors(self):
         vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size
 
         # Keep track of request metadata.
-        self._active_request_count = None
+        self._active_request_count: Optional[int] = None
+        self._active_token_count: Optional[int] = None
+        self._logits_seq_len: Optional[int] = None
         self._request_metadata: Dict[str, Tensor] = {}
 
         # Initialize bookkeeping tensors.
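+        # _all_logits_cuda persists the full logits of the current step;
+        # _sampling_logits_cuda holds only the last-token logits used for sampling.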
+        self._all_logits_cuda = torch.empty(
+            max_logits, vocab_size, dtype=logits_dtype, device=device
+        )
         self._sampling_logits_cuda = torch.empty(
             max_requests, vocab_size, dtype=logits_dtype, device=device
         )
@@ -593,12 +602,19 @@ def _dynamic_step_context_init(
         # Get flat tokens, position ids.
         if construct_graph_dimensions is not None:
             self._recording_graph = True
-            return context.current_input_and_position_ids(
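+            # Cache the step's token count; it sizes the logits buffers below.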
+            self._active_token_count = construct_graph_dimensions.token_count
+            ret = context.current_input_and_position_ids(
                 num_warmup_tokens=construct_graph_dimensions.token_count
             )
         else:
             self._recording_graph = False
-            return context.current_input_and_position_ids()
+            self._active_token_count = context.padded_active_token_count
+            ret = context.current_input_and_position_ids()
+
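+        # One logit row per request if only last-token logits are materialized,
+        # else one per (padded) active token.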
+        self._logits_seq_len = (
+            self._active_request_count if self._materialize_only_last else self._active_token_count
+        )
+        return ret
 
     def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) -> Tensor:
         """Forward step the model to get logits for dynamic batching.
@@ -610,6 +626,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
             position_ids (Tensor): The position IDs.
         """
         inference_wrapper_config = self.inference_wrapped_model.inference_wrapper_config
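+        # Hoisted here since the logits shape below needs it on every pipeline stage.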
+        vocab_size = inference_wrapper_config.padded_vocab_size
 
         context = self.inference_wrapped_model.inference_context
 
@@ -619,11 +636,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
         )
 
         if self.model_is_pipeline_parallel:
-            logits_seq_len = (
-                self._active_request_count if self._materialize_only_last else input_ids.shape[1]
-            )
-            vocab_size = inference_wrapper_config.padded_vocab_size
-            logits_shape = [1, logits_seq_len, vocab_size]
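+            # Derive the shape from the cached _logits_seq_len so non-last
+            # stages can size the broadcast buffer without a logits tensor.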
+            logits_shape = [1, self._logits_seq_len, vocab_size]
 
             if is_pipeline_last_stage(self.pp_group):
                 assert logits is not None and torch.Size(logits_shape) == logits.shape
@@ -636,8 +649,9 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
             )
 
         last_token_logits = context.last_token_logits(logits)
-        # Copy last_token_logits to contiguous buffer.
-        self._sampling_logits_cuda[:self._active_request_count].copy_(last_token_logits, non_blocking=True)
+        # Copy full logits and last-token logits to persistent contiguous buffers.
+        self._all_logits_cuda[:self._logits_seq_len, :].copy_(logits, non_blocking=True)
+        self._sampling_logits_cuda[:self._active_request_count, :].copy_(last_token_logits, non_blocking=True)
 
         return logits
 
@@ -733,23 +747,22 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]:
 
         return return_log_probs.any(), top_n_log_probs.any()
 
-    def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]:
+    def _dynamic_step_calculate_log_probs(self) -> Optional[Tensor]:
737751 """Calculate log probs from logits."""
738752 context = self .inference_wrapped_model .inference_context
739753
740754 return context .calculate_log_probs (
741- logits ,
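+            # Read from the persistent buffer rather than the forward output.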
+            self._all_logits_cuda[:self._logits_seq_len, :],
             self._sampled_tokens_cuda[:self._active_request_count],
             only_last_token_logits=self._materialize_only_last,
         )
 
     def _dynamic_step_calculate_top_n_logprobs(
-        self, logits: Tensor, log_probs_tensor: Optional[Tensor] = None
+        self, log_probs_tensor: Optional[Tensor] = None
     ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]:
         """Calculate top-n log probs from logits for dynamic batching.
 
         Args:
-            logits (Tensor): The logits to compute top-n log probs from.
             log_probs_tensor (Optional[Tensor]): Pre-computed log probabilities tensor.
                 If provided, avoids recomputing log_softmax. Should be the tensor
                 returned by calculate_log_probs.
@@ -892,7 +905,7 @@ async def async_generate_output_tokens_dynamic_batch(
             context.padded_active_request_count if context.is_decode_only() else None
         )
 
-        logits = self._dynamic_step_forward_logits(input_ids, position_ids)
+        self._dynamic_step_forward_logits(input_ids, position_ids)
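+        # The returned logits are ignored here; they were copied into
+        # self._all_logits_cuda inside _dynamic_step_forward_logits.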
 
         # This is the best place to yield control back to event loop.
         # At this point we have enqueued FW pass GPU kernels asynchronously.
@@ -911,11 +924,9 @@ async def async_generate_output_tokens_dynamic_batch(
         log_probs = None
         top_n_logprobs = None
         if return_log_probs or return_top_n_logprobs:
-            log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits)
+            log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs()
             if return_top_n_logprobs:
-                top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(
-                    logits, log_probs_tensor
-                )
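+                # Reuses log_probs_tensor, avoiding a second log_softmax pass.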
+                top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(log_probs_tensor)
 
         if skip_bookkeeping:
             request_bookkeeping = {}