Skip to content

Commit da87bf4

Browse files
committed
feat: 🎸 add STPGNN
1 parent 00f8269 commit da87bf4

File tree

5 files changed

+399
-0
lines changed

5 files changed

+399
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ MOIRAI (inference) | Unified Training of Universal Time Series Forecasting Trans
123123
124124
| 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
125125
| :--------- | :------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------- | :----- |
126+
| STPGNN | Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/28707) | [Link](https://github.com/Kongwy5689/STPGNN?tab=readme-ov-file) | AAAI'24 | STF |
126127
| BigST | Linear Complexity Spatio-Temporal Graph Neural Network for Traffic Forecasting on Large-Scale Road Networks | [Link](https://dl.acm.org/doi/10.14778/3641204.3641217) | [Link](https://github.com/usail-hkust/BigST?tab=readme-ov-file) | VLDB'24 | STF |
127128
| STDMAE | Spatio-Temporal-Decoupled Masked Pre-training for Traffic Forecasting | [Link](https://arxiv.org/abs/2312.00516) | [Link](https://github.com/Jimmy-7664/STD-MAE) | IJCAI'24 | STF |
128129
| STWave | When Spatio-Temporal Meet Wavelets: Disentangled Traffic Forecasting via Efficient Spectral Graph Attention Networks | [Link](https://ieeexplore.ieee.org/document/10184591) | [Link](https://github.com/LMissher/STWave) | ICDE'23 | STF |

README_CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ MOIRAI (inference) | Unified Training of Universal Time Series Forecasting Trans
125125
126126
| 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
127127
| :--------- | :------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------- | :----- |
128+
| STPGNN | Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/28707) | [Link](https://github.com/Kongwy5689/STPGNN?tab=readme-ov-file) | AAAI'24 | STF |
128129
| BigST | Linear Complexity Spatio-Temporal Graph Neural Network for Traffic Forecasting on Large-Scale Road Networks | [Link](https://dl.acm.org/doi/10.14778/3641204.3641217) | [Link](https://github.com/usail-hkust/BigST?tab=readme-ov-file) | VLDB'24 | STF |
129130
| STDMAE | Spatio-Temporal-Decoupled Masked Pre-training for Traffic Forecasting | [Link](https://arxiv.org/abs/2312.00516) | [Link](https://github.com/Jimmy-7664/STD-MAE) | IJCAI'24 | STF |
130131
| STWave | When Spatio-Temporal Meet Wavelets: Disentangled Traffic Forecasting via Efficient Spectral Graph Attention Networks | [Link](https://ieeexplore.ieee.org/document/10184591) | [Link](https://github.com/LMissher/STWave) | ICDE'23 | STF |

baselines/STPGNN/PEMS08.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import os
2+
import sys
3+
import torch
4+
from easydict import EasyDict
5+
sys.path.append(os.path.abspath(__file__ + '/../../..'))
6+
7+
from basicts.metrics import masked_mae, masked_mape, masked_rmse
8+
from basicts.data import TimeSeriesForecastingDataset
9+
from basicts.runners import SimpleTimeSeriesForecastingRunner
10+
from basicts.scaler import ZScoreScaler
11+
from basicts.utils import get_regular_settings, load_adj
12+
13+
from .arch import STPGNN
14+
15+
############################## Hot Parameters ##############################
16+
# Dataset & Metrics configuration
17+
DATA_NAME = 'PEMS08' # Dataset name
18+
regular_settings = get_regular_settings(DATA_NAME)
19+
INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence
20+
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence
21+
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios
22+
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data
23+
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data
24+
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
25+
# Model architecture and parameters
26+
MODEL_ARCH = STPGNN
27+
MODEL_PARAM = {
28+
"num_nodes": 170,
29+
"dropout": 0.1,
30+
"topk": 35,
31+
"out_dim": OUTPUT_LEN,
32+
"residual_channels": 32,
33+
"dilation_channels": 32,
34+
"end_channels": 512,
35+
"kernel_size": 2,
36+
"blocks": 4,
37+
"layers": 2,
38+
"days": 48, # the `days` parameter used in STPGNN
39+
"time_of_day_size": 288, # Number of time steps in a day in the specific dataset
40+
"dims": 32,
41+
"order": 2,
42+
"in_dim": 1, # Number of input features
43+
"normalization": "batch",
44+
}
45+
NUM_EPOCHS = 100
46+
47+
############################## General Configuration ##############################
48+
CFG = EasyDict()
49+
# General settings
50+
CFG.DESCRIPTION = 'An Example Config'
51+
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode)
52+
# Runner
53+
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
54+
55+
############################## Dataset Configuration ##############################
56+
CFG.DATASET = EasyDict()
57+
# Dataset settings
58+
CFG.DATASET.NAME = DATA_NAME
59+
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
60+
CFG.DATASET.PARAM = EasyDict({
61+
'dataset_name': DATA_NAME,
62+
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
63+
'input_len': INPUT_LEN,
64+
'output_len': OUTPUT_LEN,
65+
# 'mode' is automatically set by the runner
66+
})
67+
68+
############################## Scaler Configuration ##############################
69+
CFG.SCALER = EasyDict()
70+
# Scaler settings
71+
CFG.SCALER.TYPE = ZScoreScaler # Scaler class
72+
CFG.SCALER.PARAM = EasyDict({
73+
'dataset_name': DATA_NAME,
74+
'train_ratio': TRAIN_VAL_TEST_RATIO[0],
75+
'norm_each_channel': NORM_EACH_CHANNEL,
76+
'rescale': RESCALE,
77+
})
78+
79+
############################## Model Configuration ##############################
80+
CFG.MODEL = EasyDict()
81+
# Model settings
82+
CFG.MODEL.NAME = MODEL_ARCH.__name__
83+
CFG.MODEL.ARCH = MODEL_ARCH
84+
CFG.MODEL.PARAM = MODEL_PARAM
85+
CFG.MODEL.FORWARD_FEATURES = [0, 1]
86+
CFG.MODEL.TARGET_FEATURES = [0]
87+
88+
############################## Metrics Configuration ##############################
89+
90+
CFG.METRICS = EasyDict()
91+
# Metrics settings
92+
CFG.METRICS.FUNCS = EasyDict({
93+
'MAE': masked_mae,
94+
'MAPE': masked_mape,
95+
'RMSE': masked_rmse,
96+
})
97+
CFG.METRICS.TARGET = 'MAE'
98+
CFG.METRICS.NULL_VAL = NULL_VAL
99+
100+
############################## Training Configuration ##############################
101+
CFG.TRAIN = EasyDict()
102+
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
103+
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
104+
'checkpoints',
105+
MODEL_ARCH.__name__,
106+
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
107+
)
108+
CFG.TRAIN.LOSS = masked_mae
109+
# Optimizer settings
110+
CFG.TRAIN.OPTIM = EasyDict()
111+
CFG.TRAIN.OPTIM.TYPE = "Adam"
112+
CFG.TRAIN.OPTIM.PARAM = {
113+
"lr": 0.001,
114+
"weight_decay": 0.0001,
115+
}
116+
# Learning rate scheduler settings
117+
CFG.TRAIN.LR_SCHEDULER = EasyDict()
118+
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
119+
CFG.TRAIN.LR_SCHEDULER.PARAM = {
120+
"milestones": [1, 50, 80],
121+
"gamma": 0.5
122+
}
123+
CFG.TRAIN.CLIP_GRAD_PARAM = {
124+
'max_norm': 5.0
125+
}
126+
# Train data loader settings
127+
CFG.TRAIN.DATA = EasyDict()
128+
CFG.TRAIN.DATA.BATCH_SIZE = 64
129+
CFG.TRAIN.DATA.SHUFFLE = True
130+
131+
############################## Validation Configuration ##############################
132+
CFG.VAL = EasyDict()
133+
CFG.VAL.INTERVAL = 1
134+
CFG.VAL.DATA = EasyDict()
135+
CFG.VAL.DATA.BATCH_SIZE = 64
136+
137+
############################## Test Configuration ##############################
138+
CFG.TEST = EasyDict()
139+
CFG.TEST.INTERVAL = 1
140+
CFG.TEST.DATA = EasyDict()
141+
CFG.TEST.DATA.BATCH_SIZE = 64
142+
143+
############################## Evaluation Configuration ##############################
144+
145+
CFG.EVAL = EasyDict()
146+
147+
# Evaluation parameters
148+
CFG.EVAL.HORIZONS = [3, 6, 12] # Prediction horizons for evaluation. Default: []
149+
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True

baselines/STPGNN/arch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .stpgnn_arch import STPGNN

0 commit comments

Comments
 (0)