Skip to content

Commit d063f4c

Browse files
authored
TYP: DataFrame.__setitem__ with None (#1550)
* setitem with none * more tests and typings * python/mypy#20420 #1550 (comment)
1 parent 8f35f7f commit d063f4c

File tree

4 files changed

+73
-139
lines changed

4 files changed

+73
-139
lines changed

pandas-stubs/_typing.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ from pandas._libs.tslibs import (
5252
Timedelta,
5353
Timestamp,
5454
)
55+
from pandas._libs.tslibs.nattype import NaTType
5556

5657
from pandas.core.dtypes.dtypes import (
5758
CategoricalDtype,
@@ -134,6 +135,7 @@ _IndexIterScalar: TypeAlias = (
134135
Scalar: TypeAlias = (
135136
_IndexIterScalar | complex | np.integer | np.floating | np.complexfloating
136137
)
138+
ScalarOrNA: TypeAlias = Scalar | NAType | NaTType | None
137139
IntStrT = TypeVar("IntStrT", int, str)
138140

139141
# timestamp and timedelta convertible types

pandas-stubs/core/frame.pyi

Lines changed: 39 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ import xarray as xr
8181
from pandas._libs.lib import _NoDefaultDoNotUse
8282
from pandas._libs.missing import NAType
8383
from pandas._libs.tslibs import BaseOffset
84-
from pandas._libs.tslibs.nattype import NaTType
8584
from pandas._typing import (
8685
S2,
8786
AggFuncTypeBase,
@@ -146,6 +145,7 @@ from pandas._typing import (
146145
Renamer,
147146
ReplaceValue,
148147
Scalar,
148+
ScalarOrNA,
149149
ScalarT,
150150
SequenceNotStr,
151151
SeriesByT,
@@ -181,6 +181,26 @@ _T_MUTABLE_MAPPING_co = TypeVar(
181181
"_T_MUTABLE_MAPPING_co", bound=MutableMapping, covariant=True
182182
)
183183

184+
_iLocSetItemKey: TypeAlias = (
185+
int
186+
| IndexType
187+
| tuple[int, int]
188+
| tuple[IndexType, int]
189+
| tuple[IndexType, IndexType]
190+
| tuple[int, IndexType]
191+
)
192+
_LocSetItemKey: TypeAlias = (
193+
MaskType | Hashable | _IndexSliceTuple | Iterable[Scalar] | IndexingInt | slice
194+
)
195+
_SetItemValueNotDataFrame: TypeAlias = (
196+
ScalarOrNA
197+
| Sequence[ScalarOrNA]
198+
| Sequence[Sequence[ScalarOrNA]]
199+
| Mapping[Any, ScalarOrNA]
200+
| ArrayLike
201+
| IndexOpsMixin
202+
)
203+
184204
class _iLocIndexerFrame(_iLocIndexer, Generic[_T]):
185205
@overload
186206
def __getitem__(self, key: tuple[int, int]) -> Scalar: ...
@@ -202,27 +222,13 @@ class _iLocIndexerFrame(_iLocIndexer, Generic[_T]):
202222
) -> _T: ...
203223

204224
# Keep in sync with `DataFrame.__setitem__`
225+
@overload
205226
def __setitem__(
206-
self,
207-
key: (
208-
int
209-
| IndexType
210-
| tuple[int, int]
211-
| tuple[IndexType, int]
212-
| tuple[IndexType, IndexType]
213-
| tuple[int, IndexType]
214-
),
215-
value: (
216-
Scalar
217-
| IndexOpsMixin
218-
| Sequence[Scalar]
219-
| DataFrame
220-
| np_ndarray
221-
| NAType
222-
| NaTType
223-
| Mapping[Hashable, Scalar | NAType | NaTType]
224-
| None
225-
),
227+
self, key: tuple[slice, Hashable], value: _SetItemValueNotDataFrame
228+
) -> None: ...
229+
@overload
230+
def __setitem__(
231+
self, key: _iLocSetItemKey, value: _SetItemValueNotDataFrame | DataFrame
226232
) -> None: ...
227233

228234
class _LocIndexerFrame(_LocIndexer, Generic[_T]):
@@ -283,61 +289,25 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
283289
# Keep in sync with `DataFrame.__setitem__`
284290
@overload
285291
def __setitem__(
286-
self,
287-
key: tuple[_IndexSliceTuple, Hashable],
288-
value: (
289-
Scalar
290-
| NAType
291-
| NaTType
292-
| ArrayLike
293-
| IndexOpsMixin
294-
| Sequence[Scalar]
295-
| Sequence[Sequence[Scalar]]
296-
| Mapping[Hashable, Scalar | NAType | NaTType]
297-
| None
298-
),
292+
self, key: tuple[_IndexSliceTuple, Hashable], value: _SetItemValueNotDataFrame
299293
) -> None: ...
300294
@overload
301295
def __setitem__(
302-
self,
303-
key: (
304-
MaskType
305-
| Hashable
306-
| _IndexSliceTuple
307-
| Iterable[Scalar]
308-
| IndexingInt
309-
| slice
310-
),
311-
value: (
312-
Scalar
313-
| NAType
314-
| NaTType
315-
| ArrayLike
316-
| IndexOpsMixin
317-
| Sequence[Scalar]
318-
| Sequence[Sequence[Scalar]]
319-
| DataFrame
320-
| Mapping[Hashable, Scalar | NAType | NaTType]
321-
| None
322-
),
296+
self, key: _LocSetItemKey, value: _SetItemValueNotDataFrame | DataFrame
323297
) -> None: ...
324298

325299
class _iAtIndexerFrame(_iAtIndexer):
326300
def __getitem__(self, key: tuple[int, int]) -> Scalar: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
327301
def __setitem__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
328-
self,
329-
key: tuple[int, int],
330-
value: Scalar | NAType | NaTType | None,
302+
self, key: tuple[int, int], value: ScalarOrNA
331303
) -> None: ...
332304

333305
class _AtIndexerFrame(_AtIndexer):
334306
def __getitem__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
335307
self, key: tuple[Hashable, Hashable]
336308
) -> Scalar: ...
337309
def __setitem__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
338-
self,
339-
key: tuple[Hashable, Hashable],
340-
value: Scalar | NAType | NaTType | None,
310+
self, key: tuple[Hashable, Hashable], value: ScalarOrNA
341311
) -> None: ...
342312

343313
class _GetItemHack:
@@ -816,85 +786,26 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
816786
# Keep in sync with `_iLocIndexerFrame.__setitem__`
817787
@overload
818788
def __setitem__(
819-
self,
820-
idx: (
821-
int
822-
| IndexType
823-
| tuple[int, int]
824-
| tuple[IndexType, int]
825-
| tuple[IndexType, IndexType]
826-
| tuple[int, IndexType]
827-
),
828-
value: (
829-
Scalar
830-
| IndexOpsMixin
831-
| Sequence[Scalar]
832-
| DataFrame
833-
| np_ndarray
834-
| NAType
835-
| NaTType
836-
| Mapping[Hashable, Scalar | NAType | NaTType]
837-
| None
838-
),
789+
self, idx: tuple[slice, Hashable], value: _SetItemValueNotDataFrame
790+
) -> None: ...
791+
@overload
792+
def __setitem__(
793+
self, idx: _iLocSetItemKey, value: _SetItemValueNotDataFrame | DataFrame
839794
) -> None: ...
840795
# Keep in sync with `_LocIndexerFrame.__setitem__`
841796
@overload
842797
def __setitem__(
843-
self,
844-
idx: tuple[_IndexSliceTuple, Hashable],
845-
value: (
846-
Scalar
847-
| NAType
848-
| NaTType
849-
| ArrayLike
850-
| IndexOpsMixin
851-
| Sequence[Scalar]
852-
| Sequence[Sequence[Scalar]]
853-
| Mapping[Hashable, Scalar | NAType | NaTType]
854-
| None
855-
),
798+
self, idx: tuple[_IndexSliceTuple, Hashable], value: _SetItemValueNotDataFrame
856799
) -> None: ...
857800
@overload
858801
def __setitem__(
859-
self,
860-
idx: (
861-
MaskType
862-
| Hashable
863-
| _IndexSliceTuple
864-
| Iterable[Scalar]
865-
| IndexingInt
866-
| slice
867-
),
868-
value: (
869-
Scalar
870-
| NAType
871-
| NaTType
872-
| ArrayLike
873-
| IndexOpsMixin
874-
| Sequence[Scalar]
875-
| Sequence[Sequence[Scalar]]
876-
| DataFrame
877-
| Mapping[Hashable, Scalar | NAType | NaTType]
878-
| None
879-
),
802+
self, idx: _LocSetItemKey, value: _SetItemValueNotDataFrame | DataFrame
880803
) -> None: ...
881804
# Extra cases not supported by `_LocIndexerFrame.__setitem__` /
882805
# `_iLocIndexerFrame.__setitem__`.
883806
@overload
884807
def __setitem__(
885-
self,
886-
idx: IndexOpsMixin | DataFrame,
887-
value: (
888-
Scalar
889-
| NAType
890-
| NaTType
891-
| ArrayLike
892-
| IndexOpsMixin
893-
| Sequence[Scalar]
894-
| Sequence[Sequence[Scalar]]
895-
| Mapping[Hashable, Scalar | NAType | NaTType]
896-
| None
897-
),
808+
self, idx: IndexOpsMixin | DataFrame, value: _SetItemValueNotDataFrame
898809
) -> None: ...
899810
@overload
900811
def query(

tests/frame/test_indexing.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ def test_types_setitem() -> None:
9393
df[a] = [[1, 2], [3, 4]]
9494
df[i] = [8, 9]
9595

96+
df["col1"] = [None, pd.NaT]
97+
# TODO: mypy bug, remove after python/mypy#20420 has been resolved
98+
df[["col1"]] = [[None], [pd.NA]] # type: ignore[assignment,list-item]
99+
df[iter(["col1"])] = [[None], [pd.NA]] # type: ignore[assignment]
100+
96101

97102
def test_types_setitem_mask() -> None:
98103
df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4], 5: [6, 7]})
@@ -368,17 +373,6 @@ def test_isetframe() -> None:
368373
check(assert_type(frame.isetitem([0], [10, 12]), None), type(None))
369374

370375

371-
def test_setitem_none() -> None:
372-
df = pd.DataFrame(
373-
{"A": [1, 2, 3], "B": ["abc", "def", "ghi"]}, index=["x", "y", "z"]
374-
)
375-
df.loc["x", "B"] = None
376-
df.iloc[2, 0] = None
377-
sb = pd.Series([1, 2, 3], dtype=int)
378-
sb.loc["y"] = None
379-
sb.iloc[0] = None
380-
381-
382376
def test_getsetitem_multiindex() -> None:
383377
# GH 466
384378
rows = pd.Index(["project A", "project B", "project C"])
@@ -418,12 +412,32 @@ def test_frame_setitem_na() -> None:
418412

419413
df.loc[ind, :] = pd.NA
420414
df.iloc[[0, 2], :] = pd.NA
415+
df.at["a", "x"] = pd.NA
416+
df.iat[0, 0] = pd.NA
421417

422418
# reveal_type(df["y"]) gives Series[Any], so we have to cast to tell the
423419
# type checker what kind of type it is when adding to a Timedelta
424420
df["x"] = cast("pd.Series[pd.Timestamp]", df["y"]) + pd.Timedelta(days=3)
425421
df.loc[ind, :] = pd.NaT
426422
df.iloc[[0, 2], :] = pd.NaT
423+
df.at["a", "y"] = pd.NaT
424+
df.iat[0, 0] = pd.NaT
425+
426+
df.loc["a", "x"] = None
427+
df.iloc[2, 0] = None
428+
df.at["a", "y"] = None
429+
df.iat[0, 0] = None
430+
431+
df.loc[:, "x"] = [None, pd.NA, pd.NaT]
432+
df.iloc[:, 0] = [None, pd.NA, pd.NaT]
433+
434+
# TODO: mypy bug, remove after python/mypy#20420 has been resolved
435+
df.loc[:, ["x"]] = [[None], [pd.NA], [pd.NaT]] # type: ignore[assignment,index]
436+
df.iloc[:, [0]] = [[None], [pd.NA], [pd.NaT]] # type: ignore[assignment,index]
437+
438+
# TODO: mypy bug, remove after python/mypy#20420 has been resolved
439+
df.loc[:, iter(["x"])] = [[None], [pd.NA], [pd.NaT]] # type: ignore[assignment,index]
440+
df.iloc[:, iter([0])] = [[None], [pd.NA], [pd.NaT]] # type: ignore[assignment,index]
427441

428442

429443
def test_loc_set() -> None:
@@ -574,6 +588,9 @@ def test_df_loc_dict() -> None:
574588
df.iloc[0] = {"X": 0}
575589
check(assert_type(df, pd.DataFrame), pd.DataFrame)
576590

591+
df.loc[0] = {None: None, pd.NA: pd.NA, pd.NaT: pd.NaT}
592+
df.iloc[0] = {None: None, pd.NA: pd.NA, pd.NaT: pd.NaT}
593+
577594

578595
def test_iloc_npint() -> None:
579596
# GH 69

tests/series/test_indexing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ def test_series_setitem_na() -> None:
228228
s2.loc[ind] = pd.NaT
229229
s2.iloc[[0, 2]] = pd.NaT
230230

231+
sb = pd.Series([1, 2, 3], dtype=int)
232+
sb.loc["y"] = None
233+
sb.iloc[0] = None
234+
231235

232236
def test_slice_timestamp() -> None:
233237
dti = pd.date_range("1/1/2025", "2/28/2025")

0 commit comments

Comments
 (0)