 )


+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
 ) -> None:
@@ -79,6 +87,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -99,6 +108,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )

     # Generate the same test data on all ranks
@@ -283,6 +293,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
 ) -> None:
@@ -295,16 +306,17 @@ def _worker_test_all_to_all(
         in_dtype=getattr(torch, in_dtype),
         out_dtype=getattr(torch, out_dtype),
     )
-    _do_test_all_to_all(pgi, dp_size, moe_config, internode)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode)

     nvshmem_finalize()


 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize("max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2])
 @pytest.mark.parametrize("internode", [True, False])
-def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> None:
+def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool) -> None:
     world_size = 4
     dp_size = 2
     parallel_launch(
@@ -313,6 +325,7 @@ def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> None:
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
     )
@@ -322,13 +335,15 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
         pgi,
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -338,4 +353,5 @@ def _worker_test_all_to_all_multi_node(
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count)
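
For context on what the new parametrization exercises: the `max_sm_count` values are the full SM count of GPU 0 and half of it, both obtained through `torch.cuda.get_device_properties`. Below is a minimal standalone sketch of that query; the `sm_count_choices` helper and its CPU-only fallback are illustrative and not part of the diff above.

```python
import torch


def sm_count_choices() -> list[int]:
    # Illustrative only: mirrors how the test derives its max_sm_count
    # parameters (full SM count and half of it) from device properties.
    if not torch.cuda.is_available():
        # No CUDA device: return an empty list instead of raising,
        # purely so this standalone sketch runs anywhere.
        return []
    sm_count = torch.cuda.get_device_properties(0).multi_processor_count
    return [sm_count, sm_count // 2]


if __name__ == "__main__":
    # On an A100, multi_processor_count is 108, so this would print [108, 54].
    print(sm_count_choices())
```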