Skip to content

Commit 976bd2c

Browse files
committed
add codec selection + logic, add valid codecs test
1 parent b35005d commit 976bd2c

File tree

6 files changed

+84
-11
lines changed

6 files changed

+84
-11
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -649,9 +649,30 @@ VideoEncoder::VideoEncoder(
649649

650650
void VideoEncoder::initializeEncoder(
651651
const VideoStreamOptions& videoStreamOptions) {
652-
const AVCodec* avCodec =
653-
avcodec_find_encoder(avFormatContext_->oformat->video_codec);
654-
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
652+
const AVCodec* avCodec = nullptr;
653+
// If codec arg is provided, find codec using logic similar to FFmpeg:
654+
// https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835
655+
if (videoStreamOptions.codec.has_value()) {
656+
const std::string& codec = videoStreamOptions.codec.value();
657+
// Try to find codec by name ("libx264", "libsvtav1")
658+
avCodec = avcodec_find_encoder_by_name(codec.c_str());
659+
// Try to find by codec descriptor ("h264", "av1")
660+
if (!avCodec) {
661+
const AVCodecDescriptor* desc =
662+
avcodec_descriptor_get_by_name(codec.c_str());
663+
if (desc) {
664+
avCodec = avcodec_find_encoder(desc->id);
665+
}
666+
}
667+
TORCH_CHECK(
668+
avCodec != nullptr,
669+
"Video codec ",
670+
codec,
671+
" not found. Provide a codec name ('libx264', 'libx265') or a codec descriptor ('h264', 'hevc'), or do not specify a codec to use the default codec.");
672+
} else {
673+
avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec);
674+
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
675+
}
655676

656677
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
657678
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");

src/torchcodec/_core/StreamOptions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct VideoStreamOptions {
4545
std::string_view deviceVariant = "ffmpeg";
4646

4747
// Encoding options
48+
std::optional<std::string> codec;
4849
// TODO-VideoEncoder: Consider adding other optional fields here
4950
// (bit rate, gop size, max b frames, preset)
5051
std::optional<int> crf;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, int? crf=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, int? crf=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, int? crf=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, int? crf=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, int? crf=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, int? crf=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -603,9 +603,11 @@ void encode_video_to_file(
603603
const at::Tensor& frames,
604604
int64_t frame_rate,
605605
std::string_view file_name,
606+
std::optional<std::string> codec = std::nullopt,
606607
std::optional<std::string> pixel_format = std::nullopt,
607608
std::optional<int64_t> crf = std::nullopt) {
608609
VideoStreamOptions videoStreamOptions;
610+
videoStreamOptions.codec = codec;
609611
videoStreamOptions.pixelFormat = pixel_format;
610612
videoStreamOptions.crf = crf;
611613
VideoEncoder(
@@ -620,10 +622,12 @@ at::Tensor encode_video_to_tensor(
620622
const at::Tensor& frames,
621623
int64_t frame_rate,
622624
std::string_view format,
625+
std::optional<std::string> codec = std::nullopt,
623626
std::optional<std::string> pixel_format = std::nullopt,
624627
std::optional<int64_t> crf = std::nullopt) {
625628
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
626629
VideoStreamOptions videoStreamOptions;
630+
videoStreamOptions.codec = codec;
627631
videoStreamOptions.pixelFormat = pixel_format;
628632
videoStreamOptions.crf = crf;
629633
return VideoEncoder(
@@ -640,6 +644,7 @@ void _encode_video_to_file_like(
640644
int64_t frame_rate,
641645
std::string_view format,
642646
int64_t file_like_context,
647+
std::optional<std::string> codec = std::nullopt,
643648
std::optional<std::string> pixel_format = std::nullopt,
644649
std::optional<int64_t> crf = std::nullopt) {
645650
auto fileLikeContext =
@@ -649,6 +654,7 @@ void _encode_video_to_file_like(
649654
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
650655

651656
VideoStreamOptions videoStreamOptions;
657+
videoStreamOptions.codec = codec;
652658
videoStreamOptions.pixelFormat = pixel_format;
653659
videoStreamOptions.crf = crf;
654660

src/torchcodec/_core/ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def encode_video_to_file_like(
213213
frame_rate: int,
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216+
codec: Optional[str] = None,
216217
crf: Optional[int] = None,
217218
pixel_format: Optional[str] = None,
218219
) -> None:
@@ -223,6 +224,7 @@ def encode_video_to_file_like(
223224
frame_rate: Frame rate in frames per second
224225
format: Video format (e.g., "mp4", "mov", "mkv")
225226
file_like: File-like object that supports write() and seek() methods
227+
codec: Optional codec name (e.g., "libx264", "h264")
226228
crf: Optional constant rate factor for encoding quality
227229
pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
228230
"""
@@ -233,6 +235,7 @@ def encode_video_to_file_like(
233235
frame_rate,
234236
format,
235237
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
238+
codec,
236239
pixel_format,
237240
crf,
238241
)
@@ -322,8 +325,9 @@ def encode_video_to_file_abstract(
322325
frames: torch.Tensor,
323326
frame_rate: int,
324327
filename: str,
325-
crf: Optional[int] = None,
328+
codec: Optional[str],
326329
pixel_format: Optional[str] = None,
330+
crf: Optional[int] = None,
327331
) -> None:
328332
return
329333

@@ -333,6 +337,7 @@ def encode_video_to_tensor_abstract(
333337
frames: torch.Tensor,
334338
frame_rate: int,
335339
format: str,
340+
codec: Optional[str],
336341
crf: Optional[int] = None,
337342
pixel_format: Optional[str] = None,
338343
) -> torch.Tensor:
@@ -345,8 +350,9 @@ def _encode_video_to_file_like_abstract(
345350
frame_rate: int,
346351
format: str,
347352
file_like_context: int,
348-
crf: Optional[int] = None,
353+
codec: Optional[str] = None,
349354
pixel_format: Optional[str] = None,
355+
crf: Optional[int] = None,
350356
) -> None:
351357
return
352358

src/torchcodec/encoders/_video_encoder.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def to_file(
3636
self,
3737
dest: Union[str, Path],
3838
*,
39+
codec: Optional[str] = None,
3940
pixel_format: Optional[str] = None,
4041
) -> None:
4142
"""Encode frames into a file.
@@ -44,37 +45,46 @@ def to_file(
4445
dest (str or ``pathlib.Path``): The path to the output file, e.g.
4546
``video.mp4``. The extension of the file determines the video
4647
container format.
48+
codec (str, optional): The codec to use for encoding (e.g., "libx264",
49+
"h264"). If not specified, the default codec
50+
for the container format will be used.
4751
pixel_format (str, optional): The pixel format for encoding (e.g.,
4852
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
4953
"""
5054
_core.encode_video_to_file(
5155
frames=self._frames,
5256
frame_rate=self._frame_rate,
5357
filename=str(dest),
58+
codec=codec,
5459
pixel_format=pixel_format,
5560
)
5661

5762
def to_tensor(
    self,
    format: str,
    *,
    codec: Optional[str] = None,
    pixel_format: Optional[str] = None,
) -> Tensor:
    """Encode frames into raw bytes, returned as a 1D uint8 Tensor.

    Args:
        format (str): The container format of the encoded frames, e.g. "mp4",
            "mov", "mkv", "avi", "webm", "flv", etc.
        codec (str, optional): The codec to use for encoding (e.g., "libx264",
            "h264"). If not specified, the default codec for the container
            format will be used.
        pixel_format (str, optional): The pixel format to encode frames into
            (e.g., "yuv420p", "yuv444p"). If not specified, uses codec's
            default format.

    Returns:
        Tensor: The raw encoded bytes as 1D uint8 Tensor.
    """
    # All encoding work is delegated to the core op; this method only forwards
    # the frames held by the encoder plus the caller's options.
    encoded_bytes = _core.encode_video_to_tensor(
        frames=self._frames,
        frame_rate=self._frame_rate,
        format=format,
        codec=codec,
        pixel_format=pixel_format,
    )
    return encoded_bytes
8090

@@ -83,6 +93,7 @@ def to_file_like(
8393
file_like,
8494
format: str,
8595
*,
96+
codec: Optional[str] = None,
8697
pixel_format: Optional[str] = None,
8798
) -> None:
8899
"""Encode frames into a file-like object.
@@ -94,7 +105,10 @@ def to_file_like(
94105
``write(data: bytes) -> int`` and ``seek(offset: int, whence:
95106
int = 0) -> int``.
96107
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
97-
"mkv", "avi", "webm", "flv", or "gif".
108+
"mkv", "avi", "webm", "flv", etc.
109+
codec (str, optional): The codec to use for encoding (e.g., "libx264",
110+
"h264"). If not specified, the default codec
111+
for the container format will be used.
98112
pixel_format (str, optional): The pixel format for encoding (e.g.,
99113
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
100114
"""
@@ -103,5 +117,6 @@ def to_file_like(
103117
frame_rate=self._frame_rate,
104118
format=format,
105119
file_like=file_like,
120+
codec=codec,
106121
pixel_format=pixel_format,
107122
)

test/test_encoders.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,12 @@ def test_bad_input_parameterized(self, tmp_path, method):
605605
)
606606
getattr(encoder, method)(**valid_params)
607607

608+
with pytest.raises(
609+
RuntimeError,
610+
match=r"Video codec invalid_codec_name not found.",
611+
):
612+
encoder.to_file(str(tmp_path / "output.mp4"), codec="invalid_codec_name")
613+
608614
def test_bad_input(self, tmp_path):
609615
encoder = VideoEncoder(
610616
frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
@@ -629,6 +635,24 @@ def test_bad_input(self, tmp_path):
629635
):
630636
encoder.to_tensor(format="bad_format")
631637

638+
@pytest.mark.parametrize("method", ["to_file", "to_tensor", "to_file_like"])
@pytest.mark.parametrize("codec", ["h264", "hevc", "av1", "libx264", None])
def test_codec_valid_values(self, method, codec, tmp_path):
    """Each supported ``codec`` spelling encodes to mp4 through every output method.

    The parametrization covers codec descriptor names ("h264", "hevc", "av1"),
    an encoder name ("libx264"), and ``None`` (no codec given, so the container's
    default codec is used).

    Besides checking that no exception is raised, the test asserts that the
    encoder actually produced bytes, so a silently empty output is a failure.

    NOTE(review): assumes the FFmpeg build under test ships encoders for every
    listed codec -- TODO confirm on all CI configurations.
    """
    encoder = VideoEncoder(
        frames=torch.zeros((5, 3, 128, 128), dtype=torch.uint8),
        frame_rate=30,
    )

    if method == "to_file":
        dest = tmp_path / "test.mp4"
        encoder.to_file(dest=str(dest), codec=codec)
        num_encoded_bytes = dest.stat().st_size
    elif method == "to_tensor":
        encoded = encoder.to_tensor(format="mp4", codec=codec)
        num_encoded_bytes = encoded.numel()
    elif method == "to_file_like":
        file_like = io.BytesIO()
        encoder.to_file_like(file_like=file_like, format="mp4", codec=codec)
        num_encoded_bytes = len(file_like.getvalue())
    else:
        raise ValueError(f"Unknown method: {method}")

    assert num_encoded_bytes > 0, f"{method} with codec={codec!r} produced no output"
655+
632656
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
633657
def test_pixel_format_errors(self, method, tmp_path):
634658
frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8)

0 commit comments

Comments
 (0)