@@ -808,31 +808,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
         return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
             num_reqs, end_index
 
-    def _scatter_placeholders(
-        self,
-        embeds: torch.Tensor,
-        is_embed: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        if is_embed is None:
-            return embeds
-
-        placeholders = embeds.new_full(
-            (is_embed.shape[0], embeds.shape[-1]),
-            fill_value=torch.nan,
-        )
-        placeholders[is_embed] = embeds
-        return placeholders
-
-    def _gather_placeholders(
-        self,
-        placeholders: torch.Tensor,
-        is_embed: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        if is_embed is None:
-            return placeholders
-
-        return placeholders[is_embed]
-
     def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
@@ -892,12 +867,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # NOTE (NickLucche) here we diverge from logic in other runners, as we
         # assume to only have whole mm items to process. Hence we avoid the
         # intrinsic dynamism that `scatter_mm_placeholders` introduces.
-        for (mm_hash, pos_info), output in zip(
-                mm_hashes_pos,
-                encoder_outputs,
-        ):
-            if req_id not in self.encoder_cache:
-                self.encoder_cache[req_id] = {}
+        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
             assert pos_info.is_embed is None, "Expected all positions to be" \
                 " contiguous and embeddings."
            self.encoder_cache[mm_hash] = output
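
For context on what the deleted helpers did and why this runner can drop them: `_scatter_placeholders` expanded an encoder output so that each embedding row lands at the positions flagged by an `is_embed` mask, padding the remaining placeholder rows with NaN, and `_gather_placeholders` inverted that. The snippet below is a minimal, standalone sketch of that behavior and of the contiguous case the new code asserts (`pos_info.is_embed is None`); the tensor shapes and the `demo()` driver are made up for illustration and are not part of vLLM.

import torch
from typing import Optional


def scatter_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """Spread `embeds` rows into a (num_positions, hidden) buffer.

    Positions where `is_embed` is True receive the embedding rows; the
    remaining placeholder rows are filled with NaN. When `is_embed` is
    None, every position is an embedding and the input is returned as-is.
    """
    if is_embed is None:
        return embeds
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders


def demo() -> None:
    hidden = 4
    # Non-contiguous case: 3 embedding rows must land at positions 1, 2, 4
    # out of 6 placeholder positions (shape depends on the mask at runtime).
    embeds = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)
    is_embed = torch.tensor([False, True, True, False, True, False])
    scattered = scatter_placeholders(embeds, is_embed)
    assert scattered.shape == (6, hidden)
    assert torch.equal(scattered[is_embed], embeds)  # gathering inverts the scatter

    # Contiguous case (is_embed is None): the whole mm item is embeddings,
    # so the encoder output can be cached unchanged, keyed by its mm_hash,
    # which avoids the data-dependent shapes the scatter path introduces.
    assert scatter_placeholders(embeds, None) is embeds


if __name__ == "__main__":
    demo()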