Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Fixed**

**Dependencies**
- Changed `Timeseries.plot()` implementation to no longer rely on xarray under the hood while keeping same functionality. [#2932](https://github.com/unit8co/darts/pull/2932) by [Jakub Chłapek](https://github.com/jakubchlapek)

### For developers of the library:

Expand Down
43 changes: 24 additions & 19 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4318,8 +4318,6 @@ def plot(
) -> matplotlib.axes.Axes:
"""Plot the series.

This is a wrapper method around :func:`xarray.DataArray.plot()`.

Parameters
----------
new_plot
Expand All @@ -4340,7 +4338,7 @@ def plot(
default_formatting
Whether to use the darts default scheme.
title
Optionally, a custom plot title. If `None`, will use the name of the underlying `xarray.DataArray`.
Optionally, a custom plot title. If `None`, will use an empty string.
label
Can either be a string or list of strings. If a string and the series only has a single component, it is
used as the label for that component. If a string and the series has multiple components, it is used as
Expand Down Expand Up @@ -4452,19 +4450,18 @@ def plot(
if ax is None:
ax = plt.gca()

# TODO: migrate from xarray plotting to something else
data_array = self.data_array(copy=False)
for i, c in enumerate(data_array.component[:n_components_to_plot]):
comp_name = str(c.values)
comp = data_array.sel(component=c)
for i, comp_name in enumerate(self.components[:n_components_to_plot]):
comp_ts = self[comp_name]

if comp.sample.size > 1:
if self.is_stochastic:
if central_quantile == "mean":
central_series = comp.mean(dim=DIMS[2])
central_ts = comp_ts.mean()
else:
central_series = comp.quantile(q=central_quantile, dim=DIMS[2])
central_ts = comp_ts.quantile(q=central_quantile)
else:
central_series = comp.mean(dim=DIMS[2])
central_ts = comp_ts.mean()

central_series = central_ts.to_series() # shape: (time,)

if custom_labels:
label_to_use = label[i]
Expand All @@ -4479,10 +4476,17 @@ def plot(
kwargs["c"] = color[i] if custom_colors else color

kwargs_central = deepcopy(kwargs)
if not self.is_deterministic:
if self.is_stochastic:
kwargs_central["alpha"] = 1
if central_series.shape[0] > 1:
p = central_series.plot(*args, ax=ax, **kwargs_central)
p = central_series.plot(
*args,
ax=ax,
**kwargs_central,
)
color_used = (
p.get_lines()[-1].get_color() if default_formatting else None
)
# empty TimeSeries
elif central_series.shape[0] == 0:
p = ax.plot(
Expand All @@ -4491,6 +4495,7 @@ def plot(
*args,
**kwargs_central,
)
color_used = p[0].get_color() if default_formatting else None
else:
p = ax.plot(
[self.start_time()],
Expand All @@ -4499,17 +4504,17 @@ def plot(
*args,
**kwargs_central,
)
color_used = p[0].get_color() if default_formatting else None
ax.set_xlabel(self.time_dim)
color_used = p[0].get_color() if default_formatting else None

# Optionally show confidence intervals
if (
comp.sample.size > 1
self.is_stochastic
and low_quantile is not None
and high_quantile is not None
):
low_series = comp.quantile(q=low_quantile, dim=DIMS[2])
high_series = comp.quantile(q=high_quantile, dim=DIMS[2])
low_series = comp_ts.quantile(q=low_quantile).to_series()
high_series = comp_ts.quantile(q=high_quantile).to_series()
if low_series.shape[0] > 1:
ax.fill_between(
self.time_index,
Expand All @@ -4528,7 +4533,7 @@ def plot(
)

ax.legend()
ax.set_title(title if title is not None else data_array.name)
ax.set_title(title if title is not None else "")
return ax

def with_columns_renamed(
Expand Down
Loading