Skip to content

Commit 4339fbb

Browse files
authored
TYP: #1544 all sys.version-related changes (#1565)
#1544 all sys.version related changes
1 parent 7a66daa commit 4339fbb

File tree

4 files changed

+79
-39
lines changed

4 files changed

+79
-39
lines changed

pandas-stubs/_typing.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,10 @@ np_1darray_dt: TypeAlias = np_1darray[np.datetime64]
959959
np_1darray_td: TypeAlias = np_1darray[np.timedelta64]
960960
np_2darray: TypeAlias = np.ndarray[tuple[int, int], np.dtype[GenericT]]
961961

962-
NDArrayT = TypeVar("NDArrayT", bound=np.ndarray)
962+
if sys.version_info >= (3, 11):
963+
NDArrayT = TypeVar("NDArrayT", bound=np.ndarray)
964+
else:
965+
NDArrayT = TypeVar("NDArrayT", bound=np.ndarray[Any, Any])
963966

964967
DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
965968
KeysArgType: TypeAlias = Any

pandas-stubs/core/frame.pyi

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -340,28 +340,23 @@ class _AtIndexerFrame(_AtIndexer):
340340
value: Scalar | NAType | NaTType | None,
341341
) -> None: ...
342342

343-
# With mypy 1.14.1 and python 3.12, the second overload needs a type-ignore statement
344-
if sys.version_info >= (3, 12):
345-
class _GetItemHack:
346-
@overload
347-
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
343+
class _GetItemHack:
344+
@overload
345+
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
346+
# With python 3.12+, the second overload needs a type-ignore statement
347+
if sys.version_info >= (3, 12):
348348
@overload
349349
def __getitem__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
350350
self, key: Iterable[Hashable] | slice
351351
) -> Self: ...
352-
@overload
353-
def __getitem__(self, key: Hashable) -> Series: ...
354-
355-
else:
356-
class _GetItemHack:
357-
@overload
358-
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
352+
else:
359353
@overload
360354
def __getitem__( # pyright: ignore[reportOverlappingOverload]
361355
self, key: Iterable[Hashable] | slice
362356
) -> Self: ...
363-
@overload
364-
def __getitem__(self, key: Hashable) -> Series: ...
357+
358+
@overload
359+
def __getitem__(self, key: Hashable) -> Series: ...
365360

366361
_AstypeArgExt: TypeAlias = (
367362
AstypeArg
@@ -562,16 +557,29 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
562557
coerce_float: bool = False,
563558
nrows: int | None = None,
564559
) -> Self: ...
565-
def to_records(
566-
self,
567-
index: _bool = True,
568-
column_dtypes: (
569-
_str | npt.DTypeLike | Mapping[HashableT1, npt.DTypeLike] | None
570-
) = None,
571-
index_dtypes: (
572-
_str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None
573-
) = None,
574-
) -> np.recarray: ...
560+
if sys.version_info >= (3, 11):
561+
def to_records(
562+
self,
563+
index: _bool = True,
564+
column_dtypes: (
565+
_str | npt.DTypeLike | Mapping[HashableT1, npt.DTypeLike] | None
566+
) = None,
567+
index_dtypes: (
568+
_str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None
569+
) = None,
570+
) -> np.recarray: ...
571+
else:
572+
def to_records(
573+
self,
574+
index: _bool = True,
575+
column_dtypes: (
576+
_str | npt.DTypeLike | Mapping[HashableT1, npt.DTypeLike] | None
577+
) = None,
578+
index_dtypes: (
579+
_str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None
580+
) = None,
581+
) -> np.recarray[Any, Any]: ...
582+
575583
@overload
576584
def to_stata(
577585
self,

tests/frame/test_frame.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2998,19 +2998,42 @@ def test_to_xarray() -> None:
29982998

29992999
def test_to_records() -> None:
30003000
df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
3001-
check(assert_type(df.to_records(False, "int8"), np.recarray), np.recarray)
3002-
check(
3003-
assert_type(df.to_records(False, index_dtypes=np.int8), np.recarray),
3004-
np.recarray,
3005-
)
3006-
check(
3007-
assert_type(
3008-
df.to_records(False, {"col1": np.int8, "col2": np.int16}), np.recarray
3009-
),
3010-
np.recarray,
3011-
)
30123001
dtypes = {"col1": np.int8, "col2": np.int16}
3013-
check(assert_type(df.to_records(False, dtypes), np.recarray), np.recarray)
3002+
if sys.version_info >= (3, 11):
3003+
check(assert_type(df.to_records(False, "int8"), np.recarray), np.recarray)
3004+
check(
3005+
assert_type(df.to_records(False, index_dtypes=np.int8), np.recarray),
3006+
np.recarray,
3007+
)
3008+
check(
3009+
assert_type(
3010+
df.to_records(False, {"col1": np.int8, "col2": np.int16}), np.recarray
3011+
),
3012+
np.recarray,
3013+
)
3014+
check(assert_type(df.to_records(False, dtypes), np.recarray), np.recarray)
3015+
else:
3016+
check(
3017+
assert_type(df.to_records(False, "int8"), np.recarray[Any, Any]),
3018+
np.recarray,
3019+
)
3020+
check(
3021+
assert_type(
3022+
df.to_records(False, index_dtypes=np.int8), np.recarray[Any, Any]
3023+
),
3024+
np.recarray,
3025+
)
3026+
check(
3027+
assert_type(
3028+
df.to_records(False, {"col1": np.int8, "col2": np.int16}),
3029+
np.recarray[Any, Any],
3030+
),
3031+
np.recarray,
3032+
)
3033+
check(
3034+
assert_type(df.to_records(False, dtypes), np.recarray[Any, Any]),
3035+
np.recarray,
3036+
)
30143037

30153038

30163039
def test_to_dict_simple() -> None:

tests/indexes/test_indexes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,9 +1719,15 @@ def test_index_view() -> None:
17191719
# - pyright: ndarray[tuple[Any, ...], dtype[Any]]
17201720
check(assert_type(ind.view(np.ndarray), np.ndarray), np.ndarray) # type: ignore[assert-type]
17211721
else:
1722-
check(assert_type(ind.view(np.ndarray), np.ndarray), np.ndarray)
1722+
check(assert_type(ind.view(np.ndarray), np.ndarray[Any, Any]), np.ndarray)
17231723

1724-
class MyArray(np.ndarray): ...
1724+
if sys.version_info >= (3, 11):
1725+
1726+
class MyArray(np.ndarray): ...
1727+
1728+
else:
1729+
1730+
class MyArray(np.ndarray[Any, Any]): ...
17251731

17261732
check(assert_type(ind.view(MyArray), MyArray), MyArray)
17271733

0 commit comments

Comments
 (0)