Skip to content

Commit b7e52fb

Browse files
committed
crf is somtimes a double actually
1 parent 14e797b commit b7e52fb

File tree

6 files changed

+40
-23
lines changed

6 files changed

+40
-23
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -570,17 +570,16 @@ AVPixelFormat validatePixelFormat(
570570
TORCH_CHECK(false, errorMsg.str());
571571
}
572572

573-
void validateNumericOption(
573+
void validateDoubleOption(
574574
const AVCodec& avCodec,
575575
const char* optionName,
576-
int value) {
577-
// First determine if codec's private class is defined
576+
double value) {
578577
if (!avCodec.priv_class) {
579578
return;
580579
}
581580
const AVOption* option = av_opt_find2(
582-
// The obj arg must be converted from const AVClass* const* to non-const
583-
// void* First cast to remove const, then cast to void*
581+
// Convert obj arg from const AVClass* const* to non-const void*
582+
// First cast to remove const, then cast to void*
584583
const_cast<void*>(static_cast<const void*>(&avCodec.priv_class)),
585584
optionName,
586585
nullptr,
@@ -739,7 +738,7 @@ void VideoEncoder::initializeEncoder(
739738
// Apply videoStreamOptions
740739
AVDictionary* options = nullptr;
741740
if (videoStreamOptions.crf.has_value()) {
742-
validateNumericOption(*avCodec, "crf", videoStreamOptions.crf.value());
741+
validateDoubleOption(*avCodec, "crf", videoStreamOptions.crf.value());
743742
av_dict_set(
744743
&options,
745744
"crf",

src/torchcodec/_core/StreamOptions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct VideoStreamOptions {
4747
// Encoding options
4848
// TODO-VideoEncoder: Consider adding other optional fields here
4949
// (bit rate, gop size, max b frames, preset)
50-
std::optional<int> crf;
50+
std::optional<double> crf;
5151

5252
// Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p")
5353
// If not specified, uses codec's default format.

src/torchcodec/_core/custom_ops.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, int? crf=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, float? crf=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, int? crf=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, float? crf=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, int? crf=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, float? crf=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -604,7 +604,7 @@ void encode_video_to_file(
604604
int64_t frame_rate,
605605
std::string_view file_name,
606606
std::optional<std::string> pixel_format = std::nullopt,
607-
std::optional<int64_t> crf = std::nullopt) {
607+
std::optional<double> crf = std::nullopt) {
608608
VideoStreamOptions videoStreamOptions;
609609
videoStreamOptions.pixelFormat = pixel_format;
610610
videoStreamOptions.crf = crf;
@@ -621,7 +621,7 @@ at::Tensor encode_video_to_tensor(
621621
int64_t frame_rate,
622622
std::string_view format,
623623
std::optional<std::string> pixel_format = std::nullopt,
624-
std::optional<int64_t> crf = std::nullopt) {
624+
std::optional<double> crf = std::nullopt) {
625625
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
626626
VideoStreamOptions videoStreamOptions;
627627
videoStreamOptions.pixelFormat = pixel_format;
@@ -641,7 +641,7 @@ void _encode_video_to_file_like(
641641
std::string_view format,
642642
int64_t file_like_context,
643643
std::optional<std::string> pixel_format = std::nullopt,
644-
std::optional<int64_t> crf = std::nullopt) {
644+
std::optional<double> crf = std::nullopt) {
645645
auto fileLikeContext =
646646
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
647647
TORCH_CHECK(

src/torchcodec/_core/ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def encode_video_to_file_like(
213213
frame_rate: int,
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216-
crf: Optional[int] = None,
216+
crf: Optional[Union[int, float]] = None,
217217
pixel_format: Optional[str] = None,
218218
) -> None:
219219
"""Encode video frames to a file-like object.
@@ -322,7 +322,7 @@ def encode_video_to_file_abstract(
322322
frames: torch.Tensor,
323323
frame_rate: int,
324324
filename: str,
325-
crf: Optional[int] = None,
325+
crf: Optional[Union[int, float]] = None,
326326
pixel_format: Optional[str] = None,
327327
) -> None:
328328
return
@@ -333,7 +333,7 @@ def encode_video_to_tensor_abstract(
333333
frames: torch.Tensor,
334334
frame_rate: int,
335335
format: str,
336-
crf: Optional[int] = None,
336+
crf: Optional[Union[int, float]] = None,
337337
pixel_format: Optional[str] = None,
338338
) -> torch.Tensor:
339339
return torch.empty([], dtype=torch.long)
@@ -345,7 +345,7 @@ def _encode_video_to_file_like_abstract(
345345
frame_rate: int,
346346
format: str,
347347
file_like_context: int,
348-
crf: Optional[int] = None,
348+
crf: Optional[Union[int, float]] = None,
349349
pixel_format: Optional[str] = None,
350350
) -> None:
351351
return

src/torchcodec/encoders/_video_encoder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def to_file(
3737
dest: Union[str, Path],
3838
*,
3939
pixel_format: Optional[str] = None,
40-
crf: Optional[int] = None,
40+
crf: Optional[Union[int, float]] = None,
4141
) -> None:
4242
"""Encode frames into a file.
4343
@@ -47,7 +47,7 @@ def to_file(
4747
container format.
4848
pixel_format (str, optional): The pixel format for encoding (e.g.,
4949
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
50-
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
50+
crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values
5151
mean better quality. Valid range depends on the encoder (commonly 0-51).
5252
Defaults to None (which will use encoder's default).
5353
"""
@@ -64,7 +64,7 @@ def to_tensor(
6464
format: str,
6565
*,
6666
pixel_format: Optional[str] = None,
67-
crf: Optional[int] = None,
67+
crf: Optional[Union[int, float]] = None,
6868
) -> Tensor:
6969
"""Encode frames into raw bytes, as a 1D uint8 Tensor.
7070
@@ -73,7 +73,7 @@ def to_tensor(
7373
"mkv", "avi", "webm", "flv", etc.
7474
pixel_format (str, optional): The pixel format to encode frames into (e.g.,
7575
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
76-
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
76+
crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values
7777
mean better quality. Valid range depends on the encoder (commonly 0-51).
7878
Defaults to None (which will use encoder's default).
7979
@@ -94,7 +94,7 @@ def to_file_like(
9494
format: str,
9595
*,
9696
pixel_format: Optional[str] = None,
97-
crf: Optional[int] = None,
97+
crf: Optional[Union[int, float]] = None,
9898
) -> None:
9999
"""Encode frames into a file-like object.
100100
@@ -108,7 +108,7 @@ def to_file_like(
108108
"mkv", "avi", "webm", "flv", etc.
109109
pixel_format (str, optional): The pixel format for encoding (e.g.,
110110
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
111-
crf (int, optional): Constant Rate Factor for encoding quality. Lower values
111+
crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values
112112
mean better quality. Valid range depends on the encoder (commonly 0-51).
113113
Defaults to None (which will use encoder's default).
114114
"""

test/test_encoders.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,24 @@ def test_bad_input_parameterized(self, tmp_path, method):
617617
)
618618
getattr(encoder, method)(**valid_params, crf=-10)
619619

620+
@pytest.mark.parametrize("method", ["to_file", "to_tensor", "to_file_like"])
621+
@pytest.mark.parametrize("crf", [23, 23.5, -0.9])
622+
def test_crf_valid_values(self, method, crf, tmp_path):
623+
if method == "to_file":
624+
valid_params = {"dest": str(tmp_path / "test.mp4")}
625+
elif method == "to_tensor":
626+
valid_params = {"format": "mp4"}
627+
elif method == "to_file_like":
628+
valid_params = dict(file_like=io.BytesIO(), format="mp4")
629+
else:
630+
raise ValueError(f"Unknown method: {method}")
631+
632+
encoder = VideoEncoder(
633+
frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
634+
frame_rate=30,
635+
)
636+
getattr(encoder, method)(**valid_params, crf=crf)
637+
620638
def test_bad_input(self, tmp_path):
621639
encoder = VideoEncoder(
622640
frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),

0 commit comments

Comments
 (0)