Skip to content

Commit c117a00

Browse files
authored
Merge pull request #293 from wgawmy/beta
Initial Attempt at Koopa Algorithm
2 parents 61ad2bc + 64549ac commit c117a00

File tree

7 files changed

+447
-0
lines changed

7 files changed

+447
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .arch import Koopa
2+
from .config.koopa_config import KoopaConfig
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .koopa_arch import Koopa
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import torch
2+
from torch import nn
3+
from .layers import FourierFilter, MLP, TimeInvKP, TimeVarKP
4+
from ..config.koopa_config import KoopaConfig
5+
6+
class Koopa(nn.Module):
7+
"""
8+
Paper: Koopa: Learning Non-stationary Time Series Dynamics with Koopman Predictors
9+
Official Code: https://github.com/thuml/Koopa
10+
Link: https://arxiv.org/abs/2305.18803
11+
Venue: NeurIPS 2024
12+
Task: Long-term Time Series Forecasting
13+
"""
14+
def __init__(self, config: KoopaConfig):
15+
super().__init__()
16+
self.mask_spectrum = None
17+
self.amps = None
18+
self.alpha = config.alpha
19+
self.enc_in = config.enc_in
20+
self.input_len = config.input_len
21+
self.output_len = config.output_len
22+
self.seg_len = config.seg_len
23+
self.num_blocks = config.num_blocks
24+
self.dynamic_dim = config.dynamic_dim
25+
self.hidden_dim = config.hidden_dim
26+
self.hidden_layers = config.hidden_layers
27+
self.multistep = config.multistep
28+
self.disentanglement = FourierFilter(self.mask_spectrum)
29+
# shared encoder/decoder to make koopman embedding consistent
30+
self.time_inv_encoder = MLP(f_in=self.input_len, f_out=self.dynamic_dim, activation='relu',
31+
hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
32+
# fix: use self.output_len instead of non-existent attribute
33+
self.time_inv_decoder = MLP(f_in=self.dynamic_dim, f_out=self.output_len, activation='relu',
34+
hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
35+
# separate module lists for time-invariant and time-variant KPs
36+
self.time_inv_kps = nn.ModuleList([
37+
TimeInvKP(input_len=self.input_len,
38+
pred_len=self.output_len,
39+
dynamic_dim=self.dynamic_dim,
40+
encoder=self.time_inv_encoder,
41+
decoder=self.time_inv_decoder)
42+
for _ in range(self.num_blocks)])
43+
44+
# shared encoder/decoder to make koopman embedding consistent
45+
self.time_var_encoder = MLP(f_in=self.seg_len * self.enc_in, f_out=self.dynamic_dim, activation='tanh',
46+
hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
47+
self.time_var_decoder = MLP(f_in=self.dynamic_dim, f_out=self.seg_len * self.enc_in, activation='tanh',
48+
hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
49+
self.time_var_kps = nn.ModuleList([
50+
TimeVarKP(enc_in=self.enc_in,
51+
input_len=self.input_len,
52+
pred_len=self.output_len,
53+
seg_len=self.seg_len,
54+
dynamic_dim=self.dynamic_dim,
55+
encoder=self.time_var_encoder,
56+
decoder=self.time_var_decoder,
57+
multistep=self.multistep)
58+
for _ in range(self.num_blocks)])
59+
def forward(self, inputs: torch.Tensor = None) -> torch.Tensor:
60+
"""
61+
Single-`inputs` forward to match runner API.
62+
63+
Args:
64+
inputs (torch.Tensor): history input with shape [B, L, C] or [B, L, C, 1]
65+
66+
Returns:
67+
torch.Tensor: prediction tensor with shape [B, output_len, num_features] (may include trailing feature dim)
68+
"""
69+
history_data = inputs
70+
if history_data is None:
71+
raise AssertionError('Model forward requires inputs(history data) as first argument.')
72+
73+
if history_data.dim() == 4:
74+
x_enc = history_data[..., 0]
75+
elif history_data.dim() == 3:
76+
x_enc = history_data
77+
else:
78+
raise ValueError(f'Unsupported inputs shape: {tuple(history_data.shape)}')
79+
80+
mean_enc = x_enc.mean(1, keepdim=True).detach()
81+
x_enc = x_enc - mean_enc
82+
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
83+
x_enc = x_enc / std_enc
84+
if self.disentanglement is None:
85+
raise ValueError('Koopa mask_spectrum is not initialized.')
86+
87+
residual, forecast = x_enc, None
88+
for i in range(self.num_blocks):
89+
time_var_input, time_inv_input = self.disentanglement(residual)
90+
time_inv_output = self.time_inv_kps[i](time_inv_input)
91+
time_var_backcast, time_var_output = self.time_var_kps[i](time_var_input)
92+
residual = residual - time_var_backcast
93+
if forecast is None:
94+
forecast = time_inv_output + time_var_output
95+
else:
96+
forecast += (time_inv_output + time_var_output)
97+
res = forecast * std_enc + mean_enc
98+
if history_data is not None and history_data.dim() == 4 and res.dim() == 3:
99+
res = res.unsqueeze(-1)
100+
return res
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
import math
2+
import torch
3+
from torch import nn
4+
class FourierFilter(nn.Module):
5+
"""
6+
Fourier Filter: to time-variant and time-invariant term
7+
"""
8+
def __init__(self, mask_spectrum):
9+
super().__init__()
10+
self.mask_spectrum = mask_spectrum
11+
12+
def forward(self, x):
13+
xf = torch.fft.rfft(x, dim=1)
14+
mask = torch.ones_like(xf)
15+
mask[:, self.mask_spectrum, :] = 0
16+
x_var = torch.fft.irfft(xf * mask, dim=1)
17+
x_inv = x - x_var
18+
19+
return x_var, x_inv
20+
21+
22+
class MLP(nn.Module):
23+
'''
24+
Multilayer perceptron to encode/decode high dimension representation of sequential data
25+
'''
26+
27+
def __init__(self,
28+
f_in,
29+
f_out,
30+
hidden_dim=128,
31+
hidden_layers=2,
32+
dropout=0.05,
33+
activation='tanh'):
34+
super().__init__()
35+
self.f_in = f_in
36+
self.f_out = f_out
37+
self.hidden_dim = hidden_dim
38+
self.hidden_layers = hidden_layers
39+
self.dropout = dropout
40+
if activation == 'relu':
41+
self.activation = nn.ReLU()
42+
elif activation == 'tanh':
43+
self.activation = nn.Tanh()
44+
else:
45+
raise NotImplementedError
46+
47+
layers = [nn.Linear(self.f_in, self.hidden_dim),
48+
self.activation, nn.Dropout(self.dropout)]
49+
for _ in range(self.hidden_layers - 2):
50+
layers += [nn.Linear(self.hidden_dim, self.hidden_dim),
51+
self.activation, nn.Dropout(dropout)]
52+
53+
layers += [nn.Linear(hidden_dim, f_out)]
54+
self.layers = nn.Sequential(*layers)
55+
56+
def forward(self, x):
57+
# x: B x S x f_in
58+
# y: B x S x f_out
59+
y = self.layers(x)
60+
return y
61+
62+
63+
class KPLayer(nn.Module):
64+
"""
65+
A demonstration of finding one step transition of linear system by DMD iteratively
66+
"""
67+
68+
def __init__(self):
69+
super().__init__()
70+
71+
self.K = None # B E E
72+
73+
def one_step_forward(self, z, return_rec=False):
74+
B, input_len, _ = z.shape
75+
assert input_len > 1, 'snapshots number should be larger than 1'
76+
x, y = z[:, :-1], z[:, 1:]
77+
78+
# solve linear system
79+
self.K = torch.linalg.lstsq(x, y).solution # B E E
80+
if torch.isnan(self.K).any():
81+
print('Encounter K with nan, replace K by identity matrix')
82+
self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1)
83+
84+
z_pred = torch.bmm(z[:, -1:], self.K)
85+
if return_rec:
86+
z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)
87+
return z_rec, z_pred
88+
89+
return z_pred
90+
91+
def forward(self, z, pred_len=1):
92+
assert pred_len >= 1, 'prediction length should not be less than 1'
93+
z_rec, z_pred = self.one_step_forward(z, return_rec=True)
94+
z_preds = [z_pred]
95+
for _ in range(1, pred_len):
96+
z_pred = torch.bmm(z_pred, self.K)
97+
z_preds.append(z_pred)
98+
z_preds = torch.cat(z_preds, dim=1)
99+
return z_rec, z_preds
100+
101+
102+
class KPLayerApprox(nn.Module):
103+
"""
104+
Find koopman transition of linear system by DMD with multistep K approximation
105+
"""
106+
107+
def __init__(self):
108+
super().__init__()
109+
110+
self.K = None # B E E
111+
self.K_step = None # B E E
112+
113+
def forward(self, z, pred_len=1):
114+
# z: B L E, koopman invariance space representation
115+
# z_rec: B L E, reconstructed representation
116+
# z_pred: B S E, forecasting representation
117+
B, input_len, _ = z.shape
118+
assert input_len > 1, 'snapshots number should be larger than 1'
119+
x, y = z[:, :-1], z[:, 1:]
120+
121+
# solve linear system
122+
self.K = torch.linalg.lstsq(x, y).solution # B E E
123+
124+
if torch.isnan(self.K).any():
125+
print('Encounter K with nan, replace K by identity matrix')
126+
self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1)
127+
128+
z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1) # B L E
129+
130+
if pred_len <= input_len:
131+
self.K_step = torch.linalg.matrix_power(self.K, pred_len)
132+
if torch.isnan(self.K_step).any():
133+
print('Encounter multistep K with nan, replace it by identity matrix')
134+
self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1)
135+
z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step)
136+
else:
137+
self.K_step = torch.linalg.matrix_power(self.K, input_len)
138+
if torch.isnan(self.K_step).any():
139+
print('Encounter multistep K with nan, replace it by identity matrix')
140+
self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1)
141+
temp_z_pred, all_pred = z, []
142+
for _ in range(math.ceil(pred_len / input_len)):
143+
temp_z_pred = torch.bmm(temp_z_pred, self.K_step)
144+
all_pred.append(temp_z_pred)
145+
z_pred = torch.cat(all_pred, dim=1)[:, :pred_len, :]
146+
147+
return z_rec, z_pred
148+
149+
150+
class TimeVarKP(nn.Module):
151+
"""
152+
Koopman Predictor with DMD (analysitical solution of Koopman operator)
153+
Utilize local variations within individual sliding window to predict the future of time-variant term
154+
"""
155+
156+
def __init__(self,
157+
enc_in=8,
158+
input_len=96,
159+
pred_len=96,
160+
seg_len=24,
161+
dynamic_dim=128,
162+
encoder=None,
163+
decoder=None,
164+
multistep=False,
165+
):
166+
super().__init__()
167+
self.input_len = input_len
168+
self.pred_len = pred_len
169+
self.enc_in = enc_in
170+
self.seg_len = seg_len
171+
self.dynamic_dim = dynamic_dim
172+
self.multistep = multistep
173+
self.encoder, self.decoder = encoder, decoder
174+
self.freq = math.ceil(self.input_len / self.seg_len) # segment number of input
175+
self.step = math.ceil(self.pred_len / self.seg_len) # segment number of output
176+
self.padding_len = self.seg_len * self.freq - self.input_len
177+
# Approximate mulitstep K by KPLayerApprox when pred_len is large
178+
self.dynamics = KPLayerApprox() if self.multistep else KPLayer()
179+
180+
def forward(self, x):
181+
B, L, _ = x.shape
182+
183+
res = torch.cat((x[:, L - self.padding_len:, :], x), dim=1)
184+
185+
res = res.chunk(self.freq, dim=1) # F x B P C, P means seg_len
186+
res = torch.stack(res, dim=1).reshape(B, self.freq, -1) # B F PC
187+
188+
res = self.encoder(res) # B F H
189+
x_rec, x_pred = self.dynamics(res, self.step) # B F H, B S H
190+
191+
x_rec = self.decoder(x_rec) # B F PC
192+
x_rec = x_rec.reshape(B, self.freq, self.seg_len, self.enc_in)
193+
x_rec = x_rec.reshape(B, -1, self.enc_in)[:, :self.input_len, :] # B L C
194+
195+
x_pred = self.decoder(x_pred) # B S PC
196+
x_pred = x_pred.reshape(B, self.step, self.seg_len, self.enc_in)
197+
x_pred = x_pred.reshape(B, -1, self.enc_in)[:, :self.pred_len, :] # B S C
198+
199+
return x_rec, x_pred
200+
201+
202+
class TimeInvKP(nn.Module):
203+
"""
204+
Koopman Predictor with learnable Koopman operator
205+
Utilize lookback and forecast window snapshots to predict the future of time-invariant term
206+
"""
207+
208+
def __init__(self,
209+
input_len=96,
210+
pred_len=96,
211+
dynamic_dim=128,
212+
encoder=None,
213+
decoder=None):
214+
super().__init__()
215+
self.dynamic_dim = dynamic_dim
216+
self.input_len = input_len
217+
self.pred_len = pred_len
218+
self.encoder = encoder
219+
self.decoder = decoder
220+
221+
K_init = torch.randn(self.dynamic_dim, self.dynamic_dim)
222+
U, _, V = torch.svd(K_init) # stable initialization
223+
self.K = nn.Linear(self.dynamic_dim, self.dynamic_dim, bias=False)
224+
self.K.weight.data = torch.mm(U, V.t())
225+
226+
def forward(self, x):
227+
# x: B L C
228+
res = x.transpose(1, 2) # B C L
229+
res = self.encoder(res) # B C H
230+
res = self.K(res) # B C H
231+
res = self.decoder(res) # B C S
232+
res = res.transpose(1, 2) # B S C
233+
234+
return res
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from dataclasses import dataclass, field
2+
3+
from basicts.configs import BasicTSModelConfig
4+
5+
@dataclass
6+
class KoopaConfig(BasicTSModelConfig):
7+
"""
8+
Config class for Koopa model.
9+
"""
10+
alpha: float = field(default=0.2, metadata={"help": "Scaling coefficient."})
11+
enc_in: int = field(default=7, metadata={"help": "Input feature dimension."})
12+
input_len: int = field(default=None, metadata={"help": "Input sequence length."})
13+
output_len: int = field(default=None, metadata={"help": "Prediction length."})
14+
seg_len: int = field(default=48, metadata={"help": "Segment length. Recommended: e.g., 24 for hourly data."})
15+
num_blocks: int = field(default=3, metadata={"help": "Number of blocks."})
16+
dynamic_dim: int = field(default=64, metadata={"help": "Dynamic feature dimension. Must be > 0."})
17+
hidden_dim: int = field(default=64, metadata={"help": "Hidden dimension."})
18+
hidden_layers: int = field(default=2, metadata={"help": "Number of hidden layers (>=2 recommended)."})
19+
multistep: bool = field(default=False, metadata={"help": "Whether to use multistep forecasting."})

0 commit comments

Comments
 (0)