
Commit 0881fe8

split tensors inside the transformer blocks to avoid checkpointing issues
1 parent edf36f5 commit 0881fe8
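
Context for the one-line commit message, inferred from the diff below rather than stated in it: PyTorch's reentrant torch.utils.checkpoint variant only tracks gradients for tensors passed directly as arguments to the checkpointed function, and tensors hidden inside nested containers (such as tuples of shift/scale/gate triples) are not guaranteed to receive them. Before this commit, the modulation modules returned nested tuples of tensors that were then fed into the (possibly checkpointed) block forwards; after it, each block receives one flat tensor and splits it internally. A schematic of the two calling conventions, with hypothetical names:

# Before: nested tuples of tensors crossed the checkpoint boundary.
#     mod_params = modulation(temb)             # ((shift, scale, gate), ...)
#     checkpoint(block.forward, hidden_states, mod_params)
#
# After: one flat tensor crosses the boundary; the block splits it itself.
#     packed = modulation(temb)                 # single [B, sets * 3 * dim] tensor
#     checkpoint(block.forward, hidden_states, packed)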


src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 21 additions & 6 deletions
@@ -390,6 +390,15 @@ def forward(
         return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
 
 
+def split_mod(mod: torch.Tensor, mod_param_sets: int):
+    if mod.ndim == 2:
+        mod = mod.unsqueeze(1)
+    mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
+    # Return tuple of 3-tuples of modulation params shift/scale/gate
+    return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))
+
+
+
 class Flux2SingleTransformerBlock(nn.Module):
     def __init__(
         self,
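
For reference, a standalone sketch of what the new split_mod helper produces. The helper body is copied from the hunk above; the batch size and dim are made-up values:

import torch

def split_mod(mod: torch.Tensor, mod_param_sets: int):
    if mod.ndim == 2:
        mod = mod.unsqueeze(1)
    mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
    # Return tuple of 3-tuples of modulation params shift/scale/gate
    return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))

dim = 8                               # hypothetical hidden size
packed = torch.randn(2, 2 * 3 * dim)  # two sets of shift/scale/gate, packed flat

(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = split_mod(packed, 2)
print(shift_msa.shape)                # torch.Size([2, 1, 8])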
@@ -430,6 +439,8 @@ def forward(
         split_hidden_states: bool = False,
         text_seq_len: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        temb_mod_params = split_mod(temb_mod_params, 1)[0]
+
         # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
         # concatenated
         if encoder_hidden_states is not None:
@@ -504,6 +515,9 @@ def forward(
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
+        temb_mod_params_img = split_mod(temb_mod_params_img, 2)
+        temb_mod_params_txt = split_mod(temb_mod_params_txt, 2)
+
 
         # Modulation parameters shape: [1, 1, self.dim]
         (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
@@ -621,11 +635,12 @@ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor,
         mod = self.act_fn(temb)
         mod = self.linear(mod)
 
-        if mod.ndim == 2:
-            mod = mod.unsqueeze(1)
-        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
-        # Return tuple of 3-tuples of modulation params shift/scale/gate
-        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
+        return mod
+        # if mod.ndim == 2:
+        #     mod = mod.unsqueeze(1)
+        # mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
+        # # Return tuple of 3-tuples of modulation params shift/scale/gate
+        # return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
 
 
 class Flux2Transformer2DModel(
@@ -821,7 +836,7 @@ def forward(
 
         double_stream_mod_img = self.double_stream_modulation_img(temb)
         double_stream_mod_txt = self.double_stream_modulation_txt(temb)
-        single_stream_mod = self.single_stream_modulation(temb)[0]
+        single_stream_mod = self.single_stream_modulation(temb)
 
         # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
         hidden_states = self.x_embedder(hidden_states)
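
To see why the flat-tensor interface plays better with gradient checkpointing, here is a hedged, self-contained toy; only split_mod is taken from the diff, while the block body and sizes are invented. With the reentrant checkpoint variant, PyTorch documents that gradients reach only tensors passed directly as arguments, and the packed tensor satisfies that:

import torch
from torch.utils.checkpoint import checkpoint

def split_mod(mod: torch.Tensor, mod_param_sets: int):
    # Same helper this commit adds to transformer_flux2.py.
    if mod.ndim == 2:
        mod = mod.unsqueeze(1)
    mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
    return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))

def block_forward(hidden_states, temb_mod_params):
    # Toy stand-in for a block forward: the split happens inside, so the
    # checkpointed function only ever receives plain tensor arguments.
    ((shift, scale, gate),) = split_mod(temb_mod_params, 1)
    return gate * (hidden_states * (1 + scale) + shift)

dim = 8                                               # hypothetical
hidden = torch.randn(2, 16, dim, requires_grad=True)
packed = torch.randn(2, 3 * dim, requires_grad=True)  # flat modulation tensor

out = checkpoint(block_forward, hidden, packed, use_reentrant=True)
out.sum().backward()
assert packed.grad is not None  # gradients reach the modulation parameters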
