Skip to content

Commit a8b264e

Browse files
Small bugfixes and test
1 parent 5b7b38a commit a8b264e

File tree

2 files changed

+89
-29
lines changed

2 files changed

+89
-29
lines changed

pymc/data.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import pytensor.tensor as pt
2828
import xarray as xr
2929

30-
from narwhals.typing import IntoFrameT, IntoLazyFrameT, IntoSeriesT
30+
from narwhals.typing import IntoFrameT, IntoSeriesT
3131
from pytensor.compile import SharedVariable
3232
from pytensor.compile.builders import OpFromGraph
3333
from pytensor.graph.basic import Variable
@@ -174,11 +174,11 @@ def _handle_none_dims(
174174

175175
@singledispatch
176176
def determine_coords(
177-
value,
177+
value: typing.Any,
178178
model: "Model",
179179
dims: Sequence[str | None] | None = None,
180180
coords: dict[str, Sequence | np.ndarray] | None = None,
181-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
181+
) -> tuple[typing.Any, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
182182
"""Determine coordinate values from data or the model (via ``dims``)."""
183183
raise NotImplementedError(
184184
f"Cannot determine coordinates for data of type {type(value)}, please provide `coords` explicitly or "
@@ -192,12 +192,12 @@ def determine_array_coords(
192192
model: "Model",
193193
dims: Sequence[str] | None = None,
194194
coords: dict[str, Sequence | np.ndarray] | None = None,
195-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
195+
) -> tuple[np.ndarray, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
196196
if coords is None:
197197
coords = {}
198198

199199
if dims is None:
200-
return coords, _handle_none_dims(dims, value.ndim)
200+
return value, coords, _handle_none_dims(dims, value.ndim)
201201

202202
if len(dims) != value.ndim:
203203
raise ShapeError(
@@ -211,7 +211,7 @@ def determine_array_coords(
211211
if coord is None and dim is not None:
212212
coords[dim] = range(size)
213213

214-
return coords, _handle_none_dims(dims, value.ndim)
214+
return value, coords, _handle_none_dims(dims, value.ndim)
215215

216216

217217
@determine_coords.register(xr.DataArray)
@@ -220,41 +220,41 @@ def determine_xarray_coords(
220220
model: "Model",
221221
dims: Sequence[str | None] | None = None,
222222
coords: dict[str, Sequence | np.ndarray] | None = None,
223-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
223+
) -> tuple[xr.DataArray, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
224224
if coords is None:
225225
coords = {}
226226

227227
if dims is None:
228-
return coords, _handle_none_dims(dims, value.ndim)
228+
return value, coords, _handle_none_dims(dims, value.ndim)
229229

230230
for dim in dims:
231231
dim_name = dim
232232
# str is applied because dim entries may be None
233233
coords[str(dim_name)] = cast(xr.DataArray, value[dim]).to_numpy()
234234

235-
return coords, _handle_none_dims(dims, value.ndim)
235+
return value, coords, _handle_none_dims(dims, value.ndim)
236236

237237

238238
def _dataframe_agnostic_coords(
239-
value: IntoFrameT | IntoLazyFrameT | nw.DataFrame | nw.LazyFrame,
239+
value: IntoFrameT,
240240
model: "Model",
241241
ndim_in: int = 2,
242242
dims: Sequence[str | None] | None = None,
243243
coords: dict[str, Sequence | np.ndarray] | None = None,
244-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
244+
) -> tuple[IntoFrameT, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
245245
if coords is None:
246246
coords = {}
247247

248248
value = cast(nw.DataFrame | nw.LazyFrame, nw.from_native(value, allow_series=False)) # type: ignore[type-var]
249249
if isinstance(value, nw.LazyFrame):
250250
value = value.collect()
251251

252-
index = nw.maybe_get_index(value)
253-
if index is not None:
254-
value = value.with_columns(**{index.name: index.to_numpy()})
255-
256252
if dims is None:
257-
return coords, _handle_none_dims(dims, ndim_in)
253+
if ndim_in == 1:
254+
value = value[value.columns[0]]
255+
return value.to_native(), coords, _handle_none_dims(dims, ndim_in)
256+
257+
index = nw.maybe_get_index(value)
258258

259259
if len(dims) != ndim_in:
260260
raise ShapeError(
@@ -265,13 +265,13 @@ def _dataframe_agnostic_coords(
265265

266266
index_dim = dims[0]
267267
if index_dim is not None:
268-
if index_dim in value.columns:
269-
coords[index_dim] = tuple(value.select(nw.col(index_dim)).to_numpy().flatten())
268+
if index is not None:
269+
coords[index_dim] = tuple(index)
270270
elif index_dim in model.coords:
271271
coords[index_dim] = model.coords[index_dim] # type: ignore[assignment]
272272
else:
273273
raise ValueError(
274-
f"Dimension '{index_dim}' not found in DataFrame columns or model coordinates. Cannot infer "
274+
f"Dimension '{index_dim}' not found in DataFrame index or model coordinates. Cannot infer "
275275
"index coordinates."
276276
)
277277

@@ -281,7 +281,10 @@ def _dataframe_agnostic_coords(
281281
select_expr = nw.exclude(index_dim) if index_dim is not None else nw.all()
282282
coords[column_dim] = value.select(select_expr).columns
283283

284-
return coords, _handle_none_dims(dims, ndim_in)
284+
if ndim_in == 1:
285+
value = value[value.columns[0]]
286+
287+
return value.to_native(), coords, _handle_none_dims(dims, ndim_in)
285288

286289

287290
def _series_agnostic_coords(
@@ -319,7 +322,9 @@ def determine_dataframe_coords(
319322
model: "Model",
320323
dims: Sequence[str] | None = None,
321324
coords: dict[str, Sequence | np.ndarray] | None = None,
322-
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
325+
) -> tuple[
326+
IntoFrameT, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]
327+
]:
323328
return _dataframe_agnostic_coords(value, model=model, dims=dims, coords=coords)
324329

325330
except ImportError:
@@ -453,7 +458,7 @@ def Data(
453458

454459
new_dims: Sequence[str | None] | Sequence[None] | None
455460
if infer_dims_and_coords:
456-
coords, new_dims = determine_coords(value, model, dims)
461+
value, coords, new_dims = determine_coords(value, model, dims)
457462
else:
458463
new_dims = dims
459464

tests/test_data.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def test_implicit_coords_polars_series(self):
414414

415415
with pytest.raises(
416416
ValueError,
417-
match="Dimension 'date2' not found in DataFrame columns or model coordinates",
417+
match="Dimension 'date2' not found in DataFrame index or model coordinates",
418418
):
419419
pm.Data("sales_invalid", ser_sales, dims=["date2"], infer_dims_and_coords=True)
420420

@@ -428,14 +428,69 @@ def test_implicit_coords_polars_dataframe(self):
428428
df_data = pl.DataFrame(
429429
np.random.normal(size=size),
430430
schema={f"Column {c + 1}": pl.Float64 for c in range(size[1])},
431-
).with_row_count("rows")
431+
)
432432

433-
with pm.Model() as pmodel:
434-
pm.Data("observations", df_data, dims=("rows", "columns"), infer_dims_and_coords=True)
433+
# We currently count on the presence of an index in the DataFrame to infer dims. Polars has no index, so
434+
# this case errors because we can't find the 'rows' dim.
435435

436-
assert "rows" in pmodel.coords
437-
assert "columns" in pmodel.coords
438-
assert pmodel.named_vars_to_dims == {"observations": ("rows", "columns")}
436+
with pytest.raises(
437+
ValueError, match="Dimension 'rows' not found in DataFrame index or model coordinates"
438+
):
439+
with pm.Model() as pmodel:
440+
pm.Data(
441+
"observations", df_data, dims=("rows", "columns"), infer_dims_and_coords=True
442+
)
443+
444+
def test_implicit_coords_agnostic(self):
445+
pl = pytest.importorskip("polars")
446+
pd = pytest.importorskip("pandas")
447+
448+
size = (5, 7)
449+
data_np = np.random.normal(size=size)
450+
columns = [f"C{c + 1}" for c in range(size[1])]
451+
rows = [f"R{r + 1}" for r in range(size[0])]
452+
df_pd = pd.DataFrame(data_np, columns=columns, index=rows)
453+
df_pd.index.name = "rows"
454+
df_pl = pl.DataFrame(
455+
data_np,
456+
schema=dict.fromkeys(columns, pl.Float64),
457+
)
458+
459+
def make_model(coords, df, dims, infer_dims_and_coords) -> pm.Model:
460+
with pm.Model(coords=coords) as pmodel:
461+
pm.Data("X", df, dims=dims, infer_dims_and_coords=infer_dims_and_coords)
462+
return pmodel
463+
464+
expected_coords = {"rows": tuple(rows), "columns": tuple(columns)}
465+
dims = ("rows", "columns")
466+
467+
m = make_model(coords=None, df=df_pd, dims=dims, infer_dims_and_coords=True)
468+
assert m.coords == expected_coords
469+
np.testing.assert_allclose(m["X"].eval(), df_pd.values)
470+
471+
# TODO: Is infer_dims_and_coords supposed to infer dims? The current behavior is that it doesn't; it only
472+
# infers the dimension labels.
473+
for df in [df_pd, df_pl]:
474+
m = make_model(coords=None, df=df, dims=None, infer_dims_and_coords=True)
475+
assert m.coords == {}
476+
477+
m = make_model(coords=None, df=df, dims=dims, infer_dims_and_coords=False)
478+
assert m.coords == {"rows": None, "columns": None}
479+
480+
m = make_model(coords=None, df=df, dims=None, infer_dims_and_coords=False)
481+
assert m.coords == {}
482+
483+
# Pandas is special because we will infer the index dim from the DataFrame index, if one exists.
484+
m = make_model(coords=None, df=df_pd, dims=dims, infer_dims_and_coords=True)
485+
assert m.coords == expected_coords
486+
487+
# Polars (and other dataframe backends with no index concept) won't infer dims from index. This case currently
488+
# errors, because we can't find the 'rows' dim in either the DataFrame index or the model coords.
489+
with pytest.raises(
490+
ValueError,
491+
match="Dimension 'rows' not found in DataFrame index or model coordinates",
492+
):
493+
make_model(coords=None, df=df_pl, dims=dims, infer_dims_and_coords=True)
439494

440495
def test_implicit_coords_xarray(self):
441496
xr = pytest.importorskip("xarray")

0 commit comments

Comments
 (0)