Commit 29dec46

Fix: apply support_torch_compile to the whole VisionEncoder instead of any one layer
Signed-off-by: Lucas Kabela <[email protected]>
1 parent 8a8a772 commit 29dec46
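
The change below moves @support_torch_compile from the individual vision sub-layers up to the top-level Llama4VisionEncoder (keeping it on Llama4VisionPixelShuffleMLP), with dim 0 of the decorated inputs marked dynamic via dynamic_arg_dims, so one compiled graph covers the whole encoder and a varying patch count does not force recompilation. As a rough illustration of that idea only, here is a minimal sketch using plain torch.compile and torch._dynamo.mark_dynamic rather than vLLM's decorator; TinyEncoder and its sizes are hypothetical stand-ins, not code from mllama4.py.

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for Llama4VisionEncoder, not the real module."""

    def __init__(self, hidden_size: int = 32, num_layers: int = 2):
        super().__init__()
        # Sub-layers stay plain nn.Modules; only the top-level module is
        # compiled, mirroring the decorator move in the diff below.
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


# Compile the whole encoder once; mark the patch dimension (dim 0) as dynamic
# so calls with a different number of patches reuse the same compiled graph.
encoder = torch.compile(TinyEncoder())
patches = torch.randn(196, 32)          # e.g. 196 patches, hidden size 32
torch._dynamo.mark_dynamic(patches, 0)  # dim 0 may vary between calls
out = encoder(patches)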

File tree

1 file changed (+21 −23 lines)

vllm/model_executor/models/mllama4.py

Lines changed: 21 additions & 23 deletions
@@ -197,7 +197,7 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
     return output_tensor


-@support_torch_compile
+@support_torch_compile(dynamic_arg_dims={"encoded_patches": 0})
 class Llama4VisionPixelShuffleMLP(nn.Module):
     def __init__(
         self,
@@ -228,7 +228,6 @@ def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
         return self.mlp(encoded_patches)


-@support_torch_compile
 class Llama4VisionAttention(nn.Module):
     def __init__(
         self,
@@ -323,7 +322,6 @@ def forward(
         return attn_output


-@support_torch_compile
 class Llama4VisionEncoderLayer(nn.Module):
     def __init__(
         self,
@@ -377,6 +375,7 @@ def forward(
         return outputs


+@support_torch_compile(dynamic_arg_dims={"hidden_states": 0})
 class Llama4VisionEncoder(nn.Module):
     def __init__(
         self,
@@ -387,20 +386,17 @@ def __init__(
     ):
         super().__init__()
         self.config = config
-        from vllm.compilation.backends import set_model_tag
-
-        with set_model_tag("Llama4VisionEncoderLayer"):
-            self.layers = nn.ModuleList(
-                [
-                    Llama4VisionEncoderLayer(
-                        config=config,
-                        quant_config=quant_config,
-                        prefix=f"{prefix}.layers.{layer_idx}",
-                        use_data_parallel=use_data_parallel,
-                    )
-                    for layer_idx in range(config.num_hidden_layers)
-                ]
-            )
+        self.layers = nn.ModuleList(
+            [
+                Llama4VisionEncoderLayer(
+                    config=config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                    use_data_parallel=use_data_parallel,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )

     def forward(
         self,
@@ -488,14 +484,16 @@ def __init__(
         self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

         # encoders
-        self.model = Llama4VisionEncoder(
-            config=config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.model",
-            use_data_parallel=use_data_parallel,
-        )
         from vllm.compilation.backends import set_model_tag

+        with set_model_tag("Llama4VisionEncoderLayer"):
+            self.model = Llama4VisionEncoder(
+                config=config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.model",
+                use_data_parallel=use_data_parallel,
+            )
+
         with set_model_tag("Llama4VisionPixelShuffleMLP"):
             self.vision_adapter = Llama4VisionPixelShuffleMLP(
                 config=config,
