Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,10 @@ ListLikeHashable: TypeAlias = (
MutableSequence[HashableT] | np_1darray | tuple[HashableT, ...] | range
)

if TYPE_CHECKING: # noqa: PYI002
IndexT0 = TypeVar("IndexT0", bound=Index, default=Index)
IndexStrT0 = TypeVar("IndexStrT0", bound=Index, default=Index[str])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the default be RangeIndex here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For df.columns we somehow have a preference for it being Index[str], see def columns(self) -> Index[str]: ... in the old code. That's the reason.


class SupportsDType(Protocol[GenericT_co]):
@property
def dtype(self) -> np.dtype[GenericT_co]: ...
Expand Down
28 changes: 24 additions & 4 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ from pandas._typing import (
IndexingInt,
IndexKeyFunc,
IndexLabel,
IndexStrT0,
IndexT0,
IndexType,
InterpolateOptions,
IntervalClosedType,
Expand Down Expand Up @@ -378,10 +380,19 @@ _AstypeArgExt: TypeAlias = (
)
_AstypeArgExtList: TypeAlias = _AstypeArgExt | list[_AstypeArgExt]

class DataFrame(NDFrame, OpsMixin, _GetItemHack):
class DataFrame(NDFrame, OpsMixin, _GetItemHack, Generic[IndexT0, IndexStrT0]):

__hash__: ClassVar[None] # type: ignore[assignment] # pyright: ignore[reportIncompatibleMethodOverride]

@overload
def __new__(
cls,
data: DataFrame[IndexT0, IndexStrT0],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not totally sure if this is gonna break a lot of things, by default DataFrame will have RangeIndex as index and columns, isn't this gonna change that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is copying another DataFrame, I suppose, so nothing is changed. Will add tests later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you create a DataFrame from data without an index, RangeIndex will surely take over. That takes further overloads of __new__.

As for now I just want to show a prototype of what can happen.

index: None = None,
columns: None = None,
dtype: Dtype | None = None,
copy: _bool | None = None,
) -> DataFrame[IndexT0, IndexStrT0]: ...
@overload
def __new__(
cls,
Expand All @@ -398,6 +409,15 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
copy: _bool | None = None,
) -> Self: ...
@overload
def __new__(
cls,
data: Scalar,
index: IndexT0,
columns: IndexStrT0,
dtype: Dtype | None = None,
copy: _bool | None = None,
) -> DataFrame[IndexT0, IndexStrT0]: ...
@overload
def __new__(
cls,
data: Scalar,
Expand Down Expand Up @@ -1898,7 +1918,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@property
def at(self) -> _AtIndexerFrame: ...
@property
def columns(self) -> Index[str]: ...
def columns(self) -> IndexStrT0: ...
@columns.setter # setter needs to be right next to getter; otherwise mypy complains
def columns(
self, cols: AnyArrayLike | SequenceNotStr[Hashable] | tuple[Hashable, ...]
Expand All @@ -1912,7 +1932,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@property
def iloc(self) -> _iLocIndexerFrame[Self]: ...
@property
def index(self) -> Index: ...
def index(self) -> IndexT0: ...
@index.setter
def index(
self, idx: AnyArrayLike | SequenceNotStr[Hashable] | tuple[Hashable, ...]
Expand Down Expand Up @@ -2289,7 +2309,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
inplace: Literal[False] = False,
**kwargs: Any,
) -> Self: ...
def keys(self) -> Index: ...
def keys(self) -> IndexStrT0: ...
def kurt(
self,
axis: Axis | None = ...,
Expand Down
10 changes: 8 additions & 2 deletions tests/frame/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,12 +500,18 @@ def select1(df: pd.DataFrame) -> pd.Series:
def select2(df: pd.DataFrame) -> list[Hashable]:
return [i for i in df.index if cast(int, i) % 2 == 1]

check(assert_type(df.loc[select2, "x"], pd.Series), pd.Series)
# I think it has to do with the two overlapping overloads of __getitem__ in _LocIndexerFrame
# tuple[Callable[[DataFrame], ScalarT], int | StrLike] must be overlapping with
# tuple[Callable[[DataFrame], ScalarT | list[HashableT] | IndexType | MaskType], ScalarT]
check(assert_type(df.loc[select2, "x"], pd.Series), pd.Series) # type: ignore[assert-type]

def select3(_: pd.DataFrame) -> int:
return 1

check(assert_type(df.loc[select3, "x"], Scalar), np.integer)
# I think it has to do with the two overlapping overloads of __getitem__ in _LocIndexerFrame
# tuple[Callable[[DataFrame], ScalarT], int | StrLike] must be overlapping with
# tuple[Callable[[DataFrame], ScalarT | list[HashableT] | IndexType | MaskType], ScalarT]
check(assert_type(df.loc[select3, "x"], Scalar), np.integer) # type: ignore[assert-type]

check(
assert_type(df.loc[:, lambda df: df.columns.str.startswith("x")], pd.DataFrame),
Expand Down
9 changes: 8 additions & 1 deletion tests/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3157,7 +3157,14 @@ def test_rank() -> None:

def test_round() -> None:
# GH 791
check(assert_type(round(pd.DataFrame([])), pd.DataFrame), pd.DataFrame)
# TODO: microsoft/pyright#11179
check(
assert_type(
round(pd.DataFrame([])), # pyright: ignore[reportAssertTypeFailure]
pd.DataFrame,
),
pd.DataFrame,
)
check(assert_type(round(pd.Series([1], dtype=int)), "pd.Series[int]"), pd.Series)


Expand Down
15 changes: 9 additions & 6 deletions tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,9 @@ def df2scalar(val: DataFrame) -> float:

# iter
iterator = iter(GB_DF.rolling(1))
check(assert_type(iterator, Iterator[DataFrame]), Iterator)
check(assert_type(next(iterator), DataFrame), DataFrame)
# TODO: reported python/mypy#20436 python/mypy#20435
check(assert_type(iterator, Iterator[DataFrame]), Iterator) # type: ignore[assert-type]
check(assert_type(next(iterator), DataFrame), DataFrame) # type: ignore[assert-type]
check(assert_type(list(GB_DF.rolling(1)), list[DataFrame]), list, DataFrame)


Expand Down Expand Up @@ -782,8 +783,9 @@ def df2scalar(val: DataFrame) -> float:

# iter
iterator = iter(GB_DF.expanding(1))
check(assert_type(iterator, Iterator[DataFrame]), Iterator)
check(assert_type(next(iterator), DataFrame), DataFrame)
# TODO: reported python/mypy#20436 python/mypy#20435
check(assert_type(iterator, Iterator[DataFrame]), Iterator) # type: ignore[assert-type]
check(assert_type(next(iterator), DataFrame), DataFrame) # type: ignore[assert-type]
check(assert_type(list(GB_DF.expanding(1)), list[DataFrame]), list, DataFrame)


Expand Down Expand Up @@ -960,8 +962,9 @@ def test_frame_groupby_ewm() -> None:

# iter
iterator = iter(GB_DF.ewm(1))
check(assert_type(iterator, Iterator[DataFrame]), Iterator)
check(assert_type(next(iterator), DataFrame), DataFrame)
# TODO: reported python/mypy#20436 python/mypy#20435
check(assert_type(iterator, Iterator[DataFrame]), Iterator) # type: ignore[assert-type]
check(assert_type(next(iterator), DataFrame), DataFrame) # type: ignore[assert-type]
check(assert_type(list(GB_DF.ewm(1)), list[DataFrame]), list, DataFrame)


Expand Down
3 changes: 2 additions & 1 deletion tests/test_resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def test_props() -> None:


def test_iter() -> None:
assert_type(iter(DF.resample("ME")), Iterator[tuple[Hashable, DataFrame]])
# TODO: reported python/mypy#20436 python/mypy#20435
assert_type(iter(DF.resample("ME")), Iterator[tuple[Hashable, DataFrame]]) # type: ignore[assert-type]
for v in DF.resample("ME"):
check(assert_type(v, tuple[Hashable, DataFrame]), tuple)

Expand Down