@@ -390,6 +390,15 @@ def forward(
390390 return self .processor (self , hidden_states , attention_mask , image_rotary_emb , ** kwargs )
391391
392392
393+ def split_mod (mod : torch .Tensor , mod_param_sets : int ):
394+ if mod .ndim == 2 :
395+ mod = mod .unsqueeze (1 )
396+ mod_params = torch .chunk (mod , 3 * mod_param_sets , dim = - 1 )
397+ # Return tuple of 3-tuples of modulation params shift/scale/gate
398+ return tuple (mod_params [3 * i : 3 * (i + 1 )] for i in range (mod_param_sets ))
399+
400+
401+
393402class Flux2SingleTransformerBlock (nn .Module ):
394403 def __init__ (
395404 self ,
@@ -430,6 +439,8 @@ def forward(
430439 split_hidden_states : bool = False ,
431440 text_seq_len : Optional [int ] = None ,
432441 ) -> Tuple [torch .Tensor , torch .Tensor ]:
442+ temb_mod_params = split_mod (temb_mod_params , 1 )[0 ]
443+
433444 # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
434445 # concatenated
435446 if encoder_hidden_states is not None :
@@ -504,6 +515,9 @@ def forward(
504515 joint_attention_kwargs : Optional [Dict [str , Any ]] = None ,
505516 ) -> Tuple [torch .Tensor , torch .Tensor ]:
506517 joint_attention_kwargs = joint_attention_kwargs or {}
518+ temb_mod_params_img = split_mod (temb_mod_params_img , 2 )
519+ temb_mod_params_txt = split_mod (temb_mod_params_txt , 2 )
520+
507521
508522 # Modulation parameters shape: [1, 1, self.dim]
509523 (shift_msa , scale_msa , gate_msa ), (shift_mlp , scale_mlp , gate_mlp ) = temb_mod_params_img
@@ -621,11 +635,12 @@ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor,
621635 mod = self .act_fn (temb )
622636 mod = self .linear (mod )
623637
624- if mod .ndim == 2 :
625- mod = mod .unsqueeze (1 )
626- mod_params = torch .chunk (mod , 3 * self .mod_param_sets , dim = - 1 )
627- # Return tuple of 3-tuples of modulation params shift/scale/gate
628- return tuple (mod_params [3 * i : 3 * (i + 1 )] for i in range (self .mod_param_sets ))
638+ return mod
639+ # if mod.ndim == 2:
640+ # mod = mod.unsqueeze(1)
641+ # mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
642+ # # Return tuple of 3-tuples of modulation params shift/scale/gate
643+ # return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
629644
630645
631646class Flux2Transformer2DModel (
@@ -821,7 +836,7 @@ def forward(
821836
822837 double_stream_mod_img = self .double_stream_modulation_img (temb )
823838 double_stream_mod_txt = self .double_stream_modulation_txt (temb )
824- single_stream_mod = self .single_stream_modulation (temb )[ 0 ]
839+ single_stream_mod = self .single_stream_modulation (temb )
825840
826841 # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
827842 hidden_states = self .x_embedder (hidden_states )
0 commit comments