Skip to content

Commit 84cf475

Browse files
author
pytorchbot
committed
2025-11-11 nightly release (289ff5d)
1 parent cfe8aaf commit 84cf475

19 files changed

+585
-360
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ function(make_torchcodec_libraries
9696
Encoder.cpp
9797
ValidationUtils.cpp
9898
Transform.cpp
99+
Metadata.cpp
99100
)
100101

101102
if(ENABLE_CUDA)

src/torchcodec/_core/Encoder.cpp

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#include "Encoder.h"
55
#include "torch/types.h"
66

7+
extern "C" {
8+
#include <libavutil/pixdesc.h>
9+
}
10+
711
namespace facebook::torchcodec {
812

913
namespace {
@@ -534,6 +538,36 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
534538
return frames.contiguous();
535539
}
536540

541+
// Resolves `targetPixelFormat` (an FFmpeg pixel format name) to an
// AVPixelFormat and checks that it appears in the list returned by
// getSupportedPixelFormats() for `avCodec`.
//
// Returns the matching AVPixelFormat on success. Otherwise fails through
// TORCH_CHECK with a message that distinguishes an unknown format name from a
// known-but-unsupported one, followed by the formats the encoder reports.
AVPixelFormat validatePixelFormat(
    const AVCodec& avCodec,
    const std::string& targetPixelFormat) {
  // Yields AV_PIX_FMT_NONE when the name is not recognized.
  const AVPixelFormat pixelFormat = av_get_pix_fmt(targetPixelFormat.c_str());

  // Validate that the encoder supports this pixel format. The list is
  // terminated by AV_PIX_FMT_NONE, so an unrecognized name can never match.
  const AVPixelFormat* supportedFormats = getSupportedPixelFormats(avCodec);
  if (supportedFormats != nullptr) {
    for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) {
      if (supportedFormats[i] == pixelFormat) {
        return pixelFormat;
      }
    }
  }

  std::stringstream errorMsg;
  // av_get_pix_fmt failed to find a pix_fmt
  if (pixelFormat == AV_PIX_FMT_NONE) {
    errorMsg << "Unknown pixel format: " << targetPixelFormat;
  } else {
    errorMsg << "Specified pixel format " << targetPixelFormat
             << " is not supported by the " << avCodec.name << " encoder.";
  }
  // Build error message, similar to FFmpeg's error log
  errorMsg << "\nSupported pixel formats for " << avCodec.name << ":";
  if (supportedFormats == nullptr) {
    // The list may be null (see the check above); it must not be walked here
    // either, otherwise building the error message would dereference nullptr.
    errorMsg << " (none reported)";
  } else {
    for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) {
      // Streaming a null `const char*` is undefined behaviour, so substitute
      // a placeholder if FFmpeg has no name for a listed format.
      const char* formatName = av_get_pix_fmt_name(supportedFormats[i]);
      errorMsg << " " << (formatName != nullptr ? formatName : "unknown");
    }
  }
  TORCH_CHECK(false, errorMsg.str());
}
537571
} // namespace
538572

539573
VideoEncoder::~VideoEncoder() {
@@ -635,15 +669,19 @@ void VideoEncoder::initializeEncoder(
635669
outWidth_ = inWidth_;
636670
outHeight_ = inHeight_;
637671

638-
// TODO-VideoEncoder: Enable other pixel formats
639-
// Let FFmpeg choose best pixel format to minimize loss
640-
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
641-
getSupportedPixelFormats(*avCodec), // List of supported formats
642-
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
643-
0, // No alpha channel
644-
nullptr // Discard conversion loss information
645-
);
646-
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")
672+
if (videoStreamOptions.pixelFormat.has_value()) {
673+
outPixelFormat_ =
674+
validatePixelFormat(*avCodec, videoStreamOptions.pixelFormat.value());
675+
} else {
676+
const AVPixelFormat* formats = getSupportedPixelFormats(*avCodec);
677+
// Use first listed pixel format as default (often yuv420p).
678+
// This is similar to FFmpeg's logic:
679+
// https://www.ffmpeg.org/doxygen/4.0/decode_8c_source.html#l01087
680+
// If pixel formats are undefined for some reason, try yuv420p
681+
outPixelFormat_ = (formats && formats[0] != AV_PIX_FMT_NONE)
682+
? formats[0]
683+
: AV_PIX_FMT_YUV420P;
684+
}
647685

648686
// Configure codec parameters
649687
avCodecContext_->codec_id = avCodec->id;

src/torchcodec/_core/Metadata.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "Metadata.h"
8+
#include "torch/types.h"
9+
10+
namespace facebook::torchcodec {
11+
12+
// Returns the stream duration in seconds for the given seek mode.
//
// exact / custom_frame_mappings: the difference between the content-derived
// end and begin PTS; fails via TORCH_CHECK if either is missing.
// approximate: the header duration if present, otherwise
// numFramesFromHeader / averageFpsFromHeader when both are present and the
// fps is non-zero, otherwise std::nullopt.
std::optional<double> StreamMetadata::getDurationSeconds(
    SeekMode seekMode) const {
  if (seekMode == SeekMode::custom_frame_mappings ||
      seekMode == SeekMode::exact) {
    TORCH_CHECK(
        endStreamPtsSecondsFromContent.has_value() &&
            beginStreamPtsSecondsFromContent.has_value(),
        "Missing beginStreamPtsSecondsFromContent or endStreamPtsSecondsFromContent");
    const double endSeconds = *endStreamPtsSecondsFromContent;
    const double beginSeconds = *beginStreamPtsSecondsFromContent;
    return endSeconds - beginSeconds;
  }

  if (seekMode == SeekMode::approximate) {
    if (durationSecondsFromHeader.has_value()) {
      return *durationSecondsFromHeader;
    }
    // Fall back to frames / fps, guarding against a division by zero.
    const bool canDerive = numFramesFromHeader.has_value() &&
        averageFpsFromHeader.has_value() && *averageFpsFromHeader != 0.0;
    if (canDerive) {
      const double frameCount = static_cast<double>(*numFramesFromHeader);
      return frameCount / *averageFpsFromHeader;
    }
    return std::nullopt;
  }

  TORCH_CHECK(false, "Unknown SeekMode");
}
37+
38+
// Returns the stream's begin timestamp in seconds for the given seek mode.
//
// exact / custom_frame_mappings: the content-derived begin PTS; fails via
// TORCH_CHECK if it is missing.
// approximate: the content-derived begin PTS when available, else 0.0.
double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const {
  if (seekMode == SeekMode::approximate) {
    return beginStreamPtsSecondsFromContent.value_or(0.0);
  }

  if (seekMode == SeekMode::exact ||
      seekMode == SeekMode::custom_frame_mappings) {
    TORCH_CHECK(
        beginStreamPtsSecondsFromContent.has_value(),
        "Missing beginStreamPtsSecondsFromContent");
    return *beginStreamPtsSecondsFromContent;
  }

  TORCH_CHECK(false, "Unknown SeekMode");
}
55+
56+
// Returns the stream's end timestamp in seconds for the given seek mode.
//
// exact / custom_frame_mappings: the content-derived end PTS; fails via
// TORCH_CHECK if it is missing.
// approximate: the content-derived end PTS when available, otherwise whatever
// getDurationSeconds() yields for the same mode (possibly std::nullopt).
std::optional<double> StreamMetadata::getEndStreamSeconds(
    SeekMode seekMode) const {
  if (seekMode == SeekMode::approximate) {
    if (!endStreamPtsSecondsFromContent.has_value()) {
      return getDurationSeconds(seekMode);
    }
    return *endStreamPtsSecondsFromContent;
  }

  if (seekMode == SeekMode::exact ||
      seekMode == SeekMode::custom_frame_mappings) {
    TORCH_CHECK(
        endStreamPtsSecondsFromContent.has_value(),
        "Missing endStreamPtsSecondsFromContent");
    return *endStreamPtsSecondsFromContent;
  }

  TORCH_CHECK(false, "Unknown SeekMode");
}
74+
75+
// Returns the number of frames for the given seek mode.
//
// exact / custom_frame_mappings: numFramesFromContent; fails via TORCH_CHECK
// if it is missing.
// approximate: numFramesFromHeader if present, otherwise
// averageFpsFromHeader * durationSecondsFromHeader converted to int64_t
// (the conversion truncates toward zero), otherwise std::nullopt.
std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
  if (seekMode == SeekMode::custom_frame_mappings ||
      seekMode == SeekMode::exact) {
    TORCH_CHECK(
        numFramesFromContent.has_value(), "Missing numFramesFromContent");
    return *numFramesFromContent;
  }

  if (seekMode == SeekMode::approximate) {
    if (numFramesFromHeader.has_value()) {
      return *numFramesFromHeader;
    }
    if (!averageFpsFromHeader.has_value() ||
        !durationSecondsFromHeader.has_value()) {
      return std::nullopt;
    }
    const double estimatedFrames =
        *averageFpsFromHeader * *durationSecondsFromHeader;
    return static_cast<int64_t>(estimatedFrames);
  }

  TORCH_CHECK(false, "Unknown SeekMode");
}
97+
98+
// Returns the average frames-per-second for the given seek mode.
//
// approximate: averageFpsFromHeader as-is.
// exact / custom_frame_mappings: getNumFrames() divided by the
// content-derived (end - begin) span when both bounds are present and the
// span is non-zero; otherwise falls back to averageFpsFromHeader. Note that
// getNumFrames() itself fails via TORCH_CHECK in these modes when
// numFramesFromContent is missing.
std::optional<double> StreamMetadata::getAverageFps(SeekMode seekMode) const {
  if (seekMode == SeekMode::approximate) {
    return averageFpsFromHeader;
  }

  if (seekMode == SeekMode::exact ||
      seekMode == SeekMode::custom_frame_mappings) {
    const std::optional<int64_t> frameCount = getNumFrames(seekMode);
    const bool haveInputs = frameCount.has_value() &&
        beginStreamPtsSecondsFromContent.has_value() &&
        endStreamPtsSecondsFromContent.has_value();
    if (!haveInputs) {
      return averageFpsFromHeader;
    }
    const double span =
        *endStreamPtsSecondsFromContent - *beginStreamPtsSecondsFromContent;
    // A zero-length span cannot be used as a divisor; use the header value.
    if (span == 0.0) {
      return averageFpsFromHeader;
    }
    return static_cast<double>(*frameCount) / span;
  }

  TORCH_CHECK(false, "Unknown SeekMode");
}
120+
121+
} // namespace facebook::torchcodec

src/torchcodec/_core/Metadata.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ extern "C" {
1818

1919
namespace facebook::torchcodec {
2020

21+
// Mode passed to the StreamMetadata getters to choose how a value is derived.
// In those getters, `exact` and `custom_frame_mappings` share one code path
// that requires the corresponding *FromContent field, while `approximate`
// uses header-derived fields and computed fallbacks.
enum class SeekMode { exact, approximate, custom_frame_mappings };
22+
2123
struct StreamMetadata {
2224
// Common (video and audio) fields derived from the AVStream.
2325
int streamIndex;
@@ -52,6 +54,13 @@ struct StreamMetadata {
5254
std::optional<int64_t> sampleRate;
5355
std::optional<int64_t> numChannels;
5456
std::optional<std::string> sampleFormat;
57+
58+
// Computed methods with fallback logic
59+
std::optional<double> getDurationSeconds(SeekMode seekMode) const;
60+
double getBeginStreamSeconds(SeekMode seekMode) const;
61+
std::optional<double> getEndStreamSeconds(SeekMode seekMode) const;
62+
std::optional<int64_t> getNumFrames(SeekMode seekMode) const;
63+
std::optional<double> getAverageFps(SeekMode seekMode) const;
5564
};
5665

5766
struct ContainerMetadata {

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 17 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,14 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
367367
return containerMetadata_;
368368
}
369369

370+
// Accessor for the seek mode stored in seekMode_.
SeekMode SingleStreamDecoder::getSeekMode() const {
  return this->seekMode_;
}
373+
374+
int SingleStreamDecoder::getActiveStreamIndex() const {
375+
return activeStreamIndex_;
376+
}
377+
370378
torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
371379
validateActiveStream(AVMEDIA_TYPE_VIDEO);
372380
validateScannedAllStreams("getKeyFrameIndices");
@@ -611,7 +619,7 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
611619
const auto& streamMetadata =
612620
containerMetadata_.allStreamMetadata[activeStreamIndex_];
613621

614-
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
622+
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
615623
if (numFrames.has_value()) {
616624
// If the frameIndex is negative, we convert it to a positive index
617625
frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
@@ -705,7 +713,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
705713

706714
// Note that if we do not have the number of frames available in our
707715
// metadata, then we assume that the upper part of the range is valid.
708-
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
716+
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
709717
if (numFrames.has_value()) {
710718
TORCH_CHECK(
711719
stop <= numFrames.value(),
@@ -779,8 +787,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
779787
const auto& streamMetadata =
780788
containerMetadata_.allStreamMetadata[activeStreamIndex_];
781789

782-
double minSeconds = getMinSeconds(streamMetadata);
783-
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
790+
double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
791+
std::optional<double> maxSeconds =
792+
streamMetadata.getEndStreamSeconds(seekMode_);
784793

785794
// The frame played at timestamp t and the one played at timestamp `t +
786795
// eps` are probably the same frame, with the same index. The easiest way to
@@ -857,7 +866,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
857866
return frameBatchOutput;
858867
}
859868

860-
double minSeconds = getMinSeconds(streamMetadata);
869+
double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
861870
TORCH_CHECK(
862871
startSeconds >= minSeconds,
863872
"Start seconds is " + std::to_string(startSeconds) +
@@ -866,7 +875,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
866875

867876
// Note that if we can't determine the maximum seconds from the metadata,
868877
// then we assume upper range is valid.
869-
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
878+
std::optional<double> maxSeconds =
879+
streamMetadata.getEndStreamSeconds(seekMode_);
870880
if (maxSeconds.has_value()) {
871881
TORCH_CHECK(
872882
startSeconds < maxSeconds.value(),
@@ -1439,47 +1449,6 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14391449
// STREAM AND METADATA APIS
14401450
// --------------------------------------------------------------------------
14411451

1442-
std::optional<int64_t> SingleStreamDecoder::getNumFrames(
1443-
const StreamMetadata& streamMetadata) {
1444-
switch (seekMode_) {
1445-
case SeekMode::custom_frame_mappings:
1446-
case SeekMode::exact:
1447-
return streamMetadata.numFramesFromContent.value();
1448-
case SeekMode::approximate: {
1449-
return streamMetadata.numFramesFromHeader;
1450-
}
1451-
default:
1452-
TORCH_CHECK(false, "Unknown SeekMode");
1453-
}
1454-
}
1455-
1456-
double SingleStreamDecoder::getMinSeconds(
1457-
const StreamMetadata& streamMetadata) {
1458-
switch (seekMode_) {
1459-
case SeekMode::custom_frame_mappings:
1460-
case SeekMode::exact:
1461-
return streamMetadata.beginStreamPtsSecondsFromContent.value();
1462-
case SeekMode::approximate:
1463-
return 0;
1464-
default:
1465-
TORCH_CHECK(false, "Unknown SeekMode");
1466-
}
1467-
}
1468-
1469-
std::optional<double> SingleStreamDecoder::getMaxSeconds(
1470-
const StreamMetadata& streamMetadata) {
1471-
switch (seekMode_) {
1472-
case SeekMode::custom_frame_mappings:
1473-
case SeekMode::exact:
1474-
return streamMetadata.endStreamPtsSecondsFromContent.value();
1475-
case SeekMode::approximate: {
1476-
return streamMetadata.durationSecondsFromHeader;
1477-
}
1478-
default:
1479-
TORCH_CHECK(false, "Unknown SeekMode");
1480-
}
1481-
}
1482-
14831452
// --------------------------------------------------------------------------
14841453
// VALIDATION UTILS
14851454
// --------------------------------------------------------------------------
@@ -1529,7 +1498,7 @@ void SingleStreamDecoder::validateFrameIndex(
15291498

15301499
// Note that if we do not have the number of frames available in our
15311500
// metadata, then we assume that the frameIndex is valid.
1532-
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
1501+
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
15331502
if (numFrames.has_value()) {
15341503
if (frameIndex >= numFrames.value()) {
15351504
throw std::out_of_range(

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "DeviceInterface.h"
1717
#include "FFMPEGCommon.h"
1818
#include "Frame.h"
19+
#include "Metadata.h"
1920
#include "StreamOptions.h"
2021
#include "Transform.h"
2122

@@ -30,8 +31,6 @@ class SingleStreamDecoder {
3031
// CONSTRUCTION API
3132
// --------------------------------------------------------------------------
3233

33-
enum class SeekMode { exact, approximate, custom_frame_mappings };
34-
3534
// Creates a SingleStreamDecoder from the video at videoFilePath.
3635
explicit SingleStreamDecoder(
3736
const std::string& videoFilePath,
@@ -60,6 +59,12 @@ class SingleStreamDecoder {
6059
// Returns the metadata for the container.
6160
ContainerMetadata getContainerMetadata() const;
6261

62+
// Returns the seek mode of this decoder.
63+
SeekMode getSeekMode() const;
64+
65+
// Returns the active stream index. Returns -2 if no stream is active.
66+
int getActiveStreamIndex() const;
67+
6368
// Returns the key frame indices as a tensor. The tensor is 1D and contains
6469
// int64 values, where each value is the frame index for a key frame.
6570
torch::Tensor getKeyFrameIndices();
@@ -312,10 +317,6 @@ class SingleStreamDecoder {
312317
// index. Note that this index may be truncated for some files.
313318
int getBestStreamIndex(AVMediaType mediaType);
314319

315-
std::optional<int64_t> getNumFrames(const StreamMetadata& streamMetadata);
316-
double getMinSeconds(const StreamMetadata& streamMetadata);
317-
std::optional<double> getMaxSeconds(const StreamMetadata& streamMetadata);
318-
319320
// --------------------------------------------------------------------------
320321
// VALIDATION UTILS
321322
// --------------------------------------------------------------------------

0 commit comments

Comments
 (0)