@@ -367,6 +367,14 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
367367 return containerMetadata_;
368368}
369369
370+ SeekMode SingleStreamDecoder::getSeekMode() const {
371+ return seekMode_;
372+ }
373+
374+ int SingleStreamDecoder::getActiveStreamIndex() const {
375+ return activeStreamIndex_;
376+ }
377+
370378torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
371379 validateActiveStream(AVMEDIA_TYPE_VIDEO);
372380 validateScannedAllStreams("getKeyFrameIndices");
@@ -611,7 +619,7 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
611619 const auto& streamMetadata =
612620 containerMetadata_.allStreamMetadata[activeStreamIndex_];
613621
614- std::optional<int64_t> numFrames = getNumFrames(streamMetadata );
622+ std::optional<int64_t> numFrames = streamMetadata. getNumFrames(seekMode_ );
615623 if (numFrames.has_value()) {
616624 // If the frameIndex is negative, we convert it to a positive index
617625 frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
@@ -705,7 +713,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
705713
706714 // Note that if we do not have the number of frames available in our
707715 // metadata, then we assume that the upper part of the range is valid.
708- std::optional<int64_t> numFrames = getNumFrames(streamMetadata );
716+ std::optional<int64_t> numFrames = streamMetadata. getNumFrames(seekMode_ );
709717 if (numFrames.has_value()) {
710718 TORCH_CHECK(
711719 stop <= numFrames.value(),
@@ -779,8 +787,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
779787 const auto& streamMetadata =
780788 containerMetadata_.allStreamMetadata[activeStreamIndex_];
781789
782- double minSeconds = getMinSeconds(streamMetadata);
783- std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
790+ double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
791+ std::optional<double> maxSeconds =
792+ streamMetadata.getEndStreamSeconds(seekMode_);
784793
785794 // The frame played at timestamp t and the one played at timestamp `t +
786795 // eps` are probably the same frame, with the same index. The easiest way to
@@ -857,7 +866,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
857866 return frameBatchOutput;
858867 }
859868
860- double minSeconds = getMinSeconds( streamMetadata);
869+ double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_ );
861870 TORCH_CHECK(
862871 startSeconds >= minSeconds,
863872 "Start seconds is " + std::to_string(startSeconds) +
@@ -866,7 +875,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
866875
867876 // Note that if we can't determine the maximum seconds from the metadata,
868877 // then we assume upper range is valid.
869- std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
878+ std::optional<double> maxSeconds =
879+ streamMetadata.getEndStreamSeconds(seekMode_);
870880 if (maxSeconds.has_value()) {
871881 TORCH_CHECK(
872882 startSeconds < maxSeconds.value(),
@@ -1439,47 +1449,6 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14391449// STREAM AND METADATA APIS
14401450// --------------------------------------------------------------------------
14411451
1442- std::optional<int64_t> SingleStreamDecoder::getNumFrames(
1443- const StreamMetadata& streamMetadata) {
1444- switch (seekMode_) {
1445- case SeekMode::custom_frame_mappings:
1446- case SeekMode::exact:
1447- return streamMetadata.numFramesFromContent.value();
1448- case SeekMode::approximate: {
1449- return streamMetadata.numFramesFromHeader;
1450- }
1451- default:
1452- TORCH_CHECK(false, "Unknown SeekMode");
1453- }
1454- }
1455-
1456- double SingleStreamDecoder::getMinSeconds(
1457- const StreamMetadata& streamMetadata) {
1458- switch (seekMode_) {
1459- case SeekMode::custom_frame_mappings:
1460- case SeekMode::exact:
1461- return streamMetadata.beginStreamPtsSecondsFromContent.value();
1462- case SeekMode::approximate:
1463- return 0;
1464- default:
1465- TORCH_CHECK(false, "Unknown SeekMode");
1466- }
1467- }
1468-
1469- std::optional<double> SingleStreamDecoder::getMaxSeconds(
1470- const StreamMetadata& streamMetadata) {
1471- switch (seekMode_) {
1472- case SeekMode::custom_frame_mappings:
1473- case SeekMode::exact:
1474- return streamMetadata.endStreamPtsSecondsFromContent.value();
1475- case SeekMode::approximate: {
1476- return streamMetadata.durationSecondsFromHeader;
1477- }
1478- default:
1479- TORCH_CHECK(false, "Unknown SeekMode");
1480- }
1481- }
1482-
14831452// --------------------------------------------------------------------------
14841453// VALIDATION UTILS
14851454// --------------------------------------------------------------------------
@@ -1529,7 +1498,7 @@ void SingleStreamDecoder::validateFrameIndex(
15291498
15301499 // Note that if we do not have the number of frames available in our
15311500 // metadata, then we assume that the frameIndex is valid.
1532- std::optional<int64_t> numFrames = getNumFrames(streamMetadata );
1501+ std::optional<int64_t> numFrames = streamMetadata. getNumFrames(seekMode_ );
15331502 if (numFrames.has_value()) {
15341503 if (frameIndex >= numFrames.value()) {
15351504 throw std::out_of_range(
0 commit comments