Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ decoder.get_frame_at(len(decoder) - 1)
# pts_seconds: 9.960000038146973
# duration_seconds: 0.03999999910593033

decoder.get_frames_at(start=10, stop=30, step=5)
decoder.get_frames_in_range(start=10, stop=30, step=5)
# FrameBatch:
# data (shape): torch.Size([4, 3, 400, 640])
# pts_seconds: tensor([0.4000, 0.6000, 0.8000, 1.0000])
Expand Down
8 changes: 4 additions & 4 deletions examples/basic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
# their :term:`pts` (Presentation Time Stamp), and their duration.
# This can be achieved using the
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range` methods, which
# will return a :class:`~torchcodec.Frame` and
# :class:`~torchcodec.FrameBatch` objects respectively.

Expand All @@ -129,7 +129,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
print(last_frame)

# %%
middle_frames = decoder.get_frames_at(start=10, stop=20, step=2)
middle_frames = decoder.get_frames_in_range(start=10, stop=20, step=2)
print(f"{type(middle_frames) = }")
print(middle_frames)

Expand All @@ -152,7 +152,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
# So far, we have retrieved frames based on their index. We can also retrieve
# frames based on *when* they are displayed with
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_in_range`, which
# also returns :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
# respectively.

Expand All @@ -161,7 +161,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
print(frame_at_2_seconds)

# %%
first_two_seconds = decoder.get_frames_displayed_at(
first_two_seconds = decoder.get_frames_displayed_in_range(
start_seconds=0,
stop_seconds=2,
)
Expand Down
4 changes: 2 additions & 2 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def get_frame_at(self, index: int) -> Frame:
duration_seconds=duration_seconds.item(),
)

def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
"""Return multiple frames at the given index range.

Frames are in [start, stop).
Expand Down Expand Up @@ -238,7 +238,7 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
duration_seconds=duration_seconds.item(),
)

def get_frames_displayed_at(
def get_frames_displayed_in_range(
self, start_seconds: float, stop_seconds: float
) -> FrameBatch:
"""Returns multiple frames in the given range.
Expand Down
42 changes: 21 additions & 21 deletions test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,14 @@ def test_get_frame_displayed_at_fails(self):
frame = decoder.get_frame_displayed_at(100.0) # noqa

@pytest.mark.parametrize("stream_index", [0, 3, None])
def test_get_frames_at(self, stream_index):
def test_get_frames_in_range(self, stream_index):
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)

# test degenerate case where we only actually get 1 frame
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
start=9, stop=10, stream_index=stream_index
)
frames9 = decoder.get_frames_at(start=9, stop=10)
frames9 = decoder.get_frames_in_range(start=9, stop=10)

assert_tensor_equal(ref_frames9, frames9.data)
assert frames9.pts_seconds[0].item() == pytest.approx(
Expand All @@ -389,7 +389,7 @@ def test_get_frames_at(self, stream_index):
ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(
start=0, stop=10, stream_index=stream_index
)
frames0_9 = decoder.get_frames_at(start=0, stop=10)
frames0_9 = decoder.get_frames_in_range(start=0, stop=10)
assert frames0_9.data.shape == torch.Size(
[
10,
Expand All @@ -412,7 +412,7 @@ def test_get_frames_at(self, stream_index):
ref_frames0_8_2 = NASA_VIDEO.get_frame_data_by_range(
start=0, stop=10, step=2, stream_index=stream_index
)
frames0_8_2 = decoder.get_frames_at(start=0, stop=10, step=2)
frames0_8_2 = decoder.get_frames_in_range(start=0, stop=10, step=2)
assert frames0_8_2.data.shape == torch.Size(
[
5,
Expand All @@ -434,13 +434,13 @@ def test_get_frames_at(self, stream_index):
)

# test numpy.int64 for indices
frames0_8_2 = decoder.get_frames_at(
frames0_8_2 = decoder.get_frames_in_range(
start=numpy.int64(0), stop=numpy.int64(10), step=numpy.int64(2)
)
assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data)

# an empty range is valid!
empty_frames = decoder.get_frames_at(5, 5)
empty_frames = decoder.get_frames_in_range(5, 5)
assert_tensor_equal(
empty_frames.data,
NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index),
Expand All @@ -456,10 +456,10 @@ def test_get_frames_at(self, stream_index):
(
lambda decoder: decoder[0],
lambda decoder: decoder.get_frame_at(0).data,
lambda decoder: decoder.get_frames_at(0, 4).data,
lambda decoder: decoder.get_frames_in_range(0, 4).data,
lambda decoder: decoder.get_frame_displayed_at(0).data,
# TODO: uncomment once D60001893 lands
# lambda decoder: decoder.get_frames_displayed_at(0, 1).data,
# lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
),
)
def test_dimension_order(self, dimension_order, frame_getter):
Expand Down Expand Up @@ -487,7 +487,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)

# Note that we are comparing the results of VideoDecoder's method:
# get_frames_displayed_at()
# get_frames_displayed_in_range()
# With the testing framework's method:
# get_frame_data_by_range()
# That is, we are testing the correctness of a pts-based range against an index-
Expand All @@ -504,7 +504,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
# value for frame 5 that we have access to on the Python side is slightly less than the pts
# value on the C++ side. This test still produces the correct result because a slightly
# less value still falls into the correct window.
frames0_4 = decoder.get_frames_displayed_at(
frames0_4 = decoder.get_frames_displayed_in_range(
decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(5).pts_seconds
)
assert_tensor_equal(
Expand All @@ -513,15 +513,15 @@ def test_get_frames_by_pts_in_range(self, stream_index):
)

# Range where the stop seconds is about halfway between pts values for two frames.
also_frames0_4 = decoder.get_frames_displayed_at(
also_frames0_4 = decoder.get_frames_displayed_in_range(
decoder.get_frame_at(0).pts_seconds,
decoder.get_frame_at(4).pts_seconds + HALF_DURATION,
)
assert_tensor_equal(also_frames0_4.data, frames0_4.data)

# Again, the intention here is to provide the exact values we care about. In practice, our
# pts values are slightly smaller, so we nudge the start upwards.
frames5_9 = decoder.get_frames_displayed_at(
frames5_9 = decoder.get_frames_displayed_in_range(
decoder.get_frame_at(5).pts_seconds,
decoder.get_frame_at(10).pts_seconds,
)
Expand All @@ -533,7 +533,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
# Range where we provide start_seconds and stop_seconds that are different, but
# also should land in the same window of time between two frame's pts values. As
# a result, we should only get back one frame.
frame6 = decoder.get_frames_displayed_at(
frame6 = decoder.get_frames_displayed_in_range(
decoder.get_frame_at(6).pts_seconds,
decoder.get_frame_at(6).pts_seconds + HALF_DURATION,
)
Expand All @@ -543,7 +543,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
)

# Very small range that falls in the same frame.
frame35 = decoder.get_frames_displayed_at(
frame35 = decoder.get_frames_displayed_in_range(
decoder.get_frame_at(35).pts_seconds,
decoder.get_frame_at(35).pts_seconds + 1e-10,
)
Expand All @@ -555,7 +555,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
# Single frame where the start seconds is before frame i's pts, and the stop is
# after frame i's pts, but before frame i+1's pts. In that scenario, we expect
# to see frames i-1 and i.
frames7_8 = decoder.get_frames_displayed_at(
frames7_8 = decoder.get_frames_displayed_in_range(
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
- HALF_DURATION,
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
Expand All @@ -567,7 +567,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
)

# Start and stop seconds are the same value, which should not return a frame.
empty_frame = decoder.get_frames_displayed_at(
empty_frame = decoder.get_frames_displayed_in_range(
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
)
Expand All @@ -583,7 +583,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
)

# Start and stop seconds land within the first frame.
frame0 = decoder.get_frames_displayed_at(
frame0 = decoder.get_frames_displayed_in_range(
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds,
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds
+ HALF_DURATION,
Expand All @@ -595,7 +595,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):

# We should be able to get all frames by giving the beginning and ending time
# for the stream.
all_frames = decoder.get_frames_displayed_at(
all_frames = decoder.get_frames_displayed_in_range(
decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds
)
assert_tensor_equal(all_frames.data, decoder[:])
Expand All @@ -604,13 +604,13 @@ def test_get_frames_by_pts_in_range_fails(self):
decoder = VideoDecoder(NASA_VIDEO.path)

with pytest.raises(ValueError, match="Invalid start seconds"):
frame = decoder.get_frames_displayed_at(100.0, 1.0) # noqa
frame = decoder.get_frames_displayed_in_range(100.0, 1.0) # noqa

with pytest.raises(ValueError, match="Invalid start seconds"):
frame = decoder.get_frames_displayed_at(20, 23) # noqa
frame = decoder.get_frames_displayed_in_range(20, 23) # noqa

with pytest.raises(ValueError, match="Invalid stop seconds"):
frame = decoder.get_frames_displayed_at(0, 23) # noqa
frame = decoder.get_frames_displayed_in_range(0, 23) # noqa


if __name__ == "__main__":
Expand Down
Loading