|
19 | 19 | # :ref:`creating_decoder`. |
20 | 20 |
|
21 | 21 | from typing import Optional |
22 | | - |
23 | | -import requests |
24 | 22 | import torch |
| 23 | +import requests |
25 | 24 |
|
26 | 25 |
|
27 | 26 | # Video source: https://www.pexels.com/video/dog-eating-854132/ |
|
34 | 33 | raw_video_bytes = response.content |
35 | 34 |
|
36 | 35 |
|
37 | | -def plot(frames: torch.Tensor, title: Optional[str] = None): |
| 36 | +def plot(frames: torch.Tensor, title : Optional[str] = None): |
38 | 37 | try: |
39 | | - import matplotlib.pyplot as plt |
40 | | - from torchvision.transforms.v2.functional import to_pil_image |
41 | 38 | from torchvision.utils import make_grid |
| 39 | + from torchvision.transforms.v2.functional import to_pil_image |
| 40 | + import matplotlib.pyplot as plt |
42 | 41 | except ImportError: |
43 | 42 | print("Cannot plot, please run `pip install torchvision matplotlib`") |
44 | 43 | return |
45 | 44 |
|
46 | | - plt.rcParams["savefig.bbox"] = "tight" |
| 45 | + plt.rcParams["savefig.bbox"] = 'tight' |
47 | 46 | fig, ax = plt.subplots() |
48 | 47 | ax.imshow(to_pil_image(make_grid(frames))) |
49 | 48 | ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) |
@@ -77,7 +76,7 @@ def plot(frames: torch.Tensor, title: Optional[str] = None): |
77 | 76 | # --------------------------------------- |
78 | 77 |
|
79 | 78 | first_frame = decoder[0] # using a single int index |
80 | | -every_twenty_frame = decoder[0:-1:20] # using slices |
| 79 | +every_twenty_frame = decoder[0 : -1 : 20] # using slices |
81 | 80 |
|
82 | 81 | print(f"{first_frame.shape = }") |
83 | 82 | print(f"{first_frame.dtype = }") |
@@ -107,10 +106,9 @@ def plot(frames: torch.Tensor, title: Optional[str] = None): |
107 | 106 | # The decoder is a normal iterable object and can be iterated over like so: |
108 | 107 |
|
109 | 108 | for frame in decoder: |
110 | | - assert isinstance(frame, torch.Tensor) and frame.shape == ( |
111 | | - 3, |
112 | | - decoder.metadata.height, |
113 | | - decoder.metadata.width, |
| 109 | + assert ( |
| 110 | + isinstance(frame, torch.Tensor) |
| 111 | + and frame.shape == (3, decoder.metadata.height, decoder.metadata.width) |
114 | 112 | ) |
115 | 113 |
|
116 | 114 | # %% |
|
0 commit comments