Skip to content

Commit e1a635e

Browse files
committed
fixed
1 parent 06afd9b commit e1a635e

File tree

1 file changed

+137
-38
lines changed

1 file changed

+137
-38
lines changed

src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py

Lines changed: 137 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,66 @@ def _encode_prompt_clip(
349349

350350
return pooled_embed.to(dtype)
351351

352+
# def encode_prompt(
353+
# self,
354+
# prompt: Union[str, List[str]],
355+
# num_videos_per_prompt: int = 1,
356+
# max_sequence_length: int = 512,
357+
# device: Optional[torch.device] = None,
358+
# dtype: Optional[torch.dtype] = None,
359+
# ):
360+
# r"""
361+
# Encodes a single prompt (positive or negative) into text encoder hidden states.
362+
363+
# This method combines embeddings from both Qwen2.5-VL and CLIP text encoders
364+
# to create comprehensive text representations for video generation.
365+
366+
# Args:
367+
# prompt (`str` or `List[str]`):
368+
# Prompt to be encoded.
369+
# num_videos_per_prompt (`int`, *optional*, defaults to 1):
370+
# Number of videos to generate per prompt.
371+
# max_sequence_length (`int`, *optional*, defaults to 512):
372+
# Maximum sequence length for text encoding.
373+
# device (`torch.device`, *optional*):
374+
# Torch device.
375+
# dtype (`torch.dtype`, *optional*):
376+
# Torch dtype.
377+
378+
# Returns:
379+
# Tuple[Dict[str, torch.Tensor], torch.Tensor]:
380+
# - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP)
381+
# - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings
382+
# """
383+
# device = device or self._execution_device
384+
# dtype = dtype or self.text_encoder.dtype
385+
386+
# batch_size = len(prompt)
387+
388+
# prompt = [prompt_clean(p) for p in prompt]
389+
390+
# # Encode with Qwen2.5-VL
391+
# prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
392+
# prompt=prompt,
393+
# device=device,
394+
# max_sequence_length=max_sequence_length,
395+
# dtype=dtype,
396+
# )
397+
398+
# # Encode with CLIP
399+
# prompt_embeds_clip = self._encode_prompt_clip(
400+
# prompt=prompt,
401+
# device=device,
402+
# dtype=dtype,
403+
# )
404+
# prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1)
405+
# prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1)
406+
407+
# prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1)
408+
# prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1)
409+
410+
# return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens
411+
352412
def encode_prompt(
353413
self,
354414
prompt: Union[str, List[str]],
@@ -376,9 +436,10 @@ def encode_prompt(
376436
Torch dtype.
377437
378438
Returns:
379-
Tuple[Dict[str, torch.Tensor], torch.Tensor]:
380-
- A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP)
381-
- Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings
439+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
440+
- Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim)
441+
- CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim)
442+
- Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * num_videos_per_prompt + 1,)
382443
"""
383444
device = device or self._execution_device
384445
dtype = dtype or self.text_encoder.dtype
@@ -394,43 +455,69 @@ def encode_prompt(
394455
max_sequence_length=max_sequence_length,
395456
dtype=dtype,
396457
)
458+
# prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
397459

398460
# Encode with CLIP
399461
prompt_embeds_clip = self._encode_prompt_clip(
400462
prompt=prompt,
401463
device=device,
402464
dtype=dtype,
403465
)
404-
prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1)
405-
prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1)
406-
407-
prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1)
466+
# prompt_embeds_clip shape: [batch_size, clip_embed_dim]
467+
468+
# Repeat embeddings for num_videos_per_prompt
469+
# Qwen embeddings: repeat sequence for each video, then reshape
470+
prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) # [batch_size, seq_len * num_videos_per_prompt, embed_dim]
471+
# Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim]
472+
prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1])
473+
474+
# CLIP embeddings: repeat for each video
475+
prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) # [batch_size, num_videos_per_prompt, clip_embed_dim]
476+
# Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim]
408477
prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1)
409478

410-
return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens
479+
# Repeat cumulative sequence lengths for num_videos_per_prompt
480+
# Original cu_seqlens: [0, len1, len1+len2, ...]
481+
# Need to repeat the differences and reconstruct for repeated prompts
482+
# Original differences (lengths) for each prompt in the batch
483+
original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...]
484+
# Repeat the lengths for num_videos_per_prompt
485+
repeated_lengths = original_lengths.repeat_interleave(num_videos_per_prompt) # [len1, len1, ..., len2, len2, ...]
486+
# Reconstruct the cumulative lengths
487+
repeated_cu_seqlens = torch.cat([torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)])
488+
489+
return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens
411490

412491
def check_inputs(
413492
self,
414493
prompt,
415494
negative_prompt,
416495
height,
417496
width,
418-
prompt_embeds=None,
419-
negative_prompt_embeds=None,
497+
prompt_embeds_qwen=None,
498+
prompt_embeds_clip=None,
499+
negative_prompt_embeds_qwen=None,
500+
negative_prompt_embeds_clip=None,
501+
prompt_cu_seqlens=None,
502+
negative_prompt_cu_seqlens=None,
420503
callback_on_step_end_tensor_inputs=None,
421504
):
422505
"""
423506
Validate input parameters for the pipeline.
424-
507+
425508
Args:
426509
prompt: Input prompt
427510
negative_prompt: Negative prompt for guidance
428511
height: Video height
429512
width: Video width
430-
prompt_embeds: Pre-computed prompt embeddings
431-
negative_prompt_embeds: Pre-computed negative prompt embeddings
513+
prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
514+
prompt_embeds_clip: Pre-computed CLIP prompt embeddings
515+
negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
516+
negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
517+
prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
518+
negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
432519
callback_on_step_end_tensor_inputs: Callback tensor inputs
433-
520+
434521
Raises:
435522
ValueError: If inputs are invalid
436523
"""
@@ -444,23 +531,32 @@ def check_inputs(
444531
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
445532
)
446533

447-
if prompt is not None and prompt_embeds is not None:
448-
raise ValueError(
449-
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
450-
" only forward one of the two."
451-
)
452-
elif negative_prompt is not None and negative_prompt_embeds is not None:
453-
raise ValueError(
454-
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
455-
" only forward one of the two."
456-
)
457-
elif prompt is None and prompt_embeds is None:
534+
# Check for consistency within positive prompt embeddings and sequence lengths
535+
if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None:
536+
if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None:
537+
raise ValueError(
538+
f"If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, "
539+
f"all three must be provided."
540+
)
541+
542+
# Check for consistency within negative prompt embeddings and sequence lengths
543+
if negative_prompt_embeds_qwen is not None or negative_prompt_embeds_clip is not None or negative_prompt_cu_seqlens is not None:
544+
if negative_prompt_embeds_qwen is None or negative_prompt_embeds_clip is None or negative_prompt_cu_seqlens is None:
545+
raise ValueError(
546+
f"If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, "
547+
f"all three must be provided."
548+
)
549+
550+
# Check if prompt or embeddings are provided (either prompt or all required embedding components for positive)
551+
if prompt is None and prompt_embeds_qwen is None:
458552
raise ValueError(
459-
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
553+
"Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined."
460554
)
461-
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
555+
556+
# Validate types for prompt and negative_prompt if provided
557+
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
462558
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
463-
elif negative_prompt is not None and (
559+
if negative_prompt is not None and (
464560
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
465561
):
466562
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
@@ -632,13 +728,17 @@ def __call__(
632728

633729
# 1. Check inputs. Raise error if not correct
634730
self.check_inputs(
635-
prompt,
636-
negative_prompt,
637-
height,
638-
width,
639-
prompt_embeds,
640-
negative_prompt_embeds,
641-
callback_on_step_end_tensor_inputs,
731+
prompt=prompt,
732+
negative_prompt=negative_prompt,
733+
height=height,
734+
width=width,
735+
prompt_embeds_qwen=prompt_embeds_qwen,
736+
prompt_embeds_clip=prompt_embeds_clip,
737+
negative_prompt_embeds_qwen=negative_prompt_embeds_qwen,
738+
negative_prompt_embeds_clip=negative_prompt_embeds_clip,
739+
prompt_cu_seqlens=prompt_cu_seqlens,
740+
negative_prompt_cu_seqlens=negative_prompt_cu_seqlens,
741+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
642742
)
643743

644744
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -739,7 +839,7 @@ def __call__(
739839
continue
740840

741841
timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
742-
842+
743843
# Predict noise residual
744844
pred_velocity = self.transformer(
745845
hidden_states=latents.to(dtype),
@@ -753,7 +853,7 @@ def __call__(
753853
return_dict=True
754854
).sample
755855

756-
if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None:
856+
if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None:
757857
uncond_pred_velocity = self.transformer(
758858
hidden_states=latents.to(dtype),
759859
encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype),
@@ -769,7 +869,6 @@ def __call__(
769869
pred_velocity = uncond_pred_velocity + guidance_scale * (
770870
pred_velocity - uncond_pred_velocity
771871
)
772-
773872
# Compute previous sample using the scheduler
774873
latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
775874
pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False

0 commit comments

Comments
 (0)