Skip to content

Commit 961d7a4

Browse files
committed
Add SciPy conversions.
1 parent 1a20262 commit 961d7a4

File tree

8 files changed

+94
-644
lines changed

8 files changed

+94
-644
lines changed

pixi.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ mkdocs-jupyter = "*"
2727

2828
[feature.tests.tasks]
2929
test = "pytest --pyargs sparse -n auto"
30-
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
31-
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -vvv", env = { SPARSE_BACKEND = "Finch", PYTHONFAULTHANDLER = "${HOME}/faulthandler.log" }, depends-on = ["precompile"] }
30+
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v", env = { SPARSE_BACKEND = "MLIR" } }
31+
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", env = { SPARSE_BACKEND = "Finch", PYTHONFAULTHANDLER = "${HOME}/faulthandler.log" }, depends-on = ["precompile"] }
3232

3333
[feature.tests.dependencies]
3434
pytest = ">=3.5"

sparse/mlir_backend/__init__.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,20 @@
77
"to enable MLIR backend."
88
) from e
99

10-
from ._constructors import (
11-
PackedArgumentTuple,
12-
asarray,
13-
)
10+
from ._common import PackedArgumentTuple
11+
from ._conversions import asarray, to_numpy, to_scipy
1412
from ._dtypes import (
1513
asdtype,
1614
)
1715
from ._ops import (
1816
add,
19-
broadcast_to,
20-
reshape,
2117
)
2218

2319
__all__ = [
2420
"add",
25-
"broadcast_to",
2621
"asarray",
2722
"asdtype",
28-
"reshape",
2923
"PackedArgumentTuple",
24+
"to_numpy",
25+
"to_scipy",
3026
]

sparse/mlir_backend/_common.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import abc
21
import ctypes
32
import functools
43
import weakref
54
from dataclasses import dataclass
65

76
import mlir.runtime as rt
8-
from mlir import ir
97

108
import numpy as np
119

@@ -17,8 +15,12 @@ def fn_cache(f, maxsize: int | None = None):
1715
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
1816

1917

20-
@fn_cache
2118
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
19+
return _get_nd_memref_descr(int(rank), asdtype(dtype))
20+
21+
22+
@fn_cache
23+
def _get_nd_memref_descr(rank: int, dtype: type[DType]) -> ctypes.Structure:
2224
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())
2325

2426

@@ -43,12 +45,6 @@ def free_memref(obj: ctypes.Structure) -> None:
4345
libc.free(ctypes.cast(obj.allocated, ctypes.c_void_p))
4446

4547

46-
class MlirType(abc.ABC):
47-
@classmethod
48-
@abc.abstractmethod
49-
def get_mlir_type(cls) -> ir.Type: ...
50-
51-
5248
@dataclass
5349
class PackedArgumentTuple:
5450
contents: tuple
@@ -67,13 +63,13 @@ def _hold_self_ref_in_ret(fn):
6763
@functools.wraps(fn)
6864
def wrapped(self, *a, **kw):
6965
ret = fn(self, *a, **kw)
70-
_take_owneship(ret, self)
66+
_hold_ref(ret, self)
7167
return ret
7268

7369
return wrapped
7470

7571

76-
def _take_owneship(owner, obj):
72+
def _hold_ref(owner, obj):
7773
ptr = ctypes.py_object(obj)
7874
ctypes.pythonapi.Py_IncRef(ptr)
7975

0 commit comments

Comments
 (0)