Skip to content

Commit 4ad635b

Browse files
committed
Simplify and type hint to_numpy and to_scipy.
1 parent a5b692e commit 4ad635b

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

sparse/mlir_backend/_conversions.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,14 @@ def _from_numpy(arr: np.ndarray, copy: bool | None = None) -> Array:
4343
return from_constituent_arrays(format=dense_format, arrays=(arr_flat,), shape=arr.shape)
4444

4545

46-
def to_numpy(arr):
47-
storage = arr._storage
48-
storage_format: StorageFormat = storage.get_storage_format()
46+
def to_numpy(arr: Array) -> np.ndarray:
47+
storage_format: StorageFormat = arr.format
4948

5049
if not all(LevelFormat.Dense == level.format for level in storage_format.levels):
5150
raise TypeError(f"Cannot convert a non-dense array to NumPy. `{storage_format=}`")
5251

5352
data = ranked_memref_to_numpy(arr._storage.values)
54-
_hold_ref(data, storage)
53+
_hold_ref(data, arr._storage)
5554
arg_order = [0] * storage_format.storage_rank
5655
for i, o in enumerate(storage_format.order):
5756
arg_order[o] = i
@@ -126,27 +125,23 @@ def _from_scipy(arr: ScipySparseArray, copy: bool | None = None) -> Array:
126125

127126

128127
@_guard_scipy
129-
def to_scipy(arr) -> ScipySparseArray:
130-
storage = arr._storage
131-
storage_format: StorageFormat = storage.get_storage_format()
128+
def to_scipy(arr: Array) -> ScipySparseArray:
129+
storage_format = arr.format
132130

133131
match storage_format.levels:
134132
case (Level(LevelFormat.Dense, _), Level(LevelFormat.Compressed, _)):
135-
data = ranked_memref_to_numpy(storage.values)
136-
indices = ranked_memref_to_numpy(storage.indices_1)
137-
indptr = ranked_memref_to_numpy(storage.pointers_to_1)
133+
indptr, indices, data = arr.get_constituent_arrays()
138134
if storage_format.order == (0, 1):
139135
sps_arr = sps.csr_array((data, indices, indptr), shape=arr.shape)
140136
else:
141137
sps_arr = sps.csc_array((data, indices, indptr), shape=arr.shape)
142138
case (Level(LevelFormat.Compressed, _), Level(LevelFormat.Singleton, _)):
143-
data = ranked_memref_to_numpy(storage.values)
144-
coords = ranked_memref_to_numpy(storage.indices_1)
139+
_, coords, data = arr.get_constituent_arrays()
145140
sps_arr = sps.coo_array((data, (coords[:, 0], coords[:, 1])), shape=arr.shape)
146141
case _:
147142
raise RuntimeError(f"No conversion implemented for `{storage_format=}`.")
148143

149-
_hold_ref(sps_arr, storage)
144+
_hold_ref(sps_arr, arr._storage)
150145
return sps_arr
151146

152147

0 commit comments

Comments
 (0)