
Commit cf9f783

Authored by pseudo-rnd-thoughts (Mark Towers) and simonsays1980
[rllib] Improve test_single_agent_env_runner to prevent flaky tests (#58397)
## Description

While improving `SingleAgentEnvRunner.make_env`, I found that some of the tests could be flaky. This PR improves the testing, in particular of `sample`, so that the tests don't fail occasionally, and updates the documentation to reflect this.

The primary source of flakiness is that `sample(num_timesteps=X)` will not always return a total of exactly `X` timesteps; rather, it returns at least `X` timesteps, and up to the number of environments more. I've updated the documentation to clarify this for users.

In addition, I've added tests for the case where neither the number of timesteps nor the number of episodes is given, and for the `force_reset` argument.

---------

Signed-off-by: Mark Towers <[email protected]>
Co-authored-by: Mark Towers <[email protected]>
Co-authored-by: simonsays1980 <[email protected]>
1 parent 0a8b888 commit cf9f783
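
The overshoot behavior described in the commit message can be seen directly with a few lines. This is only a minimal sketch based on the updated `test_sample` defaults below; it assumes Ray RLlib and the Gymnasium `CartPole-v1` environment are available locally:

```python
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner

# Five vectorized sub-environments, as in the updated test_sample defaults.
config = (
    AlgorithmConfig()
    .environment("CartPole-v1")
    .env_runners(num_envs_per_env_runner=5, rollout_fragment_length=64)
)
env_runner = SingleAgentEnvRunner(config=config)

# Ask for 20 timesteps. All sub-environments are stepped together, so the
# total returned is at least 20 and can exceed it by up to
# num_envs_per_env_runner steps (the source of the earlier flakiness).
episodes = env_runner.sample(num_timesteps=20, random_actions=True)
total = sum(len(e) for e in episodes)
assert 20 <= total <= 20 + 5
```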

File tree

2 files changed (+168, -33 lines)


rllib/env/single_agent_env_runner.py

Lines changed: 11 additions & 1 deletion
@@ -155,8 +155,15 @@ def sample(
     ) -> List[SingleAgentEpisode]:
         """Runs and returns a sample (n timesteps or m episodes) on the env(s).

+        If neither `num_timesteps` nor `num_episodes` are provided and the config
+        `batch_mode` is "truncate_episodes" then
+        `config.get_rollout_fragment_length(self.worker_index) * self.num_envs`
+        timesteps will be sampled.
+
         Args:
             num_timesteps: The number of timesteps to sample during this call.
+                The episodes returned will contain the total timesteps greater than or
+                equal to num_timesteps and less than num_timesteps + num_envs_per_env_runner.
                 Note that only one of `num_timesteps` or `num_episodes` may be provided.
             num_episodes: The number of episodes to sample during this call.
                 Note that only one of `num_timesteps` or `num_episodes` may be provided.
@@ -169,7 +176,7 @@ def sample(
             random_actions: If True, actions will be sampled randomly (from the action
                 space of the environment). If False (default), actions or action
                 distribution parameters are computed by the RLModule.
-            force_reset: Whether to force-reset all (vector) environments before
+            force_reset: Whether to force-reset all vectorized environments before
                 sampling. Useful if you would like to collect a clean slate of new
                 episodes via this call. Note that when sampling n episodes
                 (`num_episodes != None`), this is fixed to True.
@@ -203,6 +210,7 @@ def sample(
         # desired timesteps/episodes to sample and exploration behavior.
         if explore is None:
             explore = self.config.explore
+
         if (
             num_timesteps is None
             and num_episodes is None
@@ -215,6 +223,7 @@ def sample(

         # Sample n timesteps.
         if num_timesteps is not None:
+            assert num_timesteps >= 0
             samples = self._sample(
                 num_timesteps=num_timesteps,
                 explore=explore,
@@ -223,6 +232,7 @@ def sample(
             )
         # Sample m episodes.
         elif num_episodes is not None:
+            assert num_episodes >= 0
             samples = self._sample(
                 num_episodes=num_episodes,
                 explore=explore,
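
A companion sketch of the default path the new docstring paragraph covers, when neither `num_timesteps` nor `num_episodes` is passed. Again only illustrative, assuming the default `batch_mode="truncate_episodes"`; the values mirror `test_vector_env` in the test file below:

```python
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner

config = (
    AlgorithmConfig()
    .environment("CartPole-v1")
    .env_runners(num_envs_per_env_runner=5, rollout_fragment_length=10)
)
env_runner = SingleAgentEnvRunner(config=config)

# With no num_timesteps / num_episodes, roughly
# rollout_fragment_length * num_envs (= 50) timesteps are collected,
# again with up to num_envs_per_env_runner extra steps.
episodes = env_runner.sample(random_actions=True)
total = sum(len(e) for e in episodes)
assert 50 <= total <= 50 + 5
```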

rllib/env/tests/test_single_agent_env_runner.py

Lines changed: 157 additions & 32 deletions
@@ -1,17 +1,16 @@
 import unittest
-from functools import partial
 from unittest.mock import patch

 import gymnasium as gym
+from gymnasium.envs.mujoco.swimmer_v4 import SwimmerEnv

 import ray
 from ray import tune
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.env.env_runner import StepFailedRecreateEnvError
 from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
-from ray.rllib.env.utils import _gym_env_creator
 from ray.rllib.examples.envs.classes.simple_corridor import SimpleCorridor
-from ray.rllib.utils.test_utils import check
+from ray.tune.registry import ENV_CREATOR, _global_registry


 class TestSingleAgentEnvRunner(unittest.TestCase):
@@ -21,22 +20,29 @@ def setUpClass(cls) -> None:

         tune.register_env(
             "tune-registered",
-            lambda cfg: SimpleCorridor({"corridor_length": 10}),
+            lambda cfg: SimpleCorridor({"corridor_length": 10} | cfg),
         )

         gym.register(
             "TestEnv-v0",
-            partial(
-                _gym_env_creator,
-                env_context={"corridor_length": 10},
-                env_descriptor=SimpleCorridor,
-            ),
+            entry_point=SimpleCorridor,
+            kwargs={"corridor_length": 10},
+        )
+
+        gym.register(
+            "TestEnv-v1",
+            entry_point=SwimmerEnv,
+            kwargs={"forward_reward_weight": 2.0, "reset_noise_scale": 0.2},
         )

     @classmethod
     def tearDownClass(cls) -> None:
         ray.shutdown()

+        _global_registry.unregister(ENV_CREATOR, "tune-registered")
+        gym.registry.pop("TestEnv-v0")
+        gym.registry.pop("TestEnv-v1")
+
     def test_distributed_env_runner(self):
         """Tests, whether SingleAgentEnvRunner can be distributed."""

@@ -68,7 +74,8 @@ def test_distributed_env_runner(self):
         results = ray.get(results)
         # Loop over individual EnvRunner Actor's results and inspect each.
         for episodes in results:
-            # Assert length of all fragments is `rollout_fragment_length`.
+            # Assert length of all fragments >= `rollout_fragment_length * num_envs_per_env_runner` and
+            # < rollout_fragment_length * (num_envs_per_env_runner + 1)
             self.assertIn(
                 sum(len(e) for e in episodes),
                 [
@@ -79,13 +86,19 @@ def test_distributed_env_runner(self):
                 ],
             )

-    def test_sample(self):
+    def test_sample(
+        self,
+        num_envs_per_env_runner=5,
+        expected_episodes=10,
+        expected_timesteps=20,
+        rollout_fragment_length=64,
+    ):
         config = (
             AlgorithmConfig()
             .environment("CartPole-v1")
             .env_runners(
-                num_envs_per_env_runner=2,
-                rollout_fragment_length=64,
+                num_envs_per_env_runner=num_envs_per_env_runner,
+                rollout_fragment_length=rollout_fragment_length,
             )
         )
         env_runner = SingleAgentEnvRunner(config=config)
@@ -97,32 +110,86 @@ def test_sample(self):
                 num_timesteps=10, num_episodes=10, random_actions=True
             ),
         )
+        # Verify that an error is raised if a negative number is used
+        self.assertRaises(
+            AssertionError,
+            lambda: env_runner.sample(num_timesteps=-1, random_actions=True),
+        )
+        self.assertRaises(
+            AssertionError,
+            lambda: env_runner.sample(num_episodes=-1, random_actions=True),
+        )

-        # Sample 10 episodes (5 per env, because num_envs_per_env_runner=2)
+        # Sample 10 episodes (2 per env, because num_envs_per_env_runner=5)
         # Repeat 100 times
         for _ in range(100):
-            episodes = env_runner.sample(num_episodes=10, random_actions=True)
-            check(len(episodes), 10)
+            episodes = env_runner.sample(
+                num_episodes=expected_episodes, random_actions=True
+            )
+            self.assertTrue(len(episodes) == expected_episodes)
             # Since we sampled complete episodes, there should be no ongoing episodes
             # being returned.
             self.assertTrue(all(e.is_done for e in episodes))
+            self.assertTrue(all(e.t_started == 0 for e in episodes))

-        # Sample 10 timesteps (5 per env)
+        # Sample 20 timesteps (4 per env)
         # Repeat 100 times
+        env_runner.sample(random_actions=True)  # for the `e.t_started > 0`
         for _ in range(100):
-            episodes = env_runner.sample(num_timesteps=10, random_actions=True)
+            episodes = env_runner.sample(
+                num_timesteps=expected_timesteps, random_actions=True
+            )
+            # Check the sum of lengths of all episodes returned.
+            total_timesteps = sum(len(e) for e in episodes)
+            self.assertTrue(
+                expected_timesteps
+                <= total_timesteps
+                <= expected_timesteps + num_envs_per_env_runner
+            )
+            self.assertTrue(any(e.t_started > 0 for e in episodes))
+
+        # Sample a number of timesteps that's not a factor of the number of environments
+        # Repeat 100 times
+        expected_uneven_timesteps = expected_timesteps + num_envs_per_env_runner // 2
+        for _ in range(100):
+            episodes = env_runner.sample(
+                num_timesteps=expected_uneven_timesteps, random_actions=True
+            )
             # Check the sum of lengths of all episodes returned.
-            sum_ = sum(map(len, episodes))
-            self.assertTrue(sum_ in [10, 11])
+            total_timesteps = sum(len(e) for e in episodes)
+            self.assertTrue(
+                expected_uneven_timesteps
+                <= total_timesteps
+                <= expected_uneven_timesteps + num_envs_per_env_runner,
+            )
+            self.assertTrue(any(e.t_started > 0 for e in episodes))

         # Sample rollout_fragment_length=64, 100 times
         # Repeat 100 times
         for _ in range(100):
             episodes = env_runner.sample(random_actions=True)
-            # Check, whether the sum of lengths of all episodes returned is 128
-            # 2 (num_env_per_worker) * 64 (rollout_fragment_length).
-            sum_ = sum(map(len, episodes))
-            self.assertTrue(sum_ in [128, 129])
+            # Check, whether the sum of lengths of all episodes returned is 320
+            # 5 (num_env_per_worker) * 64 (rollout_fragment_length).
+            total_timesteps = sum(len(e) for e in episodes)
+            self.assertTrue(
+                num_envs_per_env_runner * rollout_fragment_length
+                <= total_timesteps
+                <= (
+                    num_envs_per_env_runner * rollout_fragment_length
+                    + num_envs_per_env_runner
+                )
+            )
+            self.assertTrue(any(e.t_started > 0 for e in episodes))
+
+        # Test that force_reset will create episodes from scratch even with `num_timesteps`
+        episodes = env_runner.sample(
+            num_timesteps=expected_timesteps, random_actions=True, force_reset=True
+        )
+        self.assertTrue(all(e.t_started == 0 for e in episodes))
+        episodes = env_runner.sample(
+            num_timesteps=expected_timesteps, random_actions=True, force_reset=False
+        )
+        self.assertTrue(any(e.t_started > 0 for e in episodes))

     @patch(target="ray.rllib.env.env_runner.logger")
     def test_step_failed_reset_required(self, mock_logger):
@@ -172,29 +239,87 @@ def step(self, action):

         assert mock_logger.exception.call_count == 1

-    def test_vector_env(self):
+    def test_vector_env(self, num_envs_per_env_runner=5, rollout_fragment_length=10):
         """Tests, whether SingleAgentEnvRunner can run various vectorized envs."""

         for env in ["CartPole-v1", SimpleCorridor, "tune-registered"]:
             config = (
                 AlgorithmConfig()
                 .environment(env)
                 .env_runners(
-                    num_envs_per_env_runner=5,
-                    rollout_fragment_length=10,
+                    num_envs_per_env_runner=num_envs_per_env_runner,
+                    rollout_fragment_length=rollout_fragment_length,
                 )
             )

             env_runner = SingleAgentEnvRunner(config=config)

             # Sample with the async-vectorized env.
-            episodes = env_runner.sample(random_actions=True)
-            self.assertEqual(
-                sum(len(e) for e in episodes),
-                config.num_envs_per_env_runner * config.rollout_fragment_length,
-            )
+            for i in range(100):
+                episodes = env_runner.sample(random_actions=True)
+                total_timesteps = sum(len(e) for e in episodes)
+                self.assertTrue(
+                    num_envs_per_env_runner * rollout_fragment_length
+                    <= total_timesteps
+                    <= (
+                        num_envs_per_env_runner * rollout_fragment_length
+                        + num_envs_per_env_runner
+                    )
+                )
             env_runner.stop()

+    def test_env_context(self):
+        """Tests, whether SingleAgentEnvRunner can pass kwargs to the environments correctly."""
+
+        # default without env configs
+        config = AlgorithmConfig().environment("Swimmer-v4")
+        env_runner = SingleAgentEnvRunner(config=config)
+        assert env_runner.env.env.get_attr("_forward_reward_weight") == (1.0,)
+        assert env_runner.env.env.get_attr("_reset_noise_scale") == (0.1,)
+
+        # Test gym registered environment env with kwargs
+        config = AlgorithmConfig().environment(
+            "Swimmer-v4",
+            env_config={"forward_reward_weight": 2.0, "reset_noise_scale": 0.2},
+        )
+        env_runner = SingleAgentEnvRunner(config=config)
+        assert env_runner.env.env.get_attr("_forward_reward_weight") == (2.0,)
+        assert env_runner.env.env.get_attr("_reset_noise_scale") == (0.2,)
+
+        # Test gym registered environment env with pre-set kwargs
+        config = AlgorithmConfig().environment("TestEnv-v1")
+        env_runner = SingleAgentEnvRunner(config=config)
+        assert env_runner.env.env.get_attr("_forward_reward_weight") == (2.0,)
+        assert env_runner.env.env.get_attr("_reset_noise_scale") == (0.2,)
+
+        # Test using a mixture of registered kwargs and env configs
+        config = AlgorithmConfig().environment(
+            "TestEnv-v1", env_config={"forward_reward_weight": 3.0}
+        )
+        env_runner = SingleAgentEnvRunner(config=config)
+        assert env_runner.env.env.get_attr("_forward_reward_weight") == (3.0,)
+        assert env_runner.env.env.get_attr("_reset_noise_scale") == (0.2,)
+
+        # Test env-config with Tune registered or callable
+        # default
+        config = AlgorithmConfig().environment("tune-registered")
+        env_runner = SingleAgentEnvRunner(config=config)
+        assert env_runner.env.env.get_attr("end_pos") == (10.0,)
+
+        # tune-registered
+        config = AlgorithmConfig().environment(
+            "tune-registered", env_config={"corridor_length": 5.0}
+        )
+        env_runner = SingleAgentEnvRunner(config=config)
+        assert env_runner.env.env.get_attr("end_pos") == (5.0,)
+
+        # callable
+        config = AlgorithmConfig().environment(
+            SimpleCorridor, env_config={"corridor_length": 5.0}
+        )
+        env_runner = SingleAgentEnvRunner(config=config)
+        assert env_runner.env.env.get_attr("end_pos") == (5.0,)
+

 if __name__ == "__main__":
     import sys

0 commit comments