Skip to content

Commit c7724e2

Browse files
authored
Add CUDA devices to the benchmark (#344)
1 parent af13ac5 commit c7724e2

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def get_frames_from_video(self, video_file, pts_list):
123123
decoder,
124124
num_threads=self._num_threads,
125125
color_conversion_library=self._color_conversion_library,
126+
device=self._device,
126127
)
127128
metadata = json.loads(get_json_metadata(decoder))
128129
best_video_stream = metadata["bestVideoStreamIndex"]
@@ -137,6 +138,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
137138
decoder,
138139
num_threads=self._num_threads,
139140
color_conversion_library=self._color_conversion_library,
141+
device=self._device,
140142
)
141143

142144
frames = []
@@ -176,6 +178,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
176178
decoder,
177179
num_threads=self._num_threads,
178180
color_conversion_library=self._color_conversion_library,
181+
device=self._device,
179182
)
180183

181184
frames = []
@@ -187,10 +190,11 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
187190

188191

189192
class TorchCodecCoreBatch(AbstractDecoder):
190-
def __init__(self, num_threads=None, color_conversion_library=None):
193+
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
191194
self._print_each_iteration_time = False
192195
self._num_threads = int(num_threads) if num_threads else None
193196
self._color_conversion_library = color_conversion_library
197+
self._device = device
194198

195199
def get_frames_from_video(self, video_file, pts_list):
196200
decoder = create_from_file(video_file)
@@ -199,6 +203,7 @@ def get_frames_from_video(self, video_file, pts_list):
199203
decoder,
200204
num_threads=self._num_threads,
201205
color_conversion_library=self._color_conversion_library,
206+
device=self._device,
202207
)
203208
metadata = json.loads(get_json_metadata(decoder))
204209
best_video_stream = metadata["bestVideoStreamIndex"]
@@ -214,6 +219,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
214219
decoder,
215220
num_threads=self._num_threads,
216221
color_conversion_library=self._color_conversion_library,
222+
device=self._device,
217223
)
218224
metadata = json.loads(get_json_metadata(decoder))
219225
best_video_stream = metadata["bestVideoStreamIndex"]
@@ -225,17 +231,22 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
225231

226232

227233
class TorchCodecPublic(AbstractDecoder):
228-
def __init__(self, num_ffmpeg_threads=None):
234+
def __init__(self, num_ffmpeg_threads=None, device="cpu"):
229235
self._num_ffmpeg_threads = (
230236
int(num_ffmpeg_threads) if num_ffmpeg_threads else None
231237
)
238+
self._device = device
232239

233240
def get_frames_from_video(self, video_file, pts_list):
234-
decoder = VideoDecoder(video_file, num_ffmpeg_threads=self._num_ffmpeg_threads)
241+
decoder = VideoDecoder(
242+
video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
243+
)
235244
return decoder.get_frames_played_at(pts_list)
236245

237246
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
238-
decoder = VideoDecoder(video_file, num_ffmpeg_threads=self._num_ffmpeg_threads)
247+
decoder = VideoDecoder(
248+
video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
249+
)
239250
frames = []
240251
count = 0
241252
for frame in decoder:

0 commit comments

Comments
 (0)