Skip to content

Commit c8d80ef

Browse files
committed
make dataframe generic
1 parent 7a66daa commit c8d80ef

File tree

6 files changed

+55
-14
lines changed

6 files changed

+55
-14
lines changed

pandas-stubs/_typing.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,10 @@ ListLikeHashable: TypeAlias = (
972972
MutableSequence[HashableT] | np_1darray | tuple[HashableT, ...] | range
973973
)
974974

975+
if TYPE_CHECKING: # noqa: PYI002
976+
IndexT0 = TypeVar("IndexT0", bound=Index, default=Index)
977+
IndexStrT0 = TypeVar("IndexStrT0", bound=Index, default=Index[str])
978+
975979
class SupportsDType(Protocol[GenericT_co]):
976980
@property
977981
def dtype(self) -> np.dtype[GenericT_co]: ...

pandas-stubs/core/frame.pyi

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ from pandas._typing import (
117117
IndexingInt,
118118
IndexKeyFunc,
119119
IndexLabel,
120+
IndexStrT0,
121+
IndexT0,
120122
IndexType,
121123
InterpolateOptions,
122124
IntervalClosedType,
@@ -378,10 +380,19 @@ _AstypeArgExt: TypeAlias = (
378380
)
379381
_AstypeArgExtList: TypeAlias = _AstypeArgExt | list[_AstypeArgExt]
380382

381-
class DataFrame(NDFrame, OpsMixin, _GetItemHack):
383+
class DataFrame(NDFrame, OpsMixin, _GetItemHack, Generic[IndexT0, IndexStrT0]):
382384

383385
__hash__: ClassVar[None] # type: ignore[assignment] # pyright: ignore[reportIncompatibleMethodOverride]
384386

387+
@overload
388+
def __new__(
389+
cls,
390+
data: DataFrame[IndexT0, IndexStrT0],
391+
index: None = None,
392+
columns: None = None,
393+
dtype: Dtype | None = None,
394+
copy: _bool | None = None,
395+
) -> DataFrame[IndexT0, IndexStrT0]: ...
385396
@overload
386397
def __new__(
387398
cls,
@@ -398,6 +409,15 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
398409
copy: _bool | None = None,
399410
) -> Self: ...
400411
@overload
412+
def __new__(
413+
cls,
414+
data: Scalar,
415+
index: IndexT0,
416+
columns: IndexStrT0,
417+
dtype: Dtype | None = None,
418+
copy: _bool | None = None,
419+
) -> DataFrame[IndexT0, IndexStrT0]: ...
420+
@overload
401421
def __new__(
402422
cls,
403423
data: Scalar,
@@ -1898,7 +1918,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
18981918
@property
18991919
def at(self) -> _AtIndexerFrame: ...
19001920
@property
1901-
def columns(self) -> Index[str]: ...
1921+
def columns(self) -> IndexStrT0: ...
19021922
@columns.setter # setter needs to be right next to getter; otherwise mypy complains
19031923
def columns(
19041924
self, cols: AnyArrayLike | SequenceNotStr[Hashable] | tuple[Hashable, ...]
@@ -1912,7 +1932,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
19121932
@property
19131933
def iloc(self) -> _iLocIndexerFrame[Self]: ...
19141934
@property
1915-
def index(self) -> Index: ...
1935+
def index(self) -> IndexT0: ...
19161936
@index.setter
19171937
def index(
19181938
self, idx: AnyArrayLike | SequenceNotStr[Hashable] | tuple[Hashable, ...]
@@ -2289,7 +2309,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
22892309
inplace: Literal[False] = False,
22902310
**kwargs: Any,
22912311
) -> Self: ...
2292-
def keys(self) -> Index: ...
2312+
def keys(self) -> IndexStrT0: ...
22932313
def kurt(
22942314
self,
22952315
axis: Axis | None = ...,

tests/frame/test_indexing.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,12 +500,18 @@ def select1(df: pd.DataFrame) -> pd.Series:
500500
def select2(df: pd.DataFrame) -> list[Hashable]:
501501
return [i for i in df.index if cast(int, i) % 2 == 1]
502502

503-
check(assert_type(df.loc[select2, "x"], pd.Series), pd.Series)
503+
# Suspected cause of the assert-type failure ignored below: two overlapping overloads of __getitem__ in _LocIndexerFrame.
504+
# tuple[Callable[[DataFrame], ScalarT], int | StrLike] must be overlapping with
505+
# tuple[Callable[[DataFrame], ScalarT | list[HashableT] | IndexType | MaskType], ScalarT]
506+
check(assert_type(df.loc[select2, "x"], pd.Series), pd.Series) # type: ignore[assert-type]
504507

505508
def select3(_: pd.DataFrame) -> int:
506509
return 1
507510

508-
check(assert_type(df.loc[select3, "x"], Scalar), np.integer)
511+
# Suspected cause of the assert-type failure ignored below: two overlapping overloads of __getitem__ in _LocIndexerFrame.
512+
# tuple[Callable[[DataFrame], ScalarT], int | StrLike] must be overlapping with
513+
# tuple[Callable[[DataFrame], ScalarT | list[HashableT] | IndexType | MaskType], ScalarT]
514+
check(assert_type(df.loc[select3, "x"], Scalar), np.integer) # type: ignore[assert-type]
509515

510516
check(
511517
assert_type(df.loc[:, lambda df: df.columns.str.startswith("x")], pd.DataFrame),

tests/series/test_series.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3157,7 +3157,14 @@ def test_rank() -> None:
31573157

31583158
def test_round() -> None:
31593159
# GH 791
3160-
check(assert_type(round(pd.DataFrame([])), pd.DataFrame), pd.DataFrame)
3160+
# TODO: remove the pyright ignore below once microsoft/pyright#11179 is resolved
3161+
check(
3162+
assert_type(
3163+
round(pd.DataFrame([])), # pyright: ignore[reportAssertTypeFailure]
3164+
pd.DataFrame,
3165+
),
3166+
pd.DataFrame,
3167+
)
31613168
check(assert_type(round(pd.Series([1], dtype=int)), "pd.Series[int]"), pd.Series)
31623169

31633170

tests/test_groupby.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,9 @@ def df2scalar(val: DataFrame) -> float:
583583

584584
# iter
585585
iterator = iter(GB_DF.rolling(1))
586-
check(assert_type(iterator, Iterator[DataFrame]), Iterator)
587-
check(assert_type(next(iterator), DataFrame), DataFrame)
586+
# TODO: reported as python/mypy#20436 and python/mypy#20435; remove the ignores below once resolved
587+
check(assert_type(iterator, Iterator[DataFrame]), Iterator) # type: ignore[assert-type]
588+
check(assert_type(next(iterator), DataFrame), DataFrame) # type: ignore[assert-type]
588589
check(assert_type(list(GB_DF.rolling(1)), list[DataFrame]), list, DataFrame)
589590

590591

@@ -782,8 +783,9 @@ def df2scalar(val: DataFrame) -> float:
782783

783784
# iter
784785
iterator = iter(GB_DF.expanding(1))
785-
check(assert_type(iterator, Iterator[DataFrame]), Iterator)
786-
check(assert_type(next(iterator), DataFrame), DataFrame)
786+
# TODO: reported as python/mypy#20436 and python/mypy#20435; remove the ignores below once resolved
787+
check(assert_type(iterator, Iterator[DataFrame]), Iterator) # type: ignore[assert-type]
788+
check(assert_type(next(iterator), DataFrame), DataFrame) # type: ignore[assert-type]
787789
check(assert_type(list(GB_DF.expanding(1)), list[DataFrame]), list, DataFrame)
788790

789791

@@ -960,8 +962,9 @@ def test_frame_groupby_ewm() -> None:
960962

961963
# iter
962964
iterator = iter(GB_DF.ewm(1))
963-
check(assert_type(iterator, Iterator[DataFrame]), Iterator)
964-
check(assert_type(next(iterator), DataFrame), DataFrame)
965+
# TODO: reported as python/mypy#20436 and python/mypy#20435; remove the ignores below once resolved
966+
check(assert_type(iterator, Iterator[DataFrame]), Iterator) # type: ignore[assert-type]
967+
check(assert_type(next(iterator), DataFrame), DataFrame) # type: ignore[assert-type]
965968
check(assert_type(list(GB_DF.ewm(1)), list[DataFrame]), list, DataFrame)
966969

967970

tests/test_resampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def test_props() -> None:
4545

4646

4747
def test_iter() -> None:
48-
assert_type(iter(DF.resample("ME")), Iterator[tuple[Hashable, DataFrame]])
48+
# TODO: reported as python/mypy#20436 and python/mypy#20435; remove the ignore below once resolved
49+
assert_type(iter(DF.resample("ME")), Iterator[tuple[Hashable, DataFrame]]) # type: ignore[assert-type]
4950
for v in DF.resample("ME"):
5051
check(assert_type(v, tuple[Hashable, DataFrame]), tuple)
5152

0 commit comments

Comments
 (0)