Skip to content

Commit e89c22e

Browse files
committed
pass hidden state to fc layer
1 parent 80c9e10 commit e89c22e

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

darts/models/forecasting/block_rnn_model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def forward(self, x_in: Tuple):
 
         """ Here, we apply the FC network only on the last output point (at the last time step)
         """
-        predictions = hidden[:, -1, :]
+        predictions = hidden[-1, :, :]
         predictions = self.fc(predictions)
         predictions = predictions.view(
             batch_size, self.out_len, self.target_size, self.nr_params
@@ -130,18 +130,20 @@ def _rnn_sequence(
     ):
 
         modules = []
+        is_lstm = self.name == "LSTM"
         for i in range(num_layers):
             input = input_size if (i == 0) else hidden_dim
             is_last = i == num_layers - 1
             rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True)
 
             modules.append(rnn)
-            modules.append(ExtractRnnOutput())
-
+            modules.append(ExtractRnnOutput(not is_last, is_lstm))
+            modules.append(nn.Dropout(dropout))
             if normalization:
                 modules.append(self._normalization_layer(normalization, hidden_dim))
-            if is_last:  # pytorch RNNs don't have dropout applied on the last layer
+            if not is_last:  # pytorch RNNs don't have dropout applied on the last layer
                 modules.append(nn.Dropout(dropout))
+
         return nn.Sequential(*modules)
 
     def _fc_layer(

darts/utils/torch.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,15 @@ def _reshape_input(self, x):
 
 
 class ExtractRnnOutput(nn.Module):
-    def __init__(self) -> None:
+    def __init__(self, is_output, is_lstm) -> None:
+        self.is_output = is_output
+        self.is_lstm = is_lstm
         super().__init__()
 
     def forward(self, input):
-        output, _ = input
-        return output
+        output, hidden = input
+        if self.is_output:
+            return output
+        if self.is_lstm:
+            return hidden[0]
+        return hidden

0 commit comments

Comments
 (0)