Skip to content

Commit 97f986a

Browse files
.map() improvements (#2911)
* .map() improvements * changelog update * example fix * again example fix * codecov changes * add shape check and update docs * update docs --------- Co-authored-by: dennisbader <[email protected]>
1 parent b2a0c17 commit 97f986a

File tree

6 files changed

+81
-49
lines changed

6 files changed

+81
-49
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14+
- 🔴 Improved the performance of the `TimeSeries.map()` method for functions that take two arguments. The function is now called once on the entire time index and the full values array, which requires users to explicitly reshape values derived from the time index within the function so that they broadcast against the values array (e.g. `ts.month.values.reshape(-1, 1, 1)`). See more information in the `TimeSeries.map()` method documentation. [#2911](https://github.com/unit8co/darts/pull/2911) by [Jakub Chłapek](https://github.com/jakubchlapek)
15+
1416
**Fixed**
1517

1618
**Dependencies**

darts/tests/dataprocessing/transformers/test_mappers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ def inverse_func(x):
1818

1919
@staticmethod
2020
def ts_func(ts, x):
21-
return x - ts.month
21+
return x - ts.month.values.reshape(-1, 1, 1)
2222

2323
@staticmethod
2424
def inverse_ts_func(ts, x):
25-
return x + ts.month
25+
return x + ts.month.values.reshape(-1, 1, 1)
2626

2727
plus_ten = Mapper(func.__func__)
2828
plus_ten_invertible = InvertibleMapper(func.__func__, inverse_func.__func__)

darts/tests/test_timeseries.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1670,7 +1670,7 @@ def test_map_with_timestamp(self):
16701670
zeroes = zeroes.with_columns_renamed("constant", "linear")
16711671

16721672
def function(ts, x):
1673-
return x - ts.month
1673+
return x - ts.month.values.reshape(-1, 1, 1)
16741674

16751675
new_series = series.map(function)
16761676
assert new_series == zeroes
@@ -1695,6 +1695,20 @@ def add(x, y, z):
16951695
with pytest.raises(ValueError):
16961696
series.map(ufunc_add)
16971697

1698+
def test_map_fn_not_callable(self):
1699+
series = linear_timeseries(length=3)
1700+
with pytest.raises(TypeError) as exc:
1701+
series.map(fn=1)
1702+
assert str(exc.value) == "fn must be a callable"
1703+
1704+
def test_map_fn_wrong_output_shape(self):
1705+
series = linear_timeseries(length=3)
1706+
with pytest.raises(ValueError) as exc:
1707+
series.map(fn=lambda x: np.concatenate([x] * 2, axis=1))
1708+
assert str(exc.value) == (
1709+
"fn must return an array of shape `(3, 1, 1)`. Received shape `(3, 2, 1)`"
1710+
)
1711+
16981712
def test_gaps(self):
16991713
times1 = pd.date_range("20130101", "20130110")
17001714
times2 = pd.date_range("20120101", "20210301", freq=freqs["QE"])

darts/timeseries.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3662,16 +3662,16 @@ def is_within_range(self, ts: Union[pd.Timestamp, int]) -> bool:
36623662
def map(
36633663
self,
36643664
fn: Union[
3665-
Callable[[np.number], np.number],
3666-
Callable[[Union[pd.Timestamp, int], np.number], np.number],
3665+
Callable[[np.ndarray], np.ndarray],
3666+
Callable[[Union[pd.DatetimeIndex, pd.RangeIndex], np.ndarray], np.ndarray],
36673667
],
36683668
) -> Self: # noqa: E501
36693669
"""Return a new series with the function `fn` applied to the values of this series.
36703670
36713671
If `fn` takes 1 argument it is simply applied on the values array of shape `(time, n_components, n_samples)`.
3672-
If `fn` takes 2 arguments, it is applied repeatedly on the `(ts, value[ts])` tuples, where `ts` denotes a
3673-
timestamp value, and `value[ts]` denotes the array of values at this timestamp, of shape
3674-
`(n_components, n_samples)`.
3672+
If `fn` takes 2 arguments, it is applied on the `(ts, values)` tuple, where `ts` denotes the
3673+
series' time index, and `values` denotes the series' array of values, of shape
3674+
`(n_timestamps, n_components, n_samples)`. Values derived from the time index must be reshaped to `(n_timestamps, 1, 1)` inside `fn` so that they broadcast against `values`.
36753675
36763676
Parameters
36773677
----------
@@ -3687,9 +3687,36 @@ def map(
36873687
-------
36883688
TimeSeries
36893689
A new series with the function `fn` applied to the values.
3690+
3691+
Examples
3692+
--------
3693+
>>> from darts import TimeSeries
3694+
>>> from darts.utils.utils import generate_index
3695+
>>> # create a simple TimeSeries
3696+
>>> series = TimeSeries.from_times_and_values(
3697+
>>> times=generate_index("2020-01-01", length=3, freq="D"),
3698+
>>> values=range(3),
3699+
>>> )
3700+
>>> # map function on values only
3701+
>>> def fn1(values):
3702+
>>> return values / 3.
3703+
>>>
3704+
>>> series.map(fn1).values()
3705+
array([[0. ],
3706+
[0.33333333],
3707+
[0.66666667]])
3708+
>>>
3709+
>>> # map function on time index and values
3710+
>>> def fn2(times, values):
3711+
>>> return values / times.days_in_month.values.reshape(-1, 1, 1)
3712+
>>>
3713+
>>> series.map(fn2).values()
3714+
array([[0. ],
3715+
[0.03225806],
3716+
[0.06451613]])
36903717
"""
36913718
if not isinstance(fn, Callable):
3692-
raise_log(TypeError("fn should be callable"), logger)
3719+
raise_log(TypeError("fn must be a callable"), logger)
36933720

36943721
if isinstance(fn, np.ufunc):
36953722
if fn.nin == 1 and fn.nout == 1:
@@ -3716,30 +3743,18 @@ def map(
37163743

37173744
if num_args == 1: # apply fn on values directly
37183745
values = fn(self._values)
3719-
elif num_args == 2: # map function uses timestamp f(timestamp, x)
3720-
# go over shortest amount of iterations, either over time steps or components and samples
3721-
if self.n_timesteps <= self.n_components * self.n_samples:
3722-
new_vals = np.vstack([
3723-
np.expand_dims(
3724-
fn(self.time_index[i], self._values[i, :, :]), axis=0
3725-
)
3726-
for i in range(self.n_timesteps)
3727-
])
3728-
else:
3729-
new_vals = np.stack(
3730-
[
3731-
np.column_stack([
3732-
fn(self.time_index, self._values[:, i, j])
3733-
for j in range(self.n_samples)
3734-
])
3735-
for i in range(self.n_components)
3736-
],
3737-
axis=1,
3738-
)
3739-
values = new_vals
3740-
3746+
elif num_args == 2:
3747+
# apply function on (times, values)
3748+
values = fn(self._time_index, self._values)
37413749
else:
3742-
raise_log(ValueError("fn must have either one or two arguments"), logger)
3750+
raise_log(ValueError("fn must accept either one or two arguments"), logger)
3751+
3752+
if values.shape != self.shape:
3753+
raise_log(
3754+
ValueError(
3755+
f"fn must return an array of shape `{self.shape}`. Received shape `{values.shape}`"
3756+
)
3757+
)
37433758

37443759
return self.__class__(
37453760
times=self._time_index,

examples/00-quickstart.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@
420420
}
421421
],
422422
"source": [
423-
"series.map(lambda ts, x: x / ts.days_in_month).plot();"
423+
"series.map(lambda ts, x: x / ts.days_in_month.values.reshape(-1, 1, 1)).plot();"
424424
]
425425
},
426426
{

examples/02-data-processing.ipynb

Lines changed: 17 additions & 16 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)