1111 clips_at_regular_timestamps ,
1212)
1313
14+ DEFAULT_VIDEO_PATH = Path (__file__ ).parent / "../../test/resources/nasa_13013.mp4"
15+ DEFAULT_NUM_EXP = 30
1416
15- def bench (f , * args , num_exp = 100 , warmup = 0 , ** kwargs ):
17+
18+ def bench (f , * args , num_exp , warmup = 0 , seed , ** kwargs ):
1619
1720 for _ in range (warmup ):
1821 f (* args , ** kwargs )
1922
2023 num_frames = None
2124 times = []
2225 for _ in range (num_exp ):
26+ if seed is not None :
27+ torch .manual_seed (seed )
2328 start = perf_counter_ns ()
2429 clips = f (* args , ** kwargs )
2530 end = perf_counter_ns ()
@@ -54,8 +59,7 @@ def sample(decoder, sampler, **kwargs):
5459 )
5560
5661
57- def run_sampler_benchmarks (device , video ):
58- NUM_EXP = 30
62+ def run_sampler_benchmarks (device , video , num_experiments , torch_seed ):
5963
6064 for num_clips in (1 , 50 ):
6165 print ("-" * 10 )
@@ -68,8 +72,9 @@ def run_sampler_benchmarks(device, video):
6872 decoder ,
6973 clips_at_random_indices ,
7074 num_clips = num_clips ,
71- num_exp = NUM_EXP ,
75+ num_exp = num_experiments ,
7276 warmup = 2 ,
77+ seed = torch_seed ,
7378 )
7479 report_stats (times , num_frames , unit = "ms" )
7580
@@ -79,8 +84,9 @@ def run_sampler_benchmarks(device, video):
7984 decoder ,
8085 clips_at_regular_indices ,
8186 num_clips = num_clips ,
82- num_exp = NUM_EXP ,
87+ num_exp = num_experiments ,
8388 warmup = 2 ,
89+ seed = torch_seed ,
8490 )
8591 report_stats (times , num_frames , unit = "ms" )
8692
@@ -90,8 +96,9 @@ def run_sampler_benchmarks(device, video):
9096 decoder ,
9197 clips_at_random_timestamps ,
9298 num_clips = num_clips ,
93- num_exp = NUM_EXP ,
99+ num_exp = num_experiments ,
94100 warmup = 2 ,
101+ seed = torch_seed ,
95102 )
96103 report_stats (times , num_frames , unit = "ms" )
97104
@@ -102,19 +109,23 @@ def run_sampler_benchmarks(device, video):
102109 decoder ,
103110 clips_at_regular_timestamps ,
104111 seconds_between_clip_starts = seconds_between_clip_starts ,
105- num_exp = NUM_EXP ,
112+ num_exp = num_experiments ,
106113 warmup = 2 ,
114+ seed = torch_seed ,
107115 )
108116 report_stats (times , num_frames , unit = "ms" )
109117
110118
111119def main ():
112- DEFAULT_VIDEO_PATH = Path (__file__ ).parent / "../../test/resources/nasa_13013.mp4"
113120 parser = argparse .ArgumentParser ()
114121 parser .add_argument ("--device" , type = str , default = "cpu" )
115122 parser .add_argument ("--video" , type = str , default = str (DEFAULT_VIDEO_PATH ))
123+ parser .add_argument ("--num_experiments" , type = int , default = DEFAULT_NUM_EXP )
124+ parser .add_argument ("--torch_seed" , type = int )
116125 args = parser .parse_args ()
117- run_sampler_benchmarks (args .device , args .video )
126+ run_sampler_benchmarks (
127+ args .device , args .video , args .num_experiments , args .torch_seed
128+ )
118129
119130
120131if __name__ == "__main__" :
0 commit comments