
Commit 18efdde
use compute_text_seq_len_from_mask
1 parent 0477526

3 files changed, +45 -9 lines

src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 5 additions & 3 deletions

@@ -31,6 +31,7 @@
     QwenImageTransformerBlock,
     QwenTimestepProjEmbeddings,
     RMSNorm,
+    compute_text_seq_len_from_mask,
 )

@@ -244,9 +245,10 @@ def forward(

         temb = self.time_text_embed(timestep, hidden_states)

-        # Use the encoder_hidden_states sequence length for RoPE computation
-        # The mask is used for attention masking in the attention processor
-        _, text_seq_len = encoder_hidden_states.shape[:2]
+        # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+        text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+            encoder_hidden_states, encoder_hidden_states_mask
+        )

         image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device)

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 32 additions & 3 deletions

@@ -141,6 +141,34 @@ def apply_rotary_emb_qwen(
     return x_out.type_as(x)


+def compute_text_seq_len_from_mask(
+    encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
+) -> Tuple[int, Optional[torch.Tensor]]:
+    """
+    Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask.
+    """
+    batch_size, text_seq_len = encoder_hidden_states.shape[:2]
+    if encoder_hidden_states_mask is None:
+        return text_seq_len, None
+
+    if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
+        raise ValueError(
+            f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
+            f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
+        )
+
+    if encoder_hidden_states_mask.dtype != torch.bool:
+        encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
+
+    position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
+    active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
+    has_active = encoder_hidden_states_mask.any(dim=1)
+    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+    rope_text_seq_len = max(text_seq_len, int(per_sample_len.max().item()))
+
+    return rope_text_seq_len, encoder_hidden_states_mask
+
+
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim):
         super().__init__()

@@ -654,9 +682,10 @@ def forward(
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)

-        # Use the encoder_hidden_states sequence length for RoPE computation
-        # The mask is used for attention masking in the attention processor
-        _, text_seq_len = encoder_hidden_states.shape[:2]
+        # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+        text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+            encoder_hidden_states, encoder_hidden_states_mask
+        )

         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000

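For reference (not part of the diff), a minimal sketch of what the new helper returns for the kind of non-contiguous mask this change targets. It assumes a diffusers build that includes this commit and uses the import path from the updated test below; tensor sizes are illustrative only.

import torch

from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

encoder_hidden_states = torch.randn(1, 7, 16)  # toy (batch, text_seq_len, channels)

# Non-contiguous mask: token 3 is masked out even though token 4 is still active.
encoder_hidden_states_mask = torch.tensor([[1, 1, 1, 0, 1, 0, 0]])

rope_len, normalized_mask = compute_text_seq_len_from_mask(
    encoder_hidden_states, encoder_hidden_states_mask
)

# RoPE is still built for the full padded length; dropping tokens 3, 5 and 6 is left
# to the attention processor via the normalized boolean mask.
print(rope_len, normalized_mask.dtype)  # 7 torch.bool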
tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 8 additions & 3 deletions

@@ -19,6 +19,7 @@
 import torch

 from diffusers import QwenImageTransformer2DModel
+from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

 from ...testing_utils import enable_full_determinism, torch_device
 from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin

@@ -133,13 +134,17 @@ def test_non_contiguous_attention_mask(self):
         encoder_hidden_states_mask[:, 3] = 0
         encoder_hidden_states_mask[:, 5:] = 0

-        inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
+        inferred_rope_len, normalized_mask = compute_text_seq_len_from_mask(
+            inputs["encoder_hidden_states"], encoder_hidden_states_mask
+        )
+        self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
+        self.assertTrue(normalized_mask.dtype == torch.bool)
+
+        inputs["encoder_hidden_states_mask"] = normalized_mask

         with torch.no_grad():
             output = model(**inputs)

-        # The model should handle non-contiguous masks correctly
-        # RoPE uses the full sequence length, attention masking handles the pattern
         self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

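A second sketch under the same assumptions: in the common fully-attended case the helper returns the same length as the previous shape-based lookup, with the mask simply cast to bool.

import torch

from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

encoder_hidden_states = torch.randn(2, 6, 16)  # toy (batch, text_seq_len, channels)

# With no mask the helper falls back to the padded length, matching the removed
# `_, text_seq_len = encoder_hidden_states.shape[:2]` lookup.
assert compute_text_seq_len_from_mask(encoder_hidden_states, None) == (6, None)

# With an all-ones integer mask it returns the same length plus the mask cast to bool.
mask = torch.ones(2, 6, dtype=torch.int64)
text_seq_len, normalized_mask = compute_text_seq_len_from_mask(encoder_hidden_states, mask)
assert text_seq_len == 6 and normalized_mask.dtype == torch.bool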