Skip to content

Commit f1bddbd

Browse files
[Core] Cleanup TPU model runner for MM (#23894)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 9748c51 commit f1bddbd

File tree

1 file changed

+1
-31
lines changed

1 file changed

+1
-31
lines changed

vllm/v1/worker/tpu_model_runner.py

Lines changed: 1 addition & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -808,31 +808,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
808 808
return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
809 809
num_reqs, end_index
810 810

811-
def _scatter_placeholders(
812-
self,
813-
embeds: torch.Tensor,
814-
is_embed: Optional[torch.Tensor],
815-
) -> torch.Tensor:
816-
if is_embed is None:
817-
return embeds
818-
819-
placeholders = embeds.new_full(
820-
(is_embed.shape[0], embeds.shape[-1]),
821-
fill_value=torch.nan,
822-
)
823-
placeholders[is_embed] = embeds
824-
return placeholders
825-
826-
def _gather_placeholders(
827-
self,
828-
placeholders: torch.Tensor,
829-
is_embed: Optional[torch.Tensor],
830-
) -> torch.Tensor:
831-
if is_embed is None:
832-
return placeholders
833-
834-
return placeholders[is_embed]
835-
836 811
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
837 812
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
838 813
if not scheduled_encoder_inputs:
@@ -892,12 +867,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
892 867
# NOTE (NickLucche) here we diverge from logic in other runners, as we
893 868
# assume to only have whole mm items to process. Hence we avoid the
894 869
# intrinsic dynamism that `scatter_mm_placeholders` introduces.
895-
for (mm_hash, pos_info), output in zip(
896-
mm_hashes_pos,
897-
encoder_outputs,
898-
):
899-
if req_id not in self.encoder_cache:
900-
self.encoder_cache[req_id] = {}
870+
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
901 871
assert pos_info.is_embed is None, "Expected all positions to be"\
902 872
" contiguous and embeddings."
903 873
self.encoder_cache[mm_hash] = output

0 commit comments

Comments
 (0)