@@ -115,60 +115,6 @@ def generic_aggregate(
115115 return result
116116
117117
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
    """Resolve the requested ``dtype`` against the array dtype and fill value.

    Parameters
    ----------
    dtype : DTypeLike
        Requested dtype. ``None`` means "use ``array_dtype``". The abstract
        type ``np.floating`` is treated as a placeholder meaning "some
        floating type".
    array_dtype : np.dtype
        dtype of the array being reduced.
    fill_value : optional
        User-supplied fill value. Unless it is ``None`` or one of the
        ``dtypes.INF`` / ``dtypes.NINF`` / ``dtypes.NA`` sentinels, the result
        is widened with ``np.result_type`` so that it can hold the fill value.

    Returns
    -------
    np.dtype
        The resolved dtype.
    """
    requested = array_dtype if dtype is None else dtype

    if requested is np.floating:
        # mean, std, var always produce a floating result, but an array dtype
        # whose kind is already one of "f", "c", "m", "M" is kept as-is.
        resolved = array_dtype if array_dtype.kind in "fcmM" else np.dtype("float64")
    elif isinstance(requested, np.dtype):
        resolved = requested
    else:
        resolved = np.dtype(requested)

    sentinels = [None, dtypes.INF, dtypes.NINF, dtypes.NA]
    if fill_value in sentinels:
        # Sentinels are replaced by dtype-appropriate values elsewhere, so
        # they must not influence the dtype here.
        return resolved
    return np.result_type(resolved, fill_value)
133-
134-
def _maybe_promote_int(dtype) -> np.dtype:
    """Promote integer dtypes to at least the platform integer width.

    Mirrors the default accumulator rule documented for ``numpy.prod``
    (https://numpy.org/doc/stable/reference/generated/numpy.prod.html): the
    dtype of the input is used unless it is an integer dtype of less
    precision than the default platform integer.

    Parameters
    ----------
    dtype : dtype-like
        Anything accepted by ``np.dtype``.

    Returns
    -------
    np.dtype
        ``np.result_type(dtype, np.intp)`` for signed integers (kind "i"),
        ``np.result_type(dtype, np.uintp)`` for unsigned integers (kind "u"),
        and the input dtype (converted to ``np.dtype``) for every other kind.
    """
    normalized = dtype if isinstance(dtype, np.dtype) else np.dtype(dtype)

    # Map the integer kind codes to the matching platform-sized integer type.
    platform_int = {"i": np.intp, "u": np.uintp}.get(normalized.kind)
    if platform_int is None:
        return normalized
    return np.result_type(normalized, platform_int)
146-
147-
def _get_fill_value(dtype, fill_value):
    """Replace a sentinel ``fill_value`` with a concrete value for ``dtype``.

    Parameters
    ----------
    dtype : np.dtype
        dtype the fill value must be compatible with.
    fill_value
        ``None``, one of the ``dtypes.INF`` / ``dtypes.NINF`` / ``dtypes.NA``
        sentinels, or an explicit value.

    Returns
    -------
    object
        * ``""`` when ``fill_value`` is ``None`` or ``dtypes.NA`` and ``dtype``
          has kind "U" or "S".
        * ``dtypes.get_pos_infinity(dtype, max_for_int=True)`` for
          ``dtypes.INF`` or ``None``.
        * ``dtypes.get_neg_infinity(dtype, min_for_int=True)`` for
          ``dtypes.NINF``.
        * For ``dtypes.NA``: ``np.nan`` for floating and complex dtypes, the
          negative-infinity value above for integer, timedelta64 and
          datetime64 dtypes, and ``None`` for anything else.
        * Any other ``fill_value`` is returned unchanged.
    """
    if fill_value in [None, dtypes.NA] and dtype.kind in "US":
        return ""

    wants_positive_inf = fill_value == dtypes.INF or fill_value is None
    if wants_positive_inf:
        return dtypes.get_pos_infinity(dtype, max_for_int=True)

    if fill_value == dtypes.NINF:
        return dtypes.get_neg_infinity(dtype, min_for_int=True)

    if not (fill_value == dtypes.NA):
        # Explicit user value: pass it through untouched.
        return fill_value

    # From here on fill_value is the NA sentinel.
    nan_capable = (np.floating, np.complexfloating)
    if any(np.issubdtype(dtype, abstract) for abstract in nan_capable):
        return np.nan

    # This is madness, but npg checks that fill_value is compatible
    # with array dtype even if the fill_value is never used.
    needs_placeholder = (np.integer, np.timedelta64, np.datetime64)
    if any(np.issubdtype(dtype, abstract) for abstract in needs_placeholder):
        return dtypes.get_neg_infinity(dtype, min_for_int=True)

    return None
170-
171-
172118def _atleast_1d (inp , min_length : int = 1 ):
173119 if xrutils .is_scalar (inp ):
174120 inp = (inp ,) * min_length
@@ -435,9 +381,9 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
435381
436382
437383min_ = Aggregation ("min" , chunk = "min" , combine = "min" , fill_value = dtypes .INF )
438- nanmin = Aggregation ("nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = np . nan )
384+ nanmin = Aggregation ("nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = dtypes . NA )
439385max_ = Aggregation ("max" , chunk = "max" , combine = "max" , fill_value = dtypes .NINF )
440- nanmax = Aggregation ("nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = np . nan )
386+ nanmax = Aggregation ("nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = dtypes . NA )
441387
442388
443389def argreduce_preprocess (array , axis ):
@@ -634,7 +580,7 @@ def last(self) -> AlignedArrays:
634580 # TODO: automate?
635581 engine = "flox" ,
636582 dtype = self .array .dtype ,
637- fill_value = _get_fill_value (self .array .dtype , dtypes .NA ),
583+ fill_value = dtypes . _get_fill_value (self .array .dtype , dtypes .NA ),
638584 expected_groups = None ,
639585 )
640586 return AlignedArrays (array = reduced ["intermediates" ][0 ], group_idx = reduced ["groups" ])
@@ -729,15 +675,15 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
729675 binary_op = None ,
730676 reduction = "nanlast" ,
731677 scan = "ffill" ,
732- identity = np . nan ,
678+ identity = dtypes . NA ,
733679 mode = "concat_then_scan" ,
734680)
735681bfill = Scan (
736682 "bfill" ,
737683 binary_op = None ,
738684 reduction = "nanlast" ,
739685 scan = "ffill" ,
740- identity = np . nan ,
686+ identity = dtypes . NA ,
741687 mode = "concat_then_scan" ,
742688 preprocess = reverse ,
743689 finalize = reverse ,
@@ -816,16 +762,27 @@ def _initialize_aggregation(
816762 np .dtype (dtype ) if dtype is not None and not isinstance (dtype , np .dtype ) else dtype
817763 )
818764
819- final_dtype = _normalize_dtype (dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value )
820- if agg .name not in ["first" , "last" , "nanfirst" , "nanlast" , "min" , "max" , "nanmin" , "nanmax" ]:
821- final_dtype = _maybe_promote_int (final_dtype )
765+ final_dtype = dtypes ._normalize_dtype (
766+ dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value
767+ )
768+ if agg .name not in [
769+ "first" ,
770+ "last" ,
771+ "nanfirst" ,
772+ "nanlast" ,
773+ "min" ,
774+ "max" ,
775+ "nanmin" ,
776+ "nanmax" ,
777+ ]:
778+ final_dtype = dtypes ._maybe_promote_int (final_dtype )
822779 agg .dtype = {
823780 "user" : dtype , # Save to automatically choose an engine
824781 "final" : final_dtype ,
825782 "numpy" : (final_dtype ,),
826783 "intermediate" : tuple (
827784 (
828- _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
785+ dtypes . _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
829786 if int_dtype is None
830787 else np .dtype (int_dtype )
831788 )
@@ -838,10 +795,10 @@ def _initialize_aggregation(
838795 # Replace sentinel fill values according to dtype
839796 agg .fill_value ["user" ] = fill_value
840797 agg .fill_value ["intermediate" ] = tuple (
841- _get_fill_value (dt , fv )
798+ dtypes . _get_fill_value (dt , fv )
842799 for dt , fv in zip (agg .dtype ["intermediate" ], agg .fill_value ["intermediate" ])
843800 )
844- agg .fill_value [func ] = _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
801+ agg .fill_value [func ] = dtypes . _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
845802
846803 fv = fill_value if fill_value is not None else agg .fill_value [agg .name ]
847804 if _is_arg_reduction (agg ):
0 commit comments