Skip to content

Commit 22126c4

Browse files
committed
Nits
1 parent 3484615 commit 22126c4

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ torch::Tensor MaybeHWC2CHW(
5151
TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
5252
return hwcTensor.permute({2, 0, 1});
5353
} else if (numDimensions == 4) {
54-
TORCH_CHECK(shape[3] == 3, "Not a HWC tensor: ", shape);
54+
TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
5555
return hwcTensor.permute({0, 3, 1, 2});
5656
} else {
5757
TORCH_CHECK(
@@ -897,7 +897,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
897897
int streamIndex = rawOutput.streamIndex;
898898
AVFrame* frame = rawOutput.frame.get();
899899
auto& streamInfo = streams_[streamIndex];
900-
auto comes_from_batch = preAllocatedOutputTensor.has_value();
901900
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
902901
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
903902
torch::Tensor tensor;
@@ -926,14 +925,18 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
926925
} else if (
927926
streamInfo.colorConversionLibrary ==
928927
ColorConversionLibrary::FILTERGRAPH) {
929-
output.frame =
930-
convertFrameToTensorUsingFilterGraph(streamIndex, frame); // NHWC
928+
output.frame = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
931929
} else {
932930
throw std::runtime_error(
933931
"Invalid color conversion library: " +
934932
std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
935933
}
936-
if (!comes_from_batch) {
934+
if (!preAllocatedOutputTensor.has_value()) {
935+
// We only convert to CHW if a pre-allocated tensor wasn't passed. When a
936+
// pre-allocated tensor is passed, it's up to the caller (typically a
937+
// batch API) to do the conversion. This is more efficient as it allows
938+
// batch NHWC tensors to be permuted only once, instead of permuting HWC
939+
// tensors N times.
937940
output.frame = MaybeHWC2CHW(streamInfo.options, output.frame);
938941
}
939942

0 commit comments

Comments
 (0)