Skip to content

Commit f8b64ea

Browse files
committed
feedback
1 parent 2342837 commit f8b64ea

File tree

6 files changed

+45
-85
lines changed

6 files changed

+45
-85
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -523,9 +523,7 @@ void AudioEncoder::flushBuffers() {
523523

524524
namespace {
525525

526-
torch::Tensor validateFrames(
527-
const torch::Tensor& frames,
528-
const torch::Device& device) {
526+
torch::Tensor validateFrames(const torch::Tensor& frames) {
529527
TORCH_CHECK(
530528
frames.dtype() == torch::kUInt8,
531529
"frames must have uint8 dtype, got ",
@@ -538,15 +536,6 @@ torch::Tensor validateFrames(
538536
frames.sizes()[1] == 3,
539537
"frame must have 3 channels (R, G, B), got ",
540538
frames.sizes()[1]);
541-
if (device.type() != torch::kCPU) {
542-
TORCH_CHECK(
543-
frames.is_cuda(),
544-
"When using CUDA encoding (device=",
545-
device.str(),
546-
"), frames must be on a CUDA device. Got frames on ",
547-
frames.device().str(),
548-
". Please move frames to a CUDA device: frames.to('cuda')");
549-
}
550539
return frames.contiguous();
551540
}
552541

@@ -676,8 +665,7 @@ VideoEncoder::VideoEncoder(
676665
double frameRate,
677666
std::string_view fileName,
678667
const VideoStreamOptions& videoStreamOptions)
679-
: frames_(validateFrames(frames, videoStreamOptions.device)),
680-
inFrameRate_(frameRate) {
668+
: frames_(validateFrames(frames)), inFrameRate_(frameRate) {
681669
setFFmpegLogLevel();
682670

683671
// Allocate output format context
@@ -710,7 +698,7 @@ VideoEncoder::VideoEncoder(
710698
std::string_view formatName,
711699
std::unique_ptr<AVIOContextHolder> avioContextHolder,
712700
const VideoStreamOptions& videoStreamOptions)
713-
: frames_(validateFrames(frames, videoStreamOptions.device)),
701+
: frames_(validateFrames(frames)),
714702
inFrameRate_(frameRate),
715703
avioContextHolder_(std::move(avioContextHolder)) {
716704
setFFmpegLogLevel();
@@ -736,8 +724,8 @@ VideoEncoder::VideoEncoder(
736724

737725
void VideoEncoder::initializeEncoder(
738726
const VideoStreamOptions& videoStreamOptions) {
739-
if (videoStreamOptions.device.is_cuda()) {
740-
gpuEncoder_ = std::make_unique<GpuEncoder>(videoStreamOptions.device);
727+
if (frames_.device().is_cuda()) {
728+
gpuEncoder_ = std::make_unique<GpuEncoder>(frames_.device());
741729
}
742730

743731
const AVCodec* avCodec = nullptr;
@@ -764,12 +752,7 @@ void VideoEncoder::initializeEncoder(
764752
TORCH_CHECK(
765753
avFormatContext_->oformat != nullptr,
766754
"Output format is null, unable to find default codec.");
767-
// Try to find a hardware-accelerated encoder if not using CPU
768755
avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec);
769-
if (gpuEncoder_) {
770-
avCodec = gpuEncoder_->findEncoder(avFormatContext_->oformat->video_codec)
771-
.value_or(avCodec);
772-
}
773756
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
774757
}
775758

@@ -842,11 +825,9 @@ void VideoEncoder::initializeEncoder(
842825
0);
843826
}
844827

845-
// Register the hardware device context with the codec
846-
// context before calling avcodec_open2().
847828
if (gpuEncoder_) {
848829
gpuEncoder_->registerHardwareDeviceWithCodec(avCodecContext_.get());
849-
gpuEncoder_->setupEncodingContext(avCodecContext_.get());
830+
gpuEncoder_->setupHardwareFrameContext(avCodecContext_.get());
850831
}
851832

852833
int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions);

src/torchcodec/_core/GpuEncoder.cpp

Lines changed: 25 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -100,46 +100,32 @@ void GpuEncoder::initializeHardwareContext() {
100100
nppCtx_ = getNppStreamContext(device_);
101101
}
102102

103-
std::optional<const AVCodec*> GpuEncoder::findEncoder(
104-
const AVCodecID& codecId) {
105-
void* i = nullptr;
106-
const AVCodec* codec = nullptr;
107-
while ((codec = av_codec_iterate(&i)) != nullptr) {
108-
if (codec->id != codecId || !av_codec_is_encoder(codec)) {
109-
continue;
110-
}
111-
112-
const AVCodecHWConfig* config = nullptr;
113-
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
114-
++j) {
115-
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
116-
return codec;
117-
}
118-
}
119-
}
120-
return std::nullopt;
121-
}
122-
123103
void GpuEncoder::registerHardwareDeviceWithCodec(AVCodecContext* codecContext) {
124104
TORCH_CHECK(
125105
hardwareDeviceCtx_, "Hardware device context has not been initialized");
126106
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
127107
codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
128108
}
129109

130-
void GpuEncoder::setupEncodingContext(AVCodecContext* codecContext) {
110+
// Allocates and initializes an AVHWFramesContext, and sets the pixel format
111+
// fields needed to enable encoding on a CUDA device. The hw_frames_ctx field
112+
// is needed by FFmpeg to allocate frames in GPU memory.
113+
void GpuEncoder::setupHardwareFrameContext(AVCodecContext* codecContext) {
131114
TORCH_CHECK(
132115
hardwareDeviceCtx_, "Hardware device context has not been initialized");
133116
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
134117

135-
codecContext->sw_pix_fmt = AV_PIX_FMT_NV12;
136-
codecContext->pix_fmt = AV_PIX_FMT_CUDA;
137-
138118
AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get());
139119
TORCH_CHECK(
140120
hwFramesCtxRef != nullptr,
141121
"Failed to allocate hardware frames context for codec");
142122

123+
// Always set pixel formats to options that support CUDA encoding.
124+
// TODO-VideoEncoder: Allow user-specified pixel formats, and convert them
125+
// properly with the NPP functions below.
126+
codecContext->sw_pix_fmt = AV_PIX_FMT_NV12;
127+
codecContext->pix_fmt = AV_PIX_FMT_CUDA;
128+
143129
AVHWFramesContext* hwFramesCtx =
144130
reinterpret_cast<AVHWFramesContext*>(hwFramesCtxRef->data);
145131
hwFramesCtx->format = codecContext->pix_fmt;
@@ -164,41 +150,44 @@ UniqueAVFrame GpuEncoder::convertTensorToAVFrame(
164150
[[maybe_unused]] AVPixelFormat targetFormat,
165151
int frameIndex,
166152
AVCodecContext* codecContext) {
167-
TORCH_CHECK(tensor.is_cuda(), "GpuEncoder requires CUDA tensors");
153+
TORCH_CHECK(
154+
tensor.is_cuda(),
155+
"Frame tensor is not stored on GPU, but the GPU method convertTensorToAVFrame was called.");
168156
TORCH_CHECK(
169157
tensor.dim() == 3 && tensor.size(0) == 3,
170158
"Expected 3D RGB tensor (CHW format), got shape: ",
171159
tensor.sizes());
160+
161+
// TODO-VideoEncoder: Unify AVFrame creation with CPU version of this method
172162
UniqueAVFrame avFrame(av_frame_alloc());
173163
TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame");
164+
int height = static_cast<int>(tensor.size(1));
165+
int width = static_cast<int>(tensor.size(2));
174166

175167
avFrame->format = AV_PIX_FMT_CUDA;
176-
avFrame->width = static_cast<int>(tensor.size(2));
177-
avFrame->height = static_cast<int>(tensor.size(1));
168+
avFrame->height = height;
169+
avFrame->width = width;
178170
avFrame->pts = frameIndex;
179171

180-
int ret = av_hwframe_get_buffer(
181-
codecContext ? codecContext->hw_frames_ctx : nullptr, avFrame.get(), 0);
172+
// FFmpeg's av_hwframe_get_buffer allocates memory on the CUDA device.
173+
// TODO-VideoEncoder: Consider using PyTorch to allocate CUDA memory for
174+
// efficiency.
175+
int ret =
176+
av_hwframe_get_buffer(codecContext->hw_frames_ctx, avFrame.get(), 0);
182177
TORCH_CHECK(
183178
ret >= 0,
184179
"Failed to allocate hardware frame: ",
185180
getFFMPEGErrorStringFromErrorCode(ret));
186181

187-
// Validate that avFrame was properly allocated with CUDA memory
188182
TORCH_CHECK(
189183
avFrame != nullptr && avFrame->data[0] != nullptr,
190184
"avFrame must be pre-allocated with CUDA memory");
191185

192-
// Convert CHW to HWC for NPP processing
193-
int height = static_cast<int>(tensor.size(1));
194-
int width = static_cast<int>(tensor.size(2));
195186
torch::Tensor hwcFrame = tensor.permute({1, 2, 0}).contiguous();
196187

197-
// Get current CUDA stream for NPP operations
198188
at::cuda::CUDAStream currentStream =
199189
at::cuda::getCurrentCUDAStream(device_.index());
200190

201-
// Setup NPP context with current stream
202191
nppCtx_->hStream = currentStream.stream();
203192
cudaError_t cudaErr =
204193
cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
@@ -207,9 +196,7 @@ UniqueAVFrame GpuEncoder::convertTensorToAVFrame(
207196
"cudaStreamGetFlags failed: ",
208197
cudaGetErrorString(cudaErr));
209198

210-
// Always use FFmpeg's default behavior: BT.601 limited range
211199
NppiSize oSizeROI = {width, height};
212-
213200
NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx(
214201
static_cast<const Npp8u*>(hwcFrame.data_ptr()),
215202
hwcFrame.stride(0) * hwcFrame.element_size(),
@@ -224,15 +211,8 @@ UniqueAVFrame GpuEncoder::convertTensorToAVFrame(
224211
"Failed to convert RGB to NV12: NPP error code ",
225212
status);
226213

227-
// Validate CUDA operations completed successfully
228-
cudaError_t memCheck = cudaGetLastError();
229-
TORCH_CHECK(
230-
memCheck == cudaSuccess,
231-
"CUDA error detected: ",
232-
cudaGetErrorString(memCheck));
233-
234214
// TODO-VideoEncoder: Enable configuration of color properties, similar to
235-
// FFmpeg Set color properties to FFmpeg defaults
215+
// FFmpeg. Below are the default color properties used by FFmpeg.
236216
avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601
237217
avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range
238218

src/torchcodec/_core/GpuEncoder.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ class GpuEncoder {
2727
explicit GpuEncoder(const torch::Device& device);
2828
~GpuEncoder();
2929

30-
std::optional<const AVCodec*> findEncoder(const AVCodecID& codecId);
3130
void registerHardwareDeviceWithCodec(AVCodecContext* codecContext);
32-
void setupEncodingContext(AVCodecContext* codecContext);
31+
void setupHardwareFrameContext(AVCodecContext* codecContext);
3332

3433
UniqueAVFrame convertTensorToAVFrame(
3534
const torch::Tensor& tensor,

src/torchcodec/_core/StreamOptions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ struct VideoStreamOptions {
4141
ColorConversionLibrary::FILTERGRAPH;
4242

4343
// By default we use CPU for decoding for both C++ and python users.
44+
// Note: For video encoding, the device is determined by the device on which
45+
// the input frames tensor resides.
4446
torch::Device device = torch::kCPU;
4547
// Device variant (e.g., "ffmpeg", "beta", etc.)
4648
std::string_view deviceVariant = "ffmpeg";

src/torchcodec/_core/custom_ops.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,6 @@ void encode_video_to_file(
647647
std::optional<std::string_view> preset = std::nullopt,
648648
std::optional<std::vector<std::string>> extra_options = std::nullopt) {
649649
VideoStreamOptions videoStreamOptions;
650-
videoStreamOptions.device = frames.device();
651650
videoStreamOptions.codec = std::move(codec);
652651
videoStreamOptions.pixelFormat = std::move(pixel_format);
653652
videoStreamOptions.crf = crf;
@@ -672,7 +671,6 @@ at::Tensor encode_video_to_tensor(
672671
std::optional<std::vector<std::string>> extra_options = std::nullopt) {
673672
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
674673
VideoStreamOptions videoStreamOptions;
675-
videoStreamOptions.device = frames.device();
676674
videoStreamOptions.codec = std::move(codec);
677675
videoStreamOptions.pixelFormat = std::move(pixel_format);
678676
videoStreamOptions.crf = crf;
@@ -709,7 +707,6 @@ void _encode_video_to_file_like(
709707
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
710708

711709
VideoStreamOptions videoStreamOptions;
712-
videoStreamOptions.device = frames.device();
713710
videoStreamOptions.codec = std::move(codec);
714711
videoStreamOptions.pixelFormat = std::move(pixel_format);
715712
videoStreamOptions.crf = crf;

test/test_encoders.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -827,12 +827,16 @@ def test_contiguity(self, method, tmp_path, device):
827827
)
828828

829829
def encode_to_tensor(frames):
830-
common_params = dict(crf=0, pixel_format="yuv444p")
830+
common_params = dict(
831+
crf=0,
832+
pixel_format="yuv444p",
833+
codec="h264_nvenc" if device != "cpu" else None,
834+
)
831835
if method == "to_file":
832836
dest = str(tmp_path / "output.mp4")
833837
VideoEncoder(frames, frame_rate=30).to_file(dest=dest, **common_params)
834838
with open(dest, "rb") as f:
835-
return torch.frombuffer(f.read(), dtype=torch.uint8).clone()
839+
return torch.frombuffer(f.read(), dtype=torch.uint8)
836840
elif method == "to_tensor":
837841
return VideoEncoder(frames, frame_rate=30).to_tensor(
838842
format="mp4", **common_params
@@ -1269,7 +1273,6 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range
12691273

12701274
@pytest.mark.needs_cuda
12711275
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
1272-
@pytest.mark.parametrize("pixel_format", ("nv12", "yuv420p"))
12731276
@pytest.mark.parametrize(
12741277
"format_codec",
12751278
[
@@ -1280,12 +1283,12 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range
12801283
],
12811284
)
12821285
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
1283-
def test_nvenc_against_ffmpeg_cli(
1284-
self, tmp_path, pixel_format, format_codec, method
1285-
):
1286+
# TODO-VideoEncoder: Enable additional pixel formats ("yuv420p", "yuv444p")
1287+
def test_nvenc_against_ffmpeg_cli(self, tmp_path, format_codec, method):
12861288
# Encode with FFmpeg CLI using nvenc codecs
12871289
format, codec = format_codec
12881290
device = "cuda"
1291+
pixel_format = "nv12"
12891292
qp = 1 # Lossless (qp=0) is not supported on av1_nvenc, so we use 1
12901293
source_frames = self.decode(TEST_SRC_2_720P.path).data.to(device)
12911294

@@ -1315,13 +1318,11 @@ def test_nvenc_against_ffmpeg_cli(
13151318

13161319
ffmpeg_cmd.extend(["-pix_fmt", pixel_format]) # Output format
13171320
if codec == "av1_nvenc":
1318-
ffmpeg_cmd.extend(
1319-
["-rc", "constqp"]
1320-
) # Set rate control mode for AV1 else:
1321+
ffmpeg_cmd.extend(["-rc", "constqp"]) # Set rate control mode for AV1
13211322
ffmpeg_cmd.extend(["-qp", str(qp)]) # Use lossless qp for other codecs
13221323
ffmpeg_cmd.extend([ffmpeg_encoded_path])
13231324

1324-
# Will this prevent CI from treating test as failed if NVENC is not available?
1325+
# TODO-VideoEncoder: Ensure CI does not skip this test, as we know NVENC is available.
13251326
try:
13261327
subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
13271328
except subprocess.CalledProcessError as e:

0 commit comments

Comments
 (0)