Skip to content

Commit bff16e6

Browse files
authored
Merge branch 'meta-pytorch:main' into skip-tests
2 parents 451189e + 408b373 commit bff16e6

File tree

11 files changed

+142
-94
lines changed

11 files changed

+142
-94
lines changed

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ int getNumChannels(const SharedAVCodecContext& avCodecContext) {
158158
#endif
159159
}
160160

161+
int getNumChannels(const AVCodecParameters* codecpar) {
162+
TORCH_CHECK(codecpar != nullptr, "codecpar is null")
163+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
164+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
165+
return codecpar->ch_layout.nb_channels;
166+
#else
167+
return codecpar->channels;
168+
#endif
169+
}
170+
161171
void setDefaultChannelLayout(
162172
UniqueAVCodecContext& avCodecContext,
163173
int numChannels) {

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
180180

181181
int getNumChannels(const UniqueAVFrame& avFrame);
182182
int getNumChannels(const SharedAVCodecContext& avCodecContext);
183+
int getNumChannels(const AVCodecParameters* codecpar);
183184

184185
void setDefaultChannelLayout(
185186
UniqueAVCodecContext& avCodecContext,

src/torchcodec/_core/Metadata.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ std::optional<double> StreamMetadata::getDurationSeconds(
2929
return static_cast<double>(numFramesFromHeader.value()) /
3030
averageFpsFromHeader.value();
3131
}
32+
if (durationSecondsFromContainer.has_value()) {
33+
return durationSecondsFromContainer.value();
34+
}
3235
return std::nullopt;
3336
default:
3437
TORCH_CHECK(false, "Unknown SeekMode");
@@ -80,13 +83,13 @@ std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
8083
numFramesFromContent.has_value(), "Missing numFramesFromContent");
8184
return numFramesFromContent.value();
8285
case SeekMode::approximate: {
86+
auto durationSeconds = getDurationSeconds(seekMode);
8387
if (numFramesFromHeader.has_value()) {
8488
return numFramesFromHeader.value();
8589
}
86-
if (averageFpsFromHeader.has_value() &&
87-
durationSecondsFromHeader.has_value()) {
90+
if (averageFpsFromHeader.has_value() && durationSeconds.has_value()) {
8891
return static_cast<int64_t>(
89-
averageFpsFromHeader.value() * durationSecondsFromHeader.value());
92+
averageFpsFromHeader.value() * durationSeconds.value());
9093
}
9194
return std::nullopt;
9295
}

src/torchcodec/_core/Metadata.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ enum class SeekMode { exact, approximate, custom_frame_mappings };
2323
struct StreamMetadata {
2424
// Common (video and audio) fields derived from the AVStream.
2525
int streamIndex;
26+
2627
// See this link for what various values are available:
2728
// https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48
2829
AVMediaType mediaType;
30+
2931
std::optional<AVCodecID> codecId;
3032
std::optional<std::string> codecName;
3133
std::optional<double> durationSecondsFromHeader;
@@ -35,17 +37,22 @@ struct StreamMetadata {
3537
std::optional<double> averageFpsFromHeader;
3638
std::optional<double> bitRate;
3739

40+
// Used as fallback in approximate mode when stream duration is unavailable.
41+
std::optional<double> durationSecondsFromContainer;
42+
3843
// More accurate duration, obtained by scanning the file.
3944
// These presentation timestamps are in time base.
4045
std::optional<int64_t> beginStreamPtsFromContent;
4146
std::optional<int64_t> endStreamPtsFromContent;
47+
4248
// These presentation timestamps are in seconds.
4349
std::optional<double> beginStreamPtsSecondsFromContent;
4450
std::optional<double> endStreamPtsSecondsFromContent;
51+
4552
// This can be useful for index-based seeking.
4653
std::optional<int64_t> numFramesFromContent;
4754

48-
// Video-only fields derived from the AVCodecContext.
55+
// Video-only fields
4956
std::optional<int> width;
5057
std::optional<int> height;
5158
std::optional<AVRational> sampleAspectRatio;
@@ -67,13 +74,17 @@ struct ContainerMetadata {
6774
std::vector<StreamMetadata> allStreamMetadata;
6875
int numAudioStreams = 0;
6976
int numVideoStreams = 0;
77+
7078
// Note that this is the container-level duration, which is usually the max
7179
// of all stream durations available in the container.
7280
std::optional<double> durationSecondsFromHeader;
81+
7382
// Total bit rate at the container level, in bit/s.
7483
std::optional<double> bitRate;
84+
7585
// If set, this is the index to the default audio stream.
7686
std::optional<int> bestAudioStreamIndex;
87+
7788
// If set, this is the index to the default video stream.
7889
std::optional<int> bestVideoStreamIndex;
7990
};

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,26 @@ void SingleStreamDecoder::initializeDecoder() {
100100
"Failed to find stream info: ",
101101
getFFMPEGErrorStringFromErrorCode(status));
102102

103+
if (formatContext_->duration > 0) {
104+
AVRational defaultTimeBase{1, AV_TIME_BASE};
105+
containerMetadata_.durationSecondsFromHeader =
106+
ptsToSeconds(formatContext_->duration, defaultTimeBase);
107+
}
108+
109+
if (formatContext_->bit_rate > 0) {
110+
containerMetadata_.bitRate = formatContext_->bit_rate;
111+
}
112+
113+
int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
114+
if (bestVideoStream >= 0) {
115+
containerMetadata_.bestVideoStreamIndex = bestVideoStream;
116+
}
117+
118+
int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
119+
if (bestAudioStream >= 0) {
120+
containerMetadata_.bestAudioStreamIndex = bestAudioStream;
121+
}
122+
103123
for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
104124
AVStream* avStream = formatContext_->streams[i];
105125
StreamMetadata streamMetadata;
@@ -110,8 +130,8 @@ void SingleStreamDecoder::initializeDecoder() {
110130
", does not match AVStream's index, " +
111131
std::to_string(avStream->index) + ".");
112132
streamMetadata.streamIndex = i;
113-
streamMetadata.mediaType = avStream->codecpar->codec_type;
114133
streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id);
134+
streamMetadata.mediaType = avStream->codecpar->codec_type;
115135
streamMetadata.bitRate = avStream->codecpar->bit_rate;
116136

117137
int64_t frameCount = avStream->nb_frames;
@@ -133,10 +153,18 @@ void SingleStreamDecoder::initializeDecoder() {
133153
if (fps > 0) {
134154
streamMetadata.averageFpsFromHeader = fps;
135155
}
156+
streamMetadata.width = avStream->codecpar->width;
157+
streamMetadata.height = avStream->codecpar->height;
158+
streamMetadata.sampleAspectRatio =
159+
avStream->codecpar->sample_aspect_ratio;
136160
containerMetadata_.numVideoStreams++;
137161
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
138162
AVSampleFormat format =
139163
static_cast<AVSampleFormat>(avStream->codecpar->format);
164+
streamMetadata.sampleRate =
165+
static_cast<int64_t>(avStream->codecpar->sample_rate);
166+
streamMetadata.numChannels =
167+
static_cast<int64_t>(getNumChannels(avStream->codecpar));
140168

141169
// If the AVSampleFormat is not recognized, we get back nullptr. We have
142170
// to make sure we don't initialize a std::string with nullptr. There's
@@ -149,27 +177,10 @@ void SingleStreamDecoder::initializeDecoder() {
149177
containerMetadata_.numAudioStreams++;
150178
}
151179

152-
containerMetadata_.allStreamMetadata.push_back(streamMetadata);
153-
}
180+
streamMetadata.durationSecondsFromContainer =
181+
containerMetadata_.durationSecondsFromHeader;
154182

155-
if (formatContext_->duration > 0) {
156-
AVRational defaultTimeBase{1, AV_TIME_BASE};
157-
containerMetadata_.durationSecondsFromHeader =
158-
ptsToSeconds(formatContext_->duration, defaultTimeBase);
159-
}
160-
161-
if (formatContext_->bit_rate > 0) {
162-
containerMetadata_.bitRate = formatContext_->bit_rate;
163-
}
164-
165-
int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
166-
if (bestVideoStream >= 0) {
167-
containerMetadata_.bestVideoStreamIndex = bestVideoStream;
168-
}
169-
170-
int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
171-
if (bestAudioStream >= 0) {
172-
containerMetadata_.bestAudioStreamIndex = bestAudioStream;
183+
containerMetadata_.allStreamMetadata.push_back(streamMetadata);
173184
}
174185

175186
if (seekMode_ == SeekMode::exact) {
@@ -524,11 +535,6 @@ void SingleStreamDecoder::addVideoStream(
524535
auto& streamInfo = streamInfos_[activeStreamIndex_];
525536
streamInfo.videoStreamOptions = videoStreamOptions;
526537

527-
streamMetadata.width = streamInfo.codecContext->width;
528-
streamMetadata.height = streamInfo.codecContext->height;
529-
streamMetadata.sampleAspectRatio =
530-
streamInfo.codecContext->sample_aspect_ratio;
531-
532538
if (seekMode_ == SeekMode::custom_frame_mappings) {
533539
TORCH_CHECK(
534540
customFrameMappings.has_value(),
@@ -574,13 +580,6 @@ void SingleStreamDecoder::addAudioStream(
574580
auto& streamInfo = streamInfos_[activeStreamIndex_];
575581
streamInfo.audioStreamOptions = audioStreamOptions;
576582

577-
auto& streamMetadata =
578-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
579-
streamMetadata.sampleRate =
580-
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
581-
streamMetadata.numChannels =
582-
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
583-
584583
// FFmpeg docs say that the decoder will try to decode natively in this
585584
// format, if it can. Docs don't say what the decoder does when it doesn't
586585
// support that format, but it looks like it does nothing, so this probably

src/torchcodec/_core/_metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class StreamMetadata:
4444
from the actual frames if a :term:`scan` was performed. Otherwise we
4545
fall back to ``duration_seconds_from_header``. If that value is also None,
4646
we instead calculate the duration from ``num_frames_from_header`` and
47-
``average_fps_from_header``.
47+
``average_fps_from_header``. If all of those are unavailable, we fall back
48+
to the container-level ``duration_seconds_from_header``.
4849
"""
4950
begin_stream_seconds: Optional[float]
5051
"""Beginning of the stream, in seconds (float). Conceptually, this

src/torchcodec/_core/custom_ops.cpp

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,34 @@ SeekMode seekModeFromString(std::string_view seekMode) {
198198
}
199199
}
200200

201+
void writeFallbackBasedMetadata(
202+
std::map<std::string, std::string>& map,
203+
const StreamMetadata& streamMetadata,
204+
SeekMode seekMode) {
205+
auto durationSeconds = streamMetadata.getDurationSeconds(seekMode);
206+
if (durationSeconds.has_value()) {
207+
map["durationSeconds"] = std::to_string(durationSeconds.value());
208+
}
209+
210+
auto numFrames = streamMetadata.getNumFrames(seekMode);
211+
if (numFrames.has_value()) {
212+
map["numFrames"] = std::to_string(numFrames.value());
213+
}
214+
215+
double beginStreamSeconds = streamMetadata.getBeginStreamSeconds(seekMode);
216+
map["beginStreamSeconds"] = std::to_string(beginStreamSeconds);
217+
218+
auto endStreamSeconds = streamMetadata.getEndStreamSeconds(seekMode);
219+
if (endStreamSeconds.has_value()) {
220+
map["endStreamSeconds"] = std::to_string(endStreamSeconds.value());
221+
}
222+
223+
auto averageFps = streamMetadata.getAverageFps(seekMode);
224+
if (averageFps.has_value()) {
225+
map["averageFps"] = std::to_string(averageFps.value());
226+
}
227+
}
228+
201229
int checkedToPositiveInt(const std::string& str) {
202230
int ret = 0;
203231
try {
@@ -917,30 +945,28 @@ std::string get_stream_json_metadata(
917945
// In approximate mode: content-based metadata does not exist for any stream.
918946
// In custom_frame_mappings: content-based metadata exists only for the active
919947
// stream.
948+
//
920949
// Our fallback logic assumes content-based metadata is available.
921950
// It is available for decoding on the active stream, but would break
922951
// when getting metadata from non-active streams.
923952
if ((seekMode != SeekMode::custom_frame_mappings) ||
924953
(seekMode == SeekMode::custom_frame_mappings &&
925954
stream_index == activeStreamIndex)) {
926-
if (streamMetadata.getDurationSeconds(seekMode).has_value()) {
927-
map["durationSeconds"] =
928-
std::to_string(streamMetadata.getDurationSeconds(seekMode).value());
929-
}
930-
if (streamMetadata.getNumFrames(seekMode).has_value()) {
931-
map["numFrames"] =
932-
std::to_string(streamMetadata.getNumFrames(seekMode).value());
933-
}
934-
map["beginStreamSeconds"] =
935-
std::to_string(streamMetadata.getBeginStreamSeconds(seekMode));
936-
if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) {
937-
map["endStreamSeconds"] =
938-
std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value());
939-
}
940-
if (streamMetadata.getAverageFps(seekMode).has_value()) {
941-
map["averageFps"] =
942-
std::to_string(streamMetadata.getAverageFps(seekMode).value());
943-
}
955+
writeFallbackBasedMetadata(map, streamMetadata, seekMode);
956+
} else if (seekMode == SeekMode::custom_frame_mappings) {
957+
// If this is not the active stream, then we don't have content-based
958+
// metadata for custom frame mappings. In that case, we want the same
959+
// behavior as we would get with approximate mode. Encoding this behavior in
960+
// the fallback logic itself is tricky and not worth it for this corner
961+
// case. So we hardcode in approximate mode.
962+
//
963+
// TODO: This hacky behavior is only necessary because the custom frame
964+
// mapping is supplied in SingleStreamDecoder::addVideoStream() rather
965+
// than in the constructor. And it's supplied to addVideoStream() and
966+
// not the constructor because we need to know the stream index. If we
967+
// can encode the relevant stream indices into custom frame mappings
968+
// itself, then we can put it in the constructor.
969+
writeFallbackBasedMetadata(map, streamMetadata, SeekMode::approximate);
944970
}
945971

946972
return mapToJson(map);

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,6 @@ def __init__(
6363
torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder")
6464
self._decoder = create_decoder(source=source, seek_mode="approximate")
6565

66-
core.add_audio_stream(
67-
self._decoder,
68-
stream_index=stream_index,
69-
sample_rate=sample_rate,
70-
num_channels=num_channels,
71-
)
72-
7366
container_metadata = core.get_container_metadata(self._decoder)
7467
self.stream_index = (
7568
container_metadata.best_audio_stream_index
@@ -81,13 +74,28 @@ def __init__(
8174
"The best audio stream is unknown and there is no specified stream. "
8275
+ ERROR_REPORTING_INSTRUCTIONS
8376
)
77+
if self.stream_index >= len(container_metadata.streams):
78+
raise ValueError(
79+
f"The stream at index {stream_index} is not a valid stream."
80+
)
81+
8482
self.metadata = container_metadata.streams[self.stream_index]
85-
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
83+
if not isinstance(self.metadata, core._metadata.AudioStreamMetadata):
84+
raise ValueError(
85+
f"The stream at index {stream_index} is not an audio stream. "
86+
)
8687

8788
self._desired_sample_rate = (
8889
sample_rate if sample_rate is not None else self.metadata.sample_rate
8990
)
9091

92+
core.add_audio_stream(
93+
self._decoder,
94+
stream_index=stream_index,
95+
sample_rate=sample_rate,
96+
num_channels=num_channels,
97+
)
98+
9199
def get_all_samples(self) -> AudioSamples:
92100
"""Returns all the audio samples from the source.
93101

0 commit comments

Comments
 (0)