|
15 | 15 |
|
16 | 16 | import os |
17 | 17 | import warnings |
18 | | -from collections.abc import Callable, Iterable, Mapping |
| 18 | +from collections.abc import Iterable, Mapping |
19 | 19 | from contextlib import redirect_stdout |
20 | 20 | from dataclasses import dataclass, fields |
| 21 | +from functools import partial |
21 | 22 | from io import BytesIO |
22 | | -from typing import NewType, Optional, Union |
| 23 | +from types import FunctionType |
| 24 | +from typing import Callable, NewType, Optional, Union |
23 | 25 | from urllib.parse import urlparse |
24 | 26 |
|
25 | 27 | import httpx |
@@ -298,6 +300,37 @@ def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = |
298 | 300 | return indices |
299 | 301 |
|
300 | 302 |
|
| 303 | +def get_num_frames_and_fps(sample_indices_fn: Union[partial, FunctionType]) -> tuple[Optional[int], Optional[float]]: |
| 304 | + """ |
| 305 | + Extract num_frames and fps from a function or functools.partial. |
| 306 | + |
| 307 | + Args: |
| 308 | + sample_indices_fn: function or functools.partial |
| 309 | + |
| 310 | + Returns: |
| 311 | + num_frames (int or None), fps (float or None) |
| 312 | + """ |
| 313 | + num_frames, fps = None, None |
| 314 | + |
| 315 | + # Case 1: functools.partial |
| 316 | + if isinstance(sample_indices_fn, partial): |
| 317 | + num_frames = sample_indices_fn.keywords.get("num_frames") |
| 318 | + fps = sample_indices_fn.keywords.get("fps") |
| 319 | + |
| 320 | + # Case 2: normal function with closure |
| 321 | + elif isinstance(sample_indices_fn, FunctionType): |
| 322 | + if sample_indices_fn.__closure__: |
| 323 | + closure_vars = { |
| 324 | + var: cell.cell_contents |
| 325 | + for var, cell in zip(sample_indices_fn.__code__.co_freevars, sample_indices_fn.__closure__) |
| 326 | + } |
| 327 | + num_frames = closure_vars.get("num_frames") |
| 328 | + fps = closure_vars.get("fps") |
| 329 | + |
| 330 | + # Otherwise, not supported |
| 331 | + return num_frames, fps |
| 332 | + |
| 333 | + |
301 | 334 | def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs): |
302 | 335 | """ |
303 | 336 | A default sampling function that replicates the logic used in get_uniform_frame_indices, |
@@ -364,7 +397,7 @@ def sample_indices_fn(metadata, **kwargs): |
364 | 397 | video = cv2.VideoCapture(video_path) |
365 | 398 | total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) |
366 | 399 | video_fps = video.get(cv2.CAP_PROP_FPS) |
367 | | - num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps") |
| 400 | + num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn) |
368 | 401 | if fps: |
369 | 402 | sampled_fps = fps |
370 | 403 | elif num_frames and video_fps: |
@@ -434,7 +467,7 @@ def sample_indices_fn(metadata, **kwargs): |
434 | 467 | vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu |
435 | 468 | video_fps = vr.get_avg_fps() |
436 | 469 | total_num_frames = len(vr) |
437 | | - num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps") |
| 470 | + num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn) |
438 | 471 | if fps: |
439 | 472 | sampled_fps = fps |
440 | 473 | elif num_frames and video_fps: |
@@ -494,7 +527,7 @@ def sample_indices_fn(metadata, **kwargs): |
494 | 527 | container = av.open(video_path) |
495 | 528 | total_num_frames = container.streams.video[0].frames |
496 | 529 | video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`? |
497 | | - num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps") |
| 530 | + num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn) |
498 | 531 | if fps: |
499 | 532 | sampled_fps = fps |
500 | 533 | elif num_frames and video_fps: |
@@ -564,7 +597,7 @@ def sample_indices_fn(metadata, **kwargs): |
564 | 597 | ) |
565 | 598 | video_fps = info["video_fps"] |
566 | 599 | total_num_frames = video.size(0) |
567 | | - num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps") |
| 600 | + num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn) |
568 | 601 | if fps: |
569 | 602 | sampled_fps = fps |
570 | 603 | elif num_frames and video_fps: |
@@ -631,7 +664,7 @@ def sample_indices_fn(metadata, **kwargs): |
631 | 664 | ) |
632 | 665 | total_num_frames = decoder.metadata.num_frames |
633 | 666 | video_fps = decoder.metadata.average_fps |
634 | | - num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps") |
| 667 | + num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn) |
635 | 668 | if fps: |
636 | 669 | sampled_fps = fps |
637 | 670 | elif num_frames and video_fps: |
|
0 commit comments