|
| 1 | +import math |
| 2 | +import torch |
| 3 | +from torch import nn |
| 4 | +class FourierFilter(nn.Module): |
| 5 | + """ |
| 6 | + Fourier Filter: to time-variant and time-invariant term |
| 7 | + """ |
| 8 | + def __init__(self, mask_spectrum): |
| 9 | + super().__init__() |
| 10 | + self.mask_spectrum = mask_spectrum |
| 11 | + |
| 12 | + def forward(self, x): |
| 13 | + xf = torch.fft.rfft(x, dim=1) |
| 14 | + mask = torch.ones_like(xf) |
| 15 | + mask[:, self.mask_spectrum, :] = 0 |
| 16 | + x_var = torch.fft.irfft(xf * mask, dim=1) |
| 17 | + x_inv = x - x_var |
| 18 | + |
| 19 | + return x_var, x_inv |
| 20 | + |
| 21 | + |
| 22 | +class MLP(nn.Module): |
| 23 | + ''' |
| 24 | + Multilayer perceptron to encode/decode high dimension representation of sequential data |
| 25 | + ''' |
| 26 | + |
| 27 | + def __init__(self, |
| 28 | + f_in, |
| 29 | + f_out, |
| 30 | + hidden_dim=128, |
| 31 | + hidden_layers=2, |
| 32 | + dropout=0.05, |
| 33 | + activation='tanh'): |
| 34 | + super().__init__() |
| 35 | + self.f_in = f_in |
| 36 | + self.f_out = f_out |
| 37 | + self.hidden_dim = hidden_dim |
| 38 | + self.hidden_layers = hidden_layers |
| 39 | + self.dropout = dropout |
| 40 | + if activation == 'relu': |
| 41 | + self.activation = nn.ReLU() |
| 42 | + elif activation == 'tanh': |
| 43 | + self.activation = nn.Tanh() |
| 44 | + else: |
| 45 | + raise NotImplementedError |
| 46 | + |
| 47 | + layers = [nn.Linear(self.f_in, self.hidden_dim), |
| 48 | + self.activation, nn.Dropout(self.dropout)] |
| 49 | + for _ in range(self.hidden_layers - 2): |
| 50 | + layers += [nn.Linear(self.hidden_dim, self.hidden_dim), |
| 51 | + self.activation, nn.Dropout(dropout)] |
| 52 | + |
| 53 | + layers += [nn.Linear(hidden_dim, f_out)] |
| 54 | + self.layers = nn.Sequential(*layers) |
| 55 | + |
| 56 | + def forward(self, x): |
| 57 | + # x: B x S x f_in |
| 58 | + # y: B x S x f_out |
| 59 | + y = self.layers(x) |
| 60 | + return y |
| 61 | + |
| 62 | + |
| 63 | +class KPLayer(nn.Module): |
| 64 | + """ |
| 65 | + A demonstration of finding one step transition of linear system by DMD iteratively |
| 66 | + """ |
| 67 | + |
| 68 | + def __init__(self): |
| 69 | + super().__init__() |
| 70 | + |
| 71 | + self.K = None # B E E |
| 72 | + |
| 73 | + def one_step_forward(self, z, return_rec=False): |
| 74 | + B, input_len, _ = z.shape |
| 75 | + assert input_len > 1, 'snapshots number should be larger than 1' |
| 76 | + x, y = z[:, :-1], z[:, 1:] |
| 77 | + |
| 78 | + # solve linear system |
| 79 | + self.K = torch.linalg.lstsq(x, y).solution # B E E |
| 80 | + if torch.isnan(self.K).any(): |
| 81 | + print('Encounter K with nan, replace K by identity matrix') |
| 82 | + self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1) |
| 83 | + |
| 84 | + z_pred = torch.bmm(z[:, -1:], self.K) |
| 85 | + if return_rec: |
| 86 | + z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1) |
| 87 | + return z_rec, z_pred |
| 88 | + |
| 89 | + return z_pred |
| 90 | + |
| 91 | + def forward(self, z, pred_len=1): |
| 92 | + assert pred_len >= 1, 'prediction length should not be less than 1' |
| 93 | + z_rec, z_pred = self.one_step_forward(z, return_rec=True) |
| 94 | + z_preds = [z_pred] |
| 95 | + for _ in range(1, pred_len): |
| 96 | + z_pred = torch.bmm(z_pred, self.K) |
| 97 | + z_preds.append(z_pred) |
| 98 | + z_preds = torch.cat(z_preds, dim=1) |
| 99 | + return z_rec, z_preds |
| 100 | + |
| 101 | + |
| 102 | +class KPLayerApprox(nn.Module): |
| 103 | + """ |
| 104 | + Find koopman transition of linear system by DMD with multistep K approximation |
| 105 | + """ |
| 106 | + |
| 107 | + def __init__(self): |
| 108 | + super().__init__() |
| 109 | + |
| 110 | + self.K = None # B E E |
| 111 | + self.K_step = None # B E E |
| 112 | + |
| 113 | + def forward(self, z, pred_len=1): |
| 114 | + # z: B L E, koopman invariance space representation |
| 115 | + # z_rec: B L E, reconstructed representation |
| 116 | + # z_pred: B S E, forecasting representation |
| 117 | + B, input_len, _ = z.shape |
| 118 | + assert input_len > 1, 'snapshots number should be larger than 1' |
| 119 | + x, y = z[:, :-1], z[:, 1:] |
| 120 | + |
| 121 | + # solve linear system |
| 122 | + self.K = torch.linalg.lstsq(x, y).solution # B E E |
| 123 | + |
| 124 | + if torch.isnan(self.K).any(): |
| 125 | + print('Encounter K with nan, replace K by identity matrix') |
| 126 | + self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1) |
| 127 | + |
| 128 | + z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1) # B L E |
| 129 | + |
| 130 | + if pred_len <= input_len: |
| 131 | + self.K_step = torch.linalg.matrix_power(self.K, pred_len) |
| 132 | + if torch.isnan(self.K_step).any(): |
| 133 | + print('Encounter multistep K with nan, replace it by identity matrix') |
| 134 | + self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1) |
| 135 | + z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step) |
| 136 | + else: |
| 137 | + self.K_step = torch.linalg.matrix_power(self.K, input_len) |
| 138 | + if torch.isnan(self.K_step).any(): |
| 139 | + print('Encounter multistep K with nan, replace it by identity matrix') |
| 140 | + self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1) |
| 141 | + temp_z_pred, all_pred = z, [] |
| 142 | + for _ in range(math.ceil(pred_len / input_len)): |
| 143 | + temp_z_pred = torch.bmm(temp_z_pred, self.K_step) |
| 144 | + all_pred.append(temp_z_pred) |
| 145 | + z_pred = torch.cat(all_pred, dim=1)[:, :pred_len, :] |
| 146 | + |
| 147 | + return z_rec, z_pred |
| 148 | + |
| 149 | + |
| 150 | +class TimeVarKP(nn.Module): |
| 151 | + """ |
| 152 | + Koopman Predictor with DMD (analysitical solution of Koopman operator) |
| 153 | + Utilize local variations within individual sliding window to predict the future of time-variant term |
| 154 | + """ |
| 155 | + |
| 156 | + def __init__(self, |
| 157 | + enc_in=8, |
| 158 | + input_len=96, |
| 159 | + pred_len=96, |
| 160 | + seg_len=24, |
| 161 | + dynamic_dim=128, |
| 162 | + encoder=None, |
| 163 | + decoder=None, |
| 164 | + multistep=False, |
| 165 | + ): |
| 166 | + super().__init__() |
| 167 | + self.input_len = input_len |
| 168 | + self.pred_len = pred_len |
| 169 | + self.enc_in = enc_in |
| 170 | + self.seg_len = seg_len |
| 171 | + self.dynamic_dim = dynamic_dim |
| 172 | + self.multistep = multistep |
| 173 | + self.encoder, self.decoder = encoder, decoder |
| 174 | + self.freq = math.ceil(self.input_len / self.seg_len) # segment number of input |
| 175 | + self.step = math.ceil(self.pred_len / self.seg_len) # segment number of output |
| 176 | + self.padding_len = self.seg_len * self.freq - self.input_len |
| 177 | + # Approximate mulitstep K by KPLayerApprox when pred_len is large |
| 178 | + self.dynamics = KPLayerApprox() if self.multistep else KPLayer() |
| 179 | + |
| 180 | + def forward(self, x): |
| 181 | + B, L, _ = x.shape |
| 182 | + |
| 183 | + res = torch.cat((x[:, L - self.padding_len:, :], x), dim=1) |
| 184 | + |
| 185 | + res = res.chunk(self.freq, dim=1) # F x B P C, P means seg_len |
| 186 | + res = torch.stack(res, dim=1).reshape(B, self.freq, -1) # B F PC |
| 187 | + |
| 188 | + res = self.encoder(res) # B F H |
| 189 | + x_rec, x_pred = self.dynamics(res, self.step) # B F H, B S H |
| 190 | + |
| 191 | + x_rec = self.decoder(x_rec) # B F PC |
| 192 | + x_rec = x_rec.reshape(B, self.freq, self.seg_len, self.enc_in) |
| 193 | + x_rec = x_rec.reshape(B, -1, self.enc_in)[:, :self.input_len, :] # B L C |
| 194 | + |
| 195 | + x_pred = self.decoder(x_pred) # B S PC |
| 196 | + x_pred = x_pred.reshape(B, self.step, self.seg_len, self.enc_in) |
| 197 | + x_pred = x_pred.reshape(B, -1, self.enc_in)[:, :self.pred_len, :] # B S C |
| 198 | + |
| 199 | + return x_rec, x_pred |
| 200 | + |
| 201 | + |
| 202 | +class TimeInvKP(nn.Module): |
| 203 | + """ |
| 204 | + Koopman Predictor with learnable Koopman operator |
| 205 | + Utilize lookback and forecast window snapshots to predict the future of time-invariant term |
| 206 | + """ |
| 207 | + |
| 208 | + def __init__(self, |
| 209 | + input_len=96, |
| 210 | + pred_len=96, |
| 211 | + dynamic_dim=128, |
| 212 | + encoder=None, |
| 213 | + decoder=None): |
| 214 | + super().__init__() |
| 215 | + self.dynamic_dim = dynamic_dim |
| 216 | + self.input_len = input_len |
| 217 | + self.pred_len = pred_len |
| 218 | + self.encoder = encoder |
| 219 | + self.decoder = decoder |
| 220 | + |
| 221 | + K_init = torch.randn(self.dynamic_dim, self.dynamic_dim) |
| 222 | + U, _, V = torch.svd(K_init) # stable initialization |
| 223 | + self.K = nn.Linear(self.dynamic_dim, self.dynamic_dim, bias=False) |
| 224 | + self.K.weight.data = torch.mm(U, V.t()) |
| 225 | + |
| 226 | + def forward(self, x): |
| 227 | + # x: B L C |
| 228 | + res = x.transpose(1, 2) # B C L |
| 229 | + res = self.encoder(res) # B C H |
| 230 | + res = self.K(res) # B C H |
| 231 | + res = self.decoder(res) # B C S |
| 232 | + res = res.transpose(1, 2) # B S C |
| 233 | + |
| 234 | + return res |
0 commit comments