Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,52 @@
import os
import random

import pytest
import torch


def pytest_configure(config):
# register an additional marker (see pytest_collection_modifyitems)
config.addinivalue_line(
"markers", "needs_cuda: mark for tests that rely on a CUDA device"
)


def pytest_collection_modifyitems(items):
# This hook is called by pytest after it has collected the tests (google its
# name to check out its doc!). We can ignore some tests as we see fit here,
# or add marks, such as a skip mark.

out_items = []
for item in items:
# The needs_cuda mark will exist if the test was explicitly decorated
# with the @needs_cuda decorator. It will also exist if it was
# parametrized with a parameter that has the mark: for example if a test
# is parametrized with
# @pytest.mark.parametrize('device', cpu_and_cuda())
# the "instances" of the tests where device == 'cuda' will have the
# 'needs_cuda' mark, and the ones with device == 'cpu' won't have the
# mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None

if (
needs_cuda
and not torch.cuda.is_available()
and os.environ.get("FAIL_WITHOUT_CUDA") is None
):
# We skip CUDA tests on non-CUDA machines, but only if the
# FAIL_WITHOUT_CUDA env var wasn't set. If it's set, the test will
# typically fail with a "Unsupported device: cuda" error. This is
# normal and desirable: this env var is set on CI jobs that are
# supposed to run the CUDA tests, so if CUDA isn't available on
# those for whatever reason, we need to know.
item.add_marker(pytest.mark.skip(reason="CUDA not available."))

out_items.append(item)

items[:] = out_items


@pytest.fixture(autouse=True)
def prevent_leaking_rng():
# Prevent each test from leaking the rng to all other test when they call
Expand Down
101 changes: 29 additions & 72 deletions test/decoders/test_video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from ..utils import (
assert_tensor_close_on_at_least,
assert_tensor_equal,
cpu_and_cuda,
get_frame_compare_function,
NASA_AUDIO,
NASA_VIDEO,
needs_cuda,
Expand Down Expand Up @@ -129,41 +131,26 @@ def test_get_frame_with_info_at_index(self):
assert pts.item() == pytest.approx(6.006, rel=1e-3)
assert duration.item() == pytest.approx(0.03337, rel=1e-3)

def test_get_frames_at_indices(self):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_at_indices(self, device):
tensor_compare_function = get_frame_compare_function(device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: here and everywhere below, call this frame_compare_function

decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder)
add_video_stream(decoder, device=device)
frames0and180, *_ = get_frames_at_indices(
decoder, stream_index=3, frame_indices=[0, 180]
)
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
reference_frame180 = NASA_VIDEO.get_frame_data_by_index(
INDEX_OF_FRAME_AT_6_SECONDS
)
assert_tensor_equal(frames0and180[0], reference_frame0)
assert_tensor_equal(frames0and180[1], reference_frame180)
tensor_compare_function(frames0and180[0], reference_frame0)
tensor_compare_function(frames0and180[1], reference_frame180)

@needs_cuda
def test_get_frames_at_indices_with_cuda(self):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_at_indices_unsorted_indices(self, device):
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder, device="cuda")
frames0and180, *_ = get_frames_at_indices(
decoder, stream_index=3, frame_indices=[0, 180]
)
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
reference_frame180 = NASA_VIDEO.get_frame_data_by_index(
INDEX_OF_FRAME_AT_6_SECONDS
)
assert frames0and180.device.type == "cuda"
assert_tensor_close_on_at_least(frames0and180[0].to("cpu"), reference_frame0)
assert_tensor_close_on_at_least(
frames0and180[1].to("cpu"), reference_frame180, 0.3, 30
)

def test_get_frames_at_indices_unsorted_indices(self):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(decoder)
_add_video_stream(decoder, device=device)
scan_all_streams_to_update_metadata(decoder)
stream_index = 3

Expand Down Expand Up @@ -192,9 +179,10 @@ def test_get_frames_at_indices_unsorted_indices(self):
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

def test_get_frames_by_pts(self):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_by_pts(self, device):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(decoder)
_add_video_stream(decoder, device=device)
scan_all_streams_to_update_metadata(decoder)
stream_index = 3

Expand Down Expand Up @@ -222,48 +210,15 @@ def test_get_frames_by_pts(self):
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

# TODO: Figure out how to parameterize this test to run on both CPU and CUDA.abs
# The question is how to have the @needs_cuda decorator with the pytest.mark.parametrize
# decorator on the same test.
@needs_cuda
def test_get_frames_by_pts_with_cuda(self):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(decoder, device="cuda")
scan_all_streams_to_update_metadata(decoder)
stream_index = 3

# Note: 13.01 should give the last video frame for the NASA video
timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]

expected_frames = [
get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps
]

frames, *_ = get_frames_by_pts(
decoder,
stream_index=stream_index,
timestamps=timestamps,
)
for frame, expected_frame in zip(frames, expected_frames):
assert_tensor_equal(frame, expected_frame)

# first and last frame should be equal, at pts=2 [+ eps]. We then modify
# the first frame and assert that it's now different from the last
# frame. This ensures a copy was properly made during the de-duplication
# logic.
assert_tensor_equal(frames[0], frames[-1])
frames[0] += 20
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

def test_pts_apis_against_index_ref(self):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_pts_apis_against_index_ref(self, device):
# Non-regression test for https://github.com/pytorch/torchcodec/pull/287
# Get all frames in the video, then query all frames with all time-based
# APIs exactly where those frames are supposed to start. We assert that
# we get the expected frame.
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder)
add_video_stream(decoder, device=device)

metadata = get_json_metadata(decoder)
metadata_dict = json.loads(metadata)
Expand Down Expand Up @@ -316,55 +271,57 @@ def test_pts_apis_against_index_ref(self):
)
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)

def test_get_frames_in_range(self):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_in_range(self, device):
tensor_compare_function = get_frame_compare_function(device)
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder)
add_video_stream(decoder, device=device)

# ensure that the degenerate case of a range of size 1 works
ref_frame0 = NASA_VIDEO.get_frame_data_by_range(0, 1)
bulk_frame0, *_ = get_frames_in_range(decoder, stream_index=3, start=0, stop=1)
assert_tensor_equal(ref_frame0, bulk_frame0)
tensor_compare_function(ref_frame0, bulk_frame0)

ref_frame1 = NASA_VIDEO.get_frame_data_by_range(1, 2)
bulk_frame1, *_ = get_frames_in_range(decoder, stream_index=3, start=1, stop=2)
assert_tensor_equal(ref_frame1, bulk_frame1)
tensor_compare_function(ref_frame1, bulk_frame1)

ref_frame389 = NASA_VIDEO.get_frame_data_by_range(389, 390)
bulk_frame389, *_ = get_frames_in_range(
decoder, stream_index=3, start=389, stop=390
)
assert_tensor_equal(ref_frame389, bulk_frame389)
tensor_compare_function(ref_frame389, bulk_frame389)

# contiguous ranges
ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9)
bulk_frames0_9, *_ = get_frames_in_range(
decoder, stream_index=3, start=0, stop=9
)
assert_tensor_equal(ref_frames0_9, bulk_frames0_9)
tensor_compare_function(ref_frames0_9, bulk_frames0_9)

ref_frames4_8 = NASA_VIDEO.get_frame_data_by_range(4, 8)
bulk_frames4_8, *_ = get_frames_in_range(
decoder, stream_index=3, start=4, stop=8
)
assert_tensor_equal(ref_frames4_8, bulk_frames4_8)
tensor_compare_function(ref_frames4_8, bulk_frames4_8)

# ranges with a stride
ref_frames15_35 = NASA_VIDEO.get_frame_data_by_range(15, 36, 5)
bulk_frames15_35, *_ = get_frames_in_range(
decoder, stream_index=3, start=15, stop=36, step=5
)
assert_tensor_equal(ref_frames15_35, bulk_frames15_35)
tensor_compare_function(ref_frames15_35, bulk_frames15_35)

ref_frames0_9_2 = NASA_VIDEO.get_frame_data_by_range(0, 9, 2)
bulk_frames0_9_2, *_ = get_frames_in_range(
decoder, stream_index=3, start=0, stop=9, step=2
)
assert_tensor_equal(ref_frames0_9_2, bulk_frames0_9_2)
tensor_compare_function(ref_frames0_9_2, bulk_frames0_9_2)

# an empty range is valid!
empty_frame, *_ = get_frames_in_range(decoder, stream_index=3, start=5, stop=5)
assert_tensor_equal(empty_frame, NASA_VIDEO.empty_chw_tensor)
tensor_compare_function(empty_frame, NASA_VIDEO.empty_chw_tensor)

def test_throws_exception_at_eof(self):
decoder = create_from_file(str(NASA_VIDEO.path))
Expand Down
36 changes: 26 additions & 10 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,22 @@
import torch


# Decorator for skipping CUDA tests when CUDA isn't available
# Decorator for skipping CUDA tests when CUDA isn't available. The tests are
# effectively marked to be skipped in pytest_collection_modifyitems() of
# conftest.py
def needs_cuda(test_item):
if not torch.cuda.is_available():
if os.environ.get("FAIL_WITHOUT_CUDA") == "1":
raise RuntimeError("CUDA is required for this test")
return pytest.mark.skip(reason="CUDA not available")(test_item)
return test_item
return pytest.mark.needs_cuda(test_item)


def cpu_and_cuda():
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))


def get_frame_compare_function(device):
if device == "cpu":
return assert_tensor_equal
else:
return assert_tensor_close_on_at_least


# For use with decoded data frames. On Linux, we expect exact, bit-for-bit equality. On
Expand All @@ -34,10 +43,17 @@ def assert_tensor_equal(*args, **kwargs):


# Asserts that at least `percentage`% of the values are within the absolute tolerance.
def assert_tensor_close_on_at_least(frame1, frame2, percentage=99.7, abs_tolerance=20):
diff = (frame2.float() - frame1.float()).abs()
diff_percentage = 100.0 - percentage
assert (diff > abs_tolerance).float().mean() <= diff_percentage / 100.0
def assert_tensor_close_on_at_least(tensor1, tensor2, percentage=90, abs_tolerance=20):
tensor1 = tensor1.to("cpu")
tensor2 = tensor2.to("cpu")
diff = (tensor2.float() - tensor1.float()).abs()
max_diff_percentage = 100.0 - percentage
if diff.sum() == 0:
return
diff_percentage = (diff > abs_tolerance).float().mean() * 100.0
assert (
diff_percentage <= max_diff_percentage
), f"Diff too high: {diff_percentage} > {max_diff_percentage}"


# For use with floating point metadata, or in other instances where we are not confident
Expand Down
Loading