File tree Expand file tree Collapse file tree 2 files changed +18
-2
lines changed Expand file tree Collapse file tree 2 files changed +18
-2
lines changed Original file line number Diff line number Diff line change @@ -400,6 +400,12 @@ def main():
400400 # The nested `config` needs to be converted back into the proper pydantic model.
401401 model_dict ["config" ] = TauBenchPolicyConfig (** model_dict ["config" ])
402402
403+ # The nested `_internal_config` needs to be converted back into the proper pydantic model.
404+ if "_internal_config" in model_dict :
405+ model_dict ["_internal_config" ] = art .dev .InternalModelConfig (
406+ ** model_dict ["_internal_config" ]
407+ )
408+
403409 model : art .TrainableModel [TauBenchPolicyConfig ] = art .TrainableModel (** model_dict )
404410 model .config .run_config .model = (
405411 model .name
Original file line number Diff line number Diff line change 5454trainable_models ["002" ].config .training_config .training_dataset_size = 4
5555trainable_models ["002" ].config .training_config .learning_rate = 5e-6
5656
57+ # Very high learning rate (lr), very low max grad norm (gn), because twitter said so
58+ trainable_models ["003" ] = trainable_models ["002" ].model_copy (deep = True )
59+ assert trainable_models ["003" ].config .training_config is not None
60+ trainable_models ["003" ].name = "tau-bench-rl-003-tm"
61+ trainable_models ["003" ].config .training_config .learning_rate = 1e-2
62+ trainable_models ["003" ]._internal_config = art .dev .InternalModelConfig (
63+ trainer_args = art .dev .TrainerArgs (
64+ max_grad_norm = 1e-7 ,
65+ )
66+ )
5767
5868parser = argparse .ArgumentParser (
5969 description = "Train one or more tau-bench RL models (comma separated)."
7383
7484# Parse and validate the requested model keys
7585requested_models = [m .strip () for m in args .models .split ("," ) if m .strip ()]
76- unknown = [m for m in requested_models if m not in models ]
86+ unknown = [m for m in requested_models if m not in trainable_models ]
7787if unknown :
7888 raise ValueError (
79- f"Unknown model keys requested: { ', ' .join (unknown )} . Valid keys: { ', ' .join (models .keys ())} "
89+ f"Unknown model keys requested: { ', ' .join (unknown )} . Valid keys: { ', ' .join (trainable_models .keys ())} "
8090 )
8191
8292
You can’t perform that action at this time.
0 commit comments