Commit 82d5c48

Merge pull request #195 from OpenPipe/tau_bench
make model creation simpler in tau bench
2 parents aa4adad + a4cb26f commit 82d5c48
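
In the benchmarking script added below, the policy under evaluation is wrapped in a plain art.Model carrying a TauBenchPolicyConfig with training_config=None. A minimal sketch of that pattern, using the names as they appear in benchmark_rl.py (config is the RunConfig produced by parse_args()):

model = art.Model(
    name=model_name,
    config=TauBenchPolicyConfig(
        run_config=config,       # RunConfig assembled from the CLI arguments
        training_config=None,    # evaluation only; no training setup required
    ),
)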

File tree

7 files changed (+3495, -3260 lines)

dev/tau-bench/benchmark_rl.py

Lines changed: 319 additions & 0 deletions
@@ -0,0 +1,319 @@
# Copyright OpenPipe

import argparse
import asyncio
import json
import os
from datetime import datetime
from typing import List, Dict, Any
from dotenv import load_dotenv

import art
from art.utils import iterate_dataset
from tau_bench.types import RunConfig, TauBenchPolicyConfig, TauBenchTrainingConfig
from tau_bench.envs import get_env
from tau_bench.run import display_metrics
from tau_bench.types import EnvRunResult
from litellm import provider_list
from tau_bench.envs.user import UserStrategy

# Import evaluate_model and rollout functions from run_rl
from run_rl import evaluate_model, rollout_tau_bench_task

# Load environment variables
load_dotenv(override=True)


def parse_args() -> tuple[RunConfig, argparse.Namespace]:
    """Parse command line arguments for benchmarking"""
    parser = argparse.ArgumentParser(
        description="Benchmark off-the-shelf models on tau-bench using RL evaluation"
    )

    # Model configuration
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        default=["gpt-4o"],
        help="List of models to benchmark (default: gpt-4o)",
    )
    parser.add_argument(
        "--model-providers",
        type=str,
        nargs="+",
        default=["openai"],
        choices=provider_list,
        help="List of model providers corresponding to each model",
    )

    # Environment configuration
    parser.add_argument(
        "--env", type=str, choices=["retail", "airline"], default="retail"
    )
    parser.add_argument(
        "--user-model",
        type=str,
        default="gpt-4o",
        help="The model to use for the user simulator",
    )
    parser.add_argument(
        "--user-model-provider",
        type=str,
        default="openai",
        choices=provider_list,
        help="The model provider for the user simulator",
    )
    parser.add_argument(
        "--user-strategy",
        type=str,
        default="llm",
        choices=[item.value for item in UserStrategy],
    )

    # Task configuration
    parser.add_argument(
        "--task-split",
        type=str,
        default="test",
        choices=["train", "test", "dev"],
        help="The split of tasks to benchmark on",
    )
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument(
        "--end-index", type=int, default=100, help="End index for benchmark tasks"
    )
    parser.add_argument(
        "--task-ids",
        type=int,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs",
    )

    # Evaluation configuration
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="The sampling temperature for the models",
    )
    parser.add_argument(
        "--max-num-steps",
        type=int,
        default=30,
        help="Maximum number of steps per rollout",
    )
    parser.add_argument(
        "--num-trials",
        type=int,
        default=1,
        help="Number of trials to run for each task",
    )

    # Output configuration
    parser.add_argument("--log-dir", type=str, default="benchmark_results")
    parser.add_argument("--seed", type=int, default=10)

    args = parser.parse_args()

    # Ensure model providers match models
    if len(args.model_providers) == 1 and len(args.models) > 1:
        args.model_providers = args.model_providers * len(args.models)
    elif len(args.model_providers) != len(args.models):
        raise ValueError(
            "Number of model providers must match number of models or be 1"
        )

    # Create RunConfig
    run_config = RunConfig(
        model_provider=args.model_providers[0],  # Will be updated per model
        user_model_provider=args.user_model_provider,
        model=args.models[0],  # Will be updated per model
        user_model=args.user_model,
        num_trials=args.num_trials,
        env=args.env,
        agent_strategy="tool-calling-rl",
        temperature=args.temperature,
        task_split=args.task_split,
        start_index=args.start_index,
        end_index=args.end_index,
        task_ids=args.task_ids,
        log_dir=args.log_dir,
        max_concurrency=50,
        seed=args.seed,
        shuffle=0,
        user_strategy=args.user_strategy,
        max_num_steps=args.max_num_steps,
        reward_type="real",
    )

    return run_config, args


async def benchmark_model(
    model_name: str,
    model_provider: str,
    config: RunConfig,
    task_indices: List[int],
    num_trials: int,
) -> Dict[str, Any]:
    """Benchmark a single model on the given tasks"""
    print(f"\n{'=' * 60}")
    print(f"Benchmarking model: {model_name} (provider: {model_provider})")
    print(f"{'=' * 60}")

    # Update config for this model
    config.model = model_name
    config.model_provider = model_provider

    # Create a mock trainable model for evaluation
    model = art.Model(
        name=model_name,
        config=TauBenchPolicyConfig(
            run_config=config,
            training_config=None,  # No training config needed for evaluation
        ),
    )

    # Store results for each trial
    all_results = []
    trial_rewards = {}

    for trial in range(num_trials):
        print(f"\nTrial {trial + 1}/{num_trials}")

        # Run evaluation for this trial
        trial_results = []
        total_reward = 0.0

        # Collect trajectories for all tasks in this trial
        trajectories = []
        for task_idx in task_indices:
            traj = await rollout_tau_bench_task(
                model=model,
                task_index=task_idx,
                step=0,
                phase="eval",
                is_shadow=False,
            )
            trajectories.append(traj)

            # Track results
            reward = traj.reward
            total_reward += reward

            result = EnvRunResult(
                task_id=task_idx,
                reward=reward,
                info=traj.metadata,
                traj=[],  # We could extract messages if needed
                trial=trial,
            )
            trial_results.append(result)
            all_results.append(result)

            print(
                "✅" if reward == 1 else "❌",
                f"task_id={task_idx}",
                f"reward={reward}",
            )

        avg_reward = total_reward / len(task_indices)
        trial_rewards[trial] = avg_reward
        print(f"\nTrial {trial + 1} average reward: {avg_reward:.3f}")

    # Calculate overall metrics
    print(f"\n{'-' * 40}")
    print(f"Overall Results for {model_name}:")
    display_metrics(all_results)

    # Return summary
    return {
        "model": model_name,
        "provider": model_provider,
        "num_tasks": len(task_indices),
        "num_trials": num_trials,
        "trial_rewards": trial_rewards,
        "all_results": [r.model_dump() for r in all_results],
        "average_reward": sum(trial_rewards.values()) / len(trial_rewards),
    }


async def main():
    """Main benchmarking function"""
    config, args = parse_args()

    # Create output directory
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    # Get task indices
    env = get_env(
        config.env,
        user_strategy=config.user_strategy,
        user_model=config.user_model,
        user_provider=config.user_model_provider,
        task_split=config.task_split,
    )

    if args.task_ids:
        task_indices = args.task_ids
    else:
        end_index = (
            min(args.end_index, len(env.tasks))
            if args.end_index != -1
            else len(env.tasks)
        )
        task_indices = list(range(args.start_index, end_index))

    print(
        f"Benchmarking on {len(task_indices)} tasks from {config.env} {config.task_split} split"
    )
    print(f"Models to benchmark: {args.models}")

    # Benchmark each model
    all_benchmark_results = {}
    for model_name, model_provider in zip(args.models, args.model_providers):
        results = await benchmark_model(
            model_name=model_name,
            model_provider=model_provider,
            config=config,
            task_indices=task_indices,
            num_trials=args.num_trials,
        )
        all_benchmark_results[model_name] = results

    # Save results
    time_str = datetime.now().strftime("%m%d%H%M%S")
    output_path = os.path.join(
        args.log_dir, f"benchmark_{config.env}_{config.task_split}_{time_str}.json"
    )

    with open(output_path, "w") as f:
        json.dump(all_benchmark_results, f, indent=2)

    print(f"\nResults saved to {output_path}")

    # Display comparison
    print(f"\n{'=' * 60}")
    print("BENCHMARK SUMMARY")
    print(f"{'=' * 60}")
    print(f"{'Model':<30} {'Provider':<10} {'Avg Reward':<12} {'Pass@1':<10}")
    print(f"{'-' * 60}")

    for model_name, results in all_benchmark_results.items():
        # Calculate Pass@1
        pass_1 = sum(
            1
            for r in results["all_results"]
            if r["reward"] >= 0.999 and r["trial"] == 0
        ) / len(task_indices)

        print(
            f"{model_name:<30} {results['provider']:<10} "
            f"{results['average_reward']:<12.3f} {pass_1:<10.3f}"
        )


if __name__ == "__main__":
    asyncio.run(main())

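For reference, a hypothetical driver (not part of this commit) that calls benchmark_model directly instead of going through the command-line interface. The RunConfig fields mirror the ones constructed in parse_args() above; the import path benchmark_rl and the choice of model and task range are assumptions for illustration only.

import asyncio

from tau_bench.types import RunConfig

from benchmark_rl import benchmark_model  # assumes dev/tau-bench is on PYTHONPATH

# Mirror the CLI defaults: gpt-4o via OpenAI on the retail test split.
config = RunConfig(
    model_provider="openai",
    user_model_provider="openai",
    model="gpt-4o",
    user_model="gpt-4o",
    num_trials=1,
    env="retail",
    agent_strategy="tool-calling-rl",
    temperature=1.0,
    task_split="test",
    start_index=0,
    end_index=10,
    task_ids=None,
    log_dir="benchmark_results",
    max_concurrency=50,
    seed=10,
    shuffle=0,
    user_strategy="llm",
    max_num_steps=30,
    reward_type="real",
)

# Evaluate the first ten test tasks for one trial and print the summary metric.
summary = asyncio.run(
    benchmark_model(
        model_name="gpt-4o",
        model_provider="openai",
        config=config,
        task_indices=list(range(0, 10)),
        num_trials=1,
    )
)
print(f"average reward: {summary['average_reward']:.3f}")

Roughly the same run corresponds to invoking the script as: python benchmark_rl.py --models gpt-4o --model-providers openai --env retail --task-split test --start-index 0 --end-index 10
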
0 commit comments
