1313import numpy as np
1414import yaml
1515
16- RTOL = 0.08
16+ DEFAULT_RTOL = 0.08
1717
1818
1919def launch_lm_eval (eval_config , tp_size ):
2020 trust_remote_code = eval_config .get ("trust_remote_code" , False )
2121 max_model_len = eval_config .get ("max_model_len" , 4096 )
2222 batch_size = eval_config .get ("batch_size" , "auto" )
2323 backend = eval_config .get ("backend" , "vllm" )
24- model_args = (
25- f"pretrained={ eval_config ['model_name' ]} ,"
26- f"tensor_parallel_size={ tp_size } ,"
27- f"enforce_eager=true,"
28- f"add_bos_token=true,"
29- f"trust_remote_code={ trust_remote_code } ,"
30- f"max_model_len={ max_model_len } ,"
31- )
24+
25+ model_args_list = [
26+ f"pretrained={ eval_config ['model_name' ]} " ,
27+ f"tensor_parallel_size={ tp_size } " ,
28+ "enforce_eager=true" ,
29+ "add_bos_token=true" ,
30+ f"trust_remote_code={ trust_remote_code } " ,
31+ f"max_model_len={ max_model_len } " ,
32+ ]
33+
34+ if "vllm_args" in eval_config :
35+ for key , value in eval_config ["vllm_args" ].items ():
36+ if isinstance (value , bool ):
37+ value = str (value ).lower ()
38+ model_args_list .append (f"{ key } ={ value } " )
39+
40+ model_args = "," .join (model_args_list )
41+
3242 results = lm_eval .simple_evaluate (
3343 model = backend ,
3444 model_args = model_args ,
@@ -49,15 +59,18 @@ def test_lm_eval_correctness_param(config_filename, tp_size):
4959
5060 results = launch_lm_eval (eval_config , tp_size )
5161
62+ rtol = eval_config .get ("rtol" , DEFAULT_RTOL )
63+
5264 success = True
5365 for task in eval_config ["tasks" ]:
5466 for metric in task ["metrics" ]:
5567 ground_truth = metric ["value" ]
5668 measured_value = results ["results" ][task ["name" ]][metric ["name" ]]
5769 print (
5870 f"{ task ['name' ]} | { metric ['name' ]} : "
59- f"ground_truth={ ground_truth } | measured={ measured_value } "
71+ f"ground_truth={ ground_truth :.3f} | "
72+ f"measured={ measured_value :.3f} | rtol={ rtol } "
6073 )
61- success = success and np .isclose (ground_truth , measured_value , rtol = RTOL )
74+ success = success and np .isclose (ground_truth , measured_value , rtol = rtol )
6275
6376 assert success
0 commit comments