@@ -34,6 +34,31 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
3434 return ptsToSeconds (pts, timeBase.den );
3535}
3636
37+ // Converts a [N]HWC tensor to [N]CHW dimension order when the stream options ask for it.
38+ // If options.dimensionOrder equals the NHWC literal checked below, the input is returned
39+ // unchanged. Otherwise the input must be 3D (HWC) or 4D (NHWC) with a last dimension of
40+ // size 3, else a TORCH_CHECK fails. The result is a *view*, as permute() is documented to return one:
41+ // https://pytorch.org/docs/stable/generated/torch.permute.html
42+ torch::Tensor MaybeHWC2CHW (
43+ const VideoDecoder::VideoStreamDecoderOptions& options,
44+ torch::Tensor& hwcTensor) {
45+ if (options.dimensionOrder == " NHWC" ) {
46+ return hwcTensor; // Already in the requested order: hand the input back as-is.
47+ }
48+ auto numDimensions = hwcTensor.dim ();
49+ auto shape = hwcTensor.sizes ();
50+ if (numDimensions == 3 ) {
51+ TORCH_CHECK (shape[2 ] == 3 , " Not a HWC tensor: " , shape); // 3D input: the last (C) dimension must have size 3.
52+ return hwcTensor.permute ({2 , 0 , 1 }); // HWC -> CHW
53+ } else if (numDimensions == 4 ) {
54+ TORCH_CHECK (shape[3 ] == 3 , " Not a HWC tensor: " , shape); // 4D input: the last (C) dimension must have size 3.
55+ return hwcTensor.permute ({0 , 3 , 1 , 2 }); // NHWC -> NCHW, leading batch dimension stays first.
56+ } else { // Any other number of dimensions is rejected unconditionally.
57+ TORCH_CHECK (
58+ false , " Expected tensor with 3 or 4 dimensions, got " , numDimensions);
59+ }
60+ }
61+
3762struct AVInput {
3863 UniqueAVFormatContext formatContext;
3964 std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -167,28 +192,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
167192 const VideoStreamDecoderOptions& options,
168193 const StreamMetadata& metadata)
169194 : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64 })),
170- durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {
171- if (options.dimensionOrder == " NHWC" ) {
172- frames = torch::empty (
173- {numFrames,
174- options.height .value_or (*metadata.height ),
175- options.width .value_or (*metadata.width ),
176- 3 },
177- {torch::kUInt8 });
178- } else if (options.dimensionOrder == " NCHW" ) {
179- frames = torch::empty (
180- {numFrames,
181- 3 ,
182- options.height .value_or (*metadata.height ),
183- options.width .value_or (*metadata.width )},
184- torch::TensorOptions ()
185- .memory_format (torch::MemoryFormat::ChannelsLast)
186- .dtype ({torch::kUInt8 }));
187- } else {
188- TORCH_CHECK (
189- false , " Unsupported frame dimensionOrder =" + options.dimensionOrder )
190- }
191- }
195+ durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
196+ frames (torch::empty(
197+ {numFrames,
198+ options.height .value_or (*metadata.height ),
199+ options.width .value_or (*metadata.width ),
200+ 3 },
201+ {torch::kUInt8 })) {}
192202
193203VideoDecoder::VideoDecoder () {}
194204
@@ -887,35 +897,45 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
887897 int streamIndex = rawOutput.streamIndex ;
888898 AVFrame* frame = rawOutput.frame .get ();
889899 auto & streamInfo = streams_[streamIndex];
900+ auto comes_from_batch = preAllocatedOutputTensor.has_value ();
890901 if (output.streamType == AVMEDIA_TYPE_VIDEO) {
891902 if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
892903 torch::Tensor tensor;
904+ int width = streamInfo.options .width .value_or (frame->width );
905+ int height = streamInfo.options .height .value_or (frame->height );
893906 if (preAllocatedOutputTensor.has_value ()) {
894- // TODO: check shape of preAllocatedOutputTensor?
895907 tensor = preAllocatedOutputTensor.value ();
908+ auto shape = tensor.sizes ();
909+ TORCH_CHECK (
910+ (shape.size () == 3 ) && (shape[0 ] == height) &&
911+ (shape[1 ] == width) && (shape[2 ] == 3 ),
912+ " Expected tensor of shape " ,
913+ height,
914+ " x" ,
915+ width,
916+ " x3, got " ,
917+ shape);
896918 } else {
897- int width = streamInfo.options .width .value_or (frame->width );
898- int height = streamInfo.options .height .value_or (frame->height );
899919 tensor = torch::empty (
900920 {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
901921 }
902-
903922 rawOutput.data = tensor.data_ptr <uint8_t >();
904923 convertFrameToBufferUsingSwsScale (rawOutput);
905924
906- if (streamInfo.options .dimensionOrder == " NCHW" ) {
907- tensor = tensor.permute ({2 , 0 , 1 });
908- }
909925 output.frame = tensor;
910926 } else if (
911927 streamInfo.colorConversionLibrary ==
912928 ColorConversionLibrary::FILTERGRAPH) {
913- output.frame = convertFrameToTensorUsingFilterGraph (streamIndex, frame);
929+ output.frame =
930+ convertFrameToTensorUsingFilterGraph (streamIndex, frame); // NHWC
914931 } else {
915932 throw std::runtime_error (
916933 " Invalid color conversion library: " +
917934 std::to_string (static_cast <int >(streamInfo.colorConversionLibrary )));
918935 }
936+ if (!comes_from_batch) {
937+ output.frame = MaybeHWC2CHW (streamInfo.options , output.frame );
938+ }
919939
920940 } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
921941 // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
@@ -1046,6 +1066,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10461066 }
10471067 i++;
10481068 }
1069+ output.frames = MaybeHWC2CHW (options, output.frames );
10491070 return output;
10501071}
10511072
@@ -1081,7 +1102,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10811102 output.ptsSeconds [f] = singleOut.ptsSeconds ;
10821103 output.durationSeconds [f] = singleOut.durationSeconds ;
10831104 }
1084-
1105+ output. frames = MaybeHWC2CHW (options, output. frames );
10851106 return output;
10861107}
10871108
@@ -1134,6 +1155,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11341155 // need this special case below.
11351156 if (startSeconds == stopSeconds) {
11361157 BatchDecodedOutput output (0 , options, streamMetadata);
1158+ output.frames = MaybeHWC2CHW (options, output.frames );
11371159 return output;
11381160 }
11391161
@@ -1176,6 +1198,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11761198 output.ptsSeconds [f] = singleOut.ptsSeconds ;
11771199 output.durationSeconds [f] = singleOut.durationSeconds ;
11781200 }
1201+ output.frames = MaybeHWC2CHW (options, output.frames );
11791202
11801203 return output;
11811204}
@@ -1302,11 +1325,6 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
13021325 torch::Tensor tensor = torch::from_blob (
13031326 filteredFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
13041327 StreamInfo& activeStream = streams_[streamIndex];
1305- if (activeStream.options .dimensionOrder == " NCHW" ) {
1306- // The docs guaranty this to return a view:
1307- // https://pytorch.org/docs/stable/generated/torch.permute.html
1308- tensor = tensor.permute ({2 , 0 , 1 });
1309- }
13101328 return tensor;
13111329}
13121330
0 commit comments