Skip to content

Commit b6f45f6

Browse files
committed
compare frame ts in test, update framerate set
1 parent c1b3adc commit b6f45f6

File tree

2 files changed

+51
-32
lines changed

2 files changed

+51
-32
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -788,9 +788,8 @@ void VideoEncoder::initializeEncoder(
788788
avCodecContext_->height = outHeight_;
789789
avCodecContext_->pix_fmt = outPixelFormat_;
790790
// TODO-VideoEncoder: Add and utilize output frame_rate option
791-
AVRational frameRate = av_d2q(inFrameRate_, INT_MAX);
792-
avCodecContext_->time_base = av_inv_q(frameRate);
793-
avCodecContext_->framerate = frameRate;
791+
avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX);
792+
avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate);
794793

795794
// Set flag for containers that require extradata to be in the codec context
796795
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {

test/test_encoders.py

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ class TestVideoEncoder:
572572
def decode(self, source=None) -> torch.Tensor:
    """Decode ``source`` and return the frame data for start=0, stop=60.

    Builds a ``VideoDecoder`` over ``source``, requests the frame range
    ``start=0, stop=60`` and returns the ``.data`` attribute of the result.
    """
    decoder = VideoDecoder(source)
    frame_batch = decoder.get_frames_in_range(start=0, stop=60)
    return frame_batch.data
574574

575+
# TODO: add average_fps field to TestVideo asset
575576
def decode_and_get_frame_rate(self, source=None):
576577
decoder = VideoDecoder(source)
577578
frames = decoder.get_frames_in_range(start=0, stop=60).data
@@ -604,6 +605,27 @@ def _get_video_metadata(self, file_path, fields):
604605
metadata[key] = value
605606
return metadata
606607

608+
def _get_frames_info(self, file_path):
609+
"""Helper function to get frame info (pts, dts, etc.) using ffprobe."""
610+
result = subprocess.run(
611+
[
612+
"ffprobe",
613+
"-v",
614+
"error",
615+
"-select_streams",
616+
"v:0",
617+
"-show_entries",
618+
"frame=pts,pkt_pts,dts,pkt_duration,duration",
619+
"-of",
620+
"json",
621+
str(file_path),
622+
],
623+
capture_output=True,
624+
check=True,
625+
text=True,
626+
)
627+
return json.loads(result.stdout)["frames"]
628+
607629
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
608630
def test_bad_input_parameterized(self, tmp_path, method):
609631
if method == "to_file":
@@ -925,8 +947,9 @@ def test_against_to_file(self, tmp_path, format, method):
925947
],
926948
)
927949
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
950+
@pytest.mark.parametrize("frame_rate", [30, 29.97])
928951
def test_video_encoder_against_ffmpeg_cli(
929-
self, tmp_path, format, encode_params, method
952+
self, tmp_path, format, encode_params, method, frame_rate
930953
):
931954
ffmpeg_version = get_ffmpeg_major_version()
932955
if format == "webm" and (
@@ -941,7 +964,7 @@ def test_video_encoder_against_ffmpeg_cli(
941964
if format in ("avi", "flv") and pixel_format == "yuv444p":
942965
pytest.skip(f"Default codec for {format} does not support {pixel_format}")
943966

944-
source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path)
967+
source_frames = self.decode(TEST_SRC_2_720P.path)
945968

946969
# Encode with FFmpeg CLI
947970
temp_raw_path = str(tmp_path / "temp_input.raw")
@@ -1024,16 +1047,30 @@ def test_video_encoder_against_ffmpeg_cli(
10241047

10251048
# Check that video metadata is the same
10261049
if method == "to_file":
1027-
fields = ["duration", "duration_ts", "r_frame_rate", "nb_frames"]
1028-
ffmpeg_metadata = self._get_video_metadata(
1029-
ffmpeg_encoded_path,
1030-
fields=fields,
1031-
)
1032-
encoder_metadata = self._get_video_metadata(
1033-
encoder_output_path,
1034-
fields=fields,
1035-
)
1036-
assert ffmpeg_metadata == encoder_metadata
1050+
# mkv container format handles frame_rate differently, so we skip it here
1051+
if format != "mkv":
1052+
fields = ["duration", "duration_ts", "r_frame_rate", "nb_frames"]
1053+
ffmpeg_metadata = self._get_video_metadata(
1054+
ffmpeg_encoded_path,
1055+
fields=fields,
1056+
)
1057+
encoder_metadata = self._get_video_metadata(
1058+
encoder_output_path,
1059+
fields=fields,
1060+
)
1061+
assert ffmpeg_metadata == encoder_metadata
1062+
1063+
# Check that frame timestamps and duration are the same
1064+
ffmpeg_frames_info = self._get_frames_info(ffmpeg_encoded_path)
1065+
encoder_frames_info = self._get_frames_info(encoder_output_path)
1066+
1067+
assert len(ffmpeg_frames_info) == len(encoder_frames_info)
1068+
for ffmpeg_frame, encoder_frame in zip(
1069+
ffmpeg_frames_info, encoder_frames_info
1070+
):
1071+
for key in ffmpeg_frame.keys():
1072+
assert key in encoder_frame
1073+
assert ffmpeg_frame[key] == encoder_frame[key]
10371074

10381075
def test_to_file_like_custom_file_object(self):
10391076
"""Test to_file_like with a custom file-like object that implements write and seek."""
@@ -1214,20 +1251,3 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range
12141251
assert metadata["profile"].lower() == expected_profile
12151252
assert metadata["color_space"] == colorspace
12161253
assert metadata["color_range"] == color_range
1217-
1218-
@pytest.mark.parametrize("frame_rate", [29.97, 59.94, 5.001])
1219-
def test_fractional_frame_rate(self, tmp_path, frame_rate):
1220-
source_frames = torch.zeros((10, 3, 64, 64), dtype=torch.uint8)
1221-
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
1222-
output_path = str(tmp_path / "output.mp4")
1223-
encoder.to_file(dest=output_path)
1224-
# Assert the encoded frame rate via file metadata
1225-
metadata = self._get_video_metadata(output_path, fields=["r_frame_rate"])
1226-
num, den = metadata["r_frame_rate"].split("/")
1227-
encoded_frame_rate = int(num) / int(den)
1228-
assert encoded_frame_rate == pytest.approx(frame_rate, abs=1e-3)
1229-
# Assert the decoded frame rate matches the input frame rate
1230-
_decoded_frames, decoded_frame_rate = self.decode_and_get_frame_rate(
1231-
output_path
1232-
)
1233-
assert decoded_frame_rate == pytest.approx(frame_rate, abs=1e-3)

0 commit comments

Comments
 (0)