Skip to content

Commit 650e8bf

Browse files
authored
Honor preAllocatedOutputTensor with filtergraph (#295)
1 parent 679f25e commit 650e8bf

File tree

2 files changed

+53
-12
lines changed

2 files changed

+53
-12
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -873,16 +873,25 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
873873
return output;
874874
}
875875

876+
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
877+
// Callers may pass a pre-allocated tensor, where the output frame tensor will
878+
// be stored. This parameter is honored in any case, but it only leads to a
879+
// speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
880+
// decoded frame directly into `preAllocatedOutputTensor.data_ptr()`. We haven't yet
881+
// found a way to do that with filtergraph.
882+
// TODO: Figure out whether that's possible!
883+
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
884+
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
876885
void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
877886
VideoDecoder::RawDecodedOutput& rawOutput,
878887
DecodedOutput& output,
879888
std::optional<torch::Tensor> preAllocatedOutputTensor) {
880889
int streamIndex = rawOutput.streamIndex;
881890
AVFrame* frame = rawOutput.frame.get();
882891
auto& streamInfo = streams_[streamIndex];
892+
torch::Tensor tensor;
883893
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
884894
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
885-
torch::Tensor tensor;
886895
int width = streamInfo.options.width.value_or(frame->width);
887896
int height = streamInfo.options.height.value_or(frame->height);
888897
if (preAllocatedOutputTensor.has_value()) {
@@ -908,7 +917,13 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
908917
} else if (
909918
streamInfo.colorConversionLibrary ==
910919
ColorConversionLibrary::FILTERGRAPH) {
911-
output.frame = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
920+
tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
921+
if (preAllocatedOutputTensor.has_value()) {
922+
preAllocatedOutputTensor.value().copy_(tensor);
923+
output.frame = preAllocatedOutputTensor.value();
924+
} else {
925+
output.frame = tensor;
926+
}
912927
} else {
913928
throw std::runtime_error(
914929
"Invalid color conversion library: " +
@@ -1060,10 +1075,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10601075
} else {
10611076
DecodedOutput singleOut = getFrameAtIndex(
10621077
streamIndex, indexInVideo, output.frames[indexInOutput]);
1063-
if (options.colorConversionLibrary ==
1064-
ColorConversionLibrary::FILTERGRAPH) {
1065-
output.frames[indexInOutput] = singleOut.frame;
1066-
}
10671078
output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
10681079
output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
10691080
}
@@ -1140,9 +1151,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
11401151

11411152
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
11421153
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
1143-
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1144-
output.frames[f] = singleOut.frame;
1145-
}
11461154
output.ptsSeconds[f] = singleOut.ptsSeconds;
11471155
output.durationSeconds[f] = singleOut.durationSeconds;
11481156
}
@@ -1236,9 +1244,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
12361244
BatchDecodedOutput output(numFrames, options, streamMetadata);
12371245
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
12381246
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
1239-
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1240-
output.frames[f] = singleOut.frame;
1241-
}
12421247
output.ptsSeconds[f] = singleOut.ptsSeconds;
12431248
output.durationSeconds[f] = singleOut.durationSeconds;
12441249
}

test/decoders/VideoDecoderTest.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,42 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
387387
}
388388
}
389389

390+
TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) {
391+
std::string path = getResourcePath("nasa_13013.mp4");
392+
auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8});
393+
394+
std::unique_ptr<VideoDecoder> ourDecoder =
395+
VideoDecoderTest::createDecoderFromPath(path, GetParam());
396+
ourDecoder->scanFileAndUpdateMetadataAndIndex();
397+
int bestVideoStreamIndex =
398+
*ourDecoder->getContainerMetadata().bestVideoStreamIndex;
399+
ourDecoder->addVideoStreamDecoder(
400+
bestVideoStreamIndex,
401+
VideoDecoder::VideoStreamDecoderOptions(
402+
"color_conversion_library=filtergraph"));
403+
auto output = ourDecoder->getFrameAtIndex(
404+
bestVideoStreamIndex, 0, preAllocatedOutputTensor);
405+
EXPECT_EQ(output.frame.data_ptr(), preAllocatedOutputTensor.data_ptr());
406+
}
407+
408+
TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) {
409+
std::string path = getResourcePath("nasa_13013.mp4");
410+
auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8});
411+
412+
std::unique_ptr<VideoDecoder> ourDecoder =
413+
VideoDecoderTest::createDecoderFromPath(path, GetParam());
414+
ourDecoder->scanFileAndUpdateMetadataAndIndex();
415+
int bestVideoStreamIndex =
416+
*ourDecoder->getContainerMetadata().bestVideoStreamIndex;
417+
ourDecoder->addVideoStreamDecoder(
418+
bestVideoStreamIndex,
419+
VideoDecoder::VideoStreamDecoderOptions(
420+
"color_conversion_library=swscale"));
421+
auto output = ourDecoder->getFrameAtIndex(
422+
bestVideoStreamIndex, 0, preAllocatedOutputTensor);
423+
EXPECT_EQ(output.frame.data_ptr(), preAllocatedOutputTensor.data_ptr());
424+
}
425+
390426
TEST_P(VideoDecoderTest, GetAudioMetadata) {
391427
std::string path = getResourcePath("nasa_13013.mp4.audio.mp3");
392428
std::unique_ptr<VideoDecoder> decoder =

0 commit comments

Comments
 (0)