Skip to content

Commit fb42c12

Browse files
committed
Handle different 'sample_indices_fn'
1 parent d1465dd commit fb42c12

File tree

1 file changed

+40
-7
lines changed

1 file changed

+40
-7
lines changed

src/transformers/video_utils.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515

1616
import os
1717
import warnings
18-
from collections.abc import Callable, Iterable, Mapping
18+
from collections.abc import Iterable, Mapping
1919
from contextlib import redirect_stdout
2020
from dataclasses import dataclass, fields
21+
from functools import partial
2122
from io import BytesIO
22-
from typing import NewType, Optional, Union
23+
from types import FunctionType
24+
from typing import Callable, NewType, Optional, Union
2325
from urllib.parse import urlparse
2426

2527
import httpx
@@ -298,6 +300,37 @@ def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] =
298300
return indices
299301

300302

303+
def get_num_frames_and_fps(sample_indices_fn: Union[partial, FunctionType]) -> tuple[Optional[int], Optional[float]]:
304+
"""
305+
Extract num_frames and fps from a function or functools.partial.
306+
307+
Args:
308+
sample_indices_fn: function or functools.partial
309+
310+
Returns:
311+
num_frames (int or None), fps (float or None)
312+
"""
313+
num_frames, fps = None, None
314+
315+
# Case 1: functools.partial
316+
if isinstance(sample_indices_fn, partial):
317+
num_frames = sample_indices_fn.keywords.get("num_frames")
318+
fps = sample_indices_fn.keywords.get("fps")
319+
320+
# Case 2: normal function with closure
321+
elif isinstance(sample_indices_fn, FunctionType):
322+
if sample_indices_fn.__closure__:
323+
closure_vars = {
324+
var: cell.cell_contents
325+
for var, cell in zip(sample_indices_fn.__code__.co_freevars, sample_indices_fn.__closure__)
326+
}
327+
num_frames = closure_vars.get("num_frames")
328+
fps = closure_vars.get("fps")
329+
330+
# Otherwise, not supported
331+
return num_frames, fps
332+
333+
301334
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
302335
"""
303336
A default sampling function that replicates the logic used in get_uniform_frame_indices,
@@ -364,7 +397,7 @@ def sample_indices_fn(metadata, **kwargs):
364397
video = cv2.VideoCapture(video_path)
365398
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
366399
video_fps = video.get(cv2.CAP_PROP_FPS)
367-
num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps")
400+
num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn)
368401
if fps:
369402
sampled_fps = fps
370403
elif num_frames and video_fps:
@@ -434,7 +467,7 @@ def sample_indices_fn(metadata, **kwargs):
434467
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
435468
video_fps = vr.get_avg_fps()
436469
total_num_frames = len(vr)
437-
num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps")
470+
num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn)
438471
if fps:
439472
sampled_fps = fps
440473
elif num_frames and video_fps:
@@ -494,7 +527,7 @@ def sample_indices_fn(metadata, **kwargs):
494527
container = av.open(video_path)
495528
total_num_frames = container.streams.video[0].frames
496529
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
497-
num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps")
530+
num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn)
498531
if fps:
499532
sampled_fps = fps
500533
elif num_frames and video_fps:
@@ -564,7 +597,7 @@ def sample_indices_fn(metadata, **kwargs):
564597
)
565598
video_fps = info["video_fps"]
566599
total_num_frames = video.size(0)
567-
num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps")
600+
num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn)
568601
if fps:
569602
sampled_fps = fps
570603
elif num_frames and video_fps:
@@ -631,7 +664,7 @@ def sample_indices_fn(metadata, **kwargs):
631664
)
632665
total_num_frames = decoder.metadata.num_frames
633666
video_fps = decoder.metadata.average_fps
634-
num_frames, fps = sample_indices_fn.keywords.get("num_frames"), sample_indices_fn.keywords.get("fps")
667+
num_frames, fps = get_num_frames_and_fps(sample_indices_fn=sample_indices_fn)
635668
if fps:
636669
sampled_fps = fps
637670
elif num_frames and video_fps:

0 commit comments

Comments
 (0)