Skip to content

Commit 00f8269

Browse files
committed
feat: 🎸 add the CFG.EVAL.SAVE_RESULTS option (issue #222)
1 parent 485e0be commit 00f8269

File tree

5 files changed

+9
-3
lines changed

5 files changed

+9
-3
lines changed

basicts/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .launcher import launch_evaluation, launch_training
22
from .runners import BaseEpochRunner
33

4-    __version__ = '0.5.0'
4+    __version__ = '0.5.1'
55

66
__all__ = ['__version__', 'launch_training', 'launch_evaluation', 'BaseEpochRunner']

basicts/runners/base_epoch_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def __init__(self, cfg: Dict) -> None:
8787
self.val_interval = cfg.get('VAL', {}).get('INTERVAL', 1)
8888
self.test_interval = cfg.get('TEST', {}).get('INTERVAL', 1)
8989

90+        self.save_results = cfg.get('EVAL', {}).get('SAVE_RESULTS', False)
91+
9092
# create checkpoint save dir
9193
if not os.path.isdir(self.ckpt_save_dir):
9294
os.makedirs(self.ckpt_save_dir)
@@ -691,7 +693,7 @@ def on_training_end(self, cfg: Dict, train_epoch: Optional[int] = None):
691693
)
692694
self.logger.info('Evaluating the best model on the test set.')
693695
self.load_model(ckpt_path=best_model_path, strict=True)
694-        self.test_pipeline(cfg=cfg, train_epoch=train_epoch, save_metrics=True, save_results=True)
696+        self.test_pipeline(cfg=cfg, train_epoch=train_epoch, save_metrics=True, save_results=self.save_results)
695697

696698
# endregion Hook Functions
697699

basicts/runners/base_iteration_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def __init__(self, cfg: Dict) -> None:
8686
self.val_data_loader = None
8787
self.test_data_loader = None
8888

89+        self.save_results = cfg.get('EVAL', {}).get('SAVE_RESULTS', False)
90+
8991
# declare meter pool
9092
self.meter_pool = None
9193

@@ -574,7 +576,7 @@ def on_training_end(self, cfg: Dict, train_iteration: Optional[int] = None) -> N
574576
)
575577
self.logger.info('Evaluating the best model on the test set.')
576578
self.load_model(ckpt_path=best_model_path, strict=True)
577-        self.test_pipeline(cfg=cfg, train_iteration=train_iteration, save_metrics=True, save_results=True)
579+        self.test_pipeline(cfg=cfg, train_iteration=train_iteration, save_metrics=True, save_results=self.save_results)
578580

579581
def get_ckpt_path(self, iteration: int) -> str:
580582
"""Get checkpoint path.

examples/complete_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,4 @@
229229
# This is a common setting in spatiotemporal forecasting. For long-sequence predictions, it is recommended to keep HORIZONS set to the default value [] to avoid confusion.
230230
CFG.EVAL.HORIZONS = []
231231
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True
232+ CFG.EVAL.SAVE_RESULTS = False # Whether to save evaluation results in a numpy file. Default: False

examples/complete_config_cn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,4 @@
226226
# 这是时空预测中的常见配置。对于长序列预测,建议将 HORIZONS 保持为默认值 [],以避免引发误解。
227227
CFG.EVAL.HORIZONS = []
228228
CFG.EVAL.USE_GPU = True # 是否在评估时使用 GPU。默认值:True
229+ CFG.EVAL.SAVE_RESULTS = False # 是否将评估结果保存为一个numpy文件。 默认值:False

0 commit comments

Comments (0)