Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,21 @@ def _advanced_indexer_subspaces(key):
"""Indices of the advanced indexes subspaces for mixed indexing and vindex."""
if not isinstance(key, tuple):
key = (key,)
advanced_index_positions = [
i for i, k in enumerate(key) if not isinstance(k, slice)
]

# Preallocate and reuse memory for position arrays
# Optimize by using list comprehension once, then using the result multiple times
n_items = len(key)
advanced_index_positions = []
non_slices = []

# Combine both loops for advanced_index_positions and non_slices for cache locality
for i in range(n_items):
k = key[i]
if not isinstance(k, slice):
advanced_index_positions.append(i)
non_slices.append(k)

# Fast exit if nothing to reorder, avoids allocations below

if not advanced_index_positions or not _is_contiguous(advanced_index_positions):
# Nothing to reorder: dimensions on the indexing result are already
Expand All @@ -144,13 +156,32 @@ def _advanced_indexer_subspaces(key):
# https://numpy.org/doc/stable/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
return (), ()

non_slices = [k for k in key if not isinstance(k, slice)]
broadcasted_shape = np.broadcast_shapes(
*[item.shape if is_duck_array(item) else (0,) for item in non_slices]
)
# Optimize broadcast_shapes by using tuple generator to avoid creating intermediate list
shapes = []
for item in non_slices:
# Fast path: skip function call for ndarray
if type(item) is np.ndarray:
shapes.append(item.shape)
elif is_duck_array(item):
shapes.append(item.shape)
else:
shapes.append((0,))
# Avoid unpacking from generator, which is slower than from list
broadcasted_shape = np.broadcast_shapes(*shapes)
ndim = len(broadcasted_shape)
mixed_positions = advanced_index_positions[0] + np.arange(ndim)
vindex_positions = np.arange(ndim)

# Optimize np.arange calls and mixed_positions calculation
# Use cached first position for mixed_positions calculation
base = advanced_index_positions[0]
# Avoid np.arange if ndim is zero (no-op)
if ndim == 0:
mixed_positions = np.empty(0, dtype=int)
vindex_positions = np.empty(0, dtype=int)
else:
arange_ndim = np.arange(ndim)
mixed_positions = base + arange_ndim
vindex_positions = arange_ndim

return mixed_positions, vindex_positions


Expand All @@ -166,6 +197,9 @@ def __init__(self, array):

def __getitem__(self, key):
mixed_positions, vindex_positions = _advanced_indexer_subspaces(key)
# Skip moveaxis if no-op, avoids overhead for common simple indexing
if mixed_positions is None or len(mixed_positions) == 0:
return self._array[key]
return np.moveaxis(self._array[key], mixed_positions, vindex_positions)

def __setitem__(self, key, value):
Expand Down