
Commit 94b5ddf

add A2C algorithm
1 parent e1e1a79 commit 94b5ddf

File tree

14 files changed: +328 -6 lines changed


Gallery.md

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ Users are also welcome to contribute their own training examples and demos to th
 | [JRPO](https://arxiv.org/abs/2302.07515) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [code](./examples/mpe/) |
 | [GAIL](https://arxiv.org/abs/1606.03476) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [code](./examples/gail/) |
 | [Behavior Cloning](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [code](./examples/behavior_cloning/) |
+| [A2C](https://arxiv.org/abs/1602.01783) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/cartpole/) |
 | Self-Play | ![selfplay](https://img.shields.io/badge/-selfplay-blue) ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/selfplay/) |
 | [DQN](https://arxiv.org/abs/1312.5602) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![value](https://img.shields.io/badge/-value-orange) ![offpolicy](https://img.shields.io/badge/-offpolicy-blue) | [code](./examples/toy_env) [code](./examples/gridworld/) |
 | [MAT](https://arxiv.org/abs/2205.14953) | ![MARL](https://img.shields.io/badge/-MARL-yellow) ![Transformer](https://img.shields.io/badge/-Transformer-blue) | [code](./examples/mpe/) |

README.md

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ Algorithms currently supported by OpenRL (for more details, please refer to [Gal
 - [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
 - [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/abs/1606.03476)
 - [Behavior Cloning (BC)](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf)
+- [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
 - Self-Play
 - [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602)
 - [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)

README_zh.md

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ OpenRL目前支持的算法(更多详情请参考 [Gallery](Gallery.md)):
 - [Joint-ratio Policy Optimization (JRPO)](https://arxiv.org/abs/2302.07515)
 - [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/abs/1606.03476)
 - [Behavior Cloning (BC)](http://www.cse.unsw.edu.au/~claude/papers/MI15.pdf)
+- [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
 - Self-Play
 - [Deep Q-Network (DQN)](https://arxiv.org/abs/1312.5602)
 - [Multi-Agent Transformer (MAT)](https://arxiv.org/abs/2205.14953)

examples/cartpole/README.md

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,12 @@ To train with [Dual-clip PPO](https://arxiv.org/abs/1912.09729):
 python train_ppo.py --config dual_clip_ppo.yaml
 ```
 
+To train with the [A2C](https://arxiv.org/abs/1602.01783) algorithm:
+
+```shell
+python train_a2c.py
+```
+
 If you want to evaluate the agent during training, save the best model, and save checkpoints, train with callbacks:
 
 ```shell

examples/cartpole/a2c.yaml

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
seed: 0
run_dir: ./run_results/
wandb_entity: openrl-lab

examples/cartpole/train_a2c.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
"""Train and evaluate an A2C agent on CartPole-v1 with OpenRL."""
import numpy as np
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import A2CNet as Net
from openrl.runners.common import A2CAgent as Agent


def train():
    # load the training configuration from a2c.yaml
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "a2c.yaml"])

    # create environment, set environment parallelism to 9
    env = make("CartPole-v1", env_num=9)

    # create the neural network
    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # initialize the trainer
    agent = Agent(net, use_wandb=False, project_name="CartPole-v1")
    # start training, set total number of training steps to 30000
    agent.train(total_time_steps=30000)

    env.close()

    agent.save("./a2c_agent")
    return agent


def evaluation():
    # begin to test
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "a2c.yaml"])

    # Create an environment for testing and set the number of environments to interact with to 9.
    # Rendering is disabled here; set render_mode to "group_human" to watch the evaluation.
    render_mode = None
    env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)

    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # initialize the agent and load the trained weights
    agent = Agent(net)
    agent.load("./a2c_agent")
    # The trained agent sets up the interactive environment it needs.
    agent.set_env(env)
    # Initialize the environment and get initial observations and environmental information.
    obs, info = env.reset()
    done = False

    total_step = 0
    total_reward = 0.0
    while not np.any(done):
        # Based on environmental observation input, predict the next action.
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        total_step += 1
        total_reward += np.mean(r)
        if total_step % 50 == 0:
            print(f"{total_step}: reward:{np.mean(r)}")
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)


if __name__ == "__main__":
    train()
    evaluation()
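The script above trains with logging disabled (use_wandb=False), even though a2c.yaml sets a wandb_entity. A minimal sketch of the same training step with experiment tracking turned on follows; it assumes you are logged in to Weights & Biases and can write to the entity named in a2c.yaml (otherwise change wandb_entity to your own), and the function name train_with_tracking is only for illustration:

```python
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import A2CNet as Net
from openrl.runners.common import A2CAgent as Agent


def train_with_tracking():
    # same setup as train() above, but metrics are reported to Weights & Biases
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "a2c.yaml"])

    env = make("CartPole-v1", env_num=9)
    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")

    # use_wandb=True logs to the wandb_entity configured in a2c.yaml
    agent = Agent(net, use_wandb=True, project_name="CartPole-v1")
    agent.train(total_time_steps=30000)

    env.close()
    agent.save("./a2c_agent")
    return agent
```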

openrl/algorithms/a2c.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Advantage Actor-Critic (A2C), implemented on top of the PPO algorithm."""
from typing import Union

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel

from openrl.algorithms.ppo import PPOAlgorithm


class A2CAlgorithm(PPOAlgorithm):
    def __init__(
        self,
        cfg,
        init_module,
        agent_num: int = 1,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        super(A2CAlgorithm, self).__init__(cfg, init_module, agent_num, device)

        # A2C performs a single update over the whole rollout batch.
        self.num_mini_batch = 1

    def prepare_loss(
        self,
        critic_obs_batch,
        obs_batch,
        rnn_states_batch,
        rnn_states_critic_batch,
        actions_batch,
        masks_batch,
        action_masks_batch,
        old_action_log_probs_batch,
        adv_targ,
        value_preds_batch,
        return_batch,
        active_masks_batch,
        turn_on,
    ):
        if self.use_joint_action_loss:
            critic_obs_batch = self.to_single_np(critic_obs_batch)
            rnn_states_critic_batch = self.to_single_np(rnn_states_critic_batch)
            critic_masks_batch = self.to_single_np(masks_batch)
            value_preds_batch = self.to_single_np(value_preds_batch)
            return_batch = self.to_single_np(return_batch)
            adv_targ = adv_targ.reshape(-1, self.agent_num, 1)
            adv_targ = adv_targ[:, 0, :]

        else:
            critic_masks_batch = masks_batch

        (
            values,
            action_log_probs,
            dist_entropy,
            policy_values,
        ) = self.algo_module.evaluate_actions(
            critic_obs_batch,
            obs_batch,
            rnn_states_batch,
            rnn_states_critic_batch,
            actions_batch,
            masks_batch,
            action_masks_batch,
            active_masks_batch,
            critic_masks_batch=critic_masks_batch,
        )

        if self.use_joint_action_loss:
            active_masks_batch = active_masks_batch.reshape(-1, self.agent_num, 1)
            active_masks_batch = active_masks_batch[:, 0, :]

        # Vanilla policy-gradient loss: no importance-sampling ratio, no clipping.
        policy_gradient_loss = -adv_targ.detach() * action_log_probs
        if self._use_policy_active_masks:
            policy_action_loss = (
                torch.sum(policy_gradient_loss, dim=-1, keepdim=True)
                * active_masks_batch
            ).sum() / active_masks_batch.sum()
        else:
            policy_action_loss = torch.sum(
                policy_gradient_loss, dim=-1, keepdim=True
            ).mean()

        if self._use_policy_vhead:
            if isinstance(self.algo_module.models["actor"], DistributedDataParallel):
                policy_value_normalizer = self.algo_module.models[
                    "actor"
                ].module.value_normalizer
            else:
                policy_value_normalizer = self.algo_module.models[
                    "actor"
                ].value_normalizer
            policy_value_loss = self.cal_value_loss(
                policy_value_normalizer,
                policy_values,
                value_preds_batch,
                return_batch,
                active_masks_batch,
            )
            policy_loss = (
                policy_action_loss + policy_value_loss * self.policy_value_loss_coef
            )
        else:
            policy_loss = policy_action_loss

        # critic update
        if self._use_share_model:
            value_normalizer = self.algo_module.models["model"].value_normalizer
        elif isinstance(self.algo_module.models["critic"], DistributedDataParallel):
            value_normalizer = self.algo_module.models["critic"].module.value_normalizer
        else:
            value_normalizer = self.algo_module.get_critic_value_normalizer()
        value_loss = self.cal_value_loss(
            value_normalizer,
            values,
            value_preds_batch,
            return_batch,
            active_masks_batch,
        )

        loss_list = self.construct_loss_list(
            policy_loss, dist_entropy, value_loss, turn_on
        )
        # A2C has no PPO ratio; return a zero placeholder to keep the interface unchanged.
        ratio = np.zeros(1)
        return loss_list, value_loss, policy_loss, dist_entropy, ratio

    def train(self, buffer, turn_on: bool = True):
        train_info = super(A2CAlgorithm, self).train(buffer, turn_on)
        # The ratio statistic inherited from PPO is not meaningful for A2C, so drop it.
        train_info.pop("ratio", None)
        return train_info
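The only substantive change from PPO lives in prepare_loss: where PPO weights the advantage by a clipped importance ratio, A2C uses the plain policy-gradient term -adv_targ.detach() * action_log_probs from the current policy, and the ratio handed back to the PPO plumbing is a zero placeholder. Stripped of OpenRL's module and masking machinery, the objective reduces to the following sketch; the value and entropy coefficients here are illustrative only, since in OpenRL the terms are combined by construct_loss_list according to the config:

```python
import torch


def a2c_loss(action_log_probs, advantages, values, returns, entropy,
             value_coef=0.5, entropy_coef=0.01):
    """Minimal A2C objective: policy gradient + value regression - entropy bonus."""
    # Policy term: advantages are constants (detached), so gradients flow only
    # through the log-probabilities of the actions that were taken.
    policy_loss = -(advantages.detach() * action_log_probs).mean()
    # Critic term: regress predicted state values onto the computed returns.
    value_loss = 0.5 * (returns - values).pow(2).mean()
    # Entropy bonus encourages exploration.
    return policy_loss + value_coef * value_loss - entropy_coef * entropy.mean()


# toy usage with random tensors standing in for a rollout batch
logp = torch.randn(32, requires_grad=True)
adv = torch.randn(32)
v = torch.randn(32, requires_grad=True)
ret = torch.randn(32)
ent = torch.rand(32)
loss = a2c_loss(logp, adv, v, ret, ent)
loss.backward()
```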

openrl/modules/common/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from .a2c_net import A2CNet
 from .base_net import BaseNet
 from .bc_net import BCNet
 from .ddpg_net import DDPGNet
@@ -18,4 +19,5 @@
     "GAILNet",
     "BCNet",
     "SACNet",
+    "A2CNet",
 ]

openrl/modules/common/a2c_net.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A2C network: reuses the PPO network definition unchanged."""
from openrl.modules.common.ppo_net import PPONet


class A2CNet(PPONet):
    pass
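Since A2CNet inherits everything from PPONet, the actor-critic network is identical to the PPO example's; only the update rule in openrl/algorithms/a2c.py above differs. Converting the existing cartpole PPO script to A2C is therefore essentially an import swap, sketched below; the PPONet/PPOAgent aliases for train_ppo.py are assumed and not part of this commit:

```python
# assumed PPO imports (as in the existing cartpole example):
# from openrl.modules.common import PPONet as Net
# from openrl.runners.common import PPOAgent as Agent

# A2C imports introduced by this commit:
from openrl.modules.common import A2CNet as Net
from openrl.runners.common import A2CAgent as Agent
```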

openrl/runners/common/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from openrl.runners.common.a2c_agent import A2CAgent
 from openrl.runners.common.bc_agent import BCAgent
 from openrl.runners.common.chat_agent import Chat6BAgent, ChatAgent
 from openrl.runners.common.ddpg_agent import DDPGAgent
@@ -19,4 +20,5 @@
     "GAILAgent",
     "BCAgent",
     "SACAgent",
+    "A2CAgent",
 ]
