@@ -873,16 +873,25 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   return output;
 }

+// Note [preAllocatedOutputTensor with swscale and filtergraph]:
+// Callers may pass a pre-allocated tensor, where the output frame tensor will
+// be stored. This parameter is always honored, but it only leads to a speed-up
+// when swscale is used. With swscale, we can tell ffmpeg to place the decoded
+// frame directly into `preAllocatedOutputTensor.data_ptr()`. We haven't yet
+// found a way to do that with filtergraph.
+// TODO: Figure out whether that's possible!
+// The dimension order of preAllocatedOutputTensor must be HWC, regardless of
+// the `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     VideoDecoder::RawDecodedOutput& rawOutput,
     DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
+  torch::Tensor tensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      torch::Tensor tensor;
       int width = streamInfo.options.width.value_or(frame->width);
       int height = streamInfo.options.height.value_or(frame->height);
       if (preAllocatedOutputTensor.has_value()) {
@@ -908,7 +917,13 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      output.frame = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      if (preAllocatedOutputTensor.has_value()) {
+        preAllocatedOutputTensor.value().copy_(tensor);
+        output.frame = preAllocatedOutputTensor.value();
+      } else {
+        output.frame = tensor;
+      }
     } else {
       throw std::runtime_error(
           "Invalid color conversion library: " +
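
A caller-side illustration of the note above (a hedged sketch, not part of this commit): the caller pre-allocates an HWC uint8 tensor, getFrameAtIndex fills it, directly through swscale or via copy_() through filtergraph, and the caller permutes to CHW afterwards if that is the order it needs. The frame dimensions, the visibility of the VideoDecoder declaration, and the ability to call getFrameAtIndex with these three arguments are assumptions for the sketch.

#include <torch/torch.h>

// Sketch only: assumes the VideoDecoder declaration is visible and that the
// stream's frames are `height` x `width`.
void decodeIntoPreAllocated(VideoDecoder& decoder, int streamIndex,
                            int64_t frameIndex, int height, int width) {
  // Must be HWC and uint8, regardless of the stream's dimension_order option.
  torch::Tensor preAllocated = torch::empty({height, width, 3}, torch::kUInt8);

  // swscale writes into preAllocated's storage directly; filtergraph decodes
  // elsewhere and the result is copied in via copy_().
  decoder.getFrameAtIndex(streamIndex, frameIndex, preAllocated);

  // Re-shaping to CHW is the caller's responsibility.
  torch::Tensor chw = preAllocated.permute({2, 0, 1});
  (void)chw;
}
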
@@ -1060,10 +1075,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
     } else {
       DecodedOutput singleOut = getFrameAtIndex(
           streamIndex, indexInVideo, output.frames[indexInOutput]);
-      if (options.colorConversionLibrary ==
-          ColorConversionLibrary::FILTERGRAPH) {
-        output.frames[indexInOutput] = singleOut.frame;
-      }
       output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
       output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
     }
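
Why the removed FILTERGRAPH special case above is safe to drop: getFrameAtIndex receives output.frames[indexInOutput] as the pre-allocated tensor, and the filtergraph path now copies into it with copy_(), which writes through the view into the batch tensor. A minimal standalone sketch (plain libtorch, not TorchCodec code) of that aliasing behaviour:

#include <torch/torch.h>

int main() {
  // Batch tensor analogous to output.frames: N x H x W x C, uint8.
  torch::Tensor frames = torch::zeros({2, 4, 4, 3}, torch::kUInt8);

  // frames[0] is a view into the batch, so copy_() fills the slice in place.
  torch::Tensor slice = frames[0];
  torch::Tensor decoded = torch::ones({4, 4, 3}, torch::kUInt8);
  slice.copy_(decoded);

  // No re-assignment of frames[0] is needed afterwards.
  TORCH_CHECK(frames[0].equal(decoded));
  return 0;
}
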
@@ -1140,9 +1151,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(

   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
     DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
-    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-      output.frames[f] = singleOut.frame;
-    }
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
@@ -1236,9 +1244,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   BatchDecodedOutput output(numFrames, options, streamMetadata);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
     DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
-    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-      output.frames[f] = singleOut.frame;
-    }
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
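
The same simplification applies to getFramesInRange and getFramesPlayedByTimestampInRange: each loop passes output.frames[f] as the pre-allocated tensor, so both color conversion paths fill the batch slice and the FILTERGRAPH re-assignment disappears. Below is a hedged sketch of that loop shape, with BatchDecodedOutput's fields replaced by assumed stand-ins and DecodedOutput's pts/duration fields assumed to be plain doubles.

#include <torch/torch.h>
#include <vector>

// Stand-in for BatchDecodedOutput; field types are assumed for the sketch.
struct SketchBatchOutput {
  torch::Tensor frames; // numFrames x height x width x 3, uint8 (HWC frames)
  std::vector<double> ptsSeconds;
  std::vector<double> durationSeconds;
};

// Assumes the VideoDecoder declaration is visible and that getFrameAtIndex
// accepts the same three arguments as in the loops above.
SketchBatchOutput decodeRangeSketch(VideoDecoder& decoder, int streamIndex,
                                    int64_t start, int64_t stop,
                                    int height, int width) {
  int64_t numFrames = stop - start;
  SketchBatchOutput output;
  output.frames = torch::empty({numFrames, height, width, 3}, torch::kUInt8);
  output.ptsSeconds.resize(numFrames);
  output.durationSeconds.resize(numFrames);

  for (int64_t i = start, f = 0; i < stop; ++i, ++f) {
    // output.frames[f] is a view; swscale and filtergraph both end up filling
    // it, so no library-specific re-assignment is needed.
    auto singleOut = decoder.getFrameAtIndex(streamIndex, i, output.frames[f]);
    output.ptsSeconds[f] = singleOut.ptsSeconds;
    output.durationSeconds[f] = singleOut.durationSeconds;
  }
  return output;
}
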