Skip to content

Commit 84fd7a5

Browse files
committed
Refactor order of getting metadata and adding a stream
1 parent 1093339 commit 84fd7a5

File tree

7 files changed

+64 -28 lines changed

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ int getNumChannels(const SharedAVCodecContext& avCodecContext) {
158158
#endif
159159
}
160160

161+
int getNumChannels(const AVCodecParameters* codecpar) {
162+
TORCH_CHECK(codecpar != nullptr, "codecpar is null")
163+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
164+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
165+
return codecpar->ch_layout.nb_channels;
166+
#else
167+
return codecpar->channels;
168+
#endif
169+
}
170+
161171
void setDefaultChannelLayout(
162172
UniqueAVCodecContext& avCodecContext,
163173
int numChannels) {

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
180180

181181
int getNumChannels(const UniqueAVFrame& avFrame);
182182
int getNumChannels(const SharedAVCodecContext& avCodecContext);
183+
int getNumChannels(const AVCodecParameters* codecpar);
183184

184185
void setDefaultChannelLayout(
185186
UniqueAVCodecContext& avCodecContext,

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ void SingleStreamDecoder::initializeDecoder() {
110110
", does not match AVStream's index, " +
111111
std::to_string(avStream->index) + ".");
112112
streamMetadata.streamIndex = i;
113-
streamMetadata.mediaType = avStream->codecpar->codec_type;
114113
streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id);
114+
streamMetadata.mediaType = avStream->codecpar->codec_type;
115115
streamMetadata.bitRate = avStream->codecpar->bit_rate;
116116

117117
int64_t frameCount = avStream->nb_frames;
@@ -133,10 +133,18 @@ void SingleStreamDecoder::initializeDecoder() {
133133
if (fps > 0) {
134134
streamMetadata.averageFpsFromHeader = fps;
135135
}
136+
streamMetadata.width = avStream->codecpar->width;
137+
streamMetadata.height = avStream->codecpar->height;
138+
streamMetadata.sampleAspectRatio =
139+
avStream->codecpar->sample_aspect_ratio;
136140
containerMetadata_.numVideoStreams++;
137141
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
138142
AVSampleFormat format =
139143
static_cast<AVSampleFormat>(avStream->codecpar->format);
144+
streamMetadata.sampleRate =
145+
static_cast<int64_t>(avStream->codecpar->sample_rate);
146+
streamMetadata.numChannels =
147+
static_cast<int64_t>(getNumChannels(avStream->codecpar));
140148

141149
// If the AVSampleFormat is not recognized, we get back nullptr. We have
142150
// to make sure we don't initialize a std::string with nullptr. There's
@@ -516,6 +524,10 @@ void SingleStreamDecoder::addVideoStream(
516524
auto& streamInfo = streamInfos_[activeStreamIndex_];
517525
streamInfo.videoStreamOptions = videoStreamOptions;
518526

527+
// This metadata was already set in initializeDecoder() from the
528+
// AVCodecParameters that are part of the AVStream. But we consider the
529+
// AVCodecContext to be more authoritative, so we use that for our decoding
530+
// stream.
519531
streamMetadata.width = streamInfo.codecContext->width;
520532
streamMetadata.height = streamInfo.codecContext->height;
521533
streamMetadata.sampleAspectRatio =
@@ -568,6 +580,11 @@ void SingleStreamDecoder::addAudioStream(
568580

569581
auto& streamMetadata =
570582
containerMetadata_.allStreamMetadata[activeStreamIndex_];
583+
584+
// This metadata was already set in initializeDecoder() from the
585+
// AVCodecParameters that are part of the AVStream. But we consider the
586+
// AVCodecContext to be more authoritative, so we use that for our decoding
587+
// stream.
571588
streamMetadata.sampleRate =
572589
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
573590
streamMetadata.numChannels =

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,6 @@ def __init__(
6363
torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder")
6464
self._decoder = create_decoder(source=source, seek_mode="approximate")
6565

66-
core.add_audio_stream(
67-
self._decoder,
68-
stream_index=stream_index,
69-
sample_rate=sample_rate,
70-
num_channels=num_channels,
71-
)
72-
7366
container_metadata = core.get_container_metadata(self._decoder)
7467
self.stream_index = (
7568
container_metadata.best_audio_stream_index
@@ -81,13 +74,28 @@ def __init__(
8174
"The best audio stream is unknown and there is no specified stream. "
8275
+ ERROR_REPORTING_INSTRUCTIONS
8376
)
77+
if self.stream_index >= len(container_metadata.streams):
78+
raise ValueError(
79+
f"The stream at index {stream_index} is not a valid stream."
80+
)
81+
8482
self.metadata = container_metadata.streams[self.stream_index]
85-
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
83+
if not isinstance(self.metadata, core._metadata.AudioStreamMetadata):
84+
raise ValueError(
85+
f"The stream at index {stream_index} is not an audio stream. "
86+
)
8687

8788
self._desired_sample_rate = (
8889
sample_rate if sample_rate is not None else self.metadata.sample_rate
8990
)
9091

92+
core.add_audio_stream(
93+
self._decoder,
94+
stream_index=stream_index,
95+
sample_rate=sample_rate,
96+
num_channels=num_channels,
97+
)
98+
9199
def get_all_samples(self) -> AudioSamples:
92100
"""Returns all the audio samples from the source.
93101

src/torchcodec/decoders/_video_decoder.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,16 @@ def __init__(
141141

142142
self._decoder = create_decoder(source=source, seek_mode=seek_mode)
143143

144+
(
145+
self.metadata,
146+
self.stream_index,
147+
self._begin_stream_seconds,
148+
self._end_stream_seconds,
149+
self._num_frames,
150+
) = _get_and_validate_stream_metadata(
151+
decoder=self._decoder, stream_index=stream_index
152+
)
153+
144154
allowed_dimension_orders = ("NCHW", "NHWC")
145155
if dimension_order not in allowed_dimension_orders:
146156
raise ValueError(
@@ -157,12 +167,11 @@ def __init__(
157167
device = str(device)
158168

159169
device_variant = _get_cuda_backend()
160-
161170
transform_specs = _make_transform_specs(transforms)
162171

163172
core.add_video_stream(
164173
self._decoder,
165-
stream_index=stream_index,
174+
stream_index=self.stream_index,
166175
dimension_order=dimension_order,
167176
num_threads=num_ffmpeg_threads,
168177
device=device,
@@ -171,16 +180,6 @@ def __init__(
171180
custom_frame_mappings=custom_frame_mappings_data,
172181
)
173182

174-
(
175-
self.metadata,
176-
self.stream_index,
177-
self._begin_stream_seconds,
178-
self._end_stream_seconds,
179-
self._num_frames,
180-
) = _get_and_validate_stream_metadata(
181-
decoder=self._decoder, stream_index=stream_index
182-
)
183-
184183
def __len__(self) -> int:
185184
return self._num_frames
186185

@@ -413,8 +412,12 @@ def _get_and_validate_stream_metadata(
413412
+ ERROR_REPORTING_INSTRUCTIONS
414413
)
415414

415+
if stream_index >= len(container_metadata.streams):
416+
raise ValueError(f"The stream index {stream_index} is not a valid stream.")
417+
416418
metadata = container_metadata.streams[stream_index]
417-
assert isinstance(metadata, core._metadata.VideoStreamMetadata) # mypy
419+
if not isinstance(metadata, core._metadata.VideoStreamMetadata):
420+
raise ValueError(f"The stream at index {stream_index} is not a video stream. ")
418421

419422
if metadata.begin_stream_seconds is None:
420423
raise ValueError(

test/test_decoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ def test_create_fails(self, Decoder):
116116
Decoder(123)
117117

118118
# stream index that does not exist
119-
with pytest.raises(ValueError, match="No valid stream found"):
119+
with pytest.raises(ValueError, match="40 is not a valid stream"):
120120
Decoder(NASA_VIDEO.path, stream_index=40)
121121

122122
# stream index that does exist, but it's not audio or video
123-
with pytest.raises(ValueError, match="No valid stream found"):
123+
with pytest.raises(ValueError, match=r"not (a|an) (video|audio) stream"):
124124
Decoder(NASA_VIDEO.path, stream_index=2)
125125

126126
# user mistakenly forgets to specify binary reading when creating a file

test/test_metadata.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def test_get_metadata(metadata_getter):
5959
)
6060
if (seek_mode == "custom_frame_mappings") and get_ffmpeg_major_version() in (4, 5):
6161
pytest.skip(reason="ffprobe isn't accurate on ffmpeg 4 and 5")
62-
with_added_video_stream = seek_mode == "custom_frame_mappings"
6362
metadata = metadata_getter(NASA_VIDEO.path)
6463

6564
with_scan = (
@@ -99,9 +98,7 @@ def test_get_metadata(metadata_getter):
9998
assert best_video_stream_metadata.begin_stream_seconds_from_header == 0
10099
assert best_video_stream_metadata.bit_rate == 128783
101100
assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001)
102-
assert best_video_stream_metadata.pixel_aspect_ratio == (
103-
Fraction(1, 1) if with_added_video_stream else None
104-
)
101+
assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1)
105102
assert best_video_stream_metadata.codec == "h264"
106103
assert best_video_stream_metadata.num_frames_from_content == (
107104
390 if with_scan else None

0 commit comments

Comments
 (0)