Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/torchcodec/encoders/_video_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def to_file(
dest: Union[str, Path],
*,
pixel_format: Optional[str] = None,
crf: Optional[int] = None,
) -> None:
"""Encode frames into a file.

Expand All @@ -46,27 +47,35 @@ def to_file(
container format.
pixel_format (str, optional): The pixel format for encoding (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
mean better quality. Valid range depends on the encoder (commonly 0-51).
Defaults to None (which will use encoder's default).
"""
_core.encode_video_to_file(
frames=self._frames,
frame_rate=self._frame_rate,
filename=str(dest),
pixel_format=pixel_format,
crf=crf,
)

def to_tensor(
self,
format: str,
*,
pixel_format: Optional[str] = None,
crf: Optional[int] = None,
) -> Tensor:
"""Encode frames into raw bytes, as a 1D uint8 Tensor.

Args:
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
"mkv", "avi", "webm", "flv", or "gif"
"mkv", "avi", "webm", "flv", etc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q - why remove "gif"? Do we not support it anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not test explicitly for it anymore, but it still works. I mostly wanted to amend the docstring to make it seem less like a finalized, exhaustive list of supported formats.

pixel_format (str, optional): The pixel format to encode frames into (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
mean better quality. Valid range depends on the encoder (commonly 0-51).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it ever valid to be less than 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe -1 is valid and is equivalent to leaving crf unset. Otherwise, no negative values are valid.

Defaults to None (which will use encoder's default).

Returns:
Tensor: The raw encoded bytes as a 1D uint8 Tensor.
Expand All @@ -76,6 +85,7 @@ def to_tensor(
frame_rate=self._frame_rate,
format=format,
pixel_format=pixel_format,
crf=crf,
)

def to_file_like(
Expand All @@ -84,6 +94,7 @@ def to_file_like(
format: str,
*,
pixel_format: Optional[str] = None,
crf: Optional[int] = None,
) -> None:
"""Encode frames into a file-like object.

Expand All @@ -94,14 +105,18 @@ def to_file_like(
``write(data: bytes) -> int`` and ``seek(offset: int, whence:
int = 0) -> int``.
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
"mkv", "avi", "webm", "flv", or "gif".
"mkv", "avi", "webm", "flv", etc.
pixel_format (str, optional): The pixel format for encoding (e.g.,
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
mean better quality. Valid range depends on the encoder (commonly 0-51).
Defaults to None (which will use encoder's default).
"""
_core.encode_video_to_file_like(
frames=self._frames,
frame_rate=self._frame_rate,
format=format,
file_like=file_like,
pixel_format=pixel_format,
crf=crf,
)
238 changes: 237 additions & 1 deletion test/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest
import torch
from torchcodec.decoders import AudioDecoder
from torchcodec.decoders import AudioDecoder, VideoDecoder

from torchcodec.encoders import AudioEncoder, VideoEncoder

Expand All @@ -20,7 +20,9 @@
in_fbcode,
IS_WINDOWS,
NASA_AUDIO_MP3,
psnr,
SINE_MONO_S32,
TEST_SRC_2_720P,
TestContainerFile,
)

Expand Down Expand Up @@ -567,6 +569,9 @@ def write(self, data):


class TestVideoEncoder:
def decode(self, source=None):
    """Decode frames from ``source`` for use in encoder tests.

    Args:
        source: Input handed straight to ``VideoDecoder``. The tests in this
            class pass file paths, encoded byte strings and encoded tensors.

    Returns:
        Whatever ``VideoDecoder.get_frames_in_range(start=0, stop=60)``
        returns. Callers read the decoded pixels from its ``.data``
        attribute, so the result is treated as a frame batch object and not
        as a bare ``torch.Tensor`` (hence no ``-> torch.Tensor`` annotation).
    """
    return VideoDecoder(source).get_frames_in_range(start=0, stop=60)

@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
def test_bad_input_parameterized(self, tmp_path, method):
if method == "to_file":
Expand Down Expand Up @@ -700,3 +705,234 @@ def encode_to_tensor(frames):
torch.testing.assert_close(
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
)

@pytest.mark.parametrize(
    "format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow))
)
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
def test_round_trip(self, tmp_path, format, method):
    """Check that decode(encode(decode(frames))) stays close to decode(frames).

    The source frames are encoded with ``pixel_format="yuv444p"`` and
    ``crf=0`` through the encoder entry point selected by ``method``, decoded
    again, and compared frame by frame against the source.

    Args:
        tmp_path: pytest temporary directory used for the ``to_file`` output.
        format: Container format to encode into.
        method: Which ``VideoEncoder`` method to exercise.
    """
    ffmpeg_version = get_ffmpeg_major_version()
    if format == "webm" and (
        ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
    ):
        pytest.skip("Codec for webm is not available in this FFmpeg installation.")
    source_frames = self.decode(TEST_SRC_2_720P.path).data

    # Frame rate is fixed with num frames decoded
    encoder = VideoEncoder(frames=source_frames, frame_rate=30)

    if method == "to_file":
        encoded_path = str(tmp_path / f"encoder_output.{format}")
        encoder.to_file(dest=encoded_path, pixel_format="yuv444p", crf=0)
        round_trip_frames = self.decode(encoded_path).data
    elif method == "to_tensor":
        encoded_tensor = encoder.to_tensor(
            format=format, pixel_format="yuv444p", crf=0
        )
        round_trip_frames = self.decode(encoded_tensor).data
    elif method == "to_file_like":
        file_like = io.BytesIO()
        encoder.to_file_like(
            file_like=file_like, format=format, pixel_format="yuv444p", crf=0
        )
        round_trip_frames = self.decode(file_like.getvalue()).data
    else:
        raise ValueError(f"Unknown method: {method}")

    assert source_frames.shape == round_trip_frames.shape
    assert source_frames.dtype == round_trip_frames.dtype

    # webm round trips were observed to exceed the tolerance of 2 used for
    # the other containers, so it gets one extra unit of slack.
    atol = 3 if format == "webm" else 2
    for s_frame, rt_frame in zip(source_frames, round_trip_frames):
        assert psnr(s_frame, rt_frame) > 30
        torch.testing.assert_close(s_frame, rt_frame, atol=atol, rtol=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be failing for webm, you might need to use the previous logic

        # If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels
        # are within a higher tolerance.
        if ffmpeg_version == 6:
            assert_close = partial(assert_tensor_close_on_at_least, percentage=99)
            atol = 15
        else:
            assert_close = torch.testing.assert_close
            atol = 3 if format == "webm" else 2

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reminder - it seems I applied the webm tolerance to the wrong test. We can simply use atol = 3 if format == "webm" else 2 on the round_trip_test, though I'm not sure why webm needs this special handling.


@pytest.mark.parametrize(
    "format",
    (
        "mov",
        "mp4",
        "avi",
        "mkv",
        "flv",
        "gif",
        pytest.param("webm", marks=pytest.mark.slow),
    ),
)
@pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
def test_against_to_file(self, tmp_path, format, method):
    """Check that ``to_tensor`` / ``to_file_like`` decode identically to ``to_file``.

    The same frames are encoded once with ``to_file`` and once with the
    entry point named by ``method`` (both with ``crf=0``); the two decoded
    results must match exactly.
    """
    major = get_ffmpeg_major_version()
    webm_codec_missing = major == 4 or (IS_WINDOWS and major in (6, 7))
    if format == "webm" and webm_codec_missing:
        pytest.skip("Codec for webm is not available in this FFmpeg installation.")

    frames = self.decode(TEST_SRC_2_720P.path).data
    encoder = VideoEncoder(frames=frames, frame_rate=30)

    # Reference encoding written to disk.
    reference_path = tmp_path / f"output.{format}"
    encoder.to_file(dest=reference_path, crf=0)

    if method != "to_tensor":
        # Any other parametrized value means to_file_like.
        buffer = io.BytesIO()
        encoder.to_file_like(file_like=buffer, format=format, crf=0)
        other_encoding = buffer.getvalue()
    else:
        other_encoding = encoder.to_tensor(format=format, crf=0)

    from_file = self.decode(reference_path).data
    from_method = self.decode(other_encoding).data
    torch.testing.assert_close(from_file, from_method, atol=0, rtol=0)

@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
@pytest.mark.parametrize(
    "format",
    (
        "mov",
        "mp4",
        "avi",
        "mkv",
        "flv",
        pytest.param("webm", marks=pytest.mark.slow),
    ),
)
@pytest.mark.parametrize("pixel_format", ("yuv444p", "yuv420p"))
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format, pixel_format):
    """Compare ``VideoEncoder.to_file`` output with the ``ffmpeg`` CLI.

    The same frames are encoded once by invoking ``ffmpeg`` on a raw RGB dump
    and once by ``VideoEncoder``, using the same frame rate, pixel format and
    crf. Both outputs are decoded and compared frame by frame.
    """
    major = get_ffmpeg_major_version()
    if format == "webm":
        if major == 4 or (IS_WINDOWS and major in (6, 7)):
            pytest.skip(
                "Codec for webm is not available in this FFmpeg installation."
            )
    if pixel_format == "yuv444p" and format in ("avi", "flv"):
        pytest.skip(f"Default codec for {format} does not support {pixel_format}")

    frames = self.decode(TEST_SRC_2_720P.path).data

    # Dump the frames as raw bytes so the ffmpeg CLI can read them as rgb24.
    raw_input_path = str(tmp_path / "temp_input.raw")
    with open(raw_input_path, "wb") as raw_file:
        raw_file.write(frames.permute(0, 2, 3, 1).cpu().numpy().tobytes())

    fps = 30
    quality = 0
    cli_output_path = str(tmp_path / f"ffmpeg_output.{format}")
    # Some codecs (ex. MPEG4) do not support CRF.
    # Flags not supported by the selected codec will be ignored.
    input_args = [
        "-f",
        "rawvideo",
        "-pix_fmt",
        "rgb24",  # Input format
        "-s",
        f"{frames.shape[3]}x{frames.shape[2]}",
        "-r",
        str(fps),
        "-i",
        raw_input_path,
    ]
    output_args = [
        "-pix_fmt",
        pixel_format,  # Output format
        "-crf",
        str(quality),
        cli_output_path,
    ]
    subprocess.run(["ffmpeg", "-y", *input_args, *output_args], check=True)

    # Same encode through the VideoEncoder API.
    our_output_path = str(tmp_path / f"encoder_output.{format}")
    VideoEncoder(frames=frames, frame_rate=fps).to_file(
        dest=our_output_path, pixel_format=pixel_format, crf=quality
    )

    cli_frames = self.decode(cli_output_path).data
    our_frames = self.decode(our_output_path).data

    assert cli_frames.shape[0] == our_frames.shape[0]

    # If FFmpeg selects a codec or pixel format that uses qscale (not crf),
    # the VideoEncoder outputs *slightly* different frames.
    # There may be additional subtle differences in the encoder.
    if major == 6 or format == "avi":
        percentage = 94
    else:
        percentage = 99

    if format == "webm":
        atol = 3
    else:
        atol = 2

    # Both encodings must be close: high PSNR and near-equal pixel values.
    for cli_frame, our_frame in zip(cli_frames, our_frames):
        assert psnr(cli_frame, our_frame) > 30
        assert_tensor_close_on_at_least(
            cli_frame, our_frame, percentage=percentage, atol=atol
        )

def test_to_file_like_custom_file_object(self):
    """Encode into a user-defined object exposing ``write`` and ``seek``."""

    class InMemorySink:
        # Minimal file-like object: every call is forwarded to a BytesIO.
        def __init__(self):
            self._buffer = io.BytesIO()

        def write(self, data):
            return self._buffer.write(data)

        def seek(self, offset, whence=0):
            return self._buffer.seek(offset, whence)

        def get_encoded_data(self):
            return self._buffer.getvalue()

    frames = self.decode(TEST_SRC_2_720P.path).data
    sink = InMemorySink()
    VideoEncoder(frames=frames, frame_rate=30).to_file_like(
        sink, format="mp4", pixel_format="yuv444p", crf=0
    )

    # Decode what was written and compare it against the input frames.
    round_trip = self.decode(sink.get_encoded_data())
    torch.testing.assert_close(round_trip.data, frames, atol=2, rtol=0)

def test_to_file_like_real_file(self, tmp_path):
    """Encode through ``to_file_like`` into a file opened with mode ``"wb"``."""
    frames = self.decode(TEST_SRC_2_720P.path).data
    output_path = tmp_path / "test_file_like.mp4"
    encoder = VideoEncoder(frames=frames, frame_rate=30)

    with open(output_path, "wb") as handle:
        encoder.to_file_like(handle, format="mp4", pixel_format="yuv444p", crf=0)

    # The file is closed at this point, so everything written is on disk.
    round_trip = self.decode(str(output_path))
    torch.testing.assert_close(round_trip.data, frames, atol=2, rtol=0)

def test_to_file_like_bad_methods(self):
    """Check that ``to_file_like`` rejects objects lacking ``write`` or ``seek``."""

    class NoWriteMethod:
        # Has seek() only.
        def seek(self, offset, whence=0):
            return 0

    class NoSeekMethod:
        # Has write() only.
        def write(self, data):
            return len(data)

    frames = self.decode(TEST_SRC_2_720P.path).data
    encoder = VideoEncoder(frames=frames, frame_rate=30)

    missing_write = pytest.raises(
        RuntimeError, match="File like object must implement a write method"
    )
    with missing_write:
        encoder.to_file_like(NoWriteMethod(), format="mp4")

    missing_seek = pytest.raises(
        RuntimeError, match="File like object must implement a seek method"
    )
    with missing_seek:
        encoder.to_file_like(NoSeekMethod(), format="mp4")
Loading
Loading