@@ -51,7 +51,7 @@ torch::Tensor MaybeHWC2CHW(
5151 TORCH_CHECK (shape[2 ] == 3 , " Not a HWC tensor: " , shape);
5252 return hwcTensor.permute ({2 , 0 , 1 });
5353 } else if (numDimensions == 4 ) {
54- TORCH_CHECK (shape[3 ] == 3 , " Not a HWC tensor: " , shape);
54+ TORCH_CHECK (shape[3 ] == 3 , " Not a NHWC tensor: " , shape);
5555 return hwcTensor.permute ({0 , 3 , 1 , 2 });
5656 } else {
5757 TORCH_CHECK (
@@ -897,7 +897,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
897897 int streamIndex = rawOutput.streamIndex ;
898898 AVFrame* frame = rawOutput.frame .get ();
899899 auto & streamInfo = streams_[streamIndex];
900- auto comes_from_batch = preAllocatedOutputTensor.has_value ();
901900 if (output.streamType == AVMEDIA_TYPE_VIDEO) {
902901 if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
903902 torch::Tensor tensor;
@@ -926,14 +925,18 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
926925 } else if (
927926 streamInfo.colorConversionLibrary ==
928927 ColorConversionLibrary::FILTERGRAPH) {
929- output.frame =
930- convertFrameToTensorUsingFilterGraph (streamIndex, frame); // NHWC
928+ output.frame = convertFrameToTensorUsingFilterGraph (streamIndex, frame);
931929 } else {
932930 throw std::runtime_error (
933931 " Invalid color conversion library: " +
934932 std::to_string (static_cast <int >(streamInfo.colorConversionLibrary )));
935933 }
936- if (!comes_from_batch) {
934+ if (!preAllocatedOutputTensor.has_value ()) {
935+ // We only convert to CHW if a pre-allocated tensor wasn't passed. When a
936+ // pre-allocated tensor is passed, it's up to the caller (typically a
937+ // batch API) to do the conversion. This is more efficient as it allows
938+ // batch NHWC tensors to be permuted only once, instead of permuting HWC
939+ // tensors N times.
937940 output.frame = MaybeHWC2CHW (streamInfo.options , output.frame );
938941 }
939942
0 commit comments