Skip to content

Commit 400001a

Browse files
committed
.
1 parent a926934 commit 400001a

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed

examples/basic_cuda_example.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
"""
7+
Accelerated video decoding with NVDEC
8+
=====================================
9+
10+
.. _nvdec_tutorial:
11+
12+
**Author**: `Ahmad Sharif <[email protected]>`__
13+
14+
This tutorial shows how to use NVIDIA’s hardware video decoder (NVDEC)
15+
with TorchCodec. This decoder is called CUDA decoder in the documentation
16+
and APIs.
17+
18+
To use the CUDA decoder, you have to have the following installed in your
19+
environment:
20+
* NVDEC-enabled FFMPEG
21+
* libnpp
22+
* CUDA-enabled pytorch
23+
24+
FFMPEG versions 5, 6 and 7 from conda-forge are built with NVDEC support and
25+
you can install them by running (for example to install ffmpeg version 7):
26+
27+
.. code-block:: bash
28+
29+
conda install ffmpeg=7 -c conda-forge
30+
conda install libnpp -c nvidia
31+
"""
32+
33+
# %%
34+
#
35+
# .. note::
36+
#
37+
# This tutorial requires FFmpeg libraries compiled with CUDA support.
38+
#
39+
#
40+
import torch
41+
42+
print(f"{torch.__version__=}")
43+
print(f"{torch.cuda.is_available()=}")
44+
print(f"{torch.cuda.get_device_properties(0)=}")
45+
46+
47+
# %%
48+
######################################################################
49+
# Downloading the video
50+
######################################################################
51+
#
52+
# We will use the following video which has the following properties;
53+
#
54+
# - Codec: H.264
55+
# - Resolution: 960x540
56+
# - FPS: 29.97
57+
# - Pixel format: YUV420P
58+
#
59+
# .. raw:: html
60+
#
61+
# <video style="max-width: 100%" controls>
62+
# <source src="https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4" type="video/mp4">
63+
# </video>
64+
import urllib.request
65+
66+
video_file = "video.mp4"
67+
urllib.request.urlretrieve(
68+
"https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4",
69+
video_file,
70+
)
71+
72+
73+
# %%
74+
######################################################################
75+
# Decoding with CUDA
76+
######################################################################
77+
#
78+
# To use CUDA decoder, you need to pass in a cuda device to the decoder.
79+
#
80+
from torchcodec.decoders import VideoDecoder
81+
82+
vd = VideoDecoder(video_file, device="cuda:0")
83+
frame = vd[0]
84+
85+
# %%
86+
#
87+
# The video frames are decoded and returned as tensor of NCHW format.
88+
89+
print(frame.data.shape, frame.data.dtype)
90+
91+
# %%
92+
#
93+
# The video frames are left on the GPU memory.
94+
95+
print(frame.data.device)
96+
97+
98+
# %%
99+
######################################################################
100+
# Visualizing Frames
101+
######################################################################
102+
#
103+
# Let's look at the frames decoded by CUDA decoder and compare them
104+
# against equivalent results from the CPU decoders.
105+
import matplotlib.pyplot as plt
106+
107+
108+
def get_frames(timestamps: list[float], device: str):
109+
decoder = VideoDecoder(video_file, device=device)
110+
return [decoder.get_frame_played_at(ts) for ts in timestamps]
111+
112+
113+
def get_numpy_images(frames):
114+
numpy_images = []
115+
for frame in frames:
116+
# We transfer to the CPU so they can be visualized by matplotlib.
117+
numpy_image = frame.data.to("cpu").permute(1, 2, 0).numpy()
118+
numpy_images.append(numpy_image)
119+
return numpy_images
120+
121+
122+
timestamps = [12, 19, 45, 131, 180]
123+
cpu_frames = get_frames(timestamps, device="cpu")
124+
cuda_frames = get_frames(timestamps, device="cuda:0")
125+
cpu_numpy_images = get_numpy_images(cpu_frames)
126+
cuda_numpy_images = get_numpy_images(cuda_frames)
127+
128+
129+
def plot_cpu_and_cuda():
130+
n_rows = len(timestamps)
131+
fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
132+
for i in range(n_rows):
133+
axes[i][0].imshow(cpu_numpy_images[i])
134+
axes[i][1].imshow(cuda_numpy_images[i])
135+
136+
axes[0][0].set_title("CPU decoder")
137+
axes[0][1].set_title("CUDA decoder")
138+
plt.setp(axes, xticks=[], yticks=[])
139+
plt.tight_layout()
140+
141+
142+
plot_cpu_and_cuda()
143+
144+
# %%
145+
#
146+
# They look visually similar to the human eye but there may be subtle
147+
# differences because CUDA math is not bit-exact to CPU math.
148+
#
149+
first_cpu_frame = cpu_frames[0].data.to("cpu")
150+
first_cuda_frame = cuda_frames[0].data.to("cpu")
151+
frames_equal = torch.equal(first_cpu_frame, first_cuda_frame)
152+
print(f"{frames_equal=}")

0 commit comments

Comments
 (0)