Skip to content

Commit e106771

Browse files
authored
Migrated .plot() from using xarray (#2932)
1 parent 9b61374 commit e106771

File tree

3 files changed

+284
-28
lines changed

3 files changed

+284
-28
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1616
**Fixed**
1717

1818
**Dependencies**
19+
- Changed `Timeseries.plot()` implementation to no longer rely on xarray under the hood while keeping the same functionality. [#2932](https://github.com/unit8co/darts/pull/2932) by [Jakub Chłapek](https://github.com/jakubchlapek)
1920

2021
### For developers of the library:
2122

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
from itertools import product
2+
from unittest.mock import patch
3+
4+
import matplotlib.collections as mcollections
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import pandas as pd
8+
import pytest
9+
10+
from darts import TimeSeries
11+
from darts.utils.utils import generate_index
12+
13+
14+
class TestTimeSeriesPlot:
15+
# datetime index, deterministic
16+
n_comps = 2
17+
series_dt_d = TimeSeries.from_times_and_values(
18+
times=generate_index(start="2000-01-01", length=10, freq="D"),
19+
values=np.random.random((10, n_comps, 1)),
20+
)
21+
# datetime index, probabilistic
22+
series_dt_p = TimeSeries.from_times_and_values(
23+
times=generate_index(start="2000-01-01", length=10, freq="D"),
24+
values=np.random.random((10, n_comps, 5)),
25+
)
26+
# range index, deterministic
27+
series_ri_d = TimeSeries.from_times_and_values(
28+
times=generate_index(start=0, length=10, freq=1),
29+
values=np.random.random((10, n_comps, 1)),
30+
)
31+
# range index, probabilistic
32+
series_ri_p = TimeSeries.from_times_and_values(
33+
times=generate_index(start=0, length=10, freq=1),
34+
values=np.random.random((10, n_comps, 5)),
35+
)
36+
37+
@patch("matplotlib.pyplot.show")
38+
@pytest.mark.parametrize(
39+
"config",
40+
product(
41+
["dt", "ri"],
42+
["d", "p"],
43+
[True, False],
44+
),
45+
)
46+
def test_plot_single_series(self, mock_show, config):
47+
index_type, stoch_type, use_ax = config
48+
series = getattr(self, f"series_{index_type}_{stoch_type}")
49+
if use_ax:
50+
_, ax = plt.subplots()
51+
else:
52+
ax = None
53+
series.plot(ax=ax)
54+
55+
# For deterministic series with len > 1: one line per component
56+
# For probabilistic series with len > 1: one line per component + one area per component
57+
ax = ax if use_ax else plt.gca()
58+
59+
# Count lines (Line2D objects with multiple data points representing actual lines)
60+
lines = [line for line in ax.lines if len(line.get_xdata()) > 1]
61+
assert len(lines) == self.n_comps
62+
63+
# For probabilistic: count filled areas (PolyCollection from fill_between)
64+
if series.is_stochastic:
65+
areas = [
66+
coll
67+
for coll in ax.collections
68+
if isinstance(coll, mcollections.PolyCollection)
69+
]
70+
assert len(areas) == self.n_comps
71+
72+
plt.show()
73+
plt.close()
74+
75+
@patch("matplotlib.pyplot.show")
76+
@pytest.mark.parametrize(
77+
"config",
78+
product(
79+
["dt", "ri"],
80+
["d", "p"],
81+
),
82+
)
83+
def test_plot_point_series(self, mock_show, config):
84+
index_type, stoch_type = config
85+
series = getattr(self, f"series_{index_type}_{stoch_type}")
86+
series = series[:1]
87+
series.plot()
88+
89+
# For deterministic series with len == 1: one point per component
90+
# For probabilistic series with len == 1: one point per component + one vertical line per component
91+
ax = plt.gca()
92+
93+
# Count points (Line2D objects with markers representing single points)
94+
points = [
95+
line
96+
for line in ax.lines
97+
if len(line.get_xdata()) == 1 and line.get_marker() != "None"
98+
]
99+
assert len(points) == self.n_comps
100+
101+
# For probabilistic: count vertical lines for confidence intervals
102+
if series.is_stochastic:
103+
# The confidence interval is plotted as a line with "-+" marker
104+
# It's a vertical line where x-coordinates are the same
105+
vert_lines = []
106+
for line in ax.lines:
107+
xdata = np.asarray(line.get_xdata())
108+
ydata = np.asarray(line.get_ydata())
109+
if len(xdata) == 2 and len(ydata) == 2:
110+
# check if x-coords are the same (vertical line)
111+
xdiff = xdata[0] - xdata[1]
112+
113+
if isinstance(xdiff, pd.Timedelta):
114+
xdiff = xdiff.total_seconds()
115+
116+
if abs(xdiff) < 1e-10:
117+
vert_lines.append(line)
118+
assert len(vert_lines) == self.n_comps
119+
120+
plt.show()
121+
plt.close()
122+
123+
@patch("matplotlib.pyplot.show")
124+
@pytest.mark.parametrize(
125+
"config",
126+
product(
127+
["dt", "ri"],
128+
["d", "p"],
129+
),
130+
)
131+
def test_plot_empty_series(self, mock_show, config):
132+
index_type, stoch_type = config
133+
series = getattr(self, f"series_{index_type}_{stoch_type}")
134+
series = series[:0]
135+
series.plot()
136+
137+
# For len == 0: no points or lines should be plotted
138+
ax = plt.gca()
139+
# empty plot creates a line with empty data, but we want to check for actual plotted content
140+
# no points
141+
points = [
142+
line
143+
for line in ax.lines
144+
if len(line.get_xdata()) == 1 and line.get_marker() != "None"
145+
]
146+
assert len(points) == 0
147+
148+
# no lines
149+
lines_meaningful = [line for line in ax.lines if len(line.get_xdata()) > 1]
150+
assert len(lines_meaningful) == 0
151+
152+
# no areas
153+
areas = [
154+
coll
155+
for coll in ax.collections
156+
if isinstance(coll, mcollections.PolyCollection)
157+
]
158+
assert len(areas) == 0
159+
160+
plt.show()
161+
plt.close()
162+
163+
@patch("matplotlib.pyplot.show")
164+
@pytest.mark.parametrize(
165+
"config",
166+
product(
167+
["dt", "ri"],
168+
["d", "p"],
169+
[
170+
{"new_plot": True},
171+
{"default_formatting": False},
172+
{"title": "my title"},
173+
{"label": "comps"},
174+
{"label": ["comps_1", "comps_2"]},
175+
{"alpha": 0.1, "color": "blue"},
176+
{"color": ["blue", "red"]},
177+
{"lw": 2},
178+
],
179+
),
180+
)
181+
def test_plot_params(self, mock_show, config):
182+
index_type, stoch_type, kwargs = config
183+
series = getattr(self, f"series_{index_type}_{stoch_type}")
184+
series.plot(**kwargs)
185+
plt.show()
186+
plt.close()
187+
188+
@patch("matplotlib.pyplot.show")
189+
@pytest.mark.parametrize(
190+
"config",
191+
product(
192+
["dt", "ri"],
193+
[
194+
{"central_quantile": "mean"},
195+
{"central_quantile": 0.5},
196+
{
197+
"low_quantile": 0.2,
198+
"central_quantile": 0.6,
199+
"high_quantile": 0.7,
200+
"alpha": 0.1,
201+
},
202+
],
203+
),
204+
)
205+
def test_plot_stochastic_params(self, mock_show, config):
206+
(index_type, kwargs), stoch_type = config, "p"
207+
series = getattr(self, f"series_{index_type}_{stoch_type}")
208+
series.plot(**kwargs)
209+
plt.show()
210+
plt.close()
211+
212+
@patch("matplotlib.pyplot.show")
213+
@pytest.mark.parametrize("config", ["dt", "ri"])
214+
def test_plot_multiple_series(self, mock_show, config):
215+
index_type = config
216+
series1 = getattr(self, f"series_{index_type}_d")
217+
series2 = getattr(self, f"series_{index_type}_p")
218+
series1.plot()
219+
series2.plot()
220+
plt.show()
221+
plt.close()
222+
223+
@patch("matplotlib.pyplot.show")
224+
@pytest.mark.parametrize("config", ["dt", "ri"])
225+
def test_plot_deterministic_and_stochastic(self, mock_show, config):
226+
index_type = config
227+
series1 = getattr(self, f"series_{index_type}_d")
228+
series2 = getattr(self, f"series_{index_type}_p")
229+
series1.plot()
230+
series2.plot()
231+
plt.show()
232+
plt.close()
233+
234+
@patch("matplotlib.pyplot.show")
235+
@pytest.mark.parametrize("config", ["d", "p"])
236+
def test_cannot_plot_different_index_types(self, mock_show, config):
237+
stoch_type = config
238+
series1 = getattr(self, f"series_dt_{stoch_type}")
239+
series2 = getattr(self, f"series_ri_{stoch_type}")
240+
# datetime index plot changes x-axis to use datetime index
241+
series1.plot()
242+
# cannot plot a range index on datetime index
243+
with pytest.raises(TypeError):
244+
series2.plot()
245+
plt.show()
246+
plt.close()

darts/timeseries.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4323,8 +4323,6 @@ def plot(
43234323
) -> matplotlib.axes.Axes:
43244324
"""Plot the series.
43254325
4326-
This is a wrapper method around :func:`xarray.DataArray.plot()`.
4327-
43284326
Parameters
43294327
----------
43304328
new_plot
@@ -4345,7 +4343,7 @@ def plot(
43454343
default_formatting
43464344
Whether to use the darts default scheme.
43474345
title
4348-
Optionally, a custom plot title. If `None`, will use the name of the underlying `xarray.DataArray`.
4346+
Optionally, a plot title.
43494347
label
43504348
Can either be a string or list of strings. If a string and the series only has a single component, it is
43514349
used as the label for that component. If a string and the series has multiple components, it is used as
@@ -4457,19 +4455,18 @@ def plot(
44574455
if ax is None:
44584456
ax = plt.gca()
44594457

4460-
# TODO: migrate from xarray plotting to something else
4461-
data_array = self.data_array(copy=False)
4462-
for i, c in enumerate(data_array.component[:n_components_to_plot]):
4463-
comp_name = str(c.values)
4464-
comp = data_array.sel(component=c)
4458+
for i, comp_name in enumerate(self.components[:n_components_to_plot]):
4459+
comp_ts = self[comp_name]
44654460

4466-
if comp.sample.size > 1:
4461+
if self.is_stochastic:
44674462
if central_quantile == "mean":
4468-
central_series = comp.mean(dim=DIMS[2])
4463+
central_ts = comp_ts.mean()
44694464
else:
4470-
central_series = comp.quantile(q=central_quantile, dim=DIMS[2])
4465+
central_ts = comp_ts.quantile(q=central_quantile)
44714466
else:
4472-
central_series = comp.mean(dim=DIMS[2])
4467+
central_ts = comp_ts
4468+
4469+
central_series = central_ts.to_series() # shape: (time,)
44734470

44744471
if custom_labels:
44754472
label_to_use = label[i]
@@ -4484,46 +4481,58 @@ def plot(
44844481
kwargs["c"] = color[i] if custom_colors else color
44854482

44864483
kwargs_central = deepcopy(kwargs)
4487-
if not self.is_deterministic:
4484+
if self.is_stochastic:
44884485
kwargs_central["alpha"] = 1
4489-
if central_series.shape[0] > 1:
4490-
p = central_series.plot(*args, ax=ax, **kwargs_central)
4491-
# empty TimeSeries
4492-
elif central_series.shape[0] == 0:
4493-
p = ax.plot(
4494-
[],
4495-
[],
4486+
# line plot
4487+
if len(central_series) > 1:
4488+
p = central_series.plot(
44964489
*args,
4490+
ax=ax,
44974491
**kwargs_central,
44984492
)
4499-
else:
4493+
color_used = (
4494+
p.get_lines()[-1].get_color() if default_formatting else None
4495+
)
4496+
# point plot
4497+
elif len(central_series) == 1:
45004498
p = ax.plot(
45014499
[self.start_time()],
45024500
central_series.values[0],
45034501
"o",
45044502
*args,
45054503
**kwargs_central,
45064504
)
4505+
color_used = p[0].get_color() if default_formatting else None
4506+
# empty plot
4507+
else:
4508+
p = ax.plot(
4509+
[],
4510+
[],
4511+
*args,
4512+
**kwargs_central,
4513+
)
4514+
color_used = p[0].get_color() if default_formatting else None
45074515
ax.set_xlabel(self.time_dim)
4508-
color_used = p[0].get_color() if default_formatting else None
45094516

45104517
# Optionally show confidence intervals
45114518
if (
4512-
comp.sample.size > 1
4519+
self.is_stochastic
45134520
and low_quantile is not None
45144521
and high_quantile is not None
45154522
):
4516-
low_series = comp.quantile(q=low_quantile, dim=DIMS[2])
4517-
high_series = comp.quantile(q=high_quantile, dim=DIMS[2])
4518-
if low_series.shape[0] > 1:
4523+
low_series = comp_ts.quantile(q=low_quantile).to_series()
4524+
high_series = comp_ts.quantile(q=high_quantile).to_series()
4525+
# filled area
4526+
if len(low_series) > 1:
45194527
ax.fill_between(
45204528
self.time_index,
45214529
low_series,
45224530
high_series,
45234531
color=color_used,
45244532
alpha=(alpha if alpha is not None else alpha_confidence_intvls),
45254533
)
4526-
else:
4534+
# filled line
4535+
elif len(low_series) == 1:
45274536
ax.plot(
45284537
[self.start_time(), self.start_time()],
45294538
[low_series.values[0], high_series.values[0]],
@@ -4533,7 +4542,7 @@ def plot(
45334542
)
45344543

45354544
ax.legend()
4536-
ax.set_title(title if title is not None else data_array.name)
4545+
ax.set_title(title if title is not None else "")
45374546
return ax
45384547

45394548
def with_columns_renamed(

0 commit comments

Comments
 (0)