Skip to content

Commit 34fbab8

Browse files
Convert predicates to strings (#229)
This allows the use of the predicate enums instead of just the string versions.
1 parent 0567347 commit 34fbab8

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

projects/eudsl-python-extras/mlir/extras/dialects/arith.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def _arith_CmpIPredicateAttr(predicate: Union[str, Attribute], context: Context)
249249
}
250250
if isinstance(predicate, Attribute):
251251
return predicate
252+
predicate = str(predicate)
252253
assert predicate in predicates, f"{predicate=} not in predicates"
253254
return _arith_cmpipredicateattr(predicates[predicate], context)
254255

@@ -282,6 +283,7 @@ def _arith_CmpFPredicateAttr(predicate: Union[str, Attribute], context: Context)
282283
}
283284
if isinstance(predicate, Attribute):
284285
return predicate
286+
predicate = str(predicate)
285287
assert predicate in predicates, f"{predicate=} not in predicates"
286288
return _arith_cmpfpredicateattr(predicates[predicate], context)
287289

projects/eudsl-python-extras/tests/dialect/test_arith.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def test_arithmetic(ctx: MLIRContext):
9292
# CHECK: %[[VAL_10:.*]] = arith.subf %[[VAL_7]], %[[VAL_8]] : f32
9393
# CHECK: %[[VAL_11:.*]] = arith.divf %[[VAL_7]], %[[VAL_8]] : f32
9494
# CHECK: %[[VAL_12:.*]] = arith.remf %[[VAL_7]], %[[VAL_8]] : f32
95-
9695
filecheck_with_comments(ctx.module)
9796

9897

@@ -161,6 +160,51 @@ def test_arith_cmp(ctx: MLIRContext):
161160
filecheck_with_comments(ctx.module)
162161

163162

163+
def test_arith_cmp_enum_values(ctx: MLIRContext):
164+
one = arith.constant(1)
165+
two = arith.constant(2)
166+
for pred in arith.CmpIPredicate.__members__.values():
167+
arith.cmpi(pred, one, two)
168+
169+
one, two = arith.constant(1.0), arith.constant(2.0)
170+
for pred in arith.CmpFPredicate.__members__.values():
171+
arith.cmpf(pred, one, two)
172+
173+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
174+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 2 : i32
175+
# CHECK: %[[CMPI_0:.*]] = arith.cmpi eq, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
176+
# CHECK: %[[CMPI_1:.*]] = arith.cmpi ne, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
177+
# CHECK: %[[CMPI_2:.*]] = arith.cmpi slt, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
178+
# CHECK: %[[CMPI_3:.*]] = arith.cmpi sle, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
179+
# CHECK: %[[CMPI_4:.*]] = arith.cmpi sgt, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
180+
# CHECK: %[[CMPI_5:.*]] = arith.cmpi sge, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
181+
# CHECK: %[[CMPI_6:.*]] = arith.cmpi ult, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
182+
# CHECK: %[[CMPI_7:.*]] = arith.cmpi ule, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
183+
# CHECK: %[[CMPI_8:.*]] = arith.cmpi ugt, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
184+
# CHECK: %[[CMPI_9:.*]] = arith.cmpi uge, %[[CONSTANT_0]], %[[CONSTANT_1]] : i32
185+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 1.000000e+00 : f32
186+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 2.000000e+00 : f32
187+
# CHECK: %[[CMPF_0:.*]] = arith.cmpf false, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
188+
# CHECK: %[[CMPF_1:.*]] = arith.cmpf oeq, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
189+
# CHECK: %[[CMPF_2:.*]] = arith.cmpf ogt, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
190+
# CHECK: %[[CMPF_3:.*]] = arith.cmpf oge, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
191+
# CHECK: %[[CMPF_4:.*]] = arith.cmpf olt, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
192+
# CHECK: %[[CMPF_5:.*]] = arith.cmpf ole, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
193+
# CHECK: %[[CMPF_6:.*]] = arith.cmpf one, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
194+
# CHECK: %[[CMPF_7:.*]] = arith.cmpf ord, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
195+
# CHECK: %[[CMPF_8:.*]] = arith.cmpf ueq, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
196+
# CHECK: %[[CMPF_9:.*]] = arith.cmpf ugt, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
197+
# CHECK: %[[CMPF_10:.*]] = arith.cmpf uge, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
198+
# CHECK: %[[CMPF_11:.*]] = arith.cmpf ult, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
199+
# CHECK: %[[CMPF_12:.*]] = arith.cmpf ule, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
200+
# CHECK: %[[CMPF_13:.*]] = arith.cmpf une, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
201+
# CHECK: %[[CMPF_14:.*]] = arith.cmpf uno, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
202+
# CHECK: %[[CMPF_15:.*]] = arith.cmpf true, %[[CONSTANT_2]], %[[CONSTANT_3]] : f32
203+
204+
ctx.module.operation.verify()
205+
filecheck_with_comments(ctx.module)
206+
207+
164208
def test_arith_cmp_literals(ctx: MLIRContext):
165209
one = arith.constant(1)
166210
two = 2

0 commit comments

Comments
 (0)