Skip to content

Commit 6732b5e

Browse files
committed
Add format generation from levels.
1 parent d770f66 commit 6732b5e

File tree

2 files changed

+145
-1
lines changed

2 files changed

+145
-1
lines changed

pixi.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ mkdocs-jupyter = "*"
2828
[feature.tests.tasks]
2929
test = "pytest --pyargs sparse -n auto"
3030
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -vvv", env = { SPARSE_BACKEND = "MLIR" } }
31-
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto", env = { SPARSE_BACKEND = "Finch" }, depends-on = ["precompile"] }
31+
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -vvv", env = { SPARSE_BACKEND = "Finch", PYTHONFAULTHANDLER = "${HOME}/faulthandler.log" }, depends-on = ["precompile"] }
3232

3333
[feature.tests.dependencies]
3434
pytest = ">=3.5"

sparse/mlir_backend/_levels.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import ctypes
2+
import dataclasses
3+
import enum
4+
import itertools
5+
import re
6+
import typing
7+
8+
import mlir.runtime as rt
9+
from mlir import ir
10+
from mlir.dialects import sparse_tensor
11+
12+
import numpy as np
13+
14+
from ._common import MlirType, fn_cache
15+
from ._dtypes import DType, asdtype
16+
17+
_CAMEL_TO_SNAKE = [re.compile("(.)([A-Z][a-z]+)"), re.compile("([a-z0-9])([A-Z])")]
18+
19+
20+
def _camel_to_snake(name: str) -> str:
21+
for exp in _CAMEL_TO_SNAKE:
22+
name = exp.sub(r"\1_\2", name)
23+
24+
return name.lower()
25+
26+
27+
@fn_cache
28+
def get_nd_memref_descr(rank: int, dtype: type[DType]) -> type:
29+
return rt.make_nd_memref_descriptor(rank, dtype.to_ctype())
30+
31+
32+
class LevelProperties(enum.Flag):
33+
NonOrdered = enum.auto()
34+
NonUnique = enum.auto()
35+
36+
def build(self) -> list[sparse_tensor.LevelProperty]:
37+
return [getattr(sparse_tensor.LevelProperty, _camel_to_snake(p.name)) for p in type(self) if p in self]
38+
39+
40+
class LevelFormat(enum.Enum):
41+
Dense = "dense"
42+
Compressed = "compressed"
43+
Singleton = "singleton"
44+
45+
def build(self) -> sparse_tensor.LevelFormat:
46+
return getattr(sparse_tensor.LevelFormat, self.value)
47+
48+
49+
@dataclasses.dataclass
50+
class Level:
51+
format: LevelFormat
52+
properties: LevelProperties = LevelProperties(0)
53+
54+
def build(self):
55+
sparse_tensor.EncodingAttr.build_level_type(self.format.build(), self.properties.build())
56+
57+
58+
@dataclasses.dataclass
59+
class StorageFormat:
60+
levels: tuple[Level, ...]
61+
order: typing.Literal["C", "F"] | tuple[int, ...]
62+
pos_width: int
63+
crd_width: int
64+
dtype: type[DType]
65+
66+
@property
67+
def storage_rank(self) -> int:
68+
return len(self.levels)
69+
70+
@property
71+
def rank(self) -> int:
72+
return self.storage_rank
73+
74+
def __post_init__(self):
75+
rank = self.storage_rank
76+
self.dtype = asdtype(self.dtype)
77+
if self.order == "C":
78+
self.order = tuple(range(rank))
79+
return
80+
81+
if self.order == "F":
82+
self.order = tuple(range(rank))[::-1]
83+
return
84+
85+
if sorted(self.order) != list(range(rank)):
86+
raise ValueError(f"`sorted(self.order) != list(range(rank))`, {self.order=}, {rank=}.")
87+
88+
self.order = tuple(self.order)
89+
90+
@fn_cache
91+
def get_mlir_type(self, *, shape: tuple[int, ...]) -> ir.RankedTensorType:
92+
if len(shape) != self.rank:
93+
raise ValueError(f"`len(shape) != self.rank`, {shape=}, {self.rank=}")
94+
mlir_levels = [level.build() for level in self.levels]
95+
mlir_order = list(self.order)
96+
mlir_reverse_order = [0] * self.rank
97+
for i, r in enumerate(mlir_order):
98+
mlir_reverse_order[r] = i
99+
100+
dtype = self.dtype.get_mlir_type()
101+
encoding = sparse_tensor.EncodingAttr.get(
102+
mlir_levels, mlir_order, mlir_reverse_order, self.pos_width, self.crd_width
103+
)
104+
return ir.RankedTensorType.get(list(shape), dtype, encoding)
105+
106+
@fn_cache
107+
def get_ctypes_type(self):
108+
ptr_dtype = asdtype(getattr(np, f"uint{self.pos_width}"))
109+
idx_dtype = asdtype(getattr(np, f"uint{self.crd_width}"))
110+
111+
def get_fields():
112+
fields = []
113+
compressed_counter = 0
114+
for level, next_level in itertools.zip_longest(self.levels, self.levels[1:]):
115+
if LevelFormat.Compressed == level.format:
116+
compressed_counter += 1
117+
fields.append((f"pointers_to_{compressed_counter}", get_nd_memref_descr(1, ptr_dtype)))
118+
if next_level is not None and LevelFormat.Singleton == next_level.format:
119+
fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(2, idx_dtype)))
120+
else:
121+
fields.append((f"indices_{compressed_counter}", get_nd_memref_descr(1, idx_dtype)))
122+
123+
fields.append(("values", get_nd_memref_descr(1, self.dtype.np_dtype)))
124+
return fields
125+
126+
class Format(ctypes.Structure, MlirType):
127+
_fields_ = get_fields()
128+
129+
def get_mlir_type(self, *, shape: tuple[int, ...]):
130+
return self.get_mlir_type(shape=shape)
131+
132+
def to_module_arg(self) -> list:
133+
return [ctypes.pointer(ctypes.pointer(f) for f in self.get__fields_())]
134+
135+
def get__fields_(self) -> list:
136+
return [getattr(self, field[0]) for field in self._fields_]
137+
138+
return Format
139+
140+
def __hash__(self):
141+
return hash(id(self))
142+
143+
def __eq__(self, value):
144+
return self is value

0 commit comments

Comments
 (0)