@@ -349,6 +349,66 @@ def _encode_prompt_clip(
349349
350350 return pooled_embed .to (dtype )
351351
352+ # def encode_prompt(
353+ # self,
354+ # prompt: Union[str, List[str]],
355+ # num_videos_per_prompt: int = 1,
356+ # max_sequence_length: int = 512,
357+ # device: Optional[torch.device] = None,
358+ # dtype: Optional[torch.dtype] = None,
359+ # ):
360+ # r"""
361+ # Encodes a single prompt (positive or negative) into text encoder hidden states.
362+
363+ # This method combines embeddings from both Qwen2.5-VL and CLIP text encoders
364+ # to create comprehensive text representations for video generation.
365+
366+ # Args:
367+ # prompt (`str` or `List[str]`):
368+ # Prompt to be encoded.
369+ # num_videos_per_prompt (`int`, *optional*, defaults to 1):
370+ # Number of videos to generate per prompt.
371+ # max_sequence_length (`int`, *optional*, defaults to 512):
372+ # Maximum sequence length for text encoding.
373+ # device (`torch.device`, *optional*):
374+ # Torch device.
375+ # dtype (`torch.dtype`, *optional*):
376+ # Torch dtype.
377+
378+ # Returns:
379+ # Tuple[Dict[str, torch.Tensor], torch.Tensor]:
380+ # - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP)
381+ # - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings
382+ # """
383+ # device = device or self._execution_device
384+ # dtype = dtype or self.text_encoder.dtype
385+
386+ # batch_size = len(prompt)
387+
388+ # prompt = [prompt_clean(p) for p in prompt]
389+
390+ # # Encode with Qwen2.5-VL
391+ # prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
392+ # prompt=prompt,
393+ # device=device,
394+ # max_sequence_length=max_sequence_length,
395+ # dtype=dtype,
396+ # )
397+
398+ # # Encode with CLIP
399+ # prompt_embeds_clip = self._encode_prompt_clip(
400+ # prompt=prompt,
401+ # device=device,
402+ # dtype=dtype,
403+ # )
404+ # prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1)
405+ # prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1)
406+
407+ # prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1)
408+ # prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1)
409+
410+ # return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens
411+
352412 def encode_prompt (
353413 self ,
354414 prompt : Union [str , List [str ]],
@@ -376,9 +436,10 @@ def encode_prompt(
376436 Torch dtype.
377437
378438 Returns:
379- Tuple[Dict[str, torch.Tensor], torch.Tensor]:
380- - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP)
381- - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings
439+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
440+ - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim)
441+ - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim)
442+ - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * num_videos_per_prompt + 1,)
382443 """
383444 device = device or self ._execution_device
384445 dtype = dtype or self .text_encoder .dtype
@@ -394,43 +455,69 @@ def encode_prompt(
394455 max_sequence_length = max_sequence_length ,
395456 dtype = dtype ,
396457 )
458+ # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
397459
398460 # Encode with CLIP
399461 prompt_embeds_clip = self ._encode_prompt_clip (
400462 prompt = prompt ,
401463 device = device ,
402464 dtype = dtype ,
403465 )
404- prompt_embeds_qwen = prompt_embeds_qwen .repeat (1 , num_videos_per_prompt , 1 )
405- prompt_embeds_qwen = prompt_embeds_qwen .view (batch_size * num_videos_per_prompt , - 1 )
406-
407- prompt_embeds_clip = prompt_embeds_clip .repeat (1 , num_videos_per_prompt , 1 )
466+ # prompt_embeds_clip shape: [batch_size, clip_embed_dim]
467+
468+ # Repeat embeddings for num_videos_per_prompt
469+ # Qwen embeddings: repeat sequence for each video, then reshape
470+ prompt_embeds_qwen = prompt_embeds_qwen .repeat (1 , num_videos_per_prompt , 1 ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim]
471+ # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim]
472+ prompt_embeds_qwen = prompt_embeds_qwen .view (batch_size * num_videos_per_prompt , - 1 , prompt_embeds_qwen .shape [- 1 ])
473+
474+ # CLIP embeddings: repeat for each video
475+ prompt_embeds_clip = prompt_embeds_clip .repeat (1 , num_videos_per_prompt , 1 ) # [batch_size, num_videos_per_prompt, clip_embed_dim]
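476+ # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim]. NOTE(review): if prompt_embeds_clip is 2-D ([batch_size, clip_embed_dim]) at this point, the 3-argument repeat above yields [1, batch_size * num_videos_per_prompt, clip_embed_dim] with whole batches tiled (p0, p1, ..., p0, p1, ...), whereas the Qwen embeddings and cu_seqlens are interleaved per prompt (p0, p0, ..., p1, p1, ...) - confirm the row order matches when batch_size > 1 and num_videos_per_prompt > 1.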
408477 prompt_embeds_clip = prompt_embeds_clip .view (batch_size * num_videos_per_prompt , - 1 )
409478
410- return prompt_embeds_qwen , prompt_embeds_clip , prompt_cu_seqlens
479+ # Repeat cumulative sequence lengths for num_videos_per_prompt
480+ # Original cu_seqlens: [0, len1, len1+len2, ...]
481+ # Need to repeat the differences and reconstruct for repeated prompts
482+ # Original differences (lengths) for each prompt in the batch
483+ original_lengths = prompt_cu_seqlens .diff () # [len1, len2, ...]
484+ # Repeat the lengths for num_videos_per_prompt
485+ repeated_lengths = original_lengths .repeat_interleave (num_videos_per_prompt ) # [len1, len1, ..., len2, len2, ...]
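486+ # Reconstruct the cumulative lengths by prepending 0 to the running sum of the repeated lengths. NOTE(review): the leading zero is created as int32 but cumsum on an integer tensor may return int64, so the concatenated result may not be int32 - confirm the dtype the attention kernel expects for cu_seqlens.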
487+ repeated_cu_seqlens = torch .cat ([torch .tensor ([0 ], device = device , dtype = torch .int32 ), repeated_lengths .cumsum (0 )])
488+
489+ return prompt_embeds_qwen , prompt_embeds_clip , repeated_cu_seqlens
411490
412491 def check_inputs (
413492 self ,
414493 prompt ,
415494 negative_prompt ,
416495 height ,
417496 width ,
418- prompt_embeds = None ,
419- negative_prompt_embeds = None ,
497+ prompt_embeds_qwen = None ,
498+ prompt_embeds_clip = None ,
499+ negative_prompt_embeds_qwen = None ,
500+ negative_prompt_embeds_clip = None ,
501+ prompt_cu_seqlens = None ,
502+ negative_prompt_cu_seqlens = None ,
420503 callback_on_step_end_tensor_inputs = None ,
421504 ):
422505 """
423506 Validate input parameters for the pipeline.
424-
507+
425508 Args:
426509 prompt: Input prompt
427510 negative_prompt: Negative prompt for guidance
428511 height: Video height
429512 width: Video width
430- prompt_embeds: Pre-computed prompt embeddings
431- negative_prompt_embeds: Pre-computed negative prompt embeddings
513+ prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
514+ prompt_embeds_clip: Pre-computed CLIP prompt embeddings
515+ negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
516+ negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
517+ prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
518+ negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
432519 callback_on_step_end_tensor_inputs: Callback tensor inputs
433-
520+
434521 Raises:
435522 ValueError: If inputs are invalid
436523 """
@@ -444,23 +531,32 @@ def check_inputs(
444531 f"`callback_on_step_end_tensor_inputs` has to be in { self ._callback_tensor_inputs } , but found { [k for k in callback_on_step_end_tensor_inputs if k not in self ._callback_tensor_inputs ]} "
445532 )
446533
447- if prompt is not None and prompt_embeds is not None :
448- raise ValueError (
449- f"Cannot forward both `prompt`: { prompt } and `prompt_embeds`: { prompt_embeds } . Please make sure to"
450- " only forward one of the two."
451- )
452- elif negative_prompt is not None and negative_prompt_embeds is not None :
453- raise ValueError (
454- f"Cannot forward both `negative_prompt`: { negative_prompt } and `negative_prompt_embeds`: { negative_prompt_embeds } . Please make sure to"
455- " only forward one of the two."
456- )
457- elif prompt is None and prompt_embeds is None :
534+ # Check for consistency within positive prompt embeddings and sequence lengths
535+ if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None :
536+ if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None :
537+ raise ValueError (
538+ f"If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, "
539+ f"all three must be provided."
540+ )
541+
542+ # Check for consistency within negative prompt embeddings and sequence lengths
543+ if negative_prompt_embeds_qwen is not None or negative_prompt_embeds_clip is not None or negative_prompt_cu_seqlens is not None :
544+ if negative_prompt_embeds_qwen is None or negative_prompt_embeds_clip is None or negative_prompt_cu_seqlens is None :
545+ raise ValueError (
546+ f"If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, "
547+ f"all three must be provided."
548+ )
549+
550+ # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive)
551+ if prompt is None and prompt_embeds_qwen is None :
458552 raise ValueError (
459- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt ` and `prompt_embeds` undefined."
553+ "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip ` and `prompt_cu_seqlens`). Cannot leave all undefined."
460554 )
461- elif prompt is not None and (not isinstance (prompt , str ) and not isinstance (prompt , list )):
555+
556+ # Validate types for prompt and negative_prompt if provided
557+ if prompt is not None and (not isinstance (prompt , str ) and not isinstance (prompt , list )):
462558 raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
463- elif negative_prompt is not None and (
559+ if negative_prompt is not None and (
464560 not isinstance (negative_prompt , str ) and not isinstance (negative_prompt , list )
465561 ):
466562 raise ValueError (f"`negative_prompt` has to be of type `str` or `list` but is { type (negative_prompt )} " )
@@ -632,13 +728,17 @@ def __call__(
632728
633729 # 1. Check inputs. Raise error if not correct
634730 self .check_inputs (
635- prompt ,
636- negative_prompt ,
637- height ,
638- width ,
639- prompt_embeds ,
640- negative_prompt_embeds ,
641- callback_on_step_end_tensor_inputs ,
731+ prompt = prompt ,
732+ negative_prompt = negative_prompt ,
733+ height = height ,
734+ width = width ,
735+ prompt_embeds_qwen = prompt_embeds_qwen ,
736+ prompt_embeds_clip = prompt_embeds_clip ,
737+ negative_prompt_embeds_qwen = negative_prompt_embeds_qwen ,
738+ negative_prompt_embeds_clip = negative_prompt_embeds_clip ,
739+ prompt_cu_seqlens = prompt_cu_seqlens ,
740+ negative_prompt_cu_seqlens = negative_prompt_cu_seqlens ,
741+ callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs ,
642742 )
643743
644744 if num_frames % self .vae_scale_factor_temporal != 1 :
@@ -739,7 +839,7 @@ def __call__(
739839 continue
740840
741841 timestep = t .unsqueeze (0 ).repeat (batch_size * num_videos_per_prompt )
742-
842+
743843 # Predict noise residual
744844 pred_velocity = self .transformer (
745845 hidden_states = latents .to (dtype ),
@@ -753,7 +853,7 @@ def __call__(
753853 return_dict = True
754854 ).sample
755855
756- if self .do_classifier_free_guidance and negative_prompt_embeds_dict is not None :
856+ if self .do_classifier_free_guidance and negative_prompt_embeds_qwen is not None :
757857 uncond_pred_velocity = self .transformer (
758858 hidden_states = latents .to (dtype ),
759859 encoder_hidden_states = negative_prompt_embeds_qwen .to (dtype ),
@@ -769,7 +869,6 @@ def __call__(
769869 pred_velocity = uncond_pred_velocity + guidance_scale * (
770870 pred_velocity - uncond_pred_velocity
771871 )
772-
773872 # Compute previous sample using the scheduler
774873 latents [:, :, :, :, :num_channels_latents ] = self .scheduler .step (
775874 pred_velocity , t , latents [:, :, :, :, :num_channels_latents ], return_dict = False
0 commit comments