Skip to content

Commit 528184d

Browse files
authored
fix dtype for sine_timeseries (#2856)
1 parent 98b3455 commit 528184d

File tree

3 files changed

+41
-7
lines changed

3 files changed

+41
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2121

2222
- Fixed a bug in `SKLearnModel.get_estimator()` for univariate quantile models that use `multi_models=False` , where using `quantile` did not return the correct fitted quantile model / estimator. [#2838](https://github.com/unit8co/darts/pull/2838) by [Dennis Bader](https://github.com/dennisbader).
2323
- Fixed a bug in `LightGBMModel` and `CatBoostModel` when using component-specific lags and categorical features, where certain lag scenarios could result in incorrect categorical feature declaration. [#2852](https://github.com/unit8co/darts/pull/2852) by [Dennis Bader](https://github.com/dennisbader).
24+
- Fixed a bug in `darts.utils.timeseries_generation.sine_timeseries()`, where the returned series ignored the specified `dtype`. [#2856](https://github.com/unit8co/darts/pull/2856) by [Dennis Bader](https://github.com/dennisbader).
2425
- Removed `darts/tests` and `examples` from the Darts package distribution. These are only required for internal testing. [#2854](https://github.com/unit8co/darts/pull/2854) by [Dennis Bader](https://github.com/dennisbader).
2526

2627
**Dependencies**

darts/tests/utils/test_timeseries_generation.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
from typing import Union
23

34
import numpy as np
@@ -697,3 +698,35 @@ def components_f(*args, **kwargs):
697698
series_new = series_new.with_static_covariates(static_covs)
698699
series_new = series_new.with_hierarchy({"new2": ["new1"]})
699700
assert series_new == series_renamed
701+
702+
@pytest.mark.parametrize(
703+
"config",
704+
itertools.product(
705+
[np.float32, np.float64],
706+
[
707+
(autoregressive_timeseries, {"coef": [1.0]}, False),
708+
(constant_timeseries, {}, False),
709+
(datetime_attribute_timeseries, {"attribute": "dayofweek"}, True),
710+
(gaussian_timeseries, {}, False),
711+
(holidays_timeseries, {"country_code": "CH"}, True),
712+
(linear_timeseries, {}, False),
713+
(random_walk_timeseries, {}, False),
714+
(sine_timeseries, {}, False),
715+
],
716+
),
717+
)
718+
def test_generation_dtype(self, config):
719+
dtype, (gen_func, gen_kwargs, requires_idx) = config
720+
721+
if requires_idx:
722+
gen_kwargs["time_index"] = generate_index(
723+
start="2000-01-01", length=10, freq="D"
724+
)
725+
else:
726+
gen_kwargs["length"] = 10
727+
728+
# generate a TimeSeries with the specified dtype
729+
ts = gen_func(dtype=dtype, **gen_kwargs)
730+
731+
# check that the dtype of the values is as expected
732+
assert ts.dtype == dtype

darts/utils/timeseries_generation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
--------------------------------
44
"""
55

6-
import math
76
from collections.abc import Sequence
87
from typing import Any, Callable, Optional, Union
98

@@ -210,12 +209,10 @@ def sine_timeseries(
210209
start=start, end=end, freq=freq, length=length, name=TIMES_NAME
211210
)
212211
values = np.array(range(len(index)), dtype=dtype)
213-
f = np.vectorize(
214-
lambda x: value_amplitude
215-
* math.sin(2 * math.pi * value_frequency * x + value_phase)
212+
values = (
213+
value_amplitude * np.sin(2 * np.pi * value_frequency * values + value_phase)
216214
+ value_y_offset
217215
)
218-
values = f(values)
219216
return TimeSeries(
220217
times=index,
221218
values=values,
@@ -371,6 +368,7 @@ def autoregressive_timeseries(
371368
length: Optional[int] = None,
372369
freq: Union[str, int] = None,
373370
column_name: Optional[str] = "autoregressive",
371+
dtype: np.dtype = np.float64,
374372
) -> TimeSeries:
375373
"""
376374
Creates a univariate, autoregressive TimeSeries whose values are calculated using specified coefficients `coef` and
@@ -402,6 +400,8 @@ def autoregressive_timeseries(
402400
The freq is optional for generating an integer index (if not specified, 1 is used).
403401
column_name
404402
Optionally, the name of the value column for the returned TimeSeries
403+
dtype
404+
The desired NumPy dtype (np.float32 or np.float64) for the resulting series
405405
406406
Returns
407407
-------
@@ -411,7 +411,7 @@ def autoregressive_timeseries(
411411

412412
# if no start values specified default to a list of 1s
413413
if start_values is None:
414-
start_values = np.ones(len(coef))
414+
start_values = np.ones(len(coef), dtype=dtype)
415415
else:
416416
raise_if_not(
417417
len(start_values) == len(coef),
@@ -422,7 +422,7 @@ def autoregressive_timeseries(
422422
start=start, end=end, freq=freq, length=length, name=TIMES_NAME
423423
)
424424

425-
values = np.empty(len(coef) + len(index))
425+
values = np.empty(len(coef) + len(index), dtype=dtype)
426426
values[: len(coef)] = start_values
427427

428428
for i in range(len(coef), len(coef) + len(index)):

0 commit comments

Comments
 (0)