|
| 1 | +#!/usr/bin/env python3 |
| 2 | +import subprocess |
| 3 | +import tempfile |
| 4 | +from argparse import ArgumentParser |
| 5 | +from pathlib import Path |
| 6 | +from time import perf_counter_ns |
| 7 | + |
| 8 | +import psutil |
| 9 | +import torch |
| 10 | +from torchcodec.decoders import VideoDecoder |
| 11 | +from torchcodec.encoders import VideoEncoder |
| 12 | + |
# GPU monitoring is optional; it needs pynvml (install with: pip install nvidia-ml-py).
# GPU_MONITORING_AVAILABLE records whether the import succeeded.
try:
    import pynvml
except ImportError:
    print("To enable GPU monitoring, install pynvml with: pip install nvidia-ml-py")
    GPU_MONITORING_AVAILABLE = False
else:
    GPU_MONITORING_AVAILABLE = True

# Default input video for the benchmark.
DEFAULT_VIDEO_PATH = "test/resources/nasa_13013.mp4"
# To benchmark against a longer clip instead, generate one with:
# ffmpeg -f lavfi -i testsrc2=duration=600:size=1280x720:rate=30 -c:v libx264 -pix_fmt yuv420p test/resources/testsrc2_10min.mp4
# and switch the default to it:
# DEFAULT_VIDEO_PATH = "test/resources/testsrc2_10min.mp4"

# Default number of timed runs per benchmark.
DEFAULT_AVERAGE_OVER = 30
# Default cap on the number of frames decoded from the input video.
DEFAULT_MAX_FRAMES = 300
| 28 | + |
| 29 | + |
def gpu_percent():
    """Return the GPU utilization reading for device index 0 as a float.

    This is a best-effort probe: it returns 0.0 when pynvml could not be
    imported, and also when any NVML call raises.
    """
    reading = 0.0
    if GPU_MONITORING_AVAILABLE:
        try:
            pynvml.nvmlInit()
            first_device = pynvml.nvmlDeviceGetHandleByIndex(0)
            rates = pynvml.nvmlDeviceGetUtilizationRates(first_device)
            reading = float(rates.gpu)
        except Exception:
            # Monitoring must never abort a benchmark run; report 0.0 instead.
            reading = 0.0
    return reading
| 40 | + |
| 41 | + |
def bench(f, average_over=50, warmup=2, **f_kwargs):
    """Time repeated calls of ``f(**f_kwargs)`` and sample CPU/GPU utilization.

    Args:
        f: Callable to benchmark.
        average_over: Number of timed calls.
        warmup: Number of untimed calls made before measuring.
        **f_kwargs: Keyword arguments forwarded to every call of ``f``.

    Returns:
        A tuple ``(times, cpu_utils, gpu_utils)`` of float tensors with one
        entry per timed call: elapsed time in nanoseconds, the psutil CPU
        reading taken right after the call, and the ``gpu_percent()`` reading
        taken right after the call.
    """
    for _ in range(warmup):
        f(**f_kwargs)

    elapsed_ns = []
    cpu_samples = []
    gpu_samples = []

    for _ in range(average_over):
        # Discard this reading: with interval=None psutil measures since the
        # previous call, so this resets the window to start at the timed call.
        psutil.cpu_percent(interval=None)

        t0 = perf_counter_ns()
        f(**f_kwargs)
        t1 = perf_counter_ns()

        # Both utilization samples are taken after f has returned.
        cpu_now = psutil.cpu_percent(interval=None)
        gpu_now = gpu_percent()

        elapsed_ns.append(t1 - t0)
        cpu_samples.append(cpu_now)
        gpu_samples.append(gpu_now)

    return tuple(
        torch.tensor(samples).float()
        for samples in (elapsed_ns, cpu_samples, gpu_samples)
    )
| 69 | + |
| 70 | + |
def report_stats(
    times, num_frames, cpu_utils=None, gpu_utils=None, prefix="", unit="ms"
):
    """Print timing, throughput and utilization statistics for a benchmark.

    Args:
        times: Tensor of per-run durations in nanoseconds.
        num_frames: Number of frames processed per run, used to derive fps.
        cpu_utils: Optional tensor of CPU utilization samples; the CPU line is
            skipped when this is None.
        gpu_utils: Optional tensor of GPU utilization samples; the GPU line is
            skipped when this is None or empty.
        prefix: Label printed in front of the timing line.
        unit: Time unit for the timing line: "ns", "µs", "ms" or "s".

    Raises:
        KeyError: If ``unit`` is not one of the supported units.
    """
    ns_to_unit = {"ns": 1, "µs": 1e-3, "ms": 1e-6, "s": 1e-9}
    scaled = times * ns_to_unit[unit]

    spread = scaled.std().item()
    middle = scaled.median().item()
    average = scaled.mean().item()
    fastest = scaled.min().item()
    slowest = scaled.max().item()
    print(
        f"\n{prefix}: med = {middle:.2f}, mean = {average:.2f} +- {spread:.2f}, "
        f"min_time = {fastest:.2f}, max_time = {slowest:.2f} - in {unit}"
    )

    # times are in nanoseconds; convert to seconds to get frames per second.
    throughput = num_frames / (times * 1e-9)
    fps_spread = throughput.std().item()
    fps_middle = throughput.median().item()
    fps_best = throughput.max().item()
    print(f"med = {fps_middle:.1f} fps +- {fps_spread:.1f}, max_fps = {fps_best:.1f}")

    # Only the GPU samples are additionally checked for emptiness.
    for label, samples, skip_if_empty in (
        ("CPU", cpu_utils, False),
        ("GPU", gpu_utils, True),
    ):
        if samples is None:
            continue
        if skip_if_empty and samples.numel() == 0:
            continue
        avg = samples.mean().item()
        peak = samples.max().item()
        print(f"{label} utilization: avg = {avg:.1f}%, peak = {peak:.1f}%")
| 104 | + |
| 105 | + |
def encode_torchcodec(frames, output_path, device="cpu"):
    """Encode ``frames`` to ``output_path`` with torchcodec's VideoEncoder.

    Args:
        frames: Tensor of frames to encode.
        output_path: Destination file path.
        device: "cuda" selects the h264_nvenc path; any other value takes the
            CPU libx264 path.
    """
    if device != "cuda":
        cpu_encoder = VideoEncoder(frames=frames, frame_rate=30, device="cpu")
        cpu_encoder.to_file(dest=output_path, codec="libx264", crf=0)
        return

    # Frames that still live on the CPU are moved to the GPU first.
    needs_transfer = frames.device.type == "cpu"
    gpu_frames = frames.cuda() if needs_transfer else frames
    gpu_encoder = VideoEncoder(frames=gpu_frames, frame_rate=30, device="cuda")
    gpu_encoder.to_file(
        dest=output_path, codec="h264_nvenc", extra_options={"qp": 1}
    )
| 115 | + |
| 116 | + |
def write_raw_frames(frames, raw_path):
    """Write ``frames`` to ``raw_path`` as a headerless byte dump.

    The tensor is reordered from NCHW to NHWC before writing.

    Args:
        frames: 4-D frame tensor in NCHW order.
        raw_path: Destination file path; an existing file is overwritten.
    """
    channels_last = frames.permute(0, 2, 3, 1)
    # permute only changes strides; contiguous() lays the data out in NHWC
    # order so the serialized bytes follow that order.
    packed = channels_last.contiguous()
    with open(raw_path, "wb") as sink:
        host_array = packed.cpu().numpy()
        sink.write(host_array.tobytes())
| 122 | + |
| 123 | + |
def encode_ffmpeg_cli(raw_path, frames_shape, output_path, device="cpu", codec=None):
    """Encode a raw rgb24 frame dump to ``output_path`` with the ffmpeg CLI.

    Args:
        raw_path: Path to the raw frame file, read as rgb24 rawvideo.
        frames_shape: Shape of the source frame tensor; indices 2 and 3 are
            used as height and width.
        output_path: Destination file path; overwritten if it exists (``-y``).
        device: "cuda" selects NVENC defaults (``-qp 0``); any other value
            selects libx264 defaults (``-crf 0``).
        codec: Optional video codec override. When None, the codec defaults to
            "h264_nvenc" for "cuda" and "libx264" otherwise.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
            ffmpeg's captured stderr is printed before the error propagates.
    """
    height, width = frames_shape[2], frames_shape[3]

    if device == "cuda":
        default_codec = "h264_nvenc"
        quality_params = ["-qp", "0"]
    else:
        default_codec = "libx264"
        quality_params = ["-crf", "0"]

    # Honor an explicitly requested codec; previously the argument was
    # unconditionally overwritten and therefore had no effect.
    if codec is None:
        codec = default_codec

    ffmpeg_cmd = [
        "ffmpeg",
        "-y",
        "-f",
        "rawvideo",
        "-pix_fmt",
        "rgb24",
        "-s",
        f"{width}x{height}",
        "-r",
        "30",  # frame_rate is 30
        "-i",
        str(raw_path),
        "-c:v",
        codec,
        "-pix_fmt",
        "yuv420p",
    ]
    ffmpeg_cmd.extend(quality_params)
    # By not setting threads, allow FFmpeg to choose.
    # ffmpeg_cmd.extend(["-threads", "1"])
    ffmpeg_cmd.append(str(output_path))

    try:
        subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as err:
        # stderr is captured, so without this the actual ffmpeg error message
        # (e.g. an unavailable encoder) would be lost.
        if err.stderr:
            print(err.stderr.decode(errors="replace"))
        raise
| 158 | + |
| 159 | + |
def main():
    """Parse CLI arguments, load frames, then run and report every benchmark.

    GPU benchmarks (VideoEncoder, then the FFmpeg CLI) run first when CUDA is
    available, followed by the same two benchmarks on CPU. All encoded outputs
    and the raw frame dump go into a temporary directory.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--path", type=str, help="Path to input video file", default=DEFAULT_VIDEO_PATH
    )
    parser.add_argument(
        "--average-over",
        type=int,
        default=DEFAULT_AVERAGE_OVER,
        help="Number of runs to average over",
    )
    parser.add_argument(
        "--max-frames",
        type=int,
        default=DEFAULT_MAX_FRAMES,
        help="Maximum number of frames to decode for benchmarking",
    )
    args = parser.parse_args()

    print(
        f"Benchmarking up to {args.max_frames} frames from {Path(args.path).name} over {args.average_over} runs:"
    )
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        print("CUDA not available. GPU benchmarks will be skipped.")

    # Decode at most max_frames frames, or fewer if the video is shorter.
    decoder = VideoDecoder(str(args.path))
    stop_index = min(args.max_frames, len(decoder))
    frames = decoder.get_frames_in_range(start=0, stop=stop_index).data
    num_frames = frames.shape[0]
    print(f"Loaded {num_frames} frames of size {frames.shape[2]}x{frames.shape[3]}")

    with tempfile.TemporaryDirectory() as tmp_name:
        work_dir = Path(tmp_name)
        raw_frames_path = work_dir / "input_frames.raw"
        # The FFmpeg CLI benchmarks read their input from this raw dump.
        write_raw_frames(frames, str(raw_frames_path))

        def time_torchcodec(device, out_name):
            # Benchmark the torchcodec VideoEncoder on the given device.
            return bench(
                encode_torchcodec,
                frames=frames,
                output_path=str(work_dir / out_name),
                device=device,
                average_over=args.average_over,
                warmup=1,
            )

        def time_ffmpeg(device, out_name):
            # Benchmark the ffmpeg command-line encoder on the given device.
            return bench(
                encode_ffmpeg_cli,
                raw_path=str(raw_frames_path),
                frames_shape=frames.shape,
                output_path=str(work_dir / out_name),
                device=device,
                average_over=args.average_over,
                warmup=1,
            )

        if cuda_available:
            # GPU runs report GPU utilization only.
            times, _cpu_utils, gpu_utils = time_torchcodec(
                "cuda", "torchcodec_gpu.mp4"
            )
            report_stats(
                times, num_frames, None, gpu_utils, prefix="VideoEncoder on GPU"
            )

            times, _cpu_utils, gpu_utils = time_ffmpeg("cuda", "ffmpeg_gpu.mp4")
            report_stats(
                times, num_frames, None, gpu_utils, prefix="FFmpeg CLI on GPU"
            )
        else:
            print("Skipping VideoEncoder GPU benchmark (CUDA not available)")
            print("Skipping FFmpeg CLI GPU benchmark (CUDA not available)")

        # CPU runs report CPU utilization only.
        times, cpu_utils, _gpu_utils = time_torchcodec("cpu", "torchcodec_cpu.mp4")
        report_stats(
            times, num_frames, cpu_utils, None, prefix="VideoEncoder on CPU"
        )

        times, cpu_utils, _gpu_utils = time_ffmpeg("cpu", "ffmpeg_cpu.mp4")
        report_stats(times, num_frames, cpu_utils, None, prefix="FFmpeg CLI on CPU")
| 264 | + |
| 265 | + |
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments