Skip to content

Commit 117bb07

Browse files
authored
[ENH] pandas.Multiindex support for distributions (#580)
Closes #212. Introduces `MultiIndex` support for distributions, including through `iloc` and `loc` calls. Behaves like `pandas`, except that: * `loc` calls never remove levels * the dimension after `loc` and `iloc` calls always remains 2D, since there are no 1D distributions
1 parent 475485a commit 117bb07

File tree

4 files changed

+130
-2
lines changed

4 files changed

+130
-2
lines changed

skpro/distributions/base/_base.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,52 @@ def _loc(self, rowidx=None, colidx=None):
236236
colidx = pd.Index([colidx])
237237

238238
if rowidx is not None:
239-
row_iloc = self.index.get_indexer_for(rowidx)
239+
row_iloc = self._get_indexer_like_pandas(self.index, rowidx)
240240
else:
241241
row_iloc = None
242242
if colidx is not None:
243-
col_iloc = self.columns.get_indexer_for(colidx)
243+
col_iloc = self._get_indexer_like_pandas(self.columns, colidx)
244244
else:
245245
col_iloc = None
246246
return self._iloc(rowidx=row_iloc, colidx=col_iloc)
247247

248+
def _get_indexer_like_pandas(self, index, keys):
249+
"""Return indexer for keys in index.
250+
251+
A unified helper that mimics pandas' get_indexer_for but supports:
252+
253+
- scalar key (e.g., "a", ("a", 1))
254+
- tuple key (partial or full)
255+
- list of keys (partial or full)
256+
- works for both Index and MultiIndex
257+
258+
Returns
259+
-------
260+
np.ndarray of positions (integers)
261+
"""
262+
# regular index, not multiindex
263+
if not isinstance(index, pd.MultiIndex):
264+
return index.get_indexer_for(keys)
265+
266+
# if isinstance(index, pd.MultiIndex):
267+
268+
if is_scalar_notnone(keys) or isinstance(keys, tuple):
269+
keys = [keys]
270+
271+
# Use get_locs for each key (full or partial)
272+
ilocs = []
273+
for key in keys:
274+
if isinstance(key, slice):
275+
ilocs.append(index.slice_indexer(key.start, key.stop, key.step))
276+
else:
277+
if not isinstance(key, tuple):
278+
key = [key]
279+
iloc = index.get_locs(key)
280+
if isinstance(iloc, slice):
281+
iloc = np.arange(len(index))[iloc]
282+
ilocs.append(iloc)
283+
return np.concatenate(ilocs) if ilocs else np.array([], dtype=int)
284+
248285
def _at(self, rowidx=None, colidx=None):
249286
if rowidx is not None:
250287
row_iloc = self.index.get_indexer_for([rowidx])[0]
@@ -1762,6 +1799,13 @@ def is_noneslice(obj):
17621799
ref = self.ref
17631800
indexer = getattr(ref, self.method)
17641801

1802+
# handle special case of multiindex in loc with single tuple key
1803+
if isinstance(key, tuple) and not any(isinstance(k, tuple) for k in key):
1804+
if isinstance(ref.index, pd.MultiIndex) and self.method == "_loc":
1805+
if type(ref).__name__ != "Empirical":
1806+
return indexer(rowidx=key, colidx=None)
1807+
1808+
# general case
17651809
if isinstance(key, tuple):
17661810
if not len(key) == 2:
17671811
raise ValueError(
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for skpro distribution base class."""
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Test cases for the MultiIndex functionality of the BaseDistribution.
2+
3+
Uses the Normal distribution, but is intended to trigger the base layer.
4+
"""
5+
6+
import numpy as np
7+
import pandas as pd
8+
import pytest
9+
10+
from skpro.distributions.normal import Normal
11+
12+
13+
@pytest.fixture
def normal_dist():
    """Return a 4x2 Normal distribution with a two-level MultiIndex as rows.

    The row index is the product of ``(1, 2)`` and ``(2, 3)``, i.e.,
    ``(1, 2), (1, 3), (2, 2), (2, 3)``.
    """
    row_index = pd.MultiIndex.from_product([(1, 2), (2, 3)])
    values = np.array([[1, 2], [2, 3], [4, 5], [6, 7]])
    return Normal(values, 2, index=row_index)
17+
18+
19+
def test_loc_partial_level(normal_dist):
    """Check loc with a single first-level label: levels kept, result 2D."""
    subset = normal_dist.loc[1]
    want_index = pd.MultiIndex.from_tuples([(1, 2), (1, 3)])
    np.testing.assert_array_equal(subset.index, want_index)
    assert subset.mean().shape == (2, 2)
24+
25+
26+
def test_loc_full_tuple(normal_dist):
    """Check loc with a full tuple key: one row selected, result stays 2D."""
    subset = normal_dist.loc[(2, 2)]
    want_index = pd.MultiIndex.from_tuples([(2, 2)])
    np.testing.assert_array_equal(subset.index, want_index)
    assert subset.mean().shape == (1, 2)
31+
32+
33+
def test_loc_list_of_keys(normal_dist):
    """Check loc with a list of full tuple keys: rows returned in key order."""
    wanted_keys = [(1, 2), (2, 3)]
    subset = normal_dist.loc[wanted_keys]
    want_index = pd.MultiIndex.from_tuples([(1, 2), (2, 3)])
    np.testing.assert_array_equal(subset.index, want_index)
    assert subset.mean().shape == (2, 2)
38+
39+
40+
def test_iloc_single_row(normal_dist):
    """Check iloc with one integer: first row selected, result stays 2D."""
    subset = normal_dist.iloc[0]
    want_index = pd.MultiIndex.from_tuples([(1, 2)])
    np.testing.assert_array_equal(subset.index, want_index)
    assert subset.mean().shape == (1, 2)
45+
46+
47+
def test_iloc_multiple_rows(normal_dist):
    """Check iloc with a list of integers: first and last row selected."""
    positions = [0, 3]
    subset = normal_dist.iloc[positions]
    want_index = pd.MultiIndex.from_tuples([(1, 2), (2, 3)])
    np.testing.assert_array_equal(subset.index, want_index)
    assert subset.mean().shape == (2, 2)
52+
53+
54+
def test_iloc_column_slice(normal_dist):
    """Check iloc column selection: all rows kept, single column, 2D."""
    subset = normal_dist.iloc[:, 1]
    want_index = normal_dist.index
    # shape is checked first, then the row index must be unchanged
    assert subset.mean().shape == (4, 1)
    np.testing.assert_array_equal(subset.index, want_index)
59+
60+
61+
def test_loc_row_col(normal_dist):
    """Check loc with a tuple row key and a full column slice."""
    subset = normal_dist.loc[(1, 2), :]
    want_index = pd.MultiIndex.from_tuples([(1, 2)])
    # shape is checked first, then the selected row label
    assert subset.mean().shape == (1, 2)
    np.testing.assert_array_equal(subset.index, want_index)

skpro/distributions/empirical.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,24 @@ def _iloc(self, rowidx=None, colidx=None):
270270
columns=subs_colidx,
271271
)
272272

273+
def _loc(self, rowidx=None, colidx=None):
    """Subset by row and column labels.

    Labels are translated to integer positions with ``get_indexer_for``
    on ``self.index`` and ``self.columns``, then passed on to ``_iloc``.
    If both labels are scalar, the call is passed on to ``_at`` instead.

    Parameters
    ----------
    rowidx : scalar, index-like, or None, default=None
        row label(s) to select; ``None`` is passed to ``_iloc`` as-is
    colidx : scalar, index-like, or None, default=None
        column label(s) to select; ``None`` is passed to ``_iloc`` as-is

    Returns
    -------
    result of ``self._at`` if both labels are scalar,
    otherwise result of ``self._iloc``
    """
    scalar_row = is_scalar_notnone(rowidx)
    scalar_col = is_scalar_notnone(colidx)
    if scalar_row and scalar_col:
        return self._at(rowidx, colidx)

    # a lone scalar label is wrapped so that get_indexer_for gets an index
    row_labels = pd.Index([rowidx]) if scalar_row else rowidx
    col_labels = pd.Index([colidx]) if scalar_col else colidx

    row_pos = None
    if row_labels is not None:
        row_pos = self.index.get_indexer_for(row_labels)

    col_pos = None
    if col_labels is not None:
        col_pos = self.columns.get_indexer_for(col_labels)

    return self._iloc(rowidx=row_pos, colidx=col_pos)
290+
273291
def _iat(self, rowidx=None, colidx=None):
274292
if rowidx is None or colidx is None:
275293
raise ValueError("iat method requires both row and column index")

0 commit comments

Comments
 (0)