@NicolasHug NicolasHug commented Nov 7, 2025

We have "skip seeking" logic that tries to minimize the number of seeks we have to do. This logic lives in

bool SingleStreamDecoder::canWeAvoidSeeking() const {

The problem is: canWeAvoidSeeking() itself can be expensive to call!! This is especially true in approximate mode which, as I just found out, can be slower than exact mode for short videos (10s long).

In this PR, we skip the call to canWeAvoidSeeking() when we can - yes, we "skip the skip-checking logic"!
EDIT: we don't skip the call to canWeAvoidSeeking(), we just add a new condition that causes canWeAvoidSeeking() to return early, before it runs its potentially expensive parts.

We can do that when we are decoding frames contiguously: 0, 1, 2, 3....

This will provide significant speedups when:

  • seek_mode="approximate"
  • frames are decoded contiguously. That can happen in both get_frames_at() and get_frames_played_at().
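
As a rough way to tell whether a request falls into the contiguous case, the sketch below checks that each index is exactly one more than the previous one. `is_contiguous` is a hypothetical helper written for illustration; it is not part of torchcodec, and the decoder's own check happens internally, one frame at a time.

```python
from typing import Sequence


def is_contiguous(indices: Sequence[int]) -> bool:
    """Return True if every index is exactly the previous index plus one."""
    # Hypothetical helper: compares each index with its successor.
    return all(b == a + 1 for a, b in zip(indices, indices[1:]))


all_frames = list(range(300))     # 0, 1, 2, ..., 299
every_30th = list(range(0, 300, 30))  # 0, 30, 60, ...

print(is_contiguous(all_frames))  # True
print(is_contiguous(every_30th))  # False
```

Requests like `all_frames` are the ones expected to benefit from this PR; sparse requests like `every_30th` still go through the full check.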

Why is canWeAvoidSeeking() slow?

Because it calls getKeyFrameIndexForPts() which, in approximate mode, calls av_index_search_timestamp(). Calling this for all frames can dominate the runtime!
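
To make the cost concrete, here is a toy Python model of the idea, not the actual C++ code. Every name in it (`SeekCheckModel`, `sequential_fast_path`, the keyframe spacing of 30) is an assumption made for illustration. `_keyframe_for` stands in for the expensive lookup that ends up in av_index_search_timestamp(), and the model simply counts how often it runs.

```python
import bisect


class SeekCheckModel:
    """Toy model of a per-frame "can we avoid seeking?" check (hypothetical)."""

    def __init__(self, keyframe_indices, sequential_fast_path):
        self.keyframe_indices = sorted(keyframe_indices)
        self.sequential_fast_path = sequential_fast_path
        self.last_decoded = None  # index of the last decoded frame
        self.lookups = 0          # how many expensive lookups were made

    def _keyframe_for(self, frame_index):
        # Stand-in for the expensive keyframe search.
        self.lookups += 1
        return bisect.bisect_right(self.keyframe_indices, frame_index) - 1

    def can_avoid_seeking(self, target):
        # The model treats the first frame, a backwards request, and a
        # repeated frame as always needing a seek.
        if self.last_decoded is None or target <= self.last_decoded:
            return False
        # Early return: the next frame in order never needs a lookup.
        if self.sequential_fast_path and target == self.last_decoded + 1:
            return True
        # Otherwise compare the keyframes of both frames (two lookups).
        return self._keyframe_for(self.last_decoded) == self._keyframe_for(target)

    def decode(self, target):
        avoided = self.can_avoid_seeking(target)
        self.last_decoded = target
        return avoided


def count_lookups(frame_indices, sequential_fast_path):
    model = SeekCheckModel(range(0, 300, 30), sequential_fast_path)
    for i in frame_indices:
        model.decode(i)
    return model.lookups


print(count_lookups(range(300), sequential_fast_path=False))  # 598
print(count_lookups(range(300), sequential_fast_path=True))   # 0
```

In this model, decoding 300 contiguous frames costs two lookups per frame after the first one, and none at all once the early return is in place.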

Benchmarks:

When decoding all 300 frames of a short (10s) h264 720p video (testsrc2), approximate mode goes from 1249.56ms to 825.74ms (1.5X faster), which brings it in line with exact mode.

~/dev/torchcodec-cuda (avoid_seeking_checks*) » python ~/benchmark_torchcodec_decord.py ~/videos_h264/ --sampling all --num-threads 1
torchcodec.__version__ = '0.9.0a0+afd5aba'
videos: 720x1280, 30.0 fps, 300 frames long
Using 1 thread(s), averaging over 10 runs


# This PR:
=== TorchCodec approx ===
med = 825.74ms +- 1.31, max = 828.85ms
=== TorchCodec exact ===
med = 828.05ms +- 3.69, max = 836.18ms

# On main
=== TorchCodec approx ===
med = 1249.56ms +- 1.81, max = 1253.79ms
=== TorchCodec exact ===
med = 832.75ms +- 0.98, max = 835.03ms
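
The arithmetic behind the "1.5X" figure can be reproduced from the medians above. The per-frame number assumes the whole difference comes from the skipped checks, which the benchmark itself does not isolate.

```python
# Median timings for approximate mode, copied from the benchmark output.
main_ms = 1249.56  # on main
pr_ms = 825.74     # with this PR
num_frames = 300

speedup = main_ms / pr_ms
saved_ms = main_ms - pr_ms
saved_per_frame_ms = saved_ms / num_frames

print(f"speedup: {speedup:.2f}x")  # speedup: 1.51x
print(f"saved: {saved_ms:.2f}ms total, {saved_per_frame_ms:.2f}ms per frame")
# saved: 423.82ms total, 1.41ms per frame
```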

Messy (but correct) benchmarking code:

import argparse
from pathlib import Path
from time import perf_counter_ns

import decord
import psutil
import torch
from joblib import delayed, Parallel
import torchcodec
from torchcodec.decoders import VideoDecoder


def bench(f, *args, num_exp=100, warmup=0, **kwargs):
    process = psutil.Process()

    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    cpu_utils = []
    mem_usages = []

    for _ in range(num_exp):
        psutil.cpu_percent(interval=None)

        start = perf_counter_ns()
        f(*args, **kwargs)
        end = perf_counter_ns()

        cpu_util = psutil.cpu_percent(interval=None)  # since last call
        mem_end = process.memory_info().rss

        times.append(end - start)
        cpu_utils.append(cpu_util)
        mem_usages.append(mem_end)

    return torch.tensor(times).float(), torch.tensor(cpu_utils).float(), torch.tensor(mem_usages).float()


def report_stats(times, cpu_utils=None, mem_usages=None, unit="ms"):
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    max = times.max().item()
    print(f"{med = :.2f}{unit} +- {std:.2f}, {max = :.2f}{unit}")

    if cpu_utils is not None:
        cpu_avg = cpu_utils.mean().item()
        cpu_peak = cpu_utils.max().item()
        print(f"CPU utilization: avg = {cpu_avg:.1f}%, peak = {cpu_peak:.1f}%")

    if mem_usages is not None:
        mem_gb = mem_usages / (1024 ** 3)
        mem_peak = mem_gb.max().item()
        mem_min = mem_gb.min().item()
        mem_delta = mem_peak - mem_min
        print(f"Memory: peak = {mem_peak:.2f}GB, delta = +{mem_delta:.2f}GB")


def decode_one_video_torchcodec(video_path, seek_mode="approximate"):
    decoder = VideoDecoder(str(video_path), device="cpu", seek_mode=seek_mode, num_ffmpeg_threads=1)
    return decoder.get_frames_at(indices)

def decode_one_video_decord(video_path):
    vr = decord.VideoReader(str(video_path), ctx=decord.cpu(), num_threads=1)
    return vr.get_batch(indices.tolist())

def decode_videos(library="torchcodec"):
    if library == "torchcodec":
        decode_one_video = decode_one_video_torchcodec
    elif library == "decord":
        decode_one_video = decode_one_video_decord
    else:
        raise ValueError(f"Unknown library: {library}")

    Parallel(n_jobs=args.num_threads, prefer="threads")(
        delayed(decode_one_video)(video_path) for video_path in video_files
    )


def validate(video_path):
    out_tc = decode_one_video_torchcodec(video_path)
    out_dc = decode_one_video_decord(video_path)

    torch.testing.assert_close(out_tc.data, (out_dc).permute(0, 3, 1, 2), rtol=0, atol=0)
    print("outputs are the same!")


NUM_EXP = 10
parser = argparse.ArgumentParser()
parser.add_argument("video_folder", help="Folder containing .mp4 video files")
parser.add_argument(
    "--sampling",
    type=str,
    default="all",
    help="Sampling strategy. 'all' for all frames, 'firstN' (e.g. 'first50') for the first N frames, or an int N for N evenly spaced frames.",
)
parser.add_argument(
    "--num-threads",
    type=int,
    default=1,
    help="Number of threads to spawn. Each thread decodes one single video.",
)
args = parser.parse_args()

video_files = list(Path(args.video_folder).glob("*.mp4"))

# We kinda assume all the videos in the folder have the same number of frames
dummy_dec = VideoDecoder(str(video_files[0]), device="cpu")
if str(args.sampling).startswith("first"):
    num_frames_to_samples = int(args.sampling[len("first") :])
    indices = torch.arange(num_frames_to_samples)
elif args.sampling == "all":
    indices = torch.arange(len(dummy_dec))
else:
    num_frames_to_samples = int(args.sampling)
    indices = torch.linspace(
        0, len(dummy_dec) - 1, num_frames_to_samples, dtype=torch.int
    )

decord.bridge.set_bridge("torch")
# validate(video_files[0])

print(f"{torchcodec.__version__ = }")

# print(
#     f"Decoding {args.sampling} frames from {len(video_files)} video files in {args.video_folder}"
# )
print(
    f"videos: {dummy_dec.metadata.height}x{dummy_dec.metadata.width}, {dummy_dec.metadata.average_fps} fps, {dummy_dec.metadata.num_frames} frames long"
)
print(f"Using {args.num_threads} thread(s), averaging over {NUM_EXP} runs")

# print("\n=== TorchCodec ===")
# times_tc, cpu_utils_tc, mem_usages_tc = bench(decode_videos, library="torchcodec", warmup=1, num_exp=NUM_EXP)
# report_stats(times_tc, cpu_utils_tc, mem_usages_tc)

# print("\n=== Decord ===")
# times_dc, cpu_utils_dc, mem_usages_dc = bench(decode_videos, library="decord", warmup=1, num_exp=NUM_EXP)
# report_stats(times_dc, cpu_utils_dc, mem_usages_dc)


print("\n=== TorchCodec approx ===")
times_tc, cpu_utils_tc, mem_usages_tc = bench(decode_one_video_torchcodec, video_path=video_files[0], seek_mode="approximate", warmup=1,  num_exp=NUM_EXP)
report_stats(times_tc, cpu_utils_tc, mem_usages_tc)

print("\n=== TorchCodec exact ===")
times_tc, cpu_utils_tc, mem_usages_tc = bench(decode_one_video_torchcodec, video_path=video_files[0], seek_mode="exact", warmup=1, num_exp=NUM_EXP)
report_stats(times_tc, cpu_utils_tc, mem_usages_tc)

# print("\n=== Decord ===")
# times_dc, cpu_utils_dc, mem_usages_dc = bench(decode_one_video_decord, video_path=video_files[0], warmup=1, num_exp=NUM_EXP)
# report_stats(times_dc, cpu_utils_dc, mem_usages_dc)

scotts commented Nov 7, 2025

Digging into canWeAvoidSeeking() more, we already have a few optimizations where we back out before we do the mapping. But it also seems like the whole point of the mapping is to know the index of the last decoded frame:

int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_);

Can't we now replace that call with your newly added lastDecodedFrameIndex_? That would actually speed up all scenarios. We may be able to still apply parts of this "skip if sequential", but I think you've now made the most expensive step unneeded in general.

if (frameIndex != lastDecodedFrameIndex_ + 1) {
  int64_t pts = getPts(frameIndex);
  setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
}
Contributor Author

I included the changes from #1039 in this PR.

@scotts, does the comment make sense now? What the comment does not say is why returning early in canWeAvoidSeeking() is important. It's important because it avoids the calls to av_index_search_timestamp(), which are potentially slow. I don't know if we want to get to that level of detail in the comment.

Contributor

Yes, this greatly helps with understanding.

I do think we should explain somewhere that canWeAvoidSeeking() is itself expensive in some circumstances. It's deeply counter-intuitive that the function we're calling as an optimization to avoid the slow thing is itself also a slow thing (though hopefully less slow). Perhaps that explanation belongs at the top of canWeAvoidSeeking().

Contributor Author

Will address that in an immediate follow-up.

@NicolasHug
Contributor Author

Replying to #1028 (comment)

Can't we now replace that call with your newly added lastDecodedFrameIndex_?

Unfortunately no, the new "skip condition" I am introducing in this PR is different. Our new condition is to return early when we know there's no need to seek, and there isn't even a need to run further "can we skip seeking" checks.

This existing condition you mentioned:

int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_);

is part of those "further checks" (and it's actually the part that's expensive!). The idea is that if we were around keyframe K and the new frame we want is also related to K, we don't need to seek. Note that lastDecodedAvFrameIndex is not an accurate name. It's not the index of the last decoded frame, it's the keyframe index of the last decoded frame.
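
The difference between "index of the last decoded frame" and "keyframe index of the last decoded frame" can be sketched in Python. This is a simplified model under stated assumptions, not the C++ implementation: `KEYFRAME_PTS`, `keyframe_index_for_pts` and `share_keyframe` are hypothetical names, and the pts values are made up.

```python
import bisect

# Hypothetical presentation timestamps (pts) of the keyframes in a stream.
KEYFRAME_PTS = [0, 3000, 6000, 9000]


def keyframe_index_for_pts(pts: int) -> int:
    """Index of the last keyframe whose pts is <= the given pts (-1 if none)."""
    return bisect.bisect_right(KEYFRAME_PTS, pts) - 1


def share_keyframe(last_decoded_pts: int, target_pts: int) -> bool:
    """True if both frames depend on the same keyframe K."""
    last_kf = keyframe_index_for_pts(last_decoded_pts)
    target_kf = keyframe_index_for_pts(target_pts)
    return last_kf >= 0 and last_kf == target_kf


print(keyframe_index_for_pts(3500))  # 1
print(share_keyframe(3100, 3900))    # True
print(share_keyframe(2900, 3100))    # False
```

The value returned for pts 3500 is 1 because that frame sits after the second keyframe, regardless of how many frames precede it. That is why the result cannot be swapped for a plain frame index.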

@NicolasHug NicolasHug merged commit b35005d into meta-pytorch:main Nov 12, 2025
57 of 64 checks passed
@NicolasHug NicolasHug deleted the avoid_seeking_checks branch November 12, 2025 11:31