@@ -141,6 +141,34 @@ def apply_rotary_emb_qwen(
         return x_out.type_as(x)


+def compute_text_seq_len_from_mask(
+    encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
+) -> Tuple[int, Optional[torch.Tensor]]:
+    """
+    Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask.
+    """
+    batch_size, text_seq_len = encoder_hidden_states.shape[:2]
+    if encoder_hidden_states_mask is None:
+        return text_seq_len, None
+
+    if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
+        raise ValueError(
+            f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
+            f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
+        )
+
+    if encoder_hidden_states_mask.dtype != torch.bool:
+        encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
+
+    position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
+    active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
+    has_active = encoder_hidden_states_mask.any(dim=1)
+    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+    rope_text_seq_len = max(text_seq_len, int(per_sample_len.max().item()))
+
+    return rope_text_seq_len, encoder_hidden_states_mask
+
+
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim):
         super().__init__()
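For reference, a minimal sketch of how the new helper behaves on a padded, non-contiguous mask. This is illustrative only: it uses toy tensors and assumes `compute_text_seq_len_from_mask` from the hunk above has been imported or copied into the session.

```python
import torch

# Toy inputs: batch of 2, text_seq_len of 6, hidden size of 4.
encoder_hidden_states = torch.randn(2, 6, 4)

# Integer mask with padding; sample 0 has a "hole" at position 2,
# sample 1 only uses its first three positions.
encoder_hidden_states_mask = torch.tensor(
    [
        [1, 1, 0, 1, 1, 0],
        [1, 1, 1, 0, 0, 0],
    ]
)

rope_len, norm_mask = compute_text_seq_len_from_mask(encoder_hidden_states, encoder_hidden_states_mask)
print(rope_len)         # 6 -> RoPE still covers the full padded text length
print(norm_mask.dtype)  # torch.bool -> integer mask normalized for the attention processor

# With no mask, the helper falls back to the raw sequence length.
rope_len, norm_mask = compute_text_seq_len_from_mask(encoder_hidden_states, None)
print(rope_len, norm_mask)  # 6 None
```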
@@ -654,9 +682,10 @@ def forward(
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)

-        # Use the encoder_hidden_states sequence length for RoPE computation
-        # The mask is used for attention masking in the attention processor
-        _, text_seq_len = encoder_hidden_states.shape[:2]
+        # Use the encoder_hidden_states sequence length for RoPE computation and normalize the mask
+        text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+            encoder_hidden_states, encoder_hidden_states_mask
+        )

         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
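The shape check in the helper also turns a silently mismatched mask into an explicit error before RoPE is computed; a quick illustration with toy tensors, again assuming the helper from the first hunk is in scope:

```python
import torch

# text_seq_len mismatch: hidden states carry 6 text tokens, the mask claims 5.
bad_mask = torch.ones(2, 5, dtype=torch.bool)
try:
    compute_text_seq_len_from_mask(torch.randn(2, 6, 4), bad_mask)
except ValueError as err:
    print(err)  # `encoder_hidden_states_mask` shape torch.Size([2, 5]) must match (batch_size, text_seq_len)=(2, 6).
```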