11#!/usr/bin/env python3
2- # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3+ # All rights reserved.
4+ #
5+ # This source code is licensed under the BSD-style license found in the
6+ # LICENSE file in the root directory of this source tree.
37
48# pyre-strict
59
@@ -28,7 +32,7 @@ class TestMCH(unittest.TestCase):
2832 # pyre-ignore[56]
2933 @unittest .skipIf (
3034 torch .cuda .device_count () < 1 ,
31- "Not enough GPUs, this test requires at least two GPUs " ,
35+ "Not enough GPUs, this test requires at least one GPU " ,
3236 )
3337 def test_zch_hash_inference (self ) -> None :
3438 # prepare
@@ -143,11 +147,6 @@ def test_zch_hash_inference(self) -> None:
143147 f"{ torch .unique (m3 ._hash_zch_identities )= } " ,
144148 )
145149
146- # pyre-ignore[56]
147- @unittest .skipIf (
148- torch .cuda .device_count () < 1 ,
149- "This test requires CUDA device" ,
150- )
151150 def test_scriptability (self ) -> None :
152151 zch_size = 10
153152 mc_modules = {
@@ -180,11 +179,6 @@ def test_scriptability(self) -> None:
180179 )
181180 torch .jit .script (mcc_ec )
182181
183- # pyre-ignore[56]
184- @unittest .skipIf (
185- torch .cuda .device_count () < 1 ,
186- "This test requires CUDA device" ,
187- )
188182 def test_scriptability_lru (self ) -> None :
189183 zch_size = 10
190184 mc_modules = {
@@ -219,13 +213,13 @@ def test_scriptability_lru(self) -> None:
219213 torch .jit .script (mcc_ec )
220214
221215 @unittest .skipIf (
222- torch .cuda .device_count () < 1 ,
223- "Not enough GPUs, this test requires at least one GPUs" ,
216+ torch .cuda .device_count () < 2 ,
217+ "Not enough GPUs, this test requires at least two GPUs" ,
224218 )
225219 # pyre-ignore [56]
226220 @given (hash_size = st .sampled_from ([0 , 80 ]), keep_original_indices = st .booleans ())
227221 @settings (max_examples = 6 , deadline = None )
228- def test_zch_hash_train_to_inf_block_bucketize (
222+ def test_zch_hash_train_to_inf_block_bucketize_disabled_in_oss_compatibility (
229223 self , hash_size : int , keep_original_indices : bool
230224 ) -> None :
231225 # rank 0
@@ -298,13 +292,15 @@ def test_zch_hash_train_to_inf_block_bucketize(
298292 )
299293
300294 @unittest .skipIf (
301- torch .cuda .device_count () < 1 ,
302- "Not enough GPUs, this test requires at least one GPUs" ,
295+ torch .cuda .device_count () < 2 ,
296+ "Not enough GPUs, this test requires at least two GPUs" ,
303297 )
304298 # pyre-ignore [56]
305299 @given (hash_size = st .sampled_from ([0 , 80 ]))
306300 @settings (max_examples = 5 , deadline = None )
307- def test_zch_hash_train_rescales_two (self , hash_size : int ) -> None :
301+ def test_zch_hash_train_rescales_two_disabled_in_oss_compatibility (
302+ self , hash_size : int
303+ ) -> None :
308304 keep_original_indices = False
309305 # rank 0
310306 world_size = 2
@@ -410,13 +406,13 @@ def test_zch_hash_train_rescales_two(self, hash_size: int) -> None:
410406 )
411407
412408 @unittest .skipIf (
413- torch .cuda .device_count () < 1 ,
409+ torch .cuda .device_count () < 2 ,
414410 "Not enough GPUs, this test requires at least one GPUs" ,
415411 )
416412 # pyre-ignore [56]
417413 @given (hash_size = st .sampled_from ([0 , 80 ]))
418414 @settings (max_examples = 5 , deadline = None )
419- def test_zch_hash_train_rescales_four (self , hash_size : int ) -> None :
415+ def test_zch_hash_train_rescales_one (self , hash_size : int ) -> None :
420416 keep_original_indices = True
421417 kjt = KeyedJaggedTensor (
422418 keys = ["f" ],
@@ -452,23 +448,20 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
452448 ),
453449 )
454450
455- # start with world_size = 4
456- world_size = 4
451+ # start with world_size = 2
452+ world_size = 2
457453 block_sizes = torch .tensor (
458454 [(size + world_size - 1 ) // world_size for size in [hash_size ]],
459455 dtype = torch .int64 ,
460456 device = "cuda" ,
461457 )
462458
463- m1_1 = m0 .rebuild_with_output_id_range ((0 , 10 ))
464- m2_1 = m0 .rebuild_with_output_id_range ((10 , 20 ))
465- m3_1 = m0 .rebuild_with_output_id_range ((20 , 30 ))
466- m4_1 = m0 .rebuild_with_output_id_range ((30 , 40 ))
459+ m1_1 = m0 .rebuild_with_output_id_range ((0 , 20 ))
460+ m2_1 = m0 .rebuild_with_output_id_range ((20 , 40 ))
467461
468- # shard, now world size 2!
469- # start with world_size = 4
462+ # shard, now world size 1!
470463 if hash_size > 0 :
471- world_size = 2
464+ world_size = 1
472465 block_sizes = torch .tensor (
473466 [(size + world_size - 1 ) // world_size for size in [hash_size ]],
474467 dtype = torch .int64 ,
@@ -482,7 +475,7 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
482475 keep_original_indices = keep_original_indices ,
483476 output_permute = True ,
484477 )
485- in1_2 , in2_2 = bucketized_kjt .split ([len (kjt .keys ())] * world_size )
478+ in1_2 = bucketized_kjt .split ([len (kjt .keys ())] * world_size )[ 0 ]
486479 else :
487480 bucketized_kjt , permute = bucketize_kjt_before_all2all (
488481 kjt ,
@@ -498,14 +491,8 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
498491 values = torch .cat ([kjts [0 ].values (), kjts [1 ].values ()], dim = 0 ),
499492 lengths = torch .cat ([kjts [0 ].lengths (), kjts [1 ].lengths ()], dim = 0 ),
500493 )
501- in2_2 = KeyedJaggedTensor (
502- keys = kjts [2 ].keys (),
503- values = torch .cat ([kjts [2 ].values (), kjts [3 ].values ()], dim = 0 ),
504- lengths = torch .cat ([kjts [2 ].lengths (), kjts [3 ].lengths ()], dim = 0 ),
505- )
506494
507- m1_2 = m0 .rebuild_with_output_id_range ((0 , 20 ))
508- m2_2 = m0 .rebuild_with_output_id_range ((20 , 40 ))
495+ m1_2 = m0 .rebuild_with_output_id_range ((0 , 40 ))
509496 m1_zch_identities = torch .cat (
510497 [
511498 m1_1 .state_dict ()["_hash_zch_identities" ],
@@ -522,53 +509,30 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
522509 state_dict ["_hash_zch_identities" ] = m1_zch_identities
523510 state_dict ["_hash_zch_metadata" ] = m1_zch_metadata
524511 m1_2 .load_state_dict (state_dict )
525-
526- m2_zch_identities = torch .cat (
527- [
528- m3_1 .state_dict ()["_hash_zch_identities" ],
529- m4_1 .state_dict ()["_hash_zch_identities" ],
530- ]
531- )
532- m2_zch_metadata = torch .cat (
533- [
534- m3_1 .state_dict ()["_hash_zch_metadata" ],
535- m4_1 .state_dict ()["_hash_zch_metadata" ],
536- ]
537- )
538- state_dict = m2_2 .state_dict ()
539- state_dict ["_hash_zch_identities" ] = m2_zch_identities
540- state_dict ["_hash_zch_metadata" ] = m2_zch_metadata
541- m2_2 .load_state_dict (state_dict )
542-
543512 _ = m1_2 (in1_2 .to_dict ())
544- _ = m2_2 (in2_2 .to_dict ())
545513
546514 m0 .reset_inference_mode () # just clears out training state
547515 full_zch_identities = torch .cat (
548516 [
549517 m1_2 .state_dict ()["_hash_zch_identities" ],
550- m2_2 .state_dict ()["_hash_zch_identities" ],
551518 ]
552519 )
553520 state_dict = m0 .state_dict ()
554521 state_dict ["_hash_zch_identities" ] = full_zch_identities
555522 m0 .load_state_dict (state_dict )
556523
557- # now set all models to eval, and run kjt
558524 m1_2 .eval ()
559- m2_2 .eval ()
560525 assert m0 .training is False
561526
562527 inf_input = kjt .to_dict ()
563- inf_output = m0 (inf_input )
564528
529+ inf_output = m0 (inf_input )
565530 o1_2 = m1_2 (in1_2 .to_dict ())
566- o2_2 = m2_2 (in2_2 .to_dict ())
567531 self .assertTrue (
568532 torch .allclose (
569533 inf_output ["f" ].values (),
570534 torch .index_select (
571- torch . cat ([ x [ "f" ].values () for x in [ o1_2 , o2_2 ]] ),
535+ o1_2 [ "f" ].values (),
572536 dim = 0 ,
573537 index = cast (torch .Tensor , permute ),
574538 ),
@@ -578,7 +542,7 @@ def test_zch_hash_train_rescales_four(self, hash_size: int) -> None:
578542 # pyre-ignore[56]
579543 @unittest .skipIf (
580544 torch .cuda .device_count () < 1 ,
581- "This test requires CUDA device " ,
545+ "This test requires at least one GPU " ,
582546 )
583547 def test_output_global_offset_tensor (self ) -> None :
584548 m = HashZchManagedCollisionModule (
@@ -653,7 +617,7 @@ def test_output_global_offset_tensor(self) -> None:
653617 # pyre-ignore[56]
654618 @unittest .skipIf (
655619 torch .cuda .device_count () < 1 ,
656- "This test requires CUDA device " ,
620+ "This test requires at least one GPU " ,
657621 )
658622 def test_dynamically_switch_inference_training_mode (self ) -> None :
659623 m = HashZchManagedCollisionModule (
0 commit comments