Skip to content

Commit b983c84

Browse files
committed
feat(agent): use configs to replace the multi-round testing functionality of scripts.
1 parent 3c06966 commit b983c84

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

common_benchmark.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,11 @@ def preprocess_config(cfg: DictConfig, chosen_config_name: str) -> DictConfig:
700700
cfg.num_runs = 1
701701
OmegaConf.set_struct(cfg, True)
702702
# set output_dir to logs/benchmark.name/agent_set/timestamp if not set
703-
if cfg.output_dir == 'logs/':
703+
if cfg.output_dir == "logs/":
704704
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
705-
cfg.output_dir = Path(cfg.output_dir) / cfg.benchmark.name / chosen_config_name / timestamp
705+
cfg.output_dir = (
706+
Path(cfg.output_dir) / cfg.benchmark.name / chosen_config_name / timestamp
707+
)
706708
return cfg
707709

708710

@@ -736,6 +738,7 @@ def main_runs_multiprocess(cfg, args, max_workers: int | None = None):
736738

737739
futures = []
738740
import multiprocessing as mp
741+
739742
try:
740743
mp.set_start_method("spawn", force=True)
741744
except RuntimeError:
@@ -744,10 +747,7 @@ def main_runs_multiprocess(cfg, args, max_workers: int | None = None):
744747
with ProcessPoolExecutor(max_workers=max_workers) as ex:
745748
for i in range(num_runs):
746749
run_id = i + 1
747-
fut = ex.submit(
748-
_run_one_process,
749-
cfg, list(args), run_id, num_runs
750-
)
750+
fut = ex.submit(_run_one_process, cfg, list(args), run_id, num_runs)
751751
futures.append(fut)
752752

753753
ok_count, fail_count = 0, 0
@@ -757,7 +757,7 @@ def main_runs_multiprocess(cfg, args, max_workers: int | None = None):
757757
ok_count += 1
758758
else:
759759
fail_count += 1
760-
760+
761761
print("==========================================")
762762
print(f"All {num_runs} runs finished. OK={ok_count}, FAIL={fail_count}")
763763
print("==========================================")
@@ -784,4 +784,4 @@ def main(*args, config_file_name: str = ""):
784784

785785
_ = bootstrap_logger(level=LOGGER_LEVEL)
786786
# Tracing functionality removed - miroflow-contrib deleted
787-
main_runs_multiprocess(cfg, args)
787+
main_runs_multiprocess(cfg, args)

0 commit comments

Comments
 (0)