Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"return_metadata": True},
}


Expand Down Expand Up @@ -924,10 +925,17 @@ def __call__(
image_grid_thw = image_inputs["image_grid_thw"]

if videos is not None:
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]

# Get video metadata
if not kwargs.get("return_metadata"):
video_metadata = videos_inputs.pop("video_metadata")
else:
video_metadata = videos_inputs["video_metadata"]

fps = [metadata.sampled_fps or metadata.fps or 24 for metadata in video_metadata]

if isinstance(fps, (int, float)):
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"return_metadata": True},
}


Expand Down Expand Up @@ -135,10 +136,17 @@ def __call__(
image_grid_thw = image_inputs["image_grid_thw"]

if videos is not None:
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]

# Get video metadata
if not kwargs.get("return_metadata"):
video_metadata = videos_inputs.pop("video_metadata")
else:
video_metadata = videos_inputs["video_metadata"]

fps = [metadata.sampled_fps or metadata.fps or 24 for metadata in video_metadata]

if isinstance(fps, (int, float)):
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
Expand Down
24 changes: 18 additions & 6 deletions src/transformers/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
class VideoMetadata(Mapping):
total_num_frames: int
fps: Optional[float] = None
sampled_fps: Optional[float] = None
width: Optional[int] = None
height: Optional[int] = None
duration: Optional[float] = None
Expand Down Expand Up @@ -372,8 +373,9 @@ def sample_indices_fn(metadata, **kwargs):
height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)),
width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
)
indices = sample_indices_fn(metadata=metadata, **kwargs)

indices = sample_indices_fn(metadata=metadata, **kwargs)
sampled_fps = len(indices) / total_num_frames * video_fps if video_fps else 24
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We infer this value the same way everywhere, so in my opinion we could move it into a property on the `VideoMetadata` class, similar to how timestamps are handled.

index = 0
frames = []
while video.isOpened():
Expand All @@ -391,6 +393,7 @@ def sample_indices_fn(metadata, **kwargs):

video.release()
metadata.frames_indices = indices
metadata.sampled_fps = float(sampled_fps)
return np.stack(frames), metadata


Expand Down Expand Up @@ -434,11 +437,13 @@ def sample_indices_fn(metadata, **kwargs):
)

indices = sample_indices_fn(metadata=metadata, **kwargs)
sampled_fps = len(indices) / total_num_frames * video_fps if video_fps else 24
video = vr.get_batch(indices).asnumpy()

metadata.update(
{
"frames_indices": indices,
"sampled_fps": sampled_fps,
"height": video.shape[1],
"width": video.shape[2],
}
Expand Down Expand Up @@ -486,8 +491,9 @@ def sample_indices_fn(metadata, **kwargs):
height=container.streams.video[0].height,
width=container.streams.video[0].width,
)
indices = sample_indices_fn(metadata=metadata, **kwargs)

indices = sample_indices_fn(metadata=metadata, **kwargs)
sampled_fps = len(indices) / total_num_frames * video_fps if video_fps else 24
frames = []
container.seek(0)
end_index = indices[-1]
Expand All @@ -499,6 +505,7 @@ def sample_indices_fn(metadata, **kwargs):

video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
metadata.frames_indices = indices
metadata.sampled_fps = sampled_fps
return video, metadata


Expand Down Expand Up @@ -548,11 +555,12 @@ def sample_indices_fn(metadata, **kwargs):
)

indices = sample_indices_fn(metadata=metadata, **kwargs)

sampled_fps = len(indices) / total_num_frames * video_fps if video_fps else 24
video = video[indices].contiguous()
metadata.update(
{
"frames_indices": indices,
"sampled_fps": sampled_fps,
"height": video.shape[2],
"width": video.shape[3],
}
Expand Down Expand Up @@ -596,18 +604,22 @@ def sample_indices_fn(metadata, **kwargs):
num_ffmpeg_threads=0,
device=kwargs.get("device", "cpu"),
)
total_num_frames = decoder.metadata.num_frames
video_fps = decoder.metadata.average_fps
metadata = VideoMetadata(
total_num_frames=decoder.metadata.num_frames,
fps=decoder.metadata.average_fps,
total_num_frames=total_num_frames,
fps=video_fps,
duration=decoder.metadata.duration_seconds,
video_backend="torchcodec",
height=decoder.metadata.height,
width=decoder.metadata.width,
)
indices = sample_indices_fn(metadata=metadata, **kwargs)

indices = sample_indices_fn(metadata=metadata, **kwargs)
sampled_fps = len(indices) / total_num_frames * video_fps if video_fps else 24
video = decoder.get_frames_at(indices=indices).data.contiguous()
metadata.frames_indices = indices
metadata.sampled_fps = sampled_fps
return video, metadata


Expand Down
Loading