@@ -123,6 +123,7 @@ def get_frames_from_video(self, video_file, pts_list):
123123 decoder ,
124124 num_threads = self ._num_threads ,
125125 color_conversion_library = self ._color_conversion_library ,
126+ device = self ._device ,
126127 )
127128 metadata = json .loads (get_json_metadata (decoder ))
128129 best_video_stream = metadata ["bestVideoStreamIndex" ]
@@ -137,6 +138,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
137138 decoder ,
138139 num_threads = self ._num_threads ,
139140 color_conversion_library = self ._color_conversion_library ,
141+ device = self ._device ,
140142 )
141143
142144 frames = []
@@ -176,6 +178,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
176178 decoder ,
177179 num_threads = self ._num_threads ,
178180 color_conversion_library = self ._color_conversion_library ,
181+ device = self ._device ,
179182 )
180183
181184 frames = []
@@ -187,10 +190,11 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
187190
188191
189192class TorchCodecCoreBatch (AbstractDecoder ):
190- def __init__ (self , num_threads = None , color_conversion_library = None ):
193+ def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
191194 self ._print_each_iteration_time = False
192195 self ._num_threads = int (num_threads ) if num_threads else None
193196 self ._color_conversion_library = color_conversion_library
197+ self ._device = device
194198
195199 def get_frames_from_video (self , video_file , pts_list ):
196200 decoder = create_from_file (video_file )
@@ -199,6 +203,7 @@ def get_frames_from_video(self, video_file, pts_list):
199203 decoder ,
200204 num_threads = self ._num_threads ,
201205 color_conversion_library = self ._color_conversion_library ,
206+ device = self ._device ,
202207 )
203208 metadata = json .loads (get_json_metadata (decoder ))
204209 best_video_stream = metadata ["bestVideoStreamIndex" ]
@@ -214,6 +219,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
214219 decoder ,
215220 num_threads = self ._num_threads ,
216221 color_conversion_library = self ._color_conversion_library ,
222+ device = self ._device ,
217223 )
218224 metadata = json .loads (get_json_metadata (decoder ))
219225 best_video_stream = metadata ["bestVideoStreamIndex" ]
@@ -225,17 +231,22 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
225231
226232
227233class TorchCodecPublic (AbstractDecoder ):
228- def __init__ (self , num_ffmpeg_threads = None ):
234+ def __init__ (self , num_ffmpeg_threads = None , device = "cpu" ):
229235 self ._num_ffmpeg_threads = (
230236 int (num_ffmpeg_threads ) if num_ffmpeg_threads else None
231237 )
238+ self ._device = device
232239
233240 def get_frames_from_video (self , video_file , pts_list ):
234- decoder = VideoDecoder (video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads )
241+ decoder = VideoDecoder (
242+ video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads , device = self ._device
243+ )
235244 return decoder .get_frames_played_at (pts_list )
236245
237246 def get_consecutive_frames_from_video (self , video_file , numFramesToDecode ):
238- decoder = VideoDecoder (video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads )
247+ decoder = VideoDecoder (
248+ video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads , device = self ._device
249+ )
239250 frames = []
240251 count = 0
241252 for frame in decoder :
0 commit comments