@@ -23,7 +23,7 @@ def __init__(
2323 self ,
2424 obs_shape : Union [int , SequenceType ],
2525 action_shape : Union [int , SequenceType , EasyDict ],
26- actor_head_type : str ,
26+ action_space : str ,
2727 twin_critic : bool = False ,
2828 actor_head_hidden_size : int = 64 ,
2929 actor_head_layer_num : int = 1 ,
@@ -39,7 +39,7 @@ def __init__(
3939 - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
4040 - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ), \
4141 EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
42- - actor_head_type (:obj:`str`): Whether choose ``regression`` or ``reparameterization`` or ``hybrid`` .
42+ - action_space (:obj:`str`): Whether choose ``regression`` or ``reparameterization`` or ``hybrid`` .
4343 - twin_critic (:obj:`bool`): Whether include twin critic.
4444 - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
4545 - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
@@ -56,9 +56,9 @@ def __init__(
5656 obs_shape : int = squeeze (obs_shape )
5757 action_shape = squeeze (action_shape )
5858 self .action_shape = action_shape
59- self .actor_head_type = actor_head_type
60- assert self .actor_head_type in ['regression' , 'reparameterization' , 'hybrid' ]
61- if self .actor_head_type == 'regression' : # DDPG, TD3
59+ self .action_space = action_space
60+ assert self .action_space in ['regression' , 'reparameterization' , 'hybrid' ]
61+ if self .action_space == 'regression' : # DDPG, TD3
6262 self .actor = nn .Sequential (
6363 nn .Linear (obs_shape , actor_head_hidden_size ), activation ,
6464 RegressionHead (
@@ -70,7 +70,7 @@ def __init__(
7070 norm_type = norm_type
7171 )
7272 )
73- elif self .actor_head_type == 'reparameterization' : # SAC
73+ elif self .action_space == 'reparameterization' : # SAC
7474 self .actor = nn .Sequential (
7575 nn .Linear (obs_shape , actor_head_hidden_size ), activation ,
7676 ReparameterizationHead (
@@ -82,7 +82,7 @@ def __init__(
8282 norm_type = norm_type
8383 )
8484 )
85- elif self .actor_head_type == 'hybrid' : # PADDPG
85+ elif self .action_space == 'hybrid' : # PADDPG
8686 # hybrid action space: action_type(discrete) + action_args(continuous),
8787 # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
8888 action_shape .action_args_shape = squeeze (action_shape .action_args_shape )
@@ -110,7 +110,7 @@ def __init__(
110110 )
111111 self .actor = nn .ModuleList ([actor_action_type , actor_action_args ])
112112 self .twin_critic = twin_critic
113- if self .actor_head_type == 'hybrid' :
113+ if self .action_space == 'hybrid' :
114114 critic_input_size = obs_shape + action_shape .action_type_shape + action_shape .action_args_shape
115115 else :
116116 critic_input_size = obs_shape + action_shape
@@ -194,7 +194,7 @@ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
194194
195195 Critic Examples:
196196 >>> inputs = {'obs': torch.randn(4,N), 'action': torch.randn(4,1)}
197- >>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type ='regression')
197+ >>> model = QAC(obs_shape=(N, ),action_shape=1,action_space ='regression')
198198 >>> model(inputs, mode='compute_critic')['q_value'] # q value
199199 tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
200200
@@ -245,13 +245,13 @@ def compute_actor(self, inputs: torch.Tensor) -> Dict:
245245 >>> actor_outputs['logit'][1].shape # sigma
246246 >>> torch.Size([4, 64])
247247 """
248- if self .actor_head_type == 'regression' :
248+ if self .action_space == 'regression' :
249249 x = self .actor (inputs )
250250 return {'action' : x ['pred' ]}
251- elif self .actor_head_type == 'reparameterization' :
251+ elif self .action_space == 'reparameterization' :
252252 x = self .actor (inputs )
253253 return {'logit' : [x ['mu' ], x ['sigma' ]]}
254- elif self .actor_head_type == 'hybrid' :
254+ elif self .action_space == 'hybrid' :
255255 logit = self .actor [0 ](inputs )
256256 action_args = self .actor [1 ](inputs )
257257 return {'logit' : logit ['logit' ], 'action_args' : action_args ['pred' ]}
@@ -284,14 +284,14 @@ def compute_critic(self, inputs: Dict) -> Dict:
284284
285285 Examples:
286286 >>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
287- >>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type ='regression')
287+ >>> model = QAC(obs_shape=(N, ),action_shape=1,action_space ='regression')
288288 >>> model(inputs, mode='compute_critic')['q_value'] # q value
289289 >>> tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
290290 """
291291
292292 obs , action = inputs ['obs' ], inputs ['action' ]
293293 assert len (obs .shape ) == 2
294- if self .actor_head_type == 'hybrid' :
294+ if self .action_space == 'hybrid' :
295295 action_type_logit = inputs ['logit' ]
296296 action_type_logit = torch .softmax (action_type_logit , dim = - 1 )
297297 action_args = action ['action_args' ]
0 commit comments