Skip to content

Commit 4a9c00c

Browse files
committed
WIP
1 parent 0256e18 commit 4a9c00c

File tree

5 files changed

+20
-28
lines changed

5 files changed

+20
-28
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
1717
void convertAVFrameToDecodedOutputOnCuda(
1818
const torch::Device& device,
1919
const VideoDecoder::VideoStreamDecoderOptions& options,
20-
AVCodecContext* codecContext,
20+
const VideoDecoder::StreamMetadata& metadata,
2121
VideoDecoder::RawDecodedOutput& rawOutput,
2222
VideoDecoder::DecodedOutput& output,
2323
std::optional<torch::Tensor> preAllocatedOutputTensor) {

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
154154
#endif
155155
}
156156

157-
torch::Tensor allocateDeviceTensor(
158-
at::IntArrayRef shape,
159-
torch::Device device,
160-
const torch::Dtype dtype = torch::kUInt8) {
161-
return torch::empty(
162-
shape,
163-
torch::TensorOptions()
164-
.dtype(dtype)
165-
.layout(torch::kStrided)
166-
.device(device));
167-
}
168-
169157
void throwErrorIfNonCudaDevice(const torch::Device& device) {
170158
TORCH_CHECK(
171159
device.type() != torch::kCPU,
@@ -199,7 +187,7 @@ void initializeContextOnCuda(
199187
void convertAVFrameToDecodedOutputOnCuda(
200188
const torch::Device& device,
201189
const VideoDecoder::VideoStreamDecoderOptions& options,
202-
AVCodecContext* codecContext,
190+
const VideoDecoder::StreamMetadata& metadata,
203191
VideoDecoder::RawDecodedOutput& rawOutput,
204192
VideoDecoder::DecodedOutput& output,
205193
std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -209,8 +197,9 @@ void convertAVFrameToDecodedOutputOnCuda(
209197
src->format == AV_PIX_FMT_CUDA,
210198
"Expected format to be AV_PIX_FMT_CUDA, got " +
211199
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
212-
int width = options.width.value_or(codecContext->width);
213-
int height = options.height.value_or(codecContext->height);
200+
int height = 0, width = 0;
201+
std::tie(height, width) =
202+
getHeightAndWidthFromOptionsOrMetadata(options, metadata);
214203
NppiSize oSizeROI = {width, height};
215204
Npp8u* input[2] = {src->data[0], src->data[1]};
216205
torch::Tensor& dst = output.frame;
@@ -227,7 +216,7 @@ void convertAVFrameToDecodedOutputOnCuda(
227216
"x3, got ",
228217
shape);
229218
} else {
230-
dst = allocateDeviceTensor({height, width, 3}, options.device);
219+
dst = allocateEmptyHWCTensor(height, width, options.device);
231220
}
232221

233222
// Use the user-requested GPU for running the NPP kernel.

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void initializeContextOnCuda(
3535
void convertAVFrameToDecodedOutputOnCuda(
3636
const torch::Device& device,
3737
const VideoDecoder::VideoStreamDecoderOptions& options,
38-
AVCodecContext* codecContext,
38+
const VideoDecoder::StreamMetadata& metadata,
3939
VideoDecoder::RawDecodedOutput& rawOutput,
4040
VideoDecoder::DecodedOutput& output,
4141
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
193193
const StreamMetadata& metadata)
194194
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
195195
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
196-
int height, width;
196+
int height = 0, width = 0;
197197
std::tie(height, width) =
198198
getHeightAndWidthFromOptionsOrMetadata(options, metadata);
199199
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
@@ -359,12 +359,10 @@ void VideoDecoder::initializeFilterGraphForStream(
359359
inputs->pad_idx = 0;
360360
inputs->next = nullptr;
361361
char description[512];
362-
int width = activeStream.codecContext->width;
363-
int height = activeStream.codecContext->height;
364-
if (options.height.has_value() && options.width.has_value()) {
365-
width = *options.width;
366-
height = *options.height;
367-
}
362+
int height = 0, width = 0;
363+
std::tie(height, width) = getHeightAndWidthFromOptionsOrMetadata(
364+
options, containerMetadata_.streams[streamIndex]);
365+
368366
std::snprintf(
369367
description,
370368
sizeof(description),
@@ -862,7 +860,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
862860
convertAVFrameToDecodedOutputOnCuda(
863861
streamInfo.options.device,
864862
streamInfo.options,
865-
streamInfo.codecContext.get(),
863+
containerMetadata_.streams[streamIndex],
866864
rawOutput,
867865
output,
868866
preAllocatedOutputTensor);
@@ -1309,8 +1307,9 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
13091307
enum AVPixelFormat frameFormat =
13101308
static_cast<enum AVPixelFormat>(frame->format);
13111309
StreamInfo& activeStream = streams_[streamIndex];
1312-
int outputWidth = activeStream.options.width.value_or(frame->width);
1313-
int outputHeight = activeStream.options.height.value_or(frame->height);
1310+
int outputHeight = 0, outputWidth = 0;
1311+
std::tie(outputHeight, outputWidth) =
1312+
getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, frame);
13141313
if (activeStream.swsContext.get() == nullptr) {
13151314
SwsContext* swsContext = sws_getContext(
13161315
frame->width,

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,10 @@ class VideoDecoder {
414414
bool scanned_all_streams_ = false;
415415
};
416416

417+
// --------------------------------------------------------------------------
418+
// FRAME TENSOR ALLOCATION APIs
419+
// --------------------------------------------------------------------------
420+
417421
std::tuple<int, int> getHeightAndWidthFromOptionsOrMetadata(
418422
const VideoDecoder::VideoStreamDecoderOptions& options,
419423
const VideoDecoder::StreamMetadata& metadata);

0 commit comments

Comments
 (0)