@@ -49,7 +49,7 @@ class MAEMetricTest(unittest.TestCase):
4949 clazz : Type [RecMetric ] = MAEMetric
5050 task_name : str = "mae"
5151
52- def test_unfused_mae (self ) -> None :
52+ def test_mae_unfused (self ) -> None :
5353 rec_metric_value_test_launcher (
5454 target_clazz = MAEMetric ,
5555 target_compute_mode = RecComputeMode .UNFUSED_TASKS_COMPUTATION ,
@@ -63,7 +63,7 @@ def test_unfused_mae(self) -> None:
6363 entry_point = metric_test_helper ,
6464 )
6565
66- def test_fused_mae (self ) -> None :
66+ def test_mae_fused_tasks (self ) -> None :
6767 rec_metric_value_test_launcher (
6868 target_clazz = MAEMetric ,
6969 target_compute_mode = RecComputeMode .FUSED_TASKS_COMPUTATION ,
@@ -77,6 +77,20 @@ def test_fused_mae(self) -> None:
7777 entry_point = metric_test_helper ,
7878 )
7979
80+ def test_mae_fused_tasks_and_states (self ) -> None :
81+ rec_metric_value_test_launcher (
82+ target_clazz = MAEMetric ,
83+ target_compute_mode = RecComputeMode .FUSED_TASKS_AND_STATES_COMPUTATION ,
84+ test_clazz = TestMAEMetric ,
85+ metric_name = "mae" ,
86+ task_names = ["t1" , "t2" , "t3" ],
87+ fused_update_limit = 0 ,
88+ compute_on_all_ranks = False ,
89+ should_validate_update = False ,
90+ world_size = WORLD_SIZE ,
91+ entry_point = metric_test_helper ,
92+ )
93+
8094
8195class MAEGPUSyncTest (unittest .TestCase ):
8296 clazz : Type [RecMetric ] = MAEMetric
0 commit comments