
Commit e6505b5

Cleanup

1 parent 24f0d69 commit e6505b5

1 file changed: +30 −19 lines

megatron/core/inference/text_generation_controllers/text_generation_controller.py

Lines changed: 30 additions & 19 deletions
@@ -97,6 +97,10 @@ def _init_dynamic_sampling_tensors(self):
     context = self.inference_wrapped_model.inference_context
     self._materialize_only_last = context.materialize_only_last_token_logits
     max_requests = context.max_total_requests
+    if self._materialize_only_last:
+        max_logits = max_requests
+    else:
+        max_logits = context.max_tokens

     model_config = get_model_config(self.inference_wrapped_model.model)
     self._sampling_backend = model_config.sampling_backend
@@ -111,10 +115,15 @@ def _init_dynamic_sampling_tensors(self):
     vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size

     # Keep track of request metadata.
-    self._active_request_count = None
+    self._active_request_count: int = None
+    self._active_token_count: int = None
+    self._logits_seq_length: int = None
     self._request_metadata: Dict[str, Tensor] = {}

     # Initialize bookkeeping tensors.
+    self._all_logits_cuda = torch.empty(
+        max_logits, vocab_size, dtype=logits_dtype, device=device
+    )
     self._sampling_logits_cuda = torch.empty(
         max_requests, vocab_size, dtype=logits_dtype, device=device
     )
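Note: a minimal sketch of the sizing rule behind the new _all_logits_cuda buffer (sizes and dtype below are illustrative, not Megatron's defaults): when materialize_only_last_token_logits is set, one logits row per request is enough; otherwise a row is reserved per token.

    import torch

    # Illustrative capacities; the real values come from the inference context.
    max_requests, max_tokens, vocab_size = 8, 64, 32
    materialize_only_last = True

    max_logits = max_requests if materialize_only_last else max_tokens
    all_logits = torch.empty(max_logits, vocab_size, dtype=torch.float32)
    print(all_logits.shape)  # torch.Size([8, 32])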
@@ -593,12 +602,19 @@ def _dynamic_step_context_init(
     # Get flat tokens, position ids.
     if construct_graph_dimensions is not None:
         self._recording_graph = True
-        return context.current_input_and_position_ids(
+        self._active_token_count = construct_graph_dimensions.token_count
+        ret = context.current_input_and_position_ids(
             num_warmup_tokens=construct_graph_dimensions.token_count
         )
     else:
         self._recording_graph = False
-        return context.current_input_and_position_ids()
+        self._active_token_count = context.padded_active_token_count
+        ret = context.current_input_and_position_ids()
+
+    self._logits_seq_len = (
+        self._active_request_count if self._materialize_only_last else self._active_token_count
+    )
+    return ret

 def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) -> Tensor:
     """Forward step the model to get logits for dynamic batching.
@@ -610,6 +626,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
         position_ids (Tensor): The position IDs.
     """
     inference_wrapper_config = self.inference_wrapped_model.inference_wrapper_config
+    vocab_size = inference_wrapper_config.padded_vocab_size

     context = self.inference_wrapped_model.inference_context

@@ -619,11 +636,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
     )

     if self.model_is_pipeline_parallel:
-        logits_seq_len = (
-            self._active_request_count if self._materialize_only_last else input_ids.shape[1]
-        )
-        vocab_size = inference_wrapper_config.padded_vocab_size
-        logits_shape = [1, logits_seq_len, vocab_size]
+        logits_shape = [1, self._logits_seq_len, vocab_size]

         if is_pipeline_last_stage(self.pp_group):
             assert logits is not None and torch.Size(logits_shape) == logits.shape
@@ -636,8 +649,9 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
         )

     last_token_logits = context.last_token_logits(logits)
-    # Copy last_token_logits to contiguous buffer.
-    self._sampling_logits_cuda[:self._active_request_count].copy_(last_token_logits, non_blocking=True)
+    # Copy logits to contiguous buffer.
+    self._all_logits_cuda[:self._logits_seq_len, :].copy_(logits, non_blocking=True)
+    self._sampling_logits_cuda[:self._active_request_count, :].copy_(last_token_logits, non_blocking=True)

     return logits

@@ -733,23 +747,22 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]:

     return return_log_probs.any(), top_n_log_probs.any()

-def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]:
+def _dynamic_step_calculate_log_probs(self) -> Optional[Tensor]:
     """Calculate log probs from logits."""
     context = self.inference_wrapped_model.inference_context

     return context.calculate_log_probs(
-        logits,
+        self._all_logits_cuda[: self._logits_seq_len, :],
         self._sampled_tokens_cuda[:self._active_request_count],
         only_last_token_logits=self._materialize_only_last,
     )

 def _dynamic_step_calculate_top_n_logprobs(
-    self, logits: Tensor, log_probs_tensor: Optional[Tensor] = None
+    self, log_probs_tensor: Optional[Tensor] = None
 ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]:
     """Calculate top-n log probs from logits for dynamic batching.

     Args:
-        logits (Tensor): The logits to compute top-n log probs from.
         log_probs_tensor (Optional[Tensor]): Pre-computed log probabilities tensor.
             If provided, avoids recomputing log_softmax. Should be the tensor
             returned by calculate_log_probs.
@@ -892,7 +905,7 @@ async def async_generate_output_tokens_dynamic_batch(
         context.padded_active_request_count if context.is_decode_only() else None
     )

-    logits = self._dynamic_step_forward_logits(input_ids, position_ids)
+    self._dynamic_step_forward_logits(input_ids, position_ids)

     # This is the best place to yield control back to event loop.
     # At this point we have enqueued FW pass GPU kernels asynchronously.
@@ -911,11 +924,9 @@ async def async_generate_output_tokens_dynamic_batch(
     log_probs = None
     top_n_logprobs = None
     if return_log_probs or return_top_n_logprobs:
-        log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits)
+        log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs()
         if return_top_n_logprobs:
-            top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(
-                logits, log_probs_tensor
-            )
+            top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(log_probs_tensor)

     if skip_bookkeeping:
         request_bookkeeping = {}
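Note: with this change the forward logits are no longer threaded through the call sites; the log-prob helpers read from the controller's member buffer instead. A hedged, self-contained sketch of that producer/consumer pattern (class and method names are illustrative, not the Megatron API):

    import torch

    class Controller:
        def __init__(self, max_logits: int = 16, vocab_size: int = 32):
            self._all_logits = torch.empty(max_logits, vocab_size)
            self._logits_seq_len = 0

        def forward_step(self, step_logits: torch.Tensor) -> None:
            # Producer: stage this step's logits in the member buffer.
            self._logits_seq_len = step_logits.shape[0]
            self._all_logits[: self._logits_seq_len].copy_(step_logits)

        def calculate_log_probs(self) -> torch.Tensor:
            # Consumer: read back the staged logits; no tensor is passed in.
            return torch.log_softmax(self._all_logits[: self._logits_seq_len], dim=-1)

    ctrl = Controller()
    ctrl.forward_step(torch.randn(5, 32))
    log_probs = ctrl.calculate_log_probs()  # shape [5, 32]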
