Skip to content

Commit c93f512

Browse files
authored
Fix/list indexing length leq 2 (#2857)
* fix list indexing with length less than or equal to 2 * update changelog
1 parent 528184d commit c93f512

File tree

3 files changed

+115
-18
lines changed

3 files changed

+115
-18
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2222
- Fixed a bug in `SKLearnModel.get_estimator()` for univariate quantile models that use `multi_models=False` , where using `quantile` did not return the correct fitted quantile model / estimator. [#2838](https://github.com/unit8co/darts/pull/2838) by [Dennis Bader](https://github.com/dennisbader).
2323
- Fixed a bug in `LightGBMModel` and `CatBoostModel` when using component-specific lags and categorical features, where certain lag scenarios could result in incorrect categorical feature declaration. [#2852](https://github.com/unit8co/darts/pull/2852) by [Dennis Bader](https://github.com/dennisbader).
2424
- Fixed a bug in `darts.utils.timeseries_generation.sine_timeseries()`, where the returned series ignored the specified `dtype`. [#2856](https://github.com/unit8co/darts/pull/2856) by [Dennis Bader](https://github.com/dennisbader).
25+
- Fixed a bug in `TimeSeries.__getitem__()`, where indexing with a list of integers of `length <= 2` resulted in an error. [#2857](https://github.com/unit8co/darts/pull/2857) by [Dennis Bader](https://github.com/dennisbader).
2526
- Removed `darts/tests` and `examples` from the Darts package distribution. These are only required for internal testing. [#2854](https://github.com/unit8co/darts/pull/2854) by [Dennis Bader](https://github.com/dennisbader).
2627

2728
**Dependencies**

darts/tests/test_timeseries.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ def test_integer_range_indexing(self):
229229
with pytest.raises(ValueError) as exc:
230230
_ = series_int[[2, 3, 6]]
231231
assert str(exc.value).startswith(
232-
"Could not convert integer index to a `pandas.RangeIndex`. "
233-
"Found non-unique step sizes/frequencies: `{1, 3}`. "
232+
"Cannot index a `TimeSeries` with a list of integers with non-constant step sizes. "
233+
"Observed step sizes: `{1, 3}`"
234234
)
235235

236236
# check integer indexing features when series index does not start at 0
@@ -901,16 +901,6 @@ def test_getitem_datetime_index(self):
901901
assert series.freq == 3 * self.series1.freq
902902
assert series == self.series1[: 3 * 3 : 3]
903903

904-
series = self.series1[[0, 3, 6]]
905-
assert series.freq == 3 * self.series1.freq
906-
assert series == self.series1[: 3 * 3 : 3]
907-
908-
with pytest.raises(ValueError) as exc:
909-
_ = self.series1[[0, 2, 6]]
910-
assert str(exc.value).startswith(
911-
"The time index is missing the `freq` attribute, and the frequency could not be directly inferred"
912-
)
913-
914904
# not all dates in index
915905
with pytest.raises(KeyError):
916906
self.series1[pd.date_range("19990101", "19990201")]
@@ -921,6 +911,8 @@ def test_getitem_datetime_index(self):
921911
with pytest.raises(ValueError):
922912
self.series1[::-1]
923913

914+
# list of integers is tested below
915+
924916
def test_getitem_integer_index(self):
925917
freq = 3
926918
start = 1
@@ -972,6 +964,92 @@ def test_getitem_integer_index(self):
972964
with pytest.raises(KeyError):
973965
_ = series[pd.RangeIndex(start, stop=end + 2 * freq, step=freq)]
974966

967+
# list of integers is tested below
968+
969+
@pytest.mark.parametrize("is_dti", [True, False])
970+
def test_getitem_list_of_integers(self, is_dti):
971+
if is_dti:
972+
# datetime-indexed series
973+
series = self.series1
974+
else:
975+
# range-indexed series
976+
idx_int = generate_index(start=1, length=len(self.series1), freq=3)
977+
series = TimeSeries(times=idx_int, values=self.series1.values())
978+
979+
# list of integers (relative to the start) with constant step size
980+
# length 3; multiple steps of size 1 starting at 0
981+
series_new = series[[0, 1, 2]]
982+
assert series_new.freq == series.freq
983+
assert series_new == series[:3]
984+
985+
# length 3; multiple steps of size 1 starting after 0
986+
series_new = series[[1, 2, 3]]
987+
assert series_new.freq == series.freq
988+
assert series_new == series[1:4]
989+
990+
# length 3; multiple steps starting at 0
991+
series_new = series[[0, 3, 6]]
992+
assert series_new.freq == 3 * series.freq
993+
assert series_new == series[: 3 * 3 : 3]
994+
995+
# length 3; multiple steps starting after 0
996+
series_new = series[[2, 4, 6]]
997+
assert series_new.freq == 2 * series.freq
998+
assert series_new == series[2:8:2]
999+
1000+
# length 2; 1 step
1001+
series_new = series[[0, 3]]
1002+
assert series_new.freq == 3 * series.freq
1003+
assert series_new == series[: 2 * 3 : 3]
1004+
1005+
# length 1; no steps
1006+
series_new = series[[0]]
1007+
assert series_new.freq == series.freq
1008+
assert series_new == series[:1]
1009+
1010+
series_new = series[0]
1011+
assert series_new.freq == series.freq
1012+
assert series_new == series[:1]
1013+
1014+
# list of integers (relative to the end) with constant step size
1015+
n_steps = len(series)
1016+
# length 3; multiple steps
1017+
series_new = series[[-n_steps, -n_steps + 3, -n_steps + 6]]
1018+
assert series_new.freq == 3 * series.freq
1019+
assert series_new == series[: 3 * 3 : 3]
1020+
1021+
# length 2; 1 step
1022+
series_new = series[[-n_steps, -n_steps + 3]]
1023+
assert series_new.freq == 3 * series.freq
1024+
assert series_new == series[: 2 * 3 : 3]
1025+
1026+
# length 1; no steps
1027+
series_new = series[[-n_steps]]
1028+
assert series_new.freq == series.freq
1029+
assert series_new == series[:1]
1030+
1031+
# different step sizes
1032+
with pytest.raises(ValueError) as exc:
1033+
_ = series[[0, 2, 6]]
1034+
assert str(exc.value).startswith(
1035+
"Cannot index a `TimeSeries` with a list of integers with non-constant step sizes. "
1036+
"Observed step sizes: `{2, 4}`."
1037+
)
1038+
1039+
# decreasing step sizes
1040+
with pytest.raises(ValueError) as exc:
1041+
_ = series[[3, 2, 1]]
1042+
assert str(exc.value).startswith(
1043+
"Indexing a `TimeSeries` with a list of integers with `step<=0` is not possible"
1044+
)
1045+
1046+
# constant step sizes
1047+
with pytest.raises(ValueError) as exc:
1048+
_ = series[[3, 3]]
1049+
assert str(exc.value).startswith(
1050+
"Indexing a `TimeSeries` with a list of integers with `step<=0` is not possible"
1051+
)
1052+
9751053
def test_getitem_frequency_inferrence(self):
9761054
ts = self.series1
9771055
assert ts.freq == "D"

darts/timeseries.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5737,12 +5737,30 @@ def _check_range():
57375737
metadata=self.metadata,
57385738
)
57395739
elif all(isinstance(i, (int, np.int64)) for i in key):
5740-
return self.__class__(
5741-
times=self._time_index[key],
5742-
values=self._values[key],
5743-
components=self.components,
5744-
**self._attrs,
5745-
)
5740+
# convert list of integers to slice (must have constant step size)
5741+
step_sizes = set(right - left for left, right in zip(key[:-1], key[1:]))
5742+
if len(step_sizes) > 1:
5743+
raise_log(
5744+
ValueError(
5745+
f"Cannot index a `TimeSeries` with a list of integers with non-constant step sizes. "
5746+
f"Observed step sizes: `{step_sizes}`."
5747+
),
5748+
logger,
5749+
)
5750+
elif len(step_sizes) == 1:
5751+
step_size = step_sizes.pop()
5752+
else:
5753+
step_size = 1
5754+
5755+
if step_size <= 0:
5756+
raise_log(
5757+
ValueError(
5758+
"Indexing a `TimeSeries` with a list of integers with `step<=0` is not "
5759+
"possible since `TimeSeries` must have a monotonically increasing time index."
5760+
),
5761+
logger=logger,
5762+
)
5763+
return self[key[0] : key[-1] + step_size : step_size]
57465764

57475765
elif all(isinstance(t, pd.Timestamp) for t in key):
57485766
_check_dt()

0 commit comments

Comments
 (0)