@@ -191,14 +191,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
191191 int64_t numFrames,
192192 const VideoStreamDecoderOptions& options,
193193 const StreamMetadata& metadata)
194- : frames(torch::empty(
195- {numFrames,
196- options.height .value_or (*metadata.height ),
197- options.width .value_or (*metadata.width ),
198- 3 },
199- at::TensorOptions (options.device).dtype(torch::kUInt8 ))),
200- ptsSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
201- durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {}
194+ : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64 })),
195+ durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {
196+ int height, width;
197+ std::tie (height, width) =
198+ getHeightAndWidthFromOptionsOrMetadata (options, metadata);
199+ frames = allocateEmptyHWCTensor (height, width, options.device , numFrames);
200+ }
202201
203202VideoDecoder::VideoDecoder () {}
204203
@@ -893,8 +892,9 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
893892 torch::Tensor tensor;
894893 if (output.streamType == AVMEDIA_TYPE_VIDEO) {
895894 if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
896- int width = streamInfo.options .width .value_or (frame->width );
897- int height = streamInfo.options .height .value_or (frame->height );
895+ int height, width;
896+ std::tie (height, width) =
897+ getHeightAndWidthFromOptionsOrAVFrame (streamInfo.options , frame);
898898 if (preAllocatedOutputTensor.has_value ()) {
899899 tensor = preAllocatedOutputTensor.value ();
900900 auto shape = tensor.sizes ();
@@ -908,8 +908,8 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
908908 " x3, got " ,
909909 shape);
910910 } else {
911- tensor = torch::empty (
912- { height, width, 3 }, torch::TensorOptions (). dtype ({torch:: kUInt8 } ));
911+ tensor = allocateEmptyHWCTensor (
912+ height, width, streamInfo. options . device . type ( ));
913913 }
914914 rawOutput.data = tensor.data_ptr <uint8_t >();
915915 convertFrameToBufferUsingSwsScale (rawOutput);
@@ -1400,6 +1400,38 @@ VideoDecoder::~VideoDecoder() {
14001400 }
14011401}
14021402
1403+ std::tuple<int , int > getHeightAndWidthFromOptionsOrMetadata (
1404+ const VideoDecoder::VideoStreamDecoderOptions& options,
1405+ const VideoDecoder::StreamMetadata& metadata) {
1406+ return std::make_tuple (
1407+ options.height .value_or (*metadata.height ),
1408+ options.width .value_or (*metadata.width ));
1409+ }
1410+
1411+ std::tuple<int , int > getHeightAndWidthFromOptionsOrAVFrame (
1412+ const VideoDecoder::VideoStreamDecoderOptions& options,
1413+ AVFrame* avFrame) {
1414+ return std::make_tuple (
1415+ options.height .value_or (avFrame->height ),
1416+ options.width .value_or (avFrame->width ));
1417+ }
1418+
1419+ torch::Tensor allocateEmptyHWCTensor (
1420+ int height,
1421+ int width,
1422+ torch::Device device,
1423+ std::optional<int > numFrames) {
1424+ auto tensorOptions = torch::TensorOptions ()
1425+ .dtype (torch::kUInt8 )
1426+ .layout (torch::kStrided )
1427+ .device (device);
1428+ if (numFrames.has_value ()) {
1429+ return torch::empty ({numFrames.value (), height, width, 3 }, tensorOptions);
1430+ } else {
1431+ return torch::empty ({height, width, 3 }, tensorOptions);
1432+ }
1433+ }
1434+
14031435std::ostream& operator <<(
14041436 std::ostream& os,
14051437 const VideoDecoder::DecodeStats& stats) {
0 commit comments