Skip to content

Commit 3484615

Browse files
committed
WIP
1 parent faa0178 commit 3484615

File tree

2 files changed

+60
-36
lines changed

2 files changed

+60
-36
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,31 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
3434
return ptsToSeconds(pts, timeBase.den);
3535
}
3636

37+
// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
38+
// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
39+
// or 4D.
40+
// Calling permute() is guaranteed to return a view as per the docs:
41+
// https://pytorch.org/docs/stable/generated/torch.permute.html
42+
torch::Tensor MaybeHWC2CHW(
43+
const VideoDecoder::VideoStreamDecoderOptions& options,
44+
torch::Tensor& hwcTensor) {
45+
if (options.dimensionOrder == "NHWC") {
46+
return hwcTensor;
47+
}
48+
auto numDimensions = hwcTensor.dim();
49+
auto shape = hwcTensor.sizes();
50+
if (numDimensions == 3) {
51+
TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
52+
return hwcTensor.permute({2, 0, 1});
53+
} else if (numDimensions == 4) {
54+
TORCH_CHECK(shape[3] == 3, "Not a HWC tensor: ", shape);
55+
return hwcTensor.permute({0, 3, 1, 2});
56+
} else {
57+
TORCH_CHECK(
58+
false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
59+
}
60+
}
61+
3762
struct AVInput {
3863
UniqueAVFormatContext formatContext;
3964
std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -167,28 +192,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
167192
const VideoStreamDecoderOptions& options,
168193
const StreamMetadata& metadata)
169194
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
170-
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
171-
if (options.dimensionOrder == "NHWC") {
172-
frames = torch::empty(
173-
{numFrames,
174-
options.height.value_or(*metadata.height),
175-
options.width.value_or(*metadata.width),
176-
3},
177-
{torch::kUInt8});
178-
} else if (options.dimensionOrder == "NCHW") {
179-
frames = torch::empty(
180-
{numFrames,
181-
3,
182-
options.height.value_or(*metadata.height),
183-
options.width.value_or(*metadata.width)},
184-
torch::TensorOptions()
185-
.memory_format(torch::MemoryFormat::ChannelsLast)
186-
.dtype({torch::kUInt8}));
187-
} else {
188-
TORCH_CHECK(
189-
false, "Unsupported frame dimensionOrder =" + options.dimensionOrder)
190-
}
191-
}
195+
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})),
196+
frames(torch::empty(
197+
{numFrames,
198+
options.height.value_or(*metadata.height),
199+
options.width.value_or(*metadata.width),
200+
3},
201+
{torch::kUInt8})) {}
192202

193203
VideoDecoder::VideoDecoder() {}
194204

@@ -887,35 +897,45 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
887897
int streamIndex = rawOutput.streamIndex;
888898
AVFrame* frame = rawOutput.frame.get();
889899
auto& streamInfo = streams_[streamIndex];
900+
auto comes_from_batch = preAllocatedOutputTensor.has_value();
890901
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
891902
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
892903
torch::Tensor tensor;
904+
int width = streamInfo.options.width.value_or(frame->width);
905+
int height = streamInfo.options.height.value_or(frame->height);
893906
if (preAllocatedOutputTensor.has_value()) {
894-
// TODO: check shape of preAllocatedOutputTensor?
895907
tensor = preAllocatedOutputTensor.value();
908+
auto shape = tensor.sizes();
909+
TORCH_CHECK(
910+
(shape.size() == 3) && (shape[0] == height) &&
911+
(shape[1] == width) && (shape[2] == 3),
912+
"Expected tensor of shape ",
913+
height,
914+
"x",
915+
width,
916+
"x3, got ",
917+
shape);
896918
} else {
897-
int width = streamInfo.options.width.value_or(frame->width);
898-
int height = streamInfo.options.height.value_or(frame->height);
899919
tensor = torch::empty(
900920
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
901921
}
902-
903922
rawOutput.data = tensor.data_ptr<uint8_t>();
904923
convertFrameToBufferUsingSwsScale(rawOutput);
905924

906-
if (streamInfo.options.dimensionOrder == "NCHW") {
907-
tensor = tensor.permute({2, 0, 1});
908-
}
909925
output.frame = tensor;
910926
} else if (
911927
streamInfo.colorConversionLibrary ==
912928
ColorConversionLibrary::FILTERGRAPH) {
913-
output.frame = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
929+
output.frame =
930+
convertFrameToTensorUsingFilterGraph(streamIndex, frame); // NHWC
914931
} else {
915932
throw std::runtime_error(
916933
"Invalid color conversion library: " +
917934
std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
918935
}
936+
if (!comes_from_batch) {
937+
output.frame = MaybeHWC2CHW(streamInfo.options, output.frame);
938+
}
919939

920940
} else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
921941
// TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
@@ -1046,6 +1066,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10461066
}
10471067
i++;
10481068
}
1069+
output.frames = MaybeHWC2CHW(options, output.frames);
10491070
return output;
10501071
}
10511072

@@ -1081,7 +1102,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10811102
output.ptsSeconds[f] = singleOut.ptsSeconds;
10821103
output.durationSeconds[f] = singleOut.durationSeconds;
10831104
}
1084-
1105+
output.frames = MaybeHWC2CHW(options, output.frames);
10851106
return output;
10861107
}
10871108

@@ -1134,6 +1155,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11341155
// need this special case below.
11351156
if (startSeconds == stopSeconds) {
11361157
BatchDecodedOutput output(0, options, streamMetadata);
1158+
output.frames = MaybeHWC2CHW(options, output.frames);
11371159
return output;
11381160
}
11391161

@@ -1176,6 +1198,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11761198
output.ptsSeconds[f] = singleOut.ptsSeconds;
11771199
output.durationSeconds[f] = singleOut.durationSeconds;
11781200
}
1201+
output.frames = MaybeHWC2CHW(options, output.frames);
11791202

11801203
return output;
11811204
}
@@ -1302,11 +1325,6 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
13021325
torch::Tensor tensor = torch::from_blob(
13031326
filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
13041327
StreamInfo& activeStream = streams_[streamIndex];
1305-
if (activeStream.options.dimensionOrder == "NCHW") {
1306-
// The docs guaranty this to return a view:
1307-
// https://pytorch.org/docs/stable/generated/torch.permute.html
1308-
tensor = tensor.permute({2, 0, 1});
1309-
}
13101328
return tensor;
13111329
}
13121330

test/decoders/test_video_decoder_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,12 @@ def test_color_conversion_library_with_dimension_order(
425425
assert frames.shape[1:] == expected_shape
426426
assert_tensor_equal(frames[0], frame0_ref)
427427

428+
frames = get_frames_at_indices(
429+
decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4]
430+
)
431+
assert frames.shape[1:] == expected_shape
432+
assert_tensor_equal(frames[0], frame0_ref)
433+
428434
@pytest.mark.parametrize(
429435
"width_scaling_factor,height_scaling_factor",
430436
((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),

0 commit comments

Comments
 (0)