Skip to content

Commit 0256e18

Browse files
committed
WIP
1 parent 2640fa9 commit 0256e18

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
191191
int64_t numFrames,
192192
const VideoStreamDecoderOptions& options,
193193
const StreamMetadata& metadata)
194-
: frames(torch::empty(
195-
{numFrames,
196-
options.height.value_or(*metadata.height),
197-
options.width.value_or(*metadata.width),
198-
3},
199-
at::TensorOptions(options.device).dtype(torch::kUInt8))),
200-
ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
201-
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
194+
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
195+
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
196+
int height, width;
197+
std::tie(height, width) =
198+
getHeightAndWidthFromOptionsOrMetadata(options, metadata);
199+
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
200+
}
202201

203202
VideoDecoder::VideoDecoder() {}
204203

@@ -893,8 +892,9 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
893892
torch::Tensor tensor;
894893
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
895894
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
896-
int width = streamInfo.options.width.value_or(frame->width);
897-
int height = streamInfo.options.height.value_or(frame->height);
895+
int height, width;
896+
std::tie(height, width) =
897+
getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, frame);
898898
if (preAllocatedOutputTensor.has_value()) {
899899
tensor = preAllocatedOutputTensor.value();
900900
auto shape = tensor.sizes();
@@ -908,8 +908,8 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
908908
"x3, got ",
909909
shape);
910910
} else {
911-
tensor = torch::empty(
912-
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
911+
tensor = allocateEmptyHWCTensor(
912+
height, width, streamInfo.options.device.type());
913913
}
914914
rawOutput.data = tensor.data_ptr<uint8_t>();
915915
convertFrameToBufferUsingSwsScale(rawOutput);
@@ -1400,6 +1400,38 @@ VideoDecoder::~VideoDecoder() {
14001400
}
14011401
}
14021402

1403+
std::tuple<int, int> getHeightAndWidthFromOptionsOrMetadata(
1404+
const VideoDecoder::VideoStreamDecoderOptions& options,
1405+
const VideoDecoder::StreamMetadata& metadata) {
1406+
return std::make_tuple(
1407+
options.height.value_or(*metadata.height),
1408+
options.width.value_or(*metadata.width));
1409+
}
1410+
1411+
std::tuple<int, int> getHeightAndWidthFromOptionsOrAVFrame(
1412+
const VideoDecoder::VideoStreamDecoderOptions& options,
1413+
AVFrame* avFrame) {
1414+
return std::make_tuple(
1415+
options.height.value_or(avFrame->height),
1416+
options.width.value_or(avFrame->width));
1417+
}
1418+
1419+
torch::Tensor allocateEmptyHWCTensor(
1420+
int height,
1421+
int width,
1422+
torch::Device device,
1423+
std::optional<int> numFrames) {
1424+
auto tensorOptions = torch::TensorOptions()
1425+
.dtype(torch::kUInt8)
1426+
.layout(torch::kStrided)
1427+
.device(device);
1428+
if (numFrames.has_value()) {
1429+
return torch::empty({numFrames.value(), height, width, 3}, tensorOptions);
1430+
} else {
1431+
return torch::empty({height, width, 3}, tensorOptions);
1432+
}
1433+
}
1434+
14031435
std::ostream& operator<<(
14041436
std::ostream& os,
14051437
const VideoDecoder::DecodeStats& stats) {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ class VideoDecoder {
243243
const VideoStreamDecoderOptions& options,
244244
const StreamMetadata& metadata);
245245
};
246+
246247
// Returns frames at the given indices for a given stream as a single stacked
247248
// Tensor.
248249
BatchDecodedOutput getFramesAtIndices(
@@ -413,6 +414,20 @@ class VideoDecoder {
413414
bool scanned_all_streams_ = false;
414415
};
415416

417+
std::tuple<int, int> getHeightAndWidthFromOptionsOrMetadata(
418+
const VideoDecoder::VideoStreamDecoderOptions& options,
419+
const VideoDecoder::StreamMetadata& metadata);
420+
421+
std::tuple<int, int> getHeightAndWidthFromOptionsOrAVFrame(
422+
const VideoDecoder::VideoStreamDecoderOptions& options,
423+
AVFrame* avFrame);
424+
425+
torch::Tensor allocateEmptyHWCTensor(
426+
int height,
427+
int width,
428+
torch::Device device,
429+
std::optional<int> numFrames = std::nullopt);
430+
416431
// Prints the VideoDecoder::DecodeStats to the ostream.
417432
std::ostream& operator<<(
418433
std::ostream& os,

0 commit comments

Comments
 (0)