@@ -110,7 +110,7 @@ def forward(self, x_in: Tuple):
 
         """ Here, we apply the FC network only on the last output point (at the last time step)
         """
-        predictions = hidden[:, -1, :]
+        predictions = hidden[-1, :, :]
         predictions = self.fc(predictions)
         predictions = predictions.view(
             batch_size, self.out_len, self.target_size, self.nr_params
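For context on the indexing change: with `batch_first=True`, a PyTorch RNN's `output` tensor has shape `(batch, seq_len, hidden)`, so the last time step is `output[:, -1, :]`, while the final hidden state `h_n` has shape `(num_layers, batch, hidden)`, so the top layer's state is `h_n[-1, :, :]`. The new indexing therefore suggests `hidden` now carries `h_n` rather than the full output sequence. A minimal standalone shape check (illustration only, not part of the patch):

import torch
import torch.nn as nn

# For a single-layer, unidirectional RNN the two indexing schemes
# select the same values: the hidden state at the last time step.
rnn = nn.GRU(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
x = torch.randn(4, 10, 8)   # (batch, seq_len, input_size)
output, h_n = rnn(x)
print(output.shape)         # torch.Size([4, 10, 16])
print(h_n.shape)            # torch.Size([1, 4, 16])
assert torch.equal(output[:, -1, :], h_n[-1, :, :])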
@@ -130,18 +130,19 @@ def _rnn_sequence(
     ):
 
         modules = []
+        is_lstm = self.name == "LSTM"
         for i in range(num_layers):
             input = input_size if (i == 0) else hidden_dim
             is_last = i == num_layers - 1
             rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True)
 
             modules.append(rnn)
-            modules.append(ExtractRnnOutput())
-
+            modules.append(ExtractRnnOutput(not is_last, is_lstm))
             if normalization:
                 modules.append(self._normalization_layer(normalization, hidden_dim))
-            if is_last:  # pytorch RNNs don't have dropout applied on the last layer
+            if not is_last:  # pytorch RNNs don't have dropout applied on the last layer
                 modules.append(nn.Dropout(dropout))
+
         return nn.Sequential(*modules)
 
     def _fc_layer(
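On the second hunk: `nn.Sequential` passes a single value between layers, but PyTorch RNN modules return an `(output, hidden)` tuple, so an adapter is needed between stacked RNNs. Below is a plausible sketch of `ExtractRnnOutput` consistent with the new call site and with the `hidden[-1, :, :]` indexing above; the constructor arguments and their meaning are inferred from this diff, not confirmed against the full source:

import torch.nn as nn

class ExtractRnnOutput(nn.Module):
    # Hypothetical reconstruction: unwraps the (output, hidden) tuple
    # returned by PyTorch RNNs so layers can be chained in nn.Sequential.
    def __init__(self, keep_sequence: bool, is_lstm: bool):
        super().__init__()
        self.keep_sequence = keep_sequence
        self.is_lstm = is_lstm

    def forward(self, rnn_out):
        output, hidden = rnn_out
        if self.keep_sequence:
            # Intermediate layers feed the full sequence into the next RNN.
            return output
        # LSTMs return hidden as an (h_n, c_n) tuple; keep h_n only.
        return hidden[0] if self.is_lstm else hidden

The `if not is_last:` fix also matches the behavior of the built-in `dropout` argument of `nn.LSTM`/`nn.GRU`, which applies dropout to the outputs of every layer except the last, exactly as the inline comment describes.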