@@ -115,60 +115,6 @@ def generic_aggregate(
115115 return result
116116
117117
118- def _normalize_dtype (dtype : DTypeLike , array_dtype : np .dtype , fill_value = None ) -> np .dtype :
119- if dtype is None :
120- dtype = array_dtype
121- if dtype is np .floating :
122- # mean, std, var always result in floating
123- # but we preserve the array's dtype if it is floating
124- if array_dtype .kind in "fcmM" :
125- dtype = array_dtype
126- else :
127- dtype = np .dtype ("float64" )
128- elif not isinstance (dtype , np .dtype ):
129- dtype = np .dtype (dtype )
130- if fill_value not in [None , dtypes .INF , dtypes .NINF , dtypes .NA ]:
131- dtype = np .result_type (dtype , fill_value )
132- return dtype
133-
134-
135- def _maybe_promote_int (dtype ) -> np .dtype :
136- # https://numpy.org/doc/stable/reference/generated/numpy.prod.html
137- # The dtype of a is used by default unless a has an integer dtype of less precision
138- # than the default platform integer.
139- if not isinstance (dtype , np .dtype ):
140- dtype = np .dtype (dtype )
141- if dtype .kind == "i" :
142- dtype = np .result_type (dtype , np .intp )
143- elif dtype .kind == "u" :
144- dtype = np .result_type (dtype , np .uintp )
145- return dtype
146-
147-
148- def _get_fill_value (dtype , fill_value ):
149- """Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
150- if fill_value in [None , dtypes .NA ] and dtype .kind in "US" :
151- return ""
152- if fill_value == dtypes .INF or fill_value is None :
153- return dtypes .get_pos_infinity (dtype , max_for_int = True )
154- if fill_value == dtypes .NINF :
155- return dtypes .get_neg_infinity (dtype , min_for_int = True )
156- if fill_value == dtypes .NA :
157- if np .issubdtype (dtype , np .floating ) or np .issubdtype (dtype , np .complexfloating ):
158- return np .nan
159- # This is madness, but npg checks that fill_value is compatible
160- # with array dtype even if the fill_value is never used.
161- elif (
162- np .issubdtype (dtype , np .integer )
163- or np .issubdtype (dtype , np .timedelta64 )
164- or np .issubdtype (dtype , np .datetime64 )
165- ):
166- return dtypes .get_neg_infinity (dtype , min_for_int = True )
167- else :
168- return None
169- return fill_value
170-
171-
172118def _atleast_1d (inp , min_length : int = 1 ):
173119 if xrutils .is_scalar (inp ):
174120 inp = (inp ,) * min_length
@@ -210,6 +156,7 @@ def __init__(
210156 final_dtype : DTypeLike | None = None ,
211157 reduction_type : Literal ["reduce" , "argreduce" ] = "reduce" ,
212158 new_dims_func : Callable | None = None ,
159+ preserves_dtype : bool = False ,
213160 ):
214161 """
215162 Blueprint for computing grouped aggregations.
@@ -256,6 +203,8 @@ def __init__(
256203 Function that receives finalize_kwargs and returns a tupleof sizes of any new dimensions
257204 added by the reduction. For e.g. quantile for q=(0.5, 0.85) adds a new dimension of size 2,
258205 so returns (2,)
206+ preserves_dtype: bool,
207+ Whether a function preserves the dtype on return E.g. min, max, first, last, mode
259208 """
260209 self .name = name
261210 # preprocess before blockwise
@@ -292,6 +241,7 @@ def __init__(
292241 self .new_dims_func : Callable = (
293242 returns_empty_tuple if new_dims_func is None else new_dims_func
294243 )
244+ self .preserves_dtype = preserves_dtype
295245
296246 @cached_property
297247 def new_dims (self ) -> tuple [Dim ]:
@@ -434,10 +384,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
434384)
435385
436386
437- min_ = Aggregation ("min" , chunk = "min" , combine = "min" , fill_value = dtypes .INF )
438- nanmin = Aggregation ("nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = np .nan )
439- max_ = Aggregation ("max" , chunk = "max" , combine = "max" , fill_value = dtypes .NINF )
440- nanmax = Aggregation ("nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = np .nan )
387+ min_ = Aggregation ("min" , chunk = "min" , combine = "min" , fill_value = dtypes .INF , preserves_dtype = True )
388+ nanmin = Aggregation (
389+ "nanmin" , chunk = "nanmin" , combine = "nanmin" , fill_value = dtypes .NA , preserves_dtype = True
390+ )
391+ max_ = Aggregation ("max" , chunk = "max" , combine = "max" , fill_value = dtypes .NINF , preserves_dtype = True )
392+ nanmax = Aggregation (
393+ "nanmax" , chunk = "nanmax" , combine = "nanmax" , fill_value = dtypes .NA , preserves_dtype = True
394+ )
441395
442396
443397def argreduce_preprocess (array , axis ):
@@ -525,10 +479,14 @@ def _pick_second(*x):
525479 final_dtype = np .intp ,
526480)
527481
528- first = Aggregation ("first" , chunk = None , combine = None , fill_value = None )
529- last = Aggregation ("last" , chunk = None , combine = None , fill_value = None )
530- nanfirst = Aggregation ("nanfirst" , chunk = "nanfirst" , combine = "nanfirst" , fill_value = dtypes .NA )
531- nanlast = Aggregation ("nanlast" , chunk = "nanlast" , combine = "nanlast" , fill_value = dtypes .NA )
482+ first = Aggregation ("first" , chunk = None , combine = None , fill_value = None , preserves_dtype = True )
483+ last = Aggregation ("last" , chunk = None , combine = None , fill_value = None , preserves_dtype = True )
484+ nanfirst = Aggregation (
485+ "nanfirst" , chunk = "nanfirst" , combine = "nanfirst" , fill_value = dtypes .NA , preserves_dtype = True
486+ )
487+ nanlast = Aggregation (
488+ "nanlast" , chunk = "nanlast" , combine = "nanlast" , fill_value = dtypes .NA , preserves_dtype = True
489+ )
532490
533491all_ = Aggregation (
534492 "all" ,
@@ -579,8 +537,12 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
579537 final_dtype = np .floating ,
580538 new_dims_func = quantile_new_dims_func ,
581539)
582- mode = Aggregation (name = "mode" , fill_value = dtypes .NA , chunk = None , combine = None )
583- nanmode = Aggregation (name = "nanmode" , fill_value = dtypes .NA , chunk = None , combine = None )
540+ mode = Aggregation (
541+ name = "mode" , fill_value = dtypes .NA , chunk = None , combine = None , preserves_dtype = True
542+ )
543+ nanmode = Aggregation (
544+ name = "nanmode" , fill_value = dtypes .NA , chunk = None , combine = None , preserves_dtype = True
545+ )
584546
585547
586548@dataclass
@@ -634,7 +596,7 @@ def last(self) -> AlignedArrays:
634596 # TODO: automate?
635597 engine = "flox" ,
636598 dtype = self .array .dtype ,
637- fill_value = _get_fill_value (self .array .dtype , dtypes .NA ),
599+ fill_value = dtypes . _get_fill_value (self .array .dtype , dtypes .NA ),
638600 expected_groups = None ,
639601 )
640602 return AlignedArrays (array = reduced ["intermediates" ][0 ], group_idx = reduced ["groups" ])
@@ -729,6 +691,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
729691 binary_op = None ,
730692 reduction = "nanlast" ,
731693 scan = "ffill" ,
694+ # Important: this must be NaN otherwise, ffill does not work.
732695 identity = np .nan ,
733696 mode = "concat_then_scan" ,
734697)
@@ -737,6 +700,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
737700 binary_op = None ,
738701 reduction = "nanlast" ,
739702 scan = "ffill" ,
703+ # Important: this must be NaN otherwise, bfill does not work.
740704 identity = np .nan ,
741705 mode = "concat_then_scan" ,
742706 preprocess = reverse ,
@@ -815,17 +779,18 @@ def _initialize_aggregation(
815779 dtype_ : np .dtype | None = (
816780 np .dtype (dtype ) if dtype is not None and not isinstance (dtype , np .dtype ) else dtype
817781 )
818-
819- final_dtype = _normalize_dtype (dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value )
820- if agg .name not in ["first" , "last" , "nanfirst" , "nanlast" , "min" , "max" , "nanmin" , "nanmax" ]:
821- final_dtype = _maybe_promote_int (final_dtype )
782+ final_dtype = dtypes ._normalize_dtype (
783+ dtype_ or agg .dtype_init ["final" ], array_dtype , fill_value
784+ )
785+ if not agg .preserves_dtype :
786+ final_dtype = dtypes ._maybe_promote_int (final_dtype )
822787 agg .dtype = {
823788 "user" : dtype , # Save to automatically choose an engine
824789 "final" : final_dtype ,
825790 "numpy" : (final_dtype ,),
826791 "intermediate" : tuple (
827792 (
828- _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
793+ dtypes . _normalize_dtype (int_dtype , np .result_type (array_dtype , final_dtype ), int_fv )
829794 if int_dtype is None
830795 else np .dtype (int_dtype )
831796 )
@@ -838,10 +803,10 @@ def _initialize_aggregation(
838803 # Replace sentinel fill values according to dtype
839804 agg .fill_value ["user" ] = fill_value
840805 agg .fill_value ["intermediate" ] = tuple (
841- _get_fill_value (dt , fv )
806+ dtypes . _get_fill_value (dt , fv )
842807 for dt , fv in zip (agg .dtype ["intermediate" ], agg .fill_value ["intermediate" ])
843808 )
844- agg .fill_value [func ] = _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
809+ agg .fill_value [func ] = dtypes . _get_fill_value (agg .dtype ["final" ], agg .fill_value [func ])
845810
846811 fv = fill_value if fill_value is not None else agg .fill_value [agg .name ]
847812 if _is_arg_reduction (agg ):
0 commit comments