Skip to content

Commit eac4c97

Browse files
authored
Merge pull request #301 from duyifanict/beta
tests: 📏 fix smoke tests for koopa
2 parents c117a00 + 59d48b9 commit eac4c97

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+955
-16
lines changed

.github/workflows/python-test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
pip install flake8 pytest
3030
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
3131
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
32+
if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi
3233
- name: Test with pytest
3334
run: |
3435
pytest

src/basicts/configs/base_config.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99
from functools import partial
1010
from numbers import Number
1111
from types import FunctionType
12-
from typing import Callable, List, Literal, Optional, Tuple, Union
12+
from typing import Callable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
1313

1414
import numpy as np
1515
import torch
1616
from easydict import EasyDict
1717
from torch.optim.lr_scheduler import LRScheduler
1818

19-
from basicts.runners.callback import BasicTSCallback
20-
from basicts.runners.taskflow import BasicTSTaskFlow
19+
# avoid circular imports
20+
if TYPE_CHECKING:
21+
from basicts.runners.callback import BasicTSCallback
22+
from basicts.runners.taskflow import BasicTSTaskFlow
2123

2224
from .model_config import BasicTSModelConfig
2325

@@ -36,8 +38,8 @@ class BasicTSConfig(EasyDict):
3638
model_config: BasicTSModelConfig
3739

3840
dataset_name: str
39-
taskflow: BasicTSTaskFlow
40-
callbacks: List[BasicTSCallback]
41+
taskflow: "BasicTSTaskFlow"
42+
callbacks: List["BasicTSCallback"]
4143

4244
############################## General Configuration ##############################
4345

src/basicts/models/Koopa/arch/koopa_arch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import torch
22
from torch import nn
3-
from .layers import FourierFilter, MLP, TimeInvKP, TimeVarKP
3+
44
from ..config.koopa_config import KoopaConfig
5+
from .layers import MLP, FourierFilter, TimeInvKP, TimeVarKP
6+
57

68
class Koopa(nn.Module):
79
"""

src/basicts/models/Koopa/arch/layers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import math
2+
23
import torch
34
from torch import nn
5+
6+
47
class FourierFilter(nn.Module):
58
"""
69
Fourier Filter: to time-variant and time-invariant term

src/basicts/models/Koopa/config/koopa_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from basicts.configs import BasicTSModelConfig
44

5+
56
@dataclass
67
class KoopaConfig(BasicTSModelConfig):
78
"""

src/basicts/runners/callback/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .grad_accumulation import GradAccumulation
77
from .no_bp import NoBP
88
from .selective_learning import SelectiveLearning
9+
from .koopa_mask_init import KoopaMaskInitCallbackFullTrain
910

1011
__ALL__ = [
1112
'AddAuxiliaryLoss',
@@ -17,4 +18,5 @@
1718
'GradAccumulation',
1819
'NoBP',
1920
'SelectiveLearning',
21+
'KoopaMaskInitCallbackFullTrain',
2022
]

src/basicts/runners/callback/koopa_mask_init.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import torch
2-
from basicts.runners.callback.callback import BasicTSCallback
32
from easytorch.utils import get_logger
3+
44
from basicts.models.Koopa.arch.layers import FourierFilter
5+
from basicts.runners.callback.callback import BasicTSCallback
6+
57
logger = get_logger("KoopaMaskInitCallbackFullTrain")
68

79
class KoopaMaskInitCallbackFullTrain(BasicTSCallback):

tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
einops
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# pylint: disable=wrong-import-position
2+
import os
3+
import sys
4+
5+
sys.path.append(os.path.abspath(__file__ + "/../../../src/"))
6+
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__))))
7+
8+
from basicts.configs import BasicTSForecastingConfig
9+
from basicts.launcher import BasicTSLauncher
10+
from basicts.models.Autoformer.arch.autoformer_arch import Autoformer
11+
from basicts.models.Autoformer.config.autoformer_config import AutoformerConfig
12+
13+
14+
def test_autoformer_smoke_test():
15+
output_len = 24
16+
input_len = 96
17+
autoformer_config = AutoformerConfig(
18+
input_len=input_len,
19+
output_len=output_len,
20+
label_len=input_len/2,
21+
num_features=7,
22+
use_timestamps=True,
23+
timestamp_sizes=[24, 7, 31, 366],
24+
)
25+
BasicTSLauncher.launch_training(
26+
BasicTSForecastingConfig(
27+
model=Autoformer,
28+
dataset_name="ETTh1",
29+
model_config=autoformer_config,
30+
gpus=None,
31+
num_epochs=1,
32+
input_len=input_len,
33+
output_len=output_len,
34+
lr=0.001,
35+
use_timestamps=True,
36+
)
37+
)
38+
39+
if __name__ == "__main__":
40+
test_autoformer_smoke_test()

tests/smoke_test/bug-test_hi.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# pylint: disable=wrong-import-position
2+
import os
3+
import sys
4+
5+
sys.path.append(os.path.abspath(__file__ + "/../../../src/"))
6+
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__))))
7+
8+
from basicts.configs import BasicTSForecastingConfig
9+
from basicts.launcher import BasicTSLauncher
10+
from basicts.models.HI.arch.hi_arch import HI
11+
from basicts.models.HI.config.hi_config import HIConfig
12+
from basicts.runners.callback import NoBP
13+
14+
15+
def test_hi_smoke_test():
16+
output_len = 24
17+
input_len = 96
18+
hi_config = HIConfig(
19+
input_len=input_len,
20+
output_len=output_len,
21+
)
22+
BasicTSLauncher.launch_training(
23+
BasicTSForecastingConfig(
24+
model=HI,
25+
dataset_name="ETTh1_mini",
26+
model_config=hi_config,
27+
gpus=None,
28+
num_epochs=1,
29+
input_len=input_len,
30+
output_len=output_len,
31+
lr=0.001,
32+
callbacks=[NoBP()],
33+
)
34+
)
35+
36+
if __name__ == "__main__":
37+
test_hi_smoke_test()

0 commit comments

Comments
 (0)