Skip to content

Commit 644a83b

Browse files
authored
tests: 📏 add smoke tests for classification and imputation (#288)
1 parent f006bd5 commit 644a83b

File tree

7 files changed

+108
-0
lines changed

7 files changed

+108
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{
2+
"name": "ArticularyWordRecognition_mini",
3+
"num_classes": 25,
4+
"class_names": [
5+
"1.0",
6+
"10.0",
7+
"11.0",
8+
"12.0",
9+
"13.0",
10+
"14.0",
11+
"15.0",
12+
"16.0",
13+
"17.0",
14+
"18.0",
15+
"19.0",
16+
"2.0",
17+
"20.0",
18+
"21.0",
19+
"22.0",
20+
"23.0",
21+
"24.0",
22+
"25.0",
23+
"3.0",
24+
"4.0",
25+
"5.0",
26+
"6.0",
27+
"7.0",
28+
"8.0",
29+
"9.0"
30+
],
31+
"equal_length": true,
32+
"seq_len": 144,
33+
"num_nodes": 9,
34+
"num_features": 1,
35+
"shape": "[num_samples, seq_len, num_nodes, num_features]",
36+
"missing": false,
37+
"filling_missing": "NA",
38+
"norm_each_channel": true
39+
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# pylint: disable=wrong-import-position
2+
3+
import os
4+
import sys
5+
6+
sys.path.append(os.path.abspath(__file__ + "/../../../src/"))
7+
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__))))
8+
9+
from basicts import BasicTSLauncher
10+
from basicts.configs import BasicTSClassificationConfig
11+
from basicts.models.iTransformer import (iTransformerConfig,
12+
iTransformerForClassification)
13+
14+
15+
def test_itransformerforc_smoke_test():
16+
17+
model_config = iTransformerConfig(
18+
input_len=144,
19+
num_features=9,
20+
num_classes=25
21+
)
22+
23+
BasicTSLauncher.launch_training(BasicTSClassificationConfig(
24+
model=iTransformerForClassification,
25+
model_config=model_config,
26+
dataset_name="ArticularyWordRecognition_mini",
27+
gpus=None,
28+
batch_size=16,
29+
num_epochs=5,
30+
))
31+
32+
33+
if __name__ == "__main__":
34+
test_itransformerforc_smoke_test()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# pylint: disable=wrong-import-position
2+
3+
import os
4+
import sys
5+
6+
sys.path.append(os.path.abspath(__file__ + "/../../../src/"))
7+
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__))))
8+
9+
from basicts import BasicTSLauncher
10+
from basicts.configs import BasicTSImputationConfig
11+
from basicts.models.iTransformer import (iTransformerConfig,
12+
iTransformerForReconstruction)
13+
14+
15+
def test_itransformerforr_smoke_test():
16+
input_len=32
17+
model_config = iTransformerConfig(
18+
input_len=input_len,
19+
num_features=7
20+
)
21+
22+
BasicTSLauncher.launch_training(BasicTSImputationConfig(
23+
model=iTransformerForReconstruction,
24+
model_config=model_config,
25+
dataset_name="ETTh1_mini",
26+
mask_ratio=0.25,
27+
gpus=None,
28+
batch_size=16,
29+
input_len=input_len,
30+
num_epochs=5,
31+
))
32+
33+
34+
if __name__ == "__main__":
35+
test_itransformerforr_smoke_test()

0 commit comments

Comments
 (0)