@@ -197,7 +197,7 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
     return output_tensor


-@support_torch_compile
+@support_torch_compile(dynamic_arg_dims={"encoded_patches": 0})
 class Llama4VisionPixelShuffleMLP(nn.Module):
     def __init__(
         self,
@@ -228,7 +228,6 @@ def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
         return self.mlp(encoded_patches)


-@support_torch_compile
 class Llama4VisionAttention(nn.Module):
     def __init__(
         self,
@@ -323,7 +322,6 @@ def forward(
         return attn_output


-@support_torch_compile
 class Llama4VisionEncoderLayer(nn.Module):
     def __init__(
         self,
@@ -377,6 +375,7 @@ def forward(
         return outputs


+@support_torch_compile(dynamic_arg_dims={"hidden_states": 0})
 class Llama4VisionEncoder(nn.Module):
     def __init__(
         self,
@@ -387,20 +386,17 @@ def __init__(
     ):
         super().__init__()
         self.config = config
-        from vllm.compilation.backends import set_model_tag
-
-        with set_model_tag("Llama4VisionEncoderLayer"):
-            self.layers = nn.ModuleList(
-                [
-                    Llama4VisionEncoderLayer(
-                        config=config,
-                        quant_config=quant_config,
-                        prefix=f"{prefix}.layers.{layer_idx}",
-                        use_data_parallel=use_data_parallel,
-                    )
-                    for layer_idx in range(config.num_hidden_layers)
-                ]
-            )
+        self.layers = nn.ModuleList(
+            [
+                Llama4VisionEncoderLayer(
+                    config=config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                    use_data_parallel=use_data_parallel,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )

     def forward(
         self,
@@ -488,14 +484,16 @@ def __init__(
         self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

         # encoders
-        self.model = Llama4VisionEncoder(
-            config=config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.model",
-            use_data_parallel=use_data_parallel,
-        )
         from vllm.compilation.backends import set_model_tag

+        with set_model_tag("Llama4VisionEncoderLayer"):
+            self.model = Llama4VisionEncoder(
+                config=config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.model",
+                use_data_parallel=use_data_parallel,
+            )
+
         with set_model_tag("Llama4VisionPixelShuffleMLP"):
             self.vision_adapter = Llama4VisionPixelShuffleMLP(
                 config=config,
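For reference, dynamic_arg_dims={"hidden_states": 0} asks the compilation wrapper to treat dimension 0 of that forward argument as dynamic, so the compiled encoder is not retraced every time the number of patches changes. A standalone sketch of the same idea using plain torch.compile (toy module and sizes, not code from this file):

import torch
from torch import nn

# Toy stand-in for the vision encoder; only the dynamic-dim marking matters here.
model = torch.compile(nn.Linear(64, 64))

x = torch.randn(17, 64)
torch._dynamo.mark_dynamic(x, 0)  # dim 0 (patch/token count) may vary across calls
out = model(x)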