 import numpy as np
+import torch
 from gymnasium.wrappers import FlattenObservation
 
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.envs.wrappers.base_wrapper import BaseWrapper
-from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper
+from openrl.envs.wrappers.extra_wrappers import (
+    ConvertEmptyBoxWrapper,
+    FrameSkip,
+    GIFWrapper,
+)
 from openrl.modules.common import PPONet as Net
 from openrl.runners.common import PPOAgent as Agent
 
@@ -18,15 +23,15 @@ def train():
     cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
 
     # create environment, set environment parallelism to 64
+    env_num = 64
     env = make(
         env_name,
-        env_num=64,
-        cfg=cfg,
+        env_num=env_num,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation],
+        env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
     )
 
-    net = Net(env, cfg=cfg, device="cuda")
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer
     agent = Agent(
         net,
@@ -44,18 +49,18 @@ def evaluation():
     # begin to test
     # Create an environment for testing and set the number of environments to interact with to 4. Set rendering mode to group_rgb_array.
     render_mode = "group_rgb_array"
+
     env = make(
         env_name,
         render_mode=render_mode,
         env_num=4,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation],
-        cfg=cfg,
+        env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
     )
     # Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
     env = GIFWrapper(env, gif_path="./new.gif", fps=5)
 
-    net = Net(env, cfg=cfg, device="cuda")
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer
     agent = Agent(
         net,
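For reference, below is a minimal, self-contained sketch of how the two changes in this commit (adding ConvertEmptyBoxWrapper to env_wrappers and falling back to CPU when CUDA is unavailable) slot into an evaluation run. The env_name value is a placeholder, the extra Agent constructor arguments are omitted, and the rollout loop (agent.set_env, agent.act, and the vectorized env.step return values) follows the pattern in OpenRL's example scripts rather than anything shown in this diff.

import numpy as np
import torch

from gymnasium.wrappers import FlattenObservation
from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers.extra_wrappers import (
    ConvertEmptyBoxWrapper,
    FrameSkip,
    GIFWrapper,
)
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

env_name = "CartPole-v1"  # placeholder: the real env_name is defined elsewhere in this script

# Build the config the same way the script does.
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

# Device fallback introduced by this change: use CUDA only when it is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Evaluation environment with the new ConvertEmptyBoxWrapper in the wrapper list.
env = make(
    env_name,
    render_mode="group_rgb_array",
    env_num=4,
    asynchronous=True,
    env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
)
env = GIFWrapper(env, gif_path="./new.gif", fps=5)

net = Net(env, cfg=cfg, device=device)
agent = Agent(net)  # remaining Agent arguments used by the script are omitted here

# Rollout loop following OpenRL's example scripts (assumed API, not shown in this diff).
agent.set_env(env)
obs, info = env.reset()
while True:
    action, _ = agent.act(obs)
    obs, reward, done, info = env.step(action)
    if np.any(done):
        break
env.close()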