Skip to content

Commit 544ab44

Browse files
committed
Merge branch 'master' into feature/rnn-normalization
2 parents e89c22e + 83e9171 commit 544ab44

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed

darts/models/forecasting/arima.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
q: int = 0,
3333
seasonal_order: Tuple[int, int, int, int] = (0, 0, 0, 0),
3434
trend: Optional[str] = None,
35-
random_state: int = 0,
35+
random_state: Optional[int] = None,
3636
add_encoders: Optional[dict] = None,
3737
):
3838
"""ARIMA
@@ -81,7 +81,11 @@ def __init__(
8181
self.seasonal_order = seasonal_order
8282
self.trend = trend
8383
self.model = None
84-
np.random.seed(random_state)
84+
self._random_state = (
85+
random_state
86+
if random_state is None
87+
else np.random.RandomState(random_state)
88+
)
8589

8690
def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
8791
super()._fit(series, future_covariates)
@@ -144,6 +148,7 @@ def _predict(
144148
nsimulations=n,
145149
repetitions=num_samples,
146150
initial_state=self.model.states.predicted[-1, :],
151+
random_state=self._random_state,
147152
exog=future_covariates.values(copy=False)
148153
if future_covariates
149154
else None,

darts/tests/models/forecasting/test_local_forecasting_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ def test_model_repr_call(self):
616616
), # no params changed
617617
(
618618
ARIMA(1, 1, 1),
619-
"ARIMA(p=1, d=1, q=1, seasonal_order=(0, 0, 0, 0), trend=None, random_state=0, add_encoders=None)",
619+
"ARIMA(p=1, d=1, q=1, seasonal_order=(0, 0, 0, 0), trend=None, random_state=None, add_encoders=None)",
620620
), # default value for a param
621621
]
622622
for model, expected in model_expected_name_pairs:

darts/tests/models/forecasting/test_probabilistic_models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
models_cls_kwargs_errs = [
5050
(ExponentialSmoothing, {}, 0.3),
51-
(ARIMA, {"p": 1, "d": 0, "q": 1}, 0.03),
51+
(ARIMA, {"p": 1, "d": 0, "q": 1, "random_state": 42}, 0.03),
5252
]
5353

5454
models_cls_kwargs_errs += [
@@ -150,7 +150,6 @@ class ProbabilisticTorchModelsTestCase(DartsBaseTestClass):
150150
def test_fit_predict_determinism(self):
151151

152152
for model_cls, model_kwargs, _ in models_cls_kwargs_errs:
153-
154153
# whether the first predictions of two models initiated with the same random state are the same
155154
model = model_cls(**model_kwargs)
156155
model.fit(self.constant_noisy_ts)

requirements/core.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ joblib>=0.16.0
44
lightgbm>=3.2.0
55
matplotlib>=3.3.0
66
nfoursid>=1.0.0
7-
numpy>=1.19.0
7+
numpy>=1.19.0,<1.24.0
88
pandas>=1.0.5
99
pmdarima>=1.8.0
1010
prophet>=1.1.1

0 commit comments

Comments
 (0)