Skip to content

Commit 88d9483

Browse files
authored
Merge pull request #196 from OpenPipe/tau_bench
pass internal config from run_training - tau_bench
2 parents 82d5c48 + 0980a15 commit 88d9483

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

dev/tau-bench/run_rl.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,12 @@ def main():
400400
# The nested `config` needs to be converted back into the proper pydantic model.
401401
model_dict["config"] = TauBenchPolicyConfig(**model_dict["config"])
402402

403+
# the nested "_internal_config" needs to be converted back into the proper pydantic model.
404+
if "_internal_config" in model_dict:
405+
model_dict["_internal_config"] = art.dev.InternalModelConfig(
406+
**model_dict["_internal_config"]
407+
)
408+
403409
model: art.TrainableModel[TauBenchPolicyConfig] = art.TrainableModel(**model_dict)
404410
model.config.run_config.model = (
405411
model.name

dev/tau-bench/run_training.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@
5454
trainable_models["002"].config.training_config.training_dataset_size = 4
5555
trainable_models["002"].config.training_config.learning_rate = 5e-6
5656

57+
# v high lr, v low gn, because twitter said so
58+
trainable_models["003"] = trainable_models["002"].model_copy(deep=True)
59+
assert trainable_models["003"].config.training_config is not None
60+
trainable_models["003"].name = "tau-bench-rl-003-tm"
61+
trainable_models["003"].config.training_config.learning_rate = 1e-2
62+
trainable_models["003"]._internal_config = art.dev.InternalModelConfig(
63+
trainer_args=art.dev.TrainerArgs(
64+
max_grad_norm=1e-7,
65+
)
66+
)
5767

5868
parser = argparse.ArgumentParser(
5969
description="Train one or more tau-bench RL models (comma separated)."
@@ -73,10 +83,10 @@
7383

7484
# Parse and validate the requested model keys
7585
requested_models = [m.strip() for m in args.models.split(",") if m.strip()]
76-
unknown = [m for m in requested_models if m not in models]
86+
unknown = [m for m in requested_models if m not in trainable_models]
7787
if unknown:
7888
raise ValueError(
79-
f"Unknown model keys requested: {', '.join(unknown)}. Valid keys: {', '.join(models.keys())}"
89+
f"Unknown model keys requested: {', '.join(unknown)}. Valid keys: {', '.join(trainable_models.keys())}"
8090
)
8191

8292

0 commit comments

Comments
 (0)