Skip to content

Commit bf78468

Browse files
committed
actually add GpuEncoder.cpp
1 parent 9c7bae7 commit bf78468

File tree

4 files changed

+275
-25
lines changed

4 files changed

+275
-25
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ void VideoEncoder::encode() {
895895
currFrame, outPixelFormat_, i, avCodecContext_.get());
896896
} else {
897897
// Use direct CPU conversion for CPU devices
898-
avFrame = convertCpuTensorToAVFrame(currFrame, outPixelFormat_, i);
898+
avFrame = convertCpuTensorToAVFrame(currFrame, i);
899899
}
900900
encodeFrame(autoAVPacket, avFrame);
901901
}
@@ -911,33 +911,32 @@ void VideoEncoder::encode() {
911911

912912
UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame(
913913
const torch::Tensor& tensor,
914-
AVPixelFormat targetFormat,
915914
int frameIndex) {
916915
TORCH_CHECK(tensor.is_cpu(), "CPU encoder requires CPU tensors");
917916
TORCH_CHECK(
918917
tensor.dim() == 3 && tensor.size(0) == 3,
919918
"Expected 3D RGB tensor (CHW format), got shape: ",
920919
tensor.sizes());
921920

922-
int inHeight = static_cast<int>(tensor.sizes()[1]);
923-
int inWidth = static_cast<int>(tensor.sizes()[2]);
921+
inHeight_ = static_cast<int>(tensor.sizes()[1]);
922+
inWidth_ = static_cast<int>(tensor.sizes()[2]);
924923

925924
// For now, reuse input dimensions as output dimensions
926-
int outWidth = inWidth;
927-
int outHeight = inHeight;
925+
outWidth_ = inWidth_;
926+
outHeight_ = inHeight_;
928927

929928
// Input format is RGB planar (AV_PIX_FMT_GBRP after channel reordering)
930-
AVPixelFormat inPixelFormat = AV_PIX_FMT_GBRP;
929+
inPixelFormat_ = AV_PIX_FMT_GBRP;
931930

932931
// Initialize and cache scaling context if it does not exist
933932
if (!swsContext_) {
934933
swsContext_.reset(sws_getContext(
935-
inWidth,
936-
inHeight,
937-
inPixelFormat,
938-
outWidth,
939-
outHeight,
940-
targetFormat,
934+
inWidth_,
935+
inHeight_,
936+
inPixelFormat_,
937+
outWidth_,
938+
outHeight_,
939+
outPixelFormat_,
941940
SWS_BICUBIC, // Used by FFmpeg CLI
942941
nullptr,
943942
nullptr,
@@ -949,9 +948,9 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame(
949948
TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame");
950949

951950
// Set output frame properties
952-
avFrame->format = targetFormat;
953-
avFrame->width = outWidth;
954-
avFrame->height = outHeight;
951+
avFrame->format = outPixelFormat_;
952+
avFrame->width = outWidth_;
953+
avFrame->height = outHeight_;
955954
avFrame->pts = frameIndex;
956955

957956
int status = av_frame_get_buffer(avFrame.get(), 0);
@@ -962,23 +961,23 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame(
962961
UniqueAVFrame inputFrame(av_frame_alloc());
963962
TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame");
964963

965-
inputFrame->format = inPixelFormat;
966-
inputFrame->width = inWidth;
967-
inputFrame->height = inHeight;
964+
inputFrame->format = inPixelFormat_;
965+
inputFrame->width = inWidth_;
966+
inputFrame->height = inHeight_;
968967

969968
uint8_t* tensorData = static_cast<uint8_t*>(tensor.data_ptr());
970969

971970
// TODO-VideoEncoder: Reorder tensor if in NHWC format
972-
int channelSize = inHeight * inWidth;
971+
int channelSize = inHeight_ * inWidth_;
973972
// Reorder RGB -> GBR for AV_PIX_FMT_GBRP format
974973
// TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format
975974
inputFrame->data[0] = tensorData + channelSize; // G channel
976975
inputFrame->data[1] = tensorData + (2 * channelSize); // B channel
977976
inputFrame->data[2] = tensorData; // R channel
978977

979-
inputFrame->linesize[0] = inWidth;
980-
inputFrame->linesize[1] = inWidth;
981-
inputFrame->linesize[2] = inWidth;
978+
inputFrame->linesize[0] = inWidth_;
979+
inputFrame->linesize[1] = inWidth_;
980+
inputFrame->linesize[2] = inWidth_;
982981

983982
status = sws_scale(
984983
swsContext_.get(),
@@ -988,7 +987,7 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame(
988987
inputFrame->height,
989988
avFrame->data,
990989
avFrame->linesize);
991-
TORCH_CHECK(status == outHeight, "sws_scale failed");
990+
TORCH_CHECK(status == outHeight_, "sws_scale failed");
992991

993992
return avFrame;
994993
}

src/torchcodec/_core/Encoder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,6 @@ class VideoEncoder {
168168
// CPU tensor-to-frame conversion for CPU encoding
169169
UniqueAVFrame convertCpuTensorToAVFrame(
170170
const torch::Tensor& tensor,
171-
AVPixelFormat targetFormat,
172171
int frameIndex);
173172

174173
UniqueEncodingAVFormatContext avFormatContext_;
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "GpuEncoder.h"
8+
9+
#include <ATen/cuda/CUDAEvent.h>
10+
#include <c10/cuda/CUDAStream.h>
11+
#include <cuda_runtime.h>
12+
#include <torch/types.h>
13+
14+
#include "CUDACommon.h"
15+
#include "FFMPEGCommon.h"
16+
17+
extern "C" {
18+
#include <libavutil/hwcontext_cuda.h>
19+
#include <libavutil/pixdesc.h>
20+
}
21+
22+
namespace facebook::torchcodec {
23+
namespace {
24+
25+
// Redefinition from CudaDeviceInterface.cpp anonymous namespace
26+
int getFlagsAVHardwareDeviceContextCreate() {
27+
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
28+
return AV_CUDA_USE_CURRENT_CONTEXT;
29+
#else
30+
return 0;
31+
#endif
32+
}
33+
34+
// Redefinition from CudaDeviceInterface.cpp anonymous namespace
35+
// TODO-VideoEncoder: unify device context creation, add caching to encoder
36+
UniqueAVBufferRef createHardwareDeviceContext(const torch::Device& device) {
37+
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
38+
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
39+
40+
int deviceIndex = getDeviceIndex(device);
41+
42+
c10::cuda::CUDAGuard deviceGuard(device);
43+
// We set the device because we may be called from a different thread than
44+
// the one that initialized the cuda context.
45+
TORCH_CHECK(
46+
cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device");
47+
48+
AVBufferRef* hardwareDeviceCtxRaw = nullptr;
49+
std::string deviceOrdinal = std::to_string(deviceIndex);
50+
51+
int err = av_hwdevice_ctx_create(
52+
&hardwareDeviceCtxRaw,
53+
type,
54+
deviceOrdinal.c_str(),
55+
nullptr,
56+
getFlagsAVHardwareDeviceContextCreate());
57+
58+
if (err < 0) {
59+
/* clang-format off */
60+
TORCH_CHECK(
61+
false,
62+
"Failed to create specified HW device. This typically happens when ",
63+
"your installed FFmpeg doesn't support CUDA (see ",
64+
"https://github.com/pytorch/torchcodec#installing-cuda-enabled-torchcodec",
65+
"). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err));
66+
/* clang-format on */
67+
}
68+
69+
return UniqueAVBufferRef(hardwareDeviceCtxRaw);
70+
}
71+
72+
} // anonymous namespace
73+
74+
GpuEncoder::GpuEncoder(const torch::Device& device) : device_(device) {
75+
TORCH_CHECK(
76+
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
77+
78+
initializeCudaContextWithPytorch(device_);
79+
initializeHardwareContext();
80+
}
81+
82+
GpuEncoder::~GpuEncoder() {}
83+
84+
void GpuEncoder::initializeHardwareContext() {
85+
hardwareDeviceCtx_ = createHardwareDeviceContext(device_);
86+
nppCtx_ = getNppStreamContext(device_);
87+
}
88+
89+
std::optional<const AVCodec*> GpuEncoder::findEncoder(
90+
const AVCodecID& codecId) {
91+
void* i = nullptr;
92+
const AVCodec* codec = nullptr;
93+
while ((codec = av_codec_iterate(&i)) != nullptr) {
94+
if (codec->id != codecId || !av_codec_is_encoder(codec)) {
95+
continue;
96+
}
97+
98+
const AVCodecHWConfig* config = nullptr;
99+
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
100+
++j) {
101+
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
102+
return codec;
103+
}
104+
}
105+
}
106+
return std::nullopt;
107+
}
108+
109+
void GpuEncoder::registerHardwareDeviceWithCodec(AVCodecContext* codecContext) {
110+
TORCH_CHECK(
111+
hardwareDeviceCtx_, "Hardware device context has not been initialized");
112+
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
113+
codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
114+
}
115+
116+
void GpuEncoder::setupEncodingContext(AVCodecContext* codecContext) {
117+
TORCH_CHECK(
118+
hardwareDeviceCtx_, "Hardware device context has not been initialized");
119+
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
120+
121+
codecContext->sw_pix_fmt = AV_PIX_FMT_NV12;
122+
codecContext->pix_fmt = AV_PIX_FMT_CUDA;
123+
124+
AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get());
125+
TORCH_CHECK(
126+
hwFramesCtxRef != nullptr,
127+
"Failed to allocate hardware frames context for codec");
128+
129+
AVHWFramesContext* hwFramesCtx =
130+
reinterpret_cast<AVHWFramesContext*>(hwFramesCtxRef->data);
131+
hwFramesCtx->format = codecContext->pix_fmt;
132+
hwFramesCtx->sw_format = codecContext->sw_pix_fmt;
133+
hwFramesCtx->width = codecContext->width;
134+
hwFramesCtx->height = codecContext->height;
135+
136+
int ret = av_hwframe_ctx_init(hwFramesCtxRef);
137+
if (ret < 0) {
138+
av_buffer_unref(&hwFramesCtxRef);
139+
TORCH_CHECK(
140+
false,
141+
"Failed to initialize CUDA frames context for codec: ",
142+
getFFMPEGErrorStringFromErrorCode(ret));
143+
}
144+
145+
codecContext->hw_frames_ctx = hwFramesCtxRef;
146+
}
147+
148+
UniqueAVFrame GpuEncoder::convertTensorToAVFrame(
149+
const torch::Tensor& tensor,
150+
[[maybe_unused]] AVPixelFormat targetFormat,
151+
int frameIndex,
152+
AVCodecContext* codecContext) {
153+
TORCH_CHECK(tensor.is_cuda(), "GpuEncoder requires CUDA tensors");
154+
TORCH_CHECK(
155+
tensor.dim() == 3 && tensor.size(0) == 3,
156+
"Expected 3D RGB tensor (CHW format), got shape: ",
157+
tensor.sizes());
158+
159+
return convertRGBTensorToNV12Frame(tensor, frameIndex, codecContext);
160+
}
161+
162+
UniqueAVFrame GpuEncoder::convertRGBTensorToNV12Frame(
163+
const torch::Tensor& tensor,
164+
int frameIndex,
165+
AVCodecContext* codecContext) {
166+
UniqueAVFrame avFrame(av_frame_alloc());
167+
TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame");
168+
169+
avFrame->format = AV_PIX_FMT_CUDA;
170+
avFrame->width = static_cast<int>(tensor.size(2));
171+
avFrame->height = static_cast<int>(tensor.size(1));
172+
avFrame->pts = frameIndex;
173+
174+
int ret = av_hwframe_get_buffer(
175+
codecContext ? codecContext->hw_frames_ctx : nullptr, avFrame.get(), 0);
176+
TORCH_CHECK(
177+
ret >= 0,
178+
"Failed to allocate hardware frame: ",
179+
getFFMPEGErrorStringFromErrorCode(ret));
180+
181+
at::cuda::CUDAStream currentStream =
182+
at::cuda::getCurrentCUDAStream(device_.index());
183+
184+
facebook::torchcodec::convertRGBTensorToNV12Frame(
185+
tensor, avFrame, device_, nppCtx_, currentStream);
186+
187+
// Set color properties to FFmpeg defaults
188+
avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601
189+
avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range
190+
191+
return avFrame;
192+
}
193+
194+
} // namespace facebook::torchcodec

src/torchcodec/_core/GpuEncoder.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <torch/types.h>
10+
#include <memory>
11+
#include <optional>
12+
13+
#include "CUDACommon.h"
14+
#include "FFMPEGCommon.h"
15+
#include "StreamOptions.h"
16+
17+
extern "C" {
18+
#include <libavcodec/avcodec.h>
19+
#include <libavutil/buffer.h>
20+
#include <libavutil/hwcontext.h>
21+
}
22+
23+
namespace facebook::torchcodec {
24+
25+
class GpuEncoder {
26+
public:
27+
explicit GpuEncoder(const torch::Device& device);
28+
~GpuEncoder();
29+
30+
std::optional<const AVCodec*> findEncoder(const AVCodecID& codecId);
31+
void registerHardwareDeviceWithCodec(AVCodecContext* codecContext);
32+
void setupEncodingContext(AVCodecContext* codecContext);
33+
34+
UniqueAVFrame convertTensorToAVFrame(
35+
const torch::Tensor& tensor,
36+
AVPixelFormat targetFormat,
37+
int frameIndex,
38+
AVCodecContext* codecContext);
39+
40+
const torch::Device& device() const {
41+
return device_;
42+
}
43+
44+
private:
45+
torch::Device device_;
46+
UniqueAVBufferRef hardwareDeviceCtx_;
47+
UniqueNppContext nppCtx_;
48+
49+
void initializeHardwareContext();
50+
void setupHardwareFrameContext(AVCodecContext* codecContext);
51+
52+
UniqueAVFrame convertRGBTensorToNV12Frame(
53+
const torch::Tensor& tensor,
54+
int frameIndex,
55+
AVCodecContext* codecContext);
56+
};
57+
58+
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)