Skip to content

Commit 513537e

Browse files
authored
fix torch likelihood imports for missing torch dependency (#2761)
* fix torch likelihood imports for missing torch dependency * move NotImportedModule to utils/utils.py
1 parent e687fb7 commit 513537e

File tree

14 files changed

+99
-63
lines changed

14 files changed

+99
-63
lines changed

darts/explainability/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111
from darts.explainability.shap_explainer import ShapExplainer
1212
from darts.logging import get_logger
13-
from darts.models.utils import NotImportedModule
13+
from darts.utils.utils import NotImportedModule
1414

1515
logger = get_logger(__name__)
1616
try:

darts/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
logger = get_logger(__name__)
99

10-
from darts.models.utils import NotImportedModule
10+
from darts.utils.utils import NotImportedModule
1111

1212
try:
1313
# `lightgbm` needs to be imported first to avoid segmentation fault

darts/models/forecasting/conformal_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from darts.logging import get_logger, raise_log
3939
from darts.metrics.metrics import METRIC_TYPE
4040
from darts.models.forecasting.forecasting_model import GlobalForecastingModel
41-
from darts.models.utils import TORCH_AVAILABLE
4241
from darts.utils import _build_tqdm_iterator, _with_sanity_checks
4342
from darts.utils.historical_forecasts.utils import (
4443
_adjust_historical_forecasts_time_index,
@@ -50,6 +49,7 @@
5049
series2seq,
5150
)
5251
from darts.utils.utils import (
52+
TORCH_AVAILABLE,
5353
_check_quantiles,
5454
generate_index,
5555
n_steps_between,

darts/models/forecasting/ensemble_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
GlobalForecastingModel,
2222
LocalForecastingModel,
2323
)
24-
from darts.models.utils import TORCH_AVAILABLE
2524
from darts.timeseries import TimeSeries, concatenate
2625
from darts.utils.ts_utils import series2seq
26+
from darts.utils.utils import TORCH_AVAILABLE
2727

2828
if TORCH_AVAILABLE:
2929
from darts.models.forecasting.torch_forecasting_model import TorchForecastingModel

darts/models/utils.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

darts/tests/explainability/test_shap_explainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
ExponentialSmoothing,
2121
LightGBMModel,
2222
LinearRegressionModel,
23-
NotImportedModule,
2423
RegressionModel,
2524
XGBModel,
2625
)
2726
from darts.utils.timeseries_generation import linear_timeseries
27+
from darts.utils.utils import NotImportedModule
2828

2929
lgbm_available = not isinstance(LightGBMModel, NotImportedModule)
3030
cb_available = not isinstance(CatBoostModel, NotImportedModule)

darts/tests/models/forecasting/test_local_forecasting_models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
NaiveMean,
2828
NaiveMovingAverage,
2929
NaiveSeasonal,
30-
NotImportedModule,
3130
Prophet,
3231
RandomForest,
3332
RegressionModel,
@@ -44,7 +43,13 @@
4443
)
4544
from darts.timeseries import TimeSeries
4645
from darts.utils import timeseries_generation as tg
47-
from darts.utils.utils import ModelMode, SeasonalityMode, TrendMode, generate_index
46+
from darts.utils.utils import (
47+
ModelMode,
48+
NotImportedModule,
49+
SeasonalityMode,
50+
TrendMode,
51+
generate_index,
52+
)
4853

4954
logger = get_logger(__name__)
5055

darts/tests/models/forecasting/test_probabilistic_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
ExponentialSmoothing,
1717
LightGBMModel,
1818
LinearRegressionModel,
19-
NotImportedModule,
2019
XGBModel,
2120
)
2221
from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs
2322
from darts.utils import timeseries_generation as tg
23+
from darts.utils.utils import NotImportedModule
2424

2525
logger = get_logger(__name__)
2626

darts/tests/models/forecasting/test_prophet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
from darts import TimeSeries
88
from darts.logging import get_logger
9-
from darts.models import NotImportedModule, Prophet
9+
from darts.models import Prophet
1010
from darts.utils import timeseries_generation as tg
11-
from darts.utils.utils import freqs, generate_index
11+
from darts.utils.utils import NotImportedModule, freqs, generate_index
1212

1313
logger = get_logger(__name__)
1414

darts/tests/models/forecasting/test_regression_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sklearn.neighbors import KNeighborsRegressor
1616

1717
import darts
18+
import darts.utils.utils
1819
from darts import TimeSeries
1920
from darts.dataprocessing.encoders import (
2021
FutureCyclicEncoder,
@@ -26,7 +27,6 @@
2627
CatBoostModel,
2728
LightGBMModel,
2829
LinearRegressionModel,
29-
NotImportedModule,
3030
RandomForest,
3131
RegressionModel,
3232
XGBModel,
@@ -35,7 +35,7 @@
3535
from darts.utils.likelihood_models.base import Likelihood, LikelihoodType
3636
from darts.utils.likelihood_models.sklearn import _get_likelihood
3737
from darts.utils.multioutput import MultiOutputRegressor
38-
from darts.utils.utils import generate_index
38+
from darts.utils.utils import NotImportedModule, generate_index
3939

4040
logger = get_logger(__name__)
4141

@@ -3422,7 +3422,7 @@ def test_get_categorical_features_helper(self):
34223422
(
34233423
darts.models.forecasting.lgbm.lgb.LGBMRegressor
34243424
if lgbm_available
3425-
else darts.models.utils.NotImportedModule
3425+
else darts.utils.utils.NotImportedModule
34263426
),
34273427
"fit",
34283428
)

0 commit comments

Comments
 (0)