Skip to content

Commit ff7af55

Browse files
code cleanup + mypy
1 parent 3553f00 commit ff7af55

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

pymc/data.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def determine_coords(
178178
model: "Model",
179179
dims: Sequence[str | None] | None = None,
180180
coords: dict[str, Sequence | np.ndarray] | None = None,
181-
) -> tuple[typing.Any, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
181+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
182182
"""Determine coordinate values from data or the model (via ``dims``)."""
183183
raise NotImplementedError(
184184
f"Cannot determine coordinates for data of type {type(value)}, please provide `coords` explicitly or "
@@ -192,12 +192,12 @@ def determine_array_coords(
192192
model: "Model",
193193
dims: Sequence[str] | None = None,
194194
coords: dict[str, Sequence | np.ndarray] | None = None,
195-
) -> tuple[np.ndarray, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
195+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
196196
if coords is None:
197197
coords = {}
198198

199199
if dims is None:
200-
return value, coords, _handle_none_dims(dims, value.ndim)
200+
return coords, _handle_none_dims(dims, value.ndim)
201201

202202
if len(dims) != value.ndim:
203203
raise ShapeError(
@@ -211,7 +211,7 @@ def determine_array_coords(
211211
if coord is None and dim is not None:
212212
coords[dim] = range(size)
213213

214-
return value, coords, _handle_none_dims(dims, value.ndim)
214+
return coords, _handle_none_dims(dims, value.ndim)
215215

216216

217217
@determine_coords.register(xr.DataArray)
@@ -220,46 +220,44 @@ def determine_xarray_coords(
220220
model: "Model",
221221
dims: Sequence[str | None] | None = None,
222222
coords: dict[str, Sequence | np.ndarray] | None = None,
223-
) -> tuple[xr.DataArray, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
223+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
224224
if coords is None:
225225
coords = {}
226226

227227
if dims is None:
228-
return value, coords, _handle_none_dims(dims, value.ndim)
228+
return coords, _handle_none_dims(dims, value.ndim)
229229

230230
for dim in dims:
231231
dim_name = dim
232232
# str is applied because dim entries may be None
233233
coords[str(dim_name)] = cast(xr.DataArray, value[dim]).to_numpy()
234234

235-
return value, coords, _handle_none_dims(dims, value.ndim)
235+
return coords, _handle_none_dims(dims, value.ndim)
236236

237237

238238
def _dataframe_agnostic_coords(
239-
value: IntoFrameT,
239+
value: IntoFrameT | IntoSeriesT,
240240
model: "Model",
241241
ndim_in: int = 2,
242242
dims: Sequence[str | None] | None = None,
243243
coords: dict[str, Sequence | np.ndarray] | None = None,
244-
) -> tuple[IntoFrameT, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
244+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
245245
if coords is None:
246246
coords = {}
247247

248-
value = cast(nw.DataFrame | nw.LazyFrame, nw.from_native(value, allow_series=False)) # type: ignore[type-var]
248+
value = nw.from_native(value, allow_series=ndim_in == 1) # type: ignore[call-overload]
249249
if isinstance(value, nw.LazyFrame):
250250
value = value.collect()
251251

252252
if dims is None:
253-
if ndim_in == 1:
254-
value = value[value.columns[0]]
255-
return value.to_native(), coords, _handle_none_dims(dims, ndim_in)
253+
return coords, _handle_none_dims(dims, ndim_in)
256254

257-
index = nw.maybe_get_index(value)
255+
index = nw.maybe_get_index(value) # type: ignore[arg-type]
258256

259257
if len(dims) != ndim_in:
260258
raise ShapeError(
261259
"Invalid data shape. The rank of the dataset must match the length of `dims`.",
262-
actual=value.shape,
260+
actual=value.shape, # type: ignore[union-attr]
263261
expected=len(dims),
264262
)
265263

@@ -278,13 +276,9 @@ def _dataframe_agnostic_coords(
278276
if len(dims) > 1:
279277
column_dim = dims[1]
280278
if column_dim is not None:
281-
select_expr = nw.exclude(index_dim) if index_dim is not None else nw.all()
282-
coords[column_dim] = value.select(select_expr).columns
283-
284-
if ndim_in == 1:
285-
value = value[value.columns[0]]
279+
coords[column_dim] = value.columns # type: ignore[union-attr]
286280

287-
return value.to_native(), coords, _handle_none_dims(dims, ndim_in)
281+
return coords, _handle_none_dims(dims, ndim_in)
288282

289283

290284
def _series_agnostic_coords(
@@ -293,14 +287,13 @@ def _series_agnostic_coords(
293287
dims: Sequence[str | None] | None = None,
294288
coords: dict[str, Sequence | np.ndarray] | None = None,
295289
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
296-
value = cast(nw.Series, nw.from_native(value, series_only=True)) # type: ignore[assignment]
297290
return _dataframe_agnostic_coords(
298-
cast(nw.DataFrame | nw.LazyFrame, value.to_frame()), # type: ignore[attr-defined]
291+
value,
299292
ndim_in=1,
300293
model=model,
301294
dims=dims,
302295
coords=coords,
303-
) # type: ignore[arg-type]
296+
)
304297

305298

306299
def _register_dataframe_backend(library_name: str):
@@ -322,9 +315,7 @@ def determine_dataframe_coords(
322315
model: "Model",
323316
dims: Sequence[str] | None = None,
324317
coords: dict[str, Sequence | np.ndarray] | None = None,
325-
) -> tuple[
326-
IntoFrameT, dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]
327-
]:
318+
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None] | Sequence[None]]:
328319
return _dataframe_agnostic_coords(value, model=model, dims=dims, coords=coords)
329320

330321
except ImportError:
@@ -458,7 +449,7 @@ def Data(
458449

459450
new_dims: Sequence[str | None] | Sequence[None] | None
460451
if infer_dims_and_coords:
461-
value, coords, new_dims = determine_coords(value, model, dims)
452+
coords, new_dims = determine_coords(value, model, dims)
462453
else:
463454
new_dims = dims
464455

0 commit comments

Comments
 (0)