@@ -33,7 +33,9 @@
     generic_aggregate,
 )
 from .cache import memoize
-from .xrutils import is_duck_array, is_duck_dask_array, isnull
+from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available
+
+HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
 
 if TYPE_CHECKING:
     try:
@@ -69,6 +71,7 @@
 T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
 T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
 T_Engine = Literal["flox", "numpy", "numba", "numbagg"]
+T_EngineOpt = None | T_Engine
 T_Method = Literal["map-reduce", "blockwise", "cohorts"]
 T_IsBins = Union[bool | Sequence[bool]]
 
@@ -83,6 +86,10 @@
 DUMMY_AXIS = -2
 
 
+def _issorted(arr: np.ndarray) -> bool:
+    return bool((arr[:-1] <= arr[1:]).all())
+
+
 def _is_arg_reduction(func: T_Agg) -> bool:
     if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
         return True
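For reference, `module_available` (imported above from `.xrutils`) is an importlib-based availability probe. A minimal sketch of such a helper, assuming `packaging` for the version comparison; the details here are illustrative, not flox's actual implementation:

from __future__ import annotations

from importlib.metadata import version
from importlib.util import find_spec

from packaging.version import Version


def module_available(module: str, minversion: str | None = None) -> bool:
    # False when the module cannot be imported at all.
    if find_spec(module) is None:
        return False
    # Otherwise require at least `minversion` when one is given.
    if minversion is not None:
        return Version(version(module)) >= Version(minversion)
    return True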
@@ -632,6 +639,7 @@ def chunk_argreduce(
     reindex: bool = False,
     engine: T_Engine = "numpy",
     sort: bool = True,
+    user_dtype=None,
 ) -> IntermediateDict:
     """
     Per-chunk arg reduction.
@@ -652,6 +660,7 @@ def chunk_argreduce(
         dtype=dtype,
         engine=engine,
         sort=sort,
+        user_dtype=user_dtype,
     )
     if not isnull(results["groups"]).all():
         idx = np.broadcast_to(idx, array.shape)
@@ -685,6 +694,7 @@ def chunk_reduce(
     engine: T_Engine = "numpy",
     kwargs: Sequence[dict] | None = None,
     sort: bool = True,
+    user_dtype=None,
 ) -> IntermediateDict:
     """
     Wrapper for numpy_groupies aggregate that supports nD ``array`` and
@@ -785,6 +795,7 @@ def chunk_reduce(
     group_idx = group_idx.reshape(-1)
 
     assert group_idx.ndim == 1
+
     empty = np.all(props.nanmask)
 
     results: IntermediateDict = {"groups": [], "intermediates": []}
@@ -1100,6 +1111,7 @@ def _grouped_combine(
                     dtype=(np.intp,),
                     engine=engine,
                     sort=sort,
+                    user_dtype=agg.dtype["user"],
                 )["intermediates"][0]
             )
 
@@ -1129,6 +1141,7 @@ def _grouped_combine(
                 dtype=(dtype,),
                 engine=engine,
                 sort=sort,
+                user_dtype=agg.dtype["user"],
             )
             results["intermediates"].append(*_results["intermediates"])
             results["groups"] = _results["groups"]
@@ -1174,6 +1187,7 @@ def _reduce_blockwise(
         engine=engine,
         sort=sort,
         reindex=reindex,
+        user_dtype=agg.dtype["user"],
     )
 
     if _is_arg_reduction(agg):
@@ -1366,6 +1380,7 @@ def dask_groupby_agg(
         fill_value=agg.fill_value["intermediate"],
         dtype=agg.dtype["intermediate"],
         reindex=reindex,
+        user_dtype=agg.dtype["user"],
     )
     if do_simple_combine:
         # Add a dummy dimension that then gets reduced over
@@ -1757,6 +1772,23 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
     return expected_groups
 
 
+def _choose_engine(by, agg: Aggregation):
+    dtype = agg.dtype["user"]
+
+    not_arg_reduce = not _is_arg_reduction(agg)
+
+    # numbagg only supports nan-skipping reductions
+    # without dtype specified
+    if HAS_NUMBAGG and "nan" in agg.name:
+        if not_arg_reduce and dtype is None:
+            return "numbagg"
+
+    if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)):
+        return "flox"
+    else:
+        return "numpy"
+
+
 def groupby_reduce(
     array: np.ndarray | DaskArray,
     *by: T_By,
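To illustrate the new default, a rough usage sketch; the arrays are made up, and the selection follows `_choose_engine` above (whether "numbagg" wins depends on numbagg >= 0.3.0 being installed):

import numpy as np
from flox.core import groupby_reduce

labels = np.array([0, 0, 1, 1, 2])  # sorted, in-memory labels
values = np.arange(5.0)

# engine is now None by default, so _choose_engine picks:
# - "numbagg" for nan-skipping, non-arg reductions with no user dtype
#   (only when HAS_NUMBAGG is True),
# - "flox" when `by` is a sorted, non-dask array,
# - "numpy" otherwise (e.g. arg reductions, or unsorted labels).
result, groups = groupby_reduce(values, labels, func="nanmean")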
@@ -1769,7 +1801,7 @@ def groupby_reduce(
     dtype: np.typing.DTypeLike = None,
     min_count: int | None = None,
     method: T_Method = "map-reduce",
-    engine: T_Engine = "numpy",
+    engine: T_EngineOpt = None,
     reindex: bool | None = None,
     finalize_kwargs: dict[Any, Any] | None = None,
 ) -> tuple[DaskArray, Unpack[tuple[np.ndarray | DaskArray, ...]]]:  # type: ignore[misc] # Unpack not in mypy yet
@@ -2027,9 +2059,14 @@ def groupby_reduce(
         # overwrite than when min_count is set
         fill_value = np.nan
 
-    kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
+    kwargs = dict(axis=axis_, fill_value=fill_value)
     agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)
 
+    # Need to set this early using `agg`
+    # It cannot be done in the core loop of chunk_reduce
+    # since we "prepare" the data for flox.
+    kwargs["engine"] = _choose_engine(by_, agg) if engine is None else engine
+
     groups: tuple[np.ndarray | DaskArray, ...]
     if not has_dask:
         results = _reduce_blockwise(
@@ -2080,7 +2117,7 @@ def groupby_reduce(
         assert len(groups) == 1
         sorted_idx = np.argsort(groups[0])
         # This optimization helps specifically with resampling
-        if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
+        if not _issorted(sorted_idx):
             result = result[..., sorted_idx]
             groups = (groups[0][sorted_idx],)
 
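The replaced inline check and `_issorted` are equivalent; for example, with made-up labels showing when the reorder is skipped:

import numpy as np

sorted_idx = np.argsort(np.array([3, 1, 2]))  # -> [1, 2, 0]
# _issorted(sorted_idx) is False: result and groups get reordered.

sorted_idx = np.argsort(np.array([1, 2, 3]))  # -> [0, 1, 2]
# _issorted(sorted_idx) is True: the fancy-indexing copy is skipped.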