
Commit 1881e15

fix arena evaluation bugs
1 parent aa3a017 commit 1881e15

File tree: 5 files changed (+24, −25 lines)

examples/arena/run_arena.py

Lines changed: 1 addition & 0 deletions
@@ -52,3 +52,4 @@ def run_arena(

 if __name__ == "__main__":
     run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10)
+    # run_arena(render=True, parallel=True, seed=1, total_games=10, max_game_onetime=2)
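The new commented-out line documents a smaller, rendered configuration that is handy for eyeballing the arena. A minimal sketch of toggling between the two calls, using only the parameters shown above (whether rendering is intended to be combined with parallel=True is an assumption, not something this diff states):

if __name__ == "__main__":
    debug_render = False  # flip to True to watch a short rendered session
    if debug_render:
        run_arena(render=True, parallel=True, seed=1, total_games=10, max_game_onetime=2)
    else:
        run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10)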

examples/selfplay/human_vs_agent.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def get_human_env(env_num):
     env = make(
         "tictactoe_v3",
         env_num=env_num,
-        asynchronous=False,
+        asynchronous=True,
         opponent_wrappers=[TictactoeRender, HumanOpponentWrapper],
         env_wrappers=[FlattenObservation],
         auto_reset=False,
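One plausible reading of this one-flag change: the human-vs-agent example drives a pygame window through HumanOpponentWrapper, and asynchronous=True runs each sub-environment in its own worker rather than serially in the caller, keeping the interactive opponent from blocking the main loop. A sketch of the patched factory with its likely imports added (the import paths are assumptions; TictactoeRender and HumanOpponentWrapper come from the example's own tictactoe_utils and their imports are omitted here):

from openrl.envs.common import make                  # import path assumed
from openrl.envs.wrappers import FlattenObservation  # import path assumed
# TictactoeRender and HumanOpponentWrapper are defined alongside this example
# (see examples/selfplay/tictactoe_utils/ below).

def get_human_env(env_num):
    env = make(
        "tictactoe_v3",
        env_num=env_num,
        asynchronous=True,  # the change in this commit: run env workers asynchronously
        opponent_wrappers=[TictactoeRender, HumanOpponentWrapper],
        env_wrappers=[FlattenObservation],
        auto_reset=False,
    )
    return env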

examples/selfplay/tictactoe_utils/tictactoe_render.py

Lines changed: 9 additions & 10 deletions
@@ -48,16 +48,6 @@ def step(self, action: ActionType) -> None:
         self.last_action = action[0]
         return result

-    def observe(self, agent: str) -> Optional[ObsType]:
-        obs = super().observe(agent)
-        if self.last_action is not None:
-            if self.render_mode == "game":
-                self.game.make_move(self.last_action // 3, self.last_action % 3)
-                pygame.display.update()
-            self.last_action = None
-            time.sleep(0.3)
-        return obs
-
     def close(self):
         super().close()
         self.game.close()

@@ -69,3 +59,12 @@ def get_human_action(self, agent, observation, termination, truncation, info):
         return self.game.get_human_action(
             agent, observation, termination, truncation, info
         )
+
+    def last(self, observe: bool = True):
+        if self.last_action is not None:
+            if self.render_mode == "game":
+                self.game.make_move(self.last_action // 3, self.last_action % 3)
+                pygame.display.update()
+            self.last_action = None
+            time.sleep(0.3)
+        return self.env.last(observe)
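Read together, the two hunks move the board-drawing side effect from observe() into last(). In PettingZoo's AEC API, observe(agent) may be called any number of times and for any agent, while last() is what the driving loop calls once per agent_iter() step, so hooking the pygame update there gives exactly one redraw per move. A minimal sketch of the loop this override is meant for, using the standard AEC calls (env construction and the concrete policy are placeholders, not part of this commit):

# Standard PettingZoo AEC loop; `env` is assumed to be a tictactoe_v3 env
# wrapped with TictactoeRender (construction omitted).
env.reset(seed=0)
for agent in env.agent_iter():
    # TictactoeRender.last() now replays the previous move on the pygame board
    # before delegating to the wrapped env, so each turn is rendered exactly once.
    observation, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None
    else:
        action = env.action_space(agent).sample()  # stand-in for a real policy
    env.step(action)
env.close()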

openrl/arena/games/two_player_game.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def _run(self, env_fn: Callable, agents: List[BaseAgent]):
         info = {}
         for player_name in env.agent_iter():
             observation, reward, termination, truncation, info = env.last()
+
             if termination:
                 break
             action = player2agent[player_name].act(

openrl/envs/wrappers/pettingzoo_wrappers.py

Lines changed: 12 additions & 14 deletions
@@ -77,18 +77,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None):
         self.total_rewards = defaultdict(float)
         return super().reset(seed, options)

-    def step(self, action: ActionType) -> None:
-        super().step(action)
-        winners = None
-        losers = None
-        for agent in self.terminations:
-            if self.terminations[agent]:
-                if winners is None:
-                    winners = self.get_winners()
-                    losers = [player for player in self.agents if player not in winners]
-                self.infos[agent]["winners"] = winners
-                self.infos[agent]["losers"] = losers
-
     def get_winners(self):
         max_reward = max(self.total_rewards.values())

@@ -101,11 +89,21 @@ def get_winners(self):

     def last(self, observe: bool = True):
         """Returns observation, cumulative reward, terminated, truncated, info for the current agent (specified by self.agent_selection)."""
+
         agent = self.agent_selection
-        # if self._cumulative_rewards[agent]!=0:
-        #     print("agent:",agent,self._cumulative_rewards[agent])
+        # this may miss the last reward for another agent
         self.total_rewards[agent] += self._cumulative_rewards[agent]

+        winners = None
+        losers = None
+        for agent in self.terminations:
+            if self.terminations[agent]:
+                if winners is None:
+                    winners = self.get_winners()
+                    losers = [player for player in self.agents if player not in winners]
+                self.infos[agent]["winners"] = winners
+                self.infos[agent]["losers"] = losers
+
         return super().last(observe)
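A plausible motivation for relocating the winner bookkeeping from step() into last(): the arena loop in two_player_game.py above reads its outcome from the info returned by env.last() and breaks as soon as termination is seen, so infos written only during step() could be missed or stale at that point. After this commit, last() both accumulates the acting agent's cumulative reward and, once any agent has terminated, stamps "winners"/"losers" into every terminated agent's info. A condensed sketch of the consumer side (names follow the diffs; anything the diffs do not show is assumed):

# Consumer sketch, mirroring TwoPlayerGame._run above; env is assumed to be
# wrapped by the winner-recording wrapper patched in this file.
for player_name in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    # last() has already (a) added this agent's cumulative reward to
    # total_rewards and (b) written info["winners"] / info["losers"] for any
    # terminated agent, so the result survives the break below.
    if termination:
        break
    # action selection elided, as in the diff:
    # action = player2agent[player_name].act(...)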