# Copyright OpenPipe
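"""Benchmark off-the-shelf models on tau-bench using the RL evaluation rollout.

Each model passed via --models is wrapped in an art.Model and run through
rollout_tau_bench_task on the selected task split for the requested number of
trials; per-trial rewards are printed and a JSON summary (including Pass@1)
is written to --log-dir.
"""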

import argparse
import asyncio
import json
import os
from datetime import datetime
from typing import List, Dict, Any
from dotenv import load_dotenv

import art
from art.utils import iterate_dataset
from tau_bench.types import (
    EnvRunResult,
    RunConfig,
    TauBenchPolicyConfig,
    TauBenchTrainingConfig,
)
from tau_bench.envs import get_env
from tau_bench.run import display_metrics
from litellm import provider_list
from tau_bench.envs.user import UserStrategy

# Import evaluate_model and rollout functions from run_rl
from run_rl import evaluate_model, rollout_tau_bench_task

# Load environment variables
load_dotenv(override=True)


def parse_args() -> tuple[RunConfig, argparse.Namespace]:
    """Parse command line arguments for benchmarking"""
    parser = argparse.ArgumentParser(
        description="Benchmark off-the-shelf models on tau-bench using RL evaluation"
    )

    # Model configuration
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        default=["gpt-4o"],
        help="List of models to benchmark (default: gpt-4o)",
    )
    parser.add_argument(
        "--model-providers",
        type=str,
        nargs="+",
        default=["openai"],
        choices=provider_list,
        help="List of model providers corresponding to each model",
    )

    # Environment configuration
    parser.add_argument(
        "--env", type=str, choices=["retail", "airline"], default="retail"
    )
    parser.add_argument(
        "--user-model",
        type=str,
        default="gpt-4o",
        help="The model to use for the user simulator",
    )
    parser.add_argument(
        "--user-model-provider",
        type=str,
        default="openai",
        choices=provider_list,
        help="The model provider for the user simulator",
    )
    parser.add_argument(
        "--user-strategy",
        type=str,
        default="llm",
        choices=[item.value for item in UserStrategy],
    )

    # Task configuration
    parser.add_argument(
        "--task-split",
        type=str,
        default="test",
        choices=["train", "test", "dev"],
        help="The split of tasks to benchmark on",
    )
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument(
        "--end-index", type=int, default=100, help="End index for benchmark tasks"
    )
    parser.add_argument(
        "--task-ids",
        type=int,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs",
    )

    # Evaluation configuration
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="The sampling temperature for the models",
    )
    parser.add_argument(
        "--max-num-steps",
        type=int,
        default=30,
        help="Maximum number of steps per rollout",
    )
    parser.add_argument(
        "--num-trials",
        type=int,
        default=1,
        help="Number of trials to run for each task",
    )

    # Output configuration
    parser.add_argument("--log-dir", type=str, default="benchmark_results")
    parser.add_argument("--seed", type=int, default=10)

    args = parser.parse_args()

    # Ensure model providers match models
    if len(args.model_providers) == 1 and len(args.models) > 1:
        args.model_providers = args.model_providers * len(args.models)
    elif len(args.model_providers) != len(args.models):
        raise ValueError(
            "Number of model providers must match number of models or be 1"
        )

    # Create RunConfig
    run_config = RunConfig(
        model_provider=args.model_providers[0],  # Will be updated per model
        user_model_provider=args.user_model_provider,
        model=args.models[0],  # Will be updated per model
        user_model=args.user_model,
        num_trials=args.num_trials,
        env=args.env,
        agent_strategy="tool-calling-rl",
        temperature=args.temperature,
        task_split=args.task_split,
        start_index=args.start_index,
        end_index=args.end_index,
        task_ids=args.task_ids,
        log_dir=args.log_dir,
        max_concurrency=50,
        seed=args.seed,
        shuffle=0,
        user_strategy=args.user_strategy,
        max_num_steps=args.max_num_steps,
        reward_type="real",
    )

    return run_config, args


async def benchmark_model(
    model_name: str,
    model_provider: str,
    config: RunConfig,
    task_indices: List[int],
    num_trials: int,
) -> Dict[str, Any]:
    """Benchmark a single model on the given tasks"""
    print(f"\n{'=' * 60}")
    print(f"Benchmarking model: {model_name} (provider: {model_provider})")
    print(f"{'=' * 60}")

    # Update config for this model
    config.model = model_name
    config.model_provider = model_provider

    # Wrap the model in a (non-trainable) art.Model so the RL rollout code can
    # be reused for evaluation
    model = art.Model(
        name=model_name,
        config=TauBenchPolicyConfig(
            run_config=config,
            training_config=None,  # No training config needed for evaluation
        ),
    )

    # Store results for each trial
    all_results = []
    trial_rewards = {}

    for trial in range(num_trials):
        print(f"\nTrial {trial + 1}/{num_trials}")

        # Run evaluation for this trial
        trial_results = []
        total_reward = 0.0

        # Collect trajectories for all tasks in this trial
        trajectories = []
        for task_idx in task_indices:
            traj = await rollout_tau_bench_task(
                model=model,
                task_index=task_idx,
                step=0,
                phase="eval",
                is_shadow=False,
            )
            trajectories.append(traj)

            # Track results
            reward = traj.reward
            total_reward += reward

            result = EnvRunResult(
                task_id=task_idx,
                reward=reward,
                info=traj.metadata,
                traj=[],  # We could extract messages if needed
                trial=trial,
            )
            trial_results.append(result)
            all_results.append(result)

            print(
                "✅" if reward == 1 else "❌",
                f"task_id={task_idx}",
                f"reward={reward}",
            )

        avg_reward = total_reward / len(task_indices)
        trial_rewards[trial] = avg_reward
        print(f"\nTrial {trial + 1} average reward: {avg_reward:.3f}")

    # Calculate overall metrics
    print(f"\n{'-' * 40}")
    print(f"Overall Results for {model_name}:")
    display_metrics(all_results)

    # Return summary
    return {
        "model": model_name,
        "provider": model_provider,
        "num_tasks": len(task_indices),
        "num_trials": num_trials,
        "trial_rewards": trial_rewards,
        "all_results": [r.model_dump() for r in all_results],
        "average_reward": sum(trial_rewards.values()) / len(trial_rewards),
    }


async def main():
    """Main benchmarking function"""
    config, args = parse_args()

    # Create output directory
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    # Get task indices
    env = get_env(
        config.env,
        user_strategy=config.user_strategy,
        user_model=config.user_model,
        user_provider=config.user_model_provider,
        task_split=config.task_split,
    )

    if args.task_ids:
        task_indices = args.task_ids
    else:
        end_index = (
            min(args.end_index, len(env.tasks))
            if args.end_index != -1
            else len(env.tasks)
        )
        task_indices = list(range(args.start_index, end_index))

    print(
        f"Benchmarking on {len(task_indices)} tasks from {config.env} {config.task_split} split"
    )
    print(f"Models to benchmark: {args.models}")

    # Benchmark each model
    all_benchmark_results = {}
    for model_name, model_provider in zip(args.models, args.model_providers):
        results = await benchmark_model(
            model_name=model_name,
            model_provider=model_provider,
            config=config,
            task_indices=task_indices,
            num_trials=args.num_trials,
        )
        all_benchmark_results[model_name] = results

    # Save results
    time_str = datetime.now().strftime("%m%d%H%M%S")
    output_path = os.path.join(
        args.log_dir, f"benchmark_{config.env}_{config.task_split}_{time_str}.json"
    )

    with open(output_path, "w") as f:
        json.dump(all_benchmark_results, f, indent=2)

    print(f"\nResults saved to {output_path}")

    # Display comparison
    print(f"\n{'=' * 60}")
    print("BENCHMARK SUMMARY")
    print(f"{'=' * 60}")
    print(f"{'Model':<30} {'Provider':<10} {'Avg Reward':<12} {'Pass@1':<10}")
    print(f"{'-' * 60}")

    for model_name, results in all_benchmark_results.items():
        # Pass@1: fraction of tasks whose first-trial (trial == 0) reward is ~1.0
        pass_1 = sum(
            1
            for r in results["all_results"]
            if r["reward"] >= 0.999 and r["trial"] == 0
        ) / len(task_indices)

        print(
            f"{model_name:<30} {results['provider']:<10} "
            f"{results['average_reward']:<12.3f} {pass_1:<10.3f}"
        )


if __name__ == "__main__":
    asyncio.run(main())
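
# Example invocation (hypothetical file name; assumes API keys for the chosen
# providers are available in the environment or a .env file):
#
#   python benchmark.py --models gpt-4o gpt-4o-mini --model-providers openai \
#       --env retail --task-split test --start-index 0 --end-index 50 --num-trials 2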