11import unittest
2- from functools import partial
32from unittest .mock import patch
43
54import gymnasium as gym
5+ from gymnasium .envs .mujoco .swimmer_v4 import SwimmerEnv
66
77import ray
88from ray import tune
99from ray .rllib .algorithms .algorithm_config import AlgorithmConfig
1010from ray .rllib .env .env_runner import StepFailedRecreateEnvError
1111from ray .rllib .env .single_agent_env_runner import SingleAgentEnvRunner
12- from ray .rllib .env .utils import _gym_env_creator
1312from ray .rllib .examples .envs .classes .simple_corridor import SimpleCorridor
14- from ray .rllib . utils . test_utils import check
13+ from ray .tune . registry import ENV_CREATOR , _global_registry
1514
1615
1716class TestSingleAgentEnvRunner (unittest .TestCase ):
@@ -21,22 +20,29 @@ def setUpClass(cls) -> None:
2120
2221 tune .register_env (
2322 "tune-registered" ,
24- lambda cfg : SimpleCorridor ({"corridor_length" : 10 }),
23+ lambda cfg : SimpleCorridor ({"corridor_length" : 10 } | cfg ),
2524 )
2625
2726 gym .register (
2827 "TestEnv-v0" ,
29- partial (
30- _gym_env_creator ,
31- env_context = {"corridor_length" : 10 },
32- env_descriptor = SimpleCorridor ,
33- ),
28+ entry_point = SimpleCorridor ,
29+ kwargs = {"corridor_length" : 10 },
30+ )
31+
32+ gym .register (
33+ "TestEnv-v1" ,
34+ entry_point = SwimmerEnv ,
35+ kwargs = {"forward_reward_weight" : 2.0 , "reset_noise_scale" : 0.2 },
3436 )
3537
3638 @classmethod
3739 def tearDownClass (cls ) -> None :
3840 ray .shutdown ()
3941
42+ _global_registry .unregister (ENV_CREATOR , "tune-registered" )
43+ gym .registry .pop ("TestEnv-v0" )
44+ gym .registry .pop ("TestEnv-v1" )
45+
4046 def test_distributed_env_runner (self ):
4147 """Tests, whether SingleAgentEnvRunner can be distributed."""
4248
@@ -68,7 +74,8 @@ def test_distributed_env_runner(self):
6874 results = ray .get (results )
6975 # Loop over individual EnvRunner Actor's results and inspect each.
7076 for episodes in results :
71- # Assert length of all fragments is `rollout_fragment_length`.
77+ # Assert length of all fragments >= `rollout_fragment_length * num_envs_per_env_runner` and
78+ # < rollout_fragment_length * (num_envs_per_env_runner + 1)
7279 self .assertIn (
7380 sum (len (e ) for e in episodes ),
7481 [
@@ -79,13 +86,19 @@ def test_distributed_env_runner(self):
7986 ],
8087 )
8188
82- def test_sample (self ):
89+ def test_sample (
90+ self ,
91+ num_envs_per_env_runner = 5 ,
92+ expected_episodes = 10 ,
93+ expected_timesteps = 20 ,
94+ rollout_fragment_length = 64 ,
95+ ):
8396 config = (
8497 AlgorithmConfig ()
8598 .environment ("CartPole-v1" )
8699 .env_runners (
87- num_envs_per_env_runner = 2 ,
88- rollout_fragment_length = 64 ,
100+ num_envs_per_env_runner = num_envs_per_env_runner ,
101+ rollout_fragment_length = rollout_fragment_length ,
89102 )
90103 )
91104 env_runner = SingleAgentEnvRunner (config = config )
@@ -97,32 +110,86 @@ def test_sample(self):
97110 num_timesteps = 10 , num_episodes = 10 , random_actions = True
98111 ),
99112 )
113+ # Verify that an error is raised if a negative number is used
114+ self .assertRaises (
115+ AssertionError ,
116+ lambda : env_runner .sample (num_timesteps = - 1 , random_actions = True ),
117+ )
118+ self .assertRaises (
119+ AssertionError ,
120+ lambda : env_runner .sample (num_episodes = - 1 , random_actions = True ),
121+ )
100122
101- # Sample 10 episodes (5 per env, because num_envs_per_env_runner=2 )
123+ # Sample 10 episodes (2 per env, because num_envs_per_env_runner=5 )
102124 # Repeat 100 times
103125 for _ in range (100 ):
104- episodes = env_runner .sample (num_episodes = 10 , random_actions = True )
105- check (len (episodes ), 10 )
126+ episodes = env_runner .sample (
127+ num_episodes = expected_episodes , random_actions = True
128+ )
129+ self .assertTrue (len (episodes ) == expected_episodes )
106130 # Since we sampled complete episodes, there should be no ongoing episodes
107131 # being returned.
108132 self .assertTrue (all (e .is_done for e in episodes ))
133+ self .assertTrue (all (e .t_started == 0 for e in episodes ))
109134
110- # Sample 10 timesteps (5 per env)
135+ # Sample 20 timesteps (4 per env)
111136 # Repeat 100 times
137+ env_runner .sample (random_actions = True ) # for the `e.t_started > 0`
112138 for _ in range (100 ):
113- episodes = env_runner .sample (num_timesteps = 10 , random_actions = True )
139+ episodes = env_runner .sample (
140+ num_timesteps = expected_timesteps , random_actions = True
141+ )
142+ # Check the sum of lengths of all episodes returned.
143+ total_timesteps = sum (len (e ) for e in episodes )
144+ self .assertTrue (
145+ expected_timesteps
146+ <= total_timesteps
147+ <= expected_timesteps + num_envs_per_env_runner
148+ )
149+ self .assertTrue (any (e .t_started > 0 for e in episodes ))
150+
151+ # Sample a number of timesteps that's not a factor of the number of environments
152+ # Repeat 100 times
153+ expected_uneven_timesteps = expected_timesteps + num_envs_per_env_runner // 2
154+ for _ in range (100 ):
155+ episodes = env_runner .sample (
156+ num_timesteps = expected_uneven_timesteps , random_actions = True
157+ )
114158 # Check the sum of lengths of all episodes returned.
115- sum_ = sum (map (len , episodes ))
116- self .assertTrue (sum_ in [10 , 11 ])
159+ total_timesteps = sum (len (e ) for e in episodes )
160+ self .assertTrue (
161+ expected_uneven_timesteps
162+ <= total_timesteps
163+ <= expected_uneven_timesteps + num_envs_per_env_runner ,
164+ )
165+ self .assertTrue (any (e .t_started > 0 for e in episodes ))
117166
118167 # Sample rollout_fragment_length=64, 100 times
119168 # Repeat 100 times
120169 for _ in range (100 ):
121170 episodes = env_runner .sample (random_actions = True )
122- # Check, whether the sum of lengths of all episodes returned is 128
123- # 2 (num_env_per_worker) * 64 (rollout_fragment_length).
124- sum_ = sum (map (len , episodes ))
125- self .assertTrue (sum_ in [128 , 129 ])
171+ # Check, whether the sum of lengths of all episodes returned is 320
172+ # 5 (num_env_per_worker) * 64 (rollout_fragment_length).
173+ total_timesteps = sum (len (e ) for e in episodes )
174+ self .assertTrue (
175+ num_envs_per_env_runner * rollout_fragment_length
176+ <= total_timesteps
177+ <= (
178+ num_envs_per_env_runner * rollout_fragment_length
179+ + num_envs_per_env_runner
180+ )
181+ )
182+ self .assertTrue (any (e .t_started > 0 for e in episodes ))
183+
184+ # Test that force_reset will create episodes from scratch even with `num_timesteps`
185+ episodes = env_runner .sample (
186+ num_timesteps = expected_timesteps , random_actions = True , force_reset = True
187+ )
188+ self .assertTrue (all (e .t_started == 0 for e in episodes ))
189+ episodes = env_runner .sample (
190+ num_timesteps = expected_timesteps , random_actions = True , force_reset = False
191+ )
192+ self .assertTrue (any (e .t_started > 0 for e in episodes ))
126193
127194 @patch (target = "ray.rllib.env.env_runner.logger" )
128195 def test_step_failed_reset_required (self , mock_logger ):
@@ -172,29 +239,87 @@ def step(self, action):
172239
173240 assert mock_logger .exception .call_count == 1
174241
175- def test_vector_env (self ):
242+ def test_vector_env (self , num_envs_per_env_runner = 5 , rollout_fragment_length = 10 ):
176243 """Tests, whether SingleAgentEnvRunner can run various vectorized envs."""
177244
178245 for env in ["CartPole-v1" , SimpleCorridor , "tune-registered" ]:
179246 config = (
180247 AlgorithmConfig ()
181248 .environment (env )
182249 .env_runners (
183- num_envs_per_env_runner = 5 ,
184- rollout_fragment_length = 10 ,
250+ num_envs_per_env_runner = num_envs_per_env_runner ,
251+ rollout_fragment_length = rollout_fragment_length ,
185252 )
186253 )
187254
188255 env_runner = SingleAgentEnvRunner (config = config )
189256
190257 # Sample with the async-vectorized env.
191- episodes = env_runner .sample (random_actions = True )
192- self .assertEqual (
193- sum (len (e ) for e in episodes ),
194- config .num_envs_per_env_runner * config .rollout_fragment_length ,
195- )
258+ for i in range (100 ):
259+ episodes = env_runner .sample (random_actions = True )
260+ total_timesteps = sum (len (e ) for e in episodes )
261+ self .assertTrue (
262+ num_envs_per_env_runner * rollout_fragment_length
263+ <= total_timesteps
264+ <= (
265+ num_envs_per_env_runner * rollout_fragment_length
266+ + num_envs_per_env_runner
267+ )
268+ )
196269 env_runner .stop ()
197270
271+ def test_env_context (self ):
272+ """Tests, whether SingleAgentEnvRunner can pass kwargs to the environments correctly."""
273+
274+ # default without env configs
275+ config = AlgorithmConfig ().environment ("Swimmer-v4" )
276+ env_runner = SingleAgentEnvRunner (config = config )
277+ assert env_runner .env .env .get_attr ("_forward_reward_weight" ) == (1.0 ,)
278+ assert env_runner .env .env .get_attr ("_reset_noise_scale" ) == (0.1 ,)
279+
280+ # Test gym registered environment env with kwargs
281+ config = AlgorithmConfig ().environment (
282+ "Swimmer-v4" ,
283+ env_config = {"forward_reward_weight" : 2.0 , "reset_noise_scale" : 0.2 },
284+ )
285+ env_runner = SingleAgentEnvRunner (config = config )
286+ assert env_runner .env .env .get_attr ("_forward_reward_weight" ) == (2.0 ,)
287+ assert env_runner .env .env .get_attr ("_reset_noise_scale" ) == (0.2 ,)
288+
289+ # Test gym registered environment env with pre-set kwargs
290+ config = AlgorithmConfig ().environment ("TestEnv-v1" )
291+ env_runner = SingleAgentEnvRunner (config = config )
292+ assert env_runner .env .env .get_attr ("_forward_reward_weight" ) == (2.0 ,)
293+ assert env_runner .env .env .get_attr ("_reset_noise_scale" ) == (0.2 ,)
294+
295+ # Test using a mixture of registered kwargs and env configs
296+ config = AlgorithmConfig ().environment (
297+ "TestEnv-v1" , env_config = {"forward_reward_weight" : 3.0 }
298+ )
299+ env_runner = SingleAgentEnvRunner (config = config )
300+ assert env_runner .env .env .get_attr ("_forward_reward_weight" ) == (3.0 ,)
301+ assert env_runner .env .env .get_attr ("_reset_noise_scale" ) == (0.2 ,)
302+
303+ # Test env-config with Tune registered or callable
304+ # default
305+ config = AlgorithmConfig ().environment ("tune-registered" )
306+ env_runner = SingleAgentEnvRunner (config = config )
307+ assert env_runner .env .env .get_attr ("end_pos" ) == (10.0 ,)
308+
309+ # tune-registered
310+ config = AlgorithmConfig ().environment (
311+ "tune-registered" , env_config = {"corridor_length" : 5.0 }
312+ )
313+ env_runner = SingleAgentEnvRunner (config = config )
314+ assert env_runner .env .env .get_attr ("end_pos" ) == (5.0 ,)
315+
316+ # callable
317+ config = AlgorithmConfig ().environment (
318+ SimpleCorridor , env_config = {"corridor_length" : 5.0 }
319+ )
320+ env_runner = SingleAgentEnvRunner (config = config )
321+ assert env_runner .env .env .get_attr ("end_pos" ) == (5.0 ,)
322+
198323
199324if __name__ == "__main__" :
200325 import sys