Skip to content

Commit e4b6d52

Browse files
committed
commit
1 parent bf78468 commit e4b6d52

File tree

1 file changed

+267
-0
lines changed

1 file changed

+267
-0
lines changed
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#!/usr/bin/env python3
2+
import subprocess
3+
import tempfile
4+
from argparse import ArgumentParser
5+
from pathlib import Path
6+
from time import perf_counter_ns
7+
8+
import psutil
9+
import torch
10+
from torchcodec.decoders import VideoDecoder
11+
from torchcodec.encoders import VideoEncoder
12+
13+
# GPU monitoring imports (install with: pip install nvidia-ml-py)
14+
# pynvml is optional: without it the benchmark still runs, and
# GPU_MONITORING_AVAILABLE records that GPU utilization cannot be sampled.
try:
    import pynvml
except ImportError:
    print("To enable GPU monitoring, install pynvml with: pip install nvidia-ml-py")
    GPU_MONITORING_AVAILABLE = False
else:
    GPU_MONITORING_AVAILABLE = True
21+
22+
# Input video used when --path is not given on the command line.
DEFAULT_VIDEO_PATH = "test/resources/nasa_13013.mp4"
# A longer test video can be generated with:
#   ffmpeg -f lavfi -i testsrc2=duration=600:size=1280x720:rate=30 -c:v libx264 -pix_fmt yuv420p test/resources/testsrc2_10min.mp4
# and selected by pointing DEFAULT_VIDEO_PATH (or --path) at
# "test/resources/testsrc2_10min.mp4".

# Default number of timed runs per benchmark (--average-over).
DEFAULT_AVERAGE_OVER = 30
# Default cap on the number of decoded frames (--max-frames).
DEFAULT_MAX_FRAMES = 300
28+
29+
30+
def gpu_percent():
    """Return the current utilization of GPU 0 as a float percentage.

    Returns 0.0 when pynvml could not be imported or when any NVML call
    raises, so callers never have to handle monitoring failures.
    """
    utilization = 0.0
    if GPU_MONITORING_AVAILABLE:
        # NOTE(review): nvmlInit() runs on every call and this block never
        # calls nvmlShutdown() -- confirm that is acceptable for long runs.
        try:
            pynvml.nvmlInit()
            device = pynvml.nvmlDeviceGetHandleByIndex(0)
            utilization = float(pynvml.nvmlDeviceGetUtilizationRates(device).gpu)
        except Exception:
            # Best effort: a monitoring failure must not abort the benchmark.
            utilization = 0.0
    return utilization
40+
41+
42+
def bench(f, average_over=50, warmup=2, **f_kwargs):
    """Time repeated calls of ``f(**f_kwargs)`` and sample CPU/GPU utilization.

    Args:
        f: Callable to benchmark.
        average_over: Number of timed calls.
        warmup: Number of untimed calls made before measuring.
        **f_kwargs: Keyword arguments forwarded to every call of ``f``.

    Returns:
        A ``(times, cpu_utils, gpu_utils)`` tuple of float tensors with one
        entry per timed call: elapsed nanoseconds, the value returned by
        ``psutil.cpu_percent`` and the value returned by ``gpu_percent``.
    """
    for _ in range(warmup):
        f(**f_kwargs)

    samples = []
    for _ in range(average_over):
        # Called before the timed section so the cpu_percent() call after it
        # covers only the timed call (per psutil, interval=None compares
        # against the previous call).
        psutil.cpu_percent(interval=None)

        began = perf_counter_ns()
        f(**f_kwargs)
        elapsed = perf_counter_ns() - began

        samples.append((elapsed, psutil.cpu_percent(interval=None), gpu_percent()))

    elapsed_ns = [sample[0] for sample in samples]
    cpu_samples = [sample[1] for sample in samples]
    gpu_samples = [sample[2] for sample in samples]

    return (
        torch.tensor(elapsed_ns).float(),
        torch.tensor(cpu_samples).float(),
        torch.tensor(gpu_samples).float(),
    )
69+
70+
71+
def report_stats(
    times, num_frames, cpu_utils=None, gpu_utils=None, prefix="", unit="ms"
):
    """Print timing, throughput and utilization summaries for one benchmark.

    Args:
        times: Tensor of per-run durations in nanoseconds.
        num_frames: Number of frames encoded per run, used to derive fps.
        cpu_utils: Optional tensor of CPU utilization percentages.
        gpu_utils: Optional tensor of GPU utilization percentages.
        prefix: Label printed in front of the timing line.
        unit: Time unit of the timing line: "ns", "µs", "ms" or "s".

    Raises:
        KeyError: If ``unit`` is not one of the supported units.
    """

    def spread(values):
        # Tensor.std() uses the sample (n - 1) estimator by default, which is
        # NaN (plus a warning) for a single value; report 0 for a single run.
        return values.std().item() if values.numel() > 1 else 0.0

    # Multipliers converting nanoseconds into the requested unit.
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    unit_times = times * mul
    std = spread(unit_times)
    med = unit_times.median().item()
    mean = unit_times.mean().item()
    min_time = unit_times.min().item()
    max_time = unit_times.max().item()
    print(
        f"\n{prefix}: {med = :.2f}, {mean = :.2f} +- {std:.2f}, {min_time = :.2f}, {max_time = :.2f} - in {unit}"
    )

    # Frames per second of each run (times are nanoseconds).
    fps = num_frames / (times * 1e-9)
    std = spread(fps)
    med = fps.median().item()
    max_fps = fps.max().item()
    print(f"{med = :.1f} fps +- {std:.1f}, {max_fps = :.1f}")

    # Both utilization sections are skipped for missing or empty samples.
    if cpu_utils is not None and cpu_utils.numel() > 0:
        cpu_avg = cpu_utils.mean().item()
        cpu_peak = cpu_utils.max().item()
        print(f"CPU utilization: avg = {cpu_avg:.1f}%, peak = {cpu_peak:.1f}%")

    if gpu_utils is not None and gpu_utils.numel() > 0:
        gpu_avg = gpu_utils.mean().item()
        gpu_peak = gpu_utils.max().item()
        print(f"GPU utilization: avg = {gpu_avg:.1f}%, peak = {gpu_peak:.1f}%")
104+
105+
106+
def encode_torchcodec(frames, output_path, device="cpu"):
    """Encode ``frames`` to ``output_path`` with torchcodec's ``VideoEncoder``.

    Args:
        frames: Frame tensor handed to ``VideoEncoder``.
        output_path: Destination file passed to ``to_file``.
        device: "cuda" encodes with ``h264_nvenc``; any other value encodes
            on the CPU with ``libx264``.
    """
    if device != "cuda":
        cpu_encoder = VideoEncoder(frames=frames, frame_rate=30, device="cpu")
        cpu_encoder.to_file(dest=output_path, codec="libx264", crf=0)
        return

    # CPU-resident frames are copied to the GPU here, i.e. inside whatever
    # region the caller is timing.
    on_gpu = frames if frames.device.type != "cpu" else frames.cuda()
    gpu_encoder = VideoEncoder(frames=on_gpu, frame_rate=30, device="cuda")
    # NOTE(review): the FFmpeg CLI path in this file passes "-qp 0" while
    # this path uses qp=1 -- confirm the difference is intentional.
    gpu_encoder.to_file(
        dest=output_path, codec="h264_nvenc", extra_options={"qp": 1}
    )
115+
116+
117+
def write_raw_frames(frames, raw_path):
    """Dump ``frames`` to ``raw_path`` as raw bytes in NHWC order.

    Args:
        frames: Tensor laid out as NCHW.
        raw_path: File to create or overwrite with the raw bytes.
    """
    # Raw video stores channels interleaved per pixel, so move the channel
    # axis last and make the memory layout match before dumping it.
    channels_last = frames.permute(0, 2, 3, 1).contiguous()
    with open(raw_path, "wb") as sink:
        payload = channels_last.cpu().numpy().tobytes()
        sink.write(payload)
122+
123+
124+
def encode_ffmpeg_cli(raw_path, frames_shape, output_path, device="cpu", codec=None):
    """Encode a raw RGB24 frame dump with the ``ffmpeg`` command-line tool.

    Args:
        raw_path: Path of the raw frame file, read as packed ``rgb24`` frames
            at 30 fps.
        frames_shape: Shape of the frame tensor; entries 2 and 3 are used as
            height and width.
        output_path: Destination video file (overwritten if it exists).
        device: "cuda" selects the NVENC defaults (``h264_nvenc``, ``-qp 0``);
            any other value selects the CPU defaults (``libx264``, ``-crf 0``).
        codec: Optional encoder name. When ``None`` the device default is
            used. The quality flags still follow ``device``, so an explicit
            codec must accept them.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    height, width = frames_shape[2], frames_shape[3]

    if device == "cuda":
        default_codec, quality_params = "h264_nvenc", ["-qp", "0"]
    else:
        default_codec, quality_params = "libx264", ["-crf", "0"]

    # Honor an explicitly requested codec; it used to be overwritten
    # unconditionally, which made the parameter a no-op.
    if codec is None:
        codec = default_codec

    ffmpeg_cmd = [
        "ffmpeg",
        "-y",
        "-f",
        "rawvideo",
        "-pix_fmt",
        "rgb24",
        "-s",
        f"{width}x{height}",
        "-r",
        "30",  # frame_rate is 30
        "-i",
        raw_path,
        "-c:v",
        codec,
        "-pix_fmt",
        "yuv420p",
    ]
    ffmpeg_cmd.extend(quality_params)
    # No "-threads" option is passed, so FFmpeg chooses the thread count.
    ffmpeg_cmd.append(str(output_path))

    subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
158+
159+
160+
def main():
    """Run the encoder benchmarks from the command line.

    Decodes up to ``--max-frames`` frames from ``--path``, dumps them to a raw
    file in a temporary directory, then times ``VideoEncoder`` and the FFmpeg
    CLI on the GPU (only when CUDA is available) and on the CPU, printing a
    report after each benchmark.
    """
    parser = ArgumentParser()
    cli_options = (
        (
            "--path",
            {
                "type": str,
                "help": "Path to input video file",
                "default": DEFAULT_VIDEO_PATH,
            },
        ),
        (
            "--average-over",
            {
                "type": int,
                "default": DEFAULT_AVERAGE_OVER,
                "help": "Number of runs to average over",
            },
        ),
        (
            "--max-frames",
            {
                "type": int,
                "default": DEFAULT_MAX_FRAMES,
                "help": "Maximum number of frames to decode for benchmarking",
            },
        ),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    print(
        f"Benchmarking up to {args.max_frames} frames from {Path(args.path).name} over {args.average_over} runs:"
    )
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        print("CUDA not available. GPU benchmarks will be skipped.")

    # Decode at most max_frames frames, fewer if the video is shorter.
    decoder = VideoDecoder(str(args.path))
    frame_limit = min(args.max_frames, len(decoder))
    frames = decoder.get_frames_in_range(start=0, stop=frame_limit).data
    print(
        f"Loaded {frames.shape[0]} frames of size {frames.shape[2]}x{frames.shape[3]}"
    )
    num_frames = frames.shape[0]

    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        # The FFmpeg CLI benchmarks read their input from this raw dump.
        raw_frames_path = workdir / "input_frames.raw"
        write_raw_frames(frames, str(raw_frames_path))

        def time_torchcodec(device, filename):
            # Benchmark VideoEncoder writing to workdir/filename on `device`.
            return bench(
                encode_torchcodec,
                frames=frames,
                output_path=str(workdir / filename),
                device=device,
                average_over=args.average_over,
                warmup=1,
            )

        def time_ffmpeg(device, filename):
            # Benchmark the FFmpeg CLI writing to workdir/filename on `device`.
            return bench(
                encode_ffmpeg_cli,
                raw_path=str(raw_frames_path),
                frames_shape=frames.shape,
                output_path=str(workdir / filename),
                device=device,
                average_over=args.average_over,
                warmup=1,
            )

        if cuda_available:
            # GPU runs report GPU utilization only.
            times, _, gpu_utils = time_torchcodec("cuda", "torchcodec_gpu.mp4")
            report_stats(
                times, num_frames, None, gpu_utils, prefix="VideoEncoder on GPU"
            )

            times, _, gpu_utils = time_ffmpeg("cuda", "ffmpeg_gpu.mp4")
            report_stats(
                times, num_frames, None, gpu_utils, prefix="FFmpeg CLI on GPU"
            )
        else:
            print("Skipping VideoEncoder GPU benchmark (CUDA not available)")
            print("Skipping FFmpeg CLI GPU benchmark (CUDA not available)")

        # CPU runs report CPU utilization only.
        times, cpu_utils, _ = time_torchcodec("cpu", "torchcodec_cpu.mp4")
        report_stats(times, num_frames, cpu_utils, None, prefix="VideoEncoder on CPU")

        times, cpu_utils, _ = time_ffmpeg("cpu", "ffmpeg_cpu.mp4")
        report_stats(times, num_frames, cpu_utils, None, prefix="FFmpeg CLI on CPU")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)