Skip to content

Commit 09e51b2

Browse files
committed
add crf to api, move and update tests
1 parent afd5aba commit 09e51b2

File tree

3 files changed

+264
-282
lines changed

3 files changed

+264
-282
lines changed

src/torchcodec/encoders/_video_encoder.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,40 @@ def __init__(self, frames: Tensor, *, frame_rate: int):
3535
def to_file(
3636
self,
3737
dest: Union[str, Path],
38+
*,
39+
crf: int = None,
3840
) -> None:
3941
"""Encode frames into a file.
4042
4143
Args:
4244
dest (str or ``pathlib.Path``): The path to the output file, e.g.
4345
``video.mp4``. The extension of the file determines the video
4446
container format.
47+
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
48+
mean better quality. Valid range depends on the encoder (commonly 0-51).
49+
Defaults to None (which will use encoder's default).
4550
"""
4651
_core.encode_video_to_file(
4752
frames=self._frames,
4853
frame_rate=self._frame_rate,
4954
filename=str(dest),
55+
crf=crf,
5056
)
5157

5258
def to_tensor(
5359
self,
5460
format: str,
61+
*,
62+
crf: int = None,
5563
) -> Tensor:
5664
"""Encode frames into raw bytes, as a 1D uint8 Tensor.
5765
5866
Args:
5967
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
60-
"mkv", "avi", "webm", "flv", or "gif"
68+
"mkv", "avi", "webm", "flv", etc.
69+
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
70+
mean better quality. Valid range depends on the encoder (commonly 0-51).
71+
Defaults to None (which will use encoder's default).
6172
6273
Returns:
6374
Tensor: The raw encoded bytes as 4D uint8 Tensor.
@@ -66,12 +77,15 @@ def to_tensor(
6677
frames=self._frames,
6778
frame_rate=self._frame_rate,
6879
format=format,
80+
crf=crf,
6981
)
7082

7183
def to_file_like(
7284
self,
7385
file_like,
7486
format: str,
87+
*,
88+
crf: int = None,
7589
) -> None:
7690
"""Encode frames into a file-like object.
7791
@@ -82,11 +96,15 @@ def to_file_like(
8296
``write(data: bytes) -> int`` and ``seek(offset: int, whence:
8397
int = 0) -> int``.
8498
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
85-
"mkv", "avi", "webm", "flv", or "gif".
99+
"mkv", "avi", "webm", "flv", etc.
100+
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
101+
mean better quality. Valid range depends on the encoder (commonly 0-51).
102+
Defaults to None (which will use encoder's default).
86103
"""
87104
_core.encode_video_to_file_like(
88105
frames=self._frames,
89106
frame_rate=self._frame_rate,
90107
format=format,
91108
file_like=file_like,
109+
crf=crf,
92110
)

test/test_encoders.py

Lines changed: 244 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111
import torch
12-
from torchcodec.decoders import AudioDecoder
12+
from torchcodec.decoders import AudioDecoder, VideoDecoder
1313

1414
from torchcodec.encoders import AudioEncoder, VideoEncoder
1515

@@ -20,7 +20,9 @@
2020
in_fbcode,
2121
IS_WINDOWS,
2222
NASA_AUDIO_MP3,
23+
psnr,
2324
SINE_MONO_S32,
25+
TEST_SRC_2_720P,
2426
TestContainerFile,
2527
)
2628

@@ -567,6 +569,9 @@ def write(self, data):
567569

568570

569571
class TestVideoEncoder:
572+
def decode(self, source=None) -> torch.Tensor:
573+
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)
574+
570575
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
571576
def test_bad_input_parameterized(self, tmp_path, method):
572577
if method == "to_file":
@@ -676,3 +681,241 @@ def encode_to_tensor(frames):
676681
torch.testing.assert_close(
677682
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
678683
)
684+
685+
@pytest.mark.parametrize(
686+
"format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow))
687+
)
688+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
689+
def test_round_trip(self, tmp_path, format, method):
690+
# Test that decode(encode(decode(frames))) == decode(frames)
691+
ffmpeg_version = get_ffmpeg_major_version()
692+
# In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
693+
# As a result, we skip the round trip test.
694+
if ffmpeg_version == 6 and format != "webm":
695+
pytest.skip(
696+
f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test."
697+
)
698+
if format == "webm" and (
699+
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
700+
):
701+
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
702+
source_frames = self.decode(TEST_SRC_2_720P.path).data
703+
704+
# Frame rate is fixed with num frames decoded
705+
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
706+
707+
if method == "to_file":
708+
encoded_path = str(tmp_path / f"encoder_output.{format}")
709+
encoder.to_file(dest=encoded_path, crf=0)
710+
round_trip_frames = self.decode(encoded_path).data
711+
elif method == "to_tensor":
712+
encoded_tensor = encoder.to_tensor(format=format, crf=0)
713+
round_trip_frames = self.decode(encoded_tensor).data
714+
elif method == "to_file_like":
715+
file_like = io.BytesIO()
716+
encoder.to_file_like(file_like=file_like, format=format, crf=0)
717+
round_trip_frames = self.decode(file_like.getvalue()).data
718+
else:
719+
raise ValueError(f"Unknown method: {method}")
720+
721+
assert source_frames.shape == round_trip_frames.shape
722+
assert source_frames.dtype == round_trip_frames.dtype
723+
724+
# If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels
725+
# are within a higher tolerance.
726+
if ffmpeg_version == 6:
727+
assert_close = partial(assert_tensor_close_on_at_least, percentage=99)
728+
atol = 15
729+
else:
730+
assert_close = torch.testing.assert_close
731+
atol = 2
732+
for s_frame, rt_frame in zip(source_frames, round_trip_frames):
733+
assert psnr(s_frame, rt_frame) > 30
734+
assert_close(s_frame, rt_frame, atol=atol, rtol=0)
735+
736+
@pytest.mark.parametrize(
737+
"format",
738+
(
739+
"mov",
740+
"mp4",
741+
"avi",
742+
"mkv",
743+
"flv",
744+
"gif",
745+
pytest.param("webm", marks=pytest.mark.slow),
746+
),
747+
)
748+
@pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
749+
def test_against_to_file(self, tmp_path, format, method):
750+
# Test that to_file, to_tensor, and to_file_like produce the same results
751+
ffmpeg_version = get_ffmpeg_major_version()
752+
if format == "webm" and (
753+
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
754+
):
755+
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
756+
757+
source_frames = self.decode(TEST_SRC_2_720P.path).data
758+
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
759+
760+
encoded_file = tmp_path / f"output.{format}"
761+
encoder.to_file(dest=encoded_file, crf=0)
762+
763+
if method == "to_tensor":
764+
encoded_output = encoder.to_tensor(format=format, crf=0)
765+
else: # to_file_like
766+
file_like = io.BytesIO()
767+
encoder.to_file_like(file_like=file_like, format=format, crf=0)
768+
encoded_output = file_like.getvalue()
769+
770+
torch.testing.assert_close(
771+
self.decode(encoded_file).data,
772+
self.decode(encoded_output).data,
773+
atol=0,
774+
rtol=0,
775+
)
776+
777+
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
778+
@pytest.mark.parametrize(
779+
"format",
780+
(
781+
"mov",
782+
"mp4",
783+
"avi",
784+
"mkv",
785+
"flv",
786+
"gif",
787+
pytest.param("webm", marks=pytest.mark.slow),
788+
),
789+
)
790+
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
791+
# Encode samples with our encoder and with the FFmpeg CLI, and check
792+
# that both decoded outputs are similar
793+
ffmpeg_version = get_ffmpeg_major_version()
794+
if format == "webm" and (
795+
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
796+
):
797+
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
798+
799+
source_frames = self.decode(TEST_SRC_2_720P.path).data
800+
801+
# Encode with FFmpeg CLI
802+
temp_raw_path = str(tmp_path / "temp_input.raw")
803+
with open(temp_raw_path, "wb") as f:
804+
f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes())
805+
806+
ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}")
807+
frame_rate = 30
808+
crf = 0
809+
# Some codecs (ex. MPEG4) do not support CRF.
810+
# Flags not supported by the selected codec will be ignored.
811+
ffmpeg_cmd = [
812+
"ffmpeg",
813+
"-y",
814+
"-f",
815+
"rawvideo",
816+
"-pix_fmt",
817+
"rgb24",
818+
"-s",
819+
f"{source_frames.shape[3]}x{source_frames.shape[2]}",
820+
"-r",
821+
str(frame_rate),
822+
"-i",
823+
temp_raw_path,
824+
"-crf",
825+
str(crf),
826+
ffmpeg_encoded_path,
827+
]
828+
subprocess.run(ffmpeg_cmd, check=True)
829+
830+
# Encode with our video encoder
831+
encoder_output_path = str(tmp_path / f"encoder_output.{format}")
832+
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
833+
encoder.to_file(dest=encoder_output_path, crf=crf)
834+
835+
ffmpeg_frames = self.decode(ffmpeg_encoded_path).data
836+
encoder_frames = self.decode(encoder_output_path).data
837+
838+
assert ffmpeg_frames.shape[0] == encoder_frames.shape[0]
839+
840+
# If FFmpeg selects a codec or pixel format that uses qscale (not crf),
841+
# the VideoEncoder outputs *slightly* different frames.
842+
# There may be additional subtle differences in the encoder.
843+
percentage = 94 if ffmpeg_version == 6 or format == "avi" else 99
844+
845+
# Check that PSNR between both encoded versions is high
846+
for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames):
847+
res = psnr(ff_frame, enc_frame)
848+
assert res > 30
849+
assert_tensor_close_on_at_least(
850+
ff_frame, enc_frame, percentage=percentage, atol=2
851+
)
852+
853+
def test_to_file_like_custom_file_object(self):
854+
"""Test to_file_like with a custom file-like object that implements write and seek."""
855+
856+
class CustomFileObject:
857+
def __init__(self):
858+
self._file = io.BytesIO()
859+
860+
def write(self, data):
861+
return self._file.write(data)
862+
863+
def seek(self, offset, whence=0):
864+
return self._file.seek(offset, whence)
865+
866+
def get_encoded_data(self):
867+
return self._file.getvalue()
868+
869+
source_frames = self.decode(TEST_SRC_2_720P.path).data
870+
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
871+
872+
file_like = CustomFileObject()
873+
encoder.to_file_like(file_like, format="mp4", crf=0)
874+
decoded_frames = self.decode(file_like.get_encoded_data())
875+
876+
torch.testing.assert_close(
877+
decoded_frames.data,
878+
source_frames,
879+
atol=2,
880+
rtol=0,
881+
)
882+
883+
def test_to_file_like_real_file(self, tmp_path):
884+
"""Test to_file_like with a real file opened in binary write mode."""
885+
source_frames = self.decode(TEST_SRC_2_720P.path).data
886+
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
887+
888+
file_path = tmp_path / "test_file_like.mp4"
889+
890+
with open(file_path, "wb") as file_like:
891+
encoder.to_file_like(file_like, format="mp4", crf=0)
892+
decoded_frames = self.decode(str(file_path))
893+
894+
torch.testing.assert_close(
895+
decoded_frames.data,
896+
source_frames,
897+
atol=2,
898+
rtol=0,
899+
)
900+
901+
def test_to_file_like_bad_methods(self):
902+
source_frames = self.decode(TEST_SRC_2_720P.path).data
903+
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
904+
905+
class NoWriteMethod:
906+
def seek(self, offset, whence=0):
907+
return 0
908+
909+
with pytest.raises(
910+
RuntimeError, match="File like object must implement a write method"
911+
):
912+
encoder.to_file_like(NoWriteMethod(), format="mp4")
913+
914+
class NoSeekMethod:
915+
def write(self, data):
916+
return len(data)
917+
918+
with pytest.raises(
919+
RuntimeError, match="File like object must implement a seek method"
920+
):
921+
encoder.to_file_like(NoSeekMethod(), format="mp4")

0 commit comments

Comments
 (0)