Skip to content

Commit 1a20262

Browse files
committed
Add format generation from levels.
1 parent d770f66 commit 1a20262

File tree

4 files changed

+213
-34
lines changed

4 files changed

+213
-34
lines changed

pixi.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ mkdocs-jupyter = "*"
2828
[feature.tests.tasks]
2929
test = "pytest --pyargs sparse -n auto"
3030
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
31-
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto", env = { SPARSE_BACKEND = "Finch" }, depends-on = ["precompile"] }
31+
# PYTHONFAULTHANDLER is an on/off switch: any non-empty value enables the
# fault handler and tracebacks are written to stderr, not to a file path.
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -vvv", env = { SPARSE_BACKEND = "Finch", PYTHONFAULTHANDLER = "1" }, depends-on = ["precompile"] }
3232

3333
[feature.tests.dependencies]
3434
pytest = ">=3.5"

sparse/mlir_backend/_common.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,44 @@
44
import weakref
55
from dataclasses import dataclass
66

7+
import mlir.runtime as rt
78
from mlir import ir
89

10+
import numpy as np
11+
12+
from ._core import libc
13+
from ._dtypes import DType, asdtype
14+
15+
16+
def fn_cache(f, maxsize: int | None = None):
17+
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
18+
19+
20+
@fn_cache
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> type[ctypes.Structure]:
    """Return the memref descriptor type for the given `rank` and `dtype`.

    The result is cached per ``(rank, dtype)`` via `fn_cache`, so repeated
    calls with the same arguments return the same object. Callers instantiate
    the returned type with ``allocated``/``aligned``/``offset``/``shape``/
    ``strides`` keyword arguments.
    """
    return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())
23+
24+
25+
def numpy_to_ranked_memref(arr: np.ndarray) -> ctypes.Structure:
    """Build a ranked memref descriptor for `arr` using the cached descriptor type.

    The descriptor produced by the MLIR runtime is copied field by field into
    an instance of the type returned by `get_nd_memref_descr`.
    """
    source = rt.get_ranked_memref_descriptor(arr)
    descr_type = get_nd_memref_descr(arr.ndim, asdtype(arr.dtype))
    # Required due to ctypes type checks
    copied_fields = ("allocated", "aligned", "offset", "shape", "strides")
    return descr_type(**{field: getattr(source, field) for field in copied_fields})
36+
37+
38+
def ranked_memref_to_numpy(ref: ctypes.Structure) -> np.ndarray:
    """Convert the memref descriptor `ref` into a NumPy array via the MLIR runtime."""
    # The runtime helper is handed a one-element list wrapping the descriptor.
    wrapped = [ref]
    return rt.ranked_memref_to_numpy(wrapped)
40+
41+
42+
def free_memref(obj: ctypes.Structure) -> None:
    """Pass the descriptor's ``allocated`` pointer to ``libc.free``."""
    allocated_ptr = ctypes.cast(obj.allocated, ctypes.c_void_p)
    libc.free(allocated_ptr)
44+
945

1046
class MlirType(abc.ABC):
1147
@classmethod
@@ -27,10 +63,6 @@ def __len__(self):
2763
return len(self.contents)
2864

2965

30-
def fn_cache(f, maxsize: int | None = None):
31-
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))
32-
33-
3466
def _hold_self_ref_in_ret(fn):
3567
@functools.wraps(fn)
3668
def wrapped(self, *a, **kw):

sparse/mlir_backend/_constructors.py

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,30 @@
22
from collections.abc import Iterable
33
from typing import Any
44

5-
import mlir.runtime as rt
65
from mlir import ir
76
from mlir.dialects import sparse_tensor
87

98
import numpy as np
109
import scipy.sparse as sps
1110

12-
from ._common import PackedArgumentTuple, _hold_self_ref_in_ret, _take_owneship, fn_cache
13-
from ._core import ctx, libc
11+
from ._common import (
12+
PackedArgumentTuple,
13+
_hold_self_ref_in_ret,
14+
_take_owneship,
15+
fn_cache,
16+
free_memref,
17+
get_nd_memref_descr,
18+
numpy_to_ranked_memref,
19+
ranked_memref_to_numpy,
20+
)
21+
from ._core import ctx
1422
from ._dtypes import DType, asdtype
1523

1624
###########
1725
# Memrefs #
1826
###########
1927

2028

21-
@fn_cache
22-
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> type:
23-
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())
24-
25-
26-
def numpy_to_ranked_memref(arr: np.ndarray) -> ctypes.Structure:
27-
memref = rt.get_ranked_memref_descriptor(arr)
28-
memref_descr = get_nd_memref_descr(arr.ndim, asdtype(arr.dtype))
29-
# Required due to ctypes type checks
30-
return memref_descr(
31-
allocated=memref.allocated,
32-
aligned=memref.aligned,
33-
offset=memref.offset,
34-
shape=memref.shape,
35-
strides=memref.strides,
36-
)
37-
38-
39-
def ranked_memref_to_numpy(ref: ctypes.Structure) -> np.ndarray:
40-
return rt.ranked_memref_to_numpy([ref])
41-
42-
43-
def free_memref(obj: ctypes.Structure) -> None:
44-
libc.free(ctypes.cast(obj.allocated, ctypes.c_void_p))
45-
46-
4729
###########
4830
# Formats #
4931
###########

sparse/mlir_backend/_levels.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import ctypes
2+
import dataclasses
3+
import enum
4+
import itertools
5+
import re
6+
import typing
7+
8+
import mlir.runtime as rt
9+
from mlir import ir
10+
from mlir.dialects import sparse_tensor
11+
12+
import numpy as np
13+
14+
from ._common import (
15+
PackedArgumentTuple,
16+
_take_owneship,
17+
fn_cache,
18+
numpy_to_ranked_memref,
19+
ranked_memref_to_numpy,
20+
)
21+
from ._dtypes import DType, asdtype
22+
23+
_CAMEL_TO_SNAKE = [re.compile("(.)([A-Z][a-z]+)"), re.compile("([a-z0-9])([A-Z])")]
24+
25+
26+
def _camel_to_snake(name: str) -> str:
    """Convert a CamelCase identifier to snake_case (e.g. ``NonOrdered`` -> ``non_ordered``)."""
    snake = name
    # Apply each pattern in turn, inserting "_" between the two captured groups.
    for pattern in _CAMEL_TO_SNAKE:
        snake = pattern.sub(r"\1_\2", snake)
    return snake.lower()
31+
32+
33+
@fn_cache
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> type:
    """Return the memref descriptor type for `rank` and `dtype`.

    Delegates to `._common.get_nd_memref_descr` instead of calling
    `rt.make_nd_memref_descriptor` again. A second, independent cache would
    hand out a different descriptor type than the one `numpy_to_ranked_memref`
    instantiates, and ctypes structure fields reject instances of a different
    type.
    """
    # Function-scope import: the shared helper has the same name as this function.
    from ._common import get_nd_memref_descr as _shared_get_nd_memref_descr

    return _shared_get_nd_memref_descr(rank, dtype)
36+
37+
38+
class LevelProperties(enum.Flag):
    """Bit flags describing optional properties of a storage level."""

    NonOrdered = enum.auto()
    NonUnique = enum.auto()

    def build(self) -> list[sparse_tensor.LevelProperty]:
        """Return the `sparse_tensor.LevelProperty` members for every flag set on `self`.

        Each flag name is converted to snake_case and looked up by that name.
        """
        built = []
        for flag in type(self):
            if flag not in self:
                continue
            built.append(getattr(sparse_tensor.LevelProperty, _camel_to_snake(flag.name)))
        return built
44+
45+
46+
class LevelFormat(enum.Enum):
    """Storage format of a single level; each value names a `sparse_tensor.LevelFormat` attribute."""

    Dense = "dense"
    Compressed = "compressed"
    Singleton = "singleton"

    def build(self) -> sparse_tensor.LevelFormat:
        """Look up the `sparse_tensor.LevelFormat` attribute named by this member's value."""
        mlir_formats = sparse_tensor.LevelFormat
        return getattr(mlir_formats, self.value)
53+
54+
55+
@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
class Level:
    """A single storage level: a format plus optional property flags."""

    format: LevelFormat
    properties: LevelProperties = LevelProperties(0)

    def build(self):
        """Return the level type built from this level's format and properties.

        The result of `sparse_tensor.EncodingAttr.build_level_type` is
        returned so that callers collecting ``level.build()`` values receive
        the built level types rather than ``None``.
        """
        return sparse_tensor.EncodingAttr.build_level_type(self.format.build(), self.properties.build())
62+
63+
64+
@dataclasses.dataclass(kw_only=True)
class StorageFormat:
    """Description of a sparse storage layout built from a sequence of levels.

    `order` may be ``"C"``, ``"F"`` or an explicit permutation of
    ``range(rank)``; `__post_init__` normalises it to a tuple of ints and
    passes `dtype` through `asdtype`.

    Instances hash and compare by identity (see `__hash__` / `__eq__`), which
    lets the `fn_cache`-decorated methods be keyed on ``self``.
    """

    levels: tuple[Level, ...]
    order: typing.Literal["C", "F"] | tuple[int, ...]
    pos_width: int
    crd_width: int
    dtype: type[DType]

    @property
    def storage_rank(self) -> int:
        """Number of levels in this format."""
        return len(self.levels)

    @property
    def rank(self) -> int:
        """Rank of the described tensor; currently equal to `storage_rank`."""
        return self.storage_rank

    def __post_init__(self):
        """Normalise `dtype` and `order`.

        Raises
        ------
        ValueError
            If `order` is neither ``"C"``, ``"F"`` nor a permutation of
            ``range(rank)``.
        """
        rank = self.storage_rank
        self.dtype = asdtype(self.dtype)
        if self.order == "C":
            self.order = tuple(range(rank))
            return

        if self.order == "F":
            self.order = tuple(reversed(range(rank)))
            return

        if sorted(self.order) != list(range(rank)):
            raise ValueError(f"`sorted(self.order) != list(range(rank))`, {self.order=}, {rank=}.")

        self.order = tuple(self.order)

    # NOTE(review): caching on a method keeps every instance alive for the
    # lifetime of the cache; presumably fine for long-lived formats, confirm.
    @fn_cache
    def get_mlir_type(self, *, shape: tuple[int, ...]) -> ir.RankedTensorType:
        """Return the ranked tensor type with this format's sparse encoding for `shape`.

        Raises
        ------
        ValueError
            If ``len(shape)`` differs from `rank`.
        """
        if len(shape) != self.rank:
            raise ValueError(f"`len(shape) != self.rank`, {shape=}, {self.rank=}")
        mlir_levels = [level.build() for level in self.levels]
        mlir_order = list(self.order)
        # Inverse permutation of `order`.
        mlir_reverse_order = [0] * self.rank
        for i, r in enumerate(mlir_order):
            mlir_reverse_order[r] = i

        dtype = self.dtype.get_mlir_type()
        encoding = sparse_tensor.EncodingAttr.get(
            mlir_levels, mlir_order, mlir_reverse_order, self.pos_width, self.crd_width
        )
        return ir.RankedTensorType.get(list(shape), dtype, encoding)

    @fn_cache
    def get_ctypes_type(self):
        """Return a `ctypes.Structure` subclass holding this format's constituent memrefs.

        The structure has one pointers field and one indices field per
        compressed level, followed by a final ``values`` field.
        """
        ptr_dtype = asdtype(getattr(np, f"uint{self.pos_width}"))
        idx_dtype = asdtype(getattr(np, f"uint{self.crd_width}"))

        def get_fields():
            fields = []
            compressed_counter = 0
            for level, next_level in itertools.zip_longest(self.levels, self.levels[1:]):
                if LevelFormat.Compressed == level.format:
                    compressed_counter += 1
                    fields.append((f"pointers_to_{compressed_counter}", get_nd_memref_descr(1, ptr_dtype)))
                    # A compressed level followed by a singleton level gets a
                    # rank-2 indices memref; otherwise the indices are rank 1.
                    if next_level is not None and LevelFormat.Singleton == next_level.format:
                        fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(2, idx_dtype)))
                    else:
                        fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(1, idx_dtype)))

            # `get_nd_memref_descr` calls `.to_ctype()` on its dtype argument,
            # so pass the `asdtype`-normalised `self.dtype`, as is done for
            # the pointer and index dtypes above.
            fields.append(("values", get_nd_memref_descr(1, self.dtype)))
            return fields

        storage_format = self

        class Format(ctypes.Structure):
            _fields_ = get_fields()

            def get_mlir_type(self, *, shape: tuple[int, ...]):
                """Return the MLIR type of the owning storage format for `shape`."""
                return self.get_storage_format().get_mlir_type(shape=shape)

            def to_module_arg(self) -> list:
                """Return one pointer-to-pointer per field."""
                return [ctypes.pointer(ctypes.pointer(f)) for f in self.get__fields_()]

            def get__fields_(self) -> list:
                """Return the field values in declaration order."""
                return [getattr(self, field[0]) for field in self._fields_]

            def to_constituent_arrays(self) -> PackedArgumentTuple:
                """Convert every field to a NumPy array, packed in a `PackedArgumentTuple`."""
                return PackedArgumentTuple(tuple(ranked_memref_to_numpy(field) for field in self.get__fields_()))

            def get_storage_format(self) -> StorageFormat:
                """Return the `StorageFormat` this structure type was generated from."""
                return storage_format

            @classmethod
            def from_constituent_arrays(cls, arrs: list[np.ndarray]) -> "Format":
                """Build an instance from NumPy arrays given in field order.

                `_take_owneship` is then called for each array with the new instance.
                """
                inst = cls(*(numpy_to_ranked_memref(arr) for arr in arrs))
                for arr in arrs:
                    _take_owneship(inst, arr)
                return inst

        return Format

    def __hash__(self):
        """Hash by object identity."""
        return hash(id(self))

    def __eq__(self, value):
        """Compare by object identity."""
        return self is value

0 commit comments

Comments
 (0)