Skip to content

Commit d025cea

Browse files
[CIR][MLIR][LoweringThroughMLIR] CIRUnaryOpLowering on float values. (#2032)
UnaryOpKind Inc, Dec, Plus and Minus can accept float operands, the lowering should also handle those situations. Before this commit, the compiler crash when it met float operands.
1 parent f35f133 commit d025cea

File tree

4 files changed

+73
-23
lines changed

4 files changed

+73
-23
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,33 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
807807
public:
808808
using OpConversionPattern<cir::UnaryOp>::OpConversionPattern;
809809

810+
template <typename OpFloat, typename OpInt, bool rev>
811+
mlir::Operation *
812+
replaceImmediateOp(cir::UnaryOp op, mlir::Type type, mlir::Value input,
813+
int64_t n,
814+
mlir::ConversionPatternRewriter &rewriter) const {
815+
if (type.isFloat()) {
816+
auto imm = mlir::arith::ConstantOp::create(
817+
rewriter, op.getLoc(),
818+
mlir::FloatAttr::get(type, static_cast<double>(n)));
819+
if constexpr (rev)
820+
return rewriter.replaceOpWithNewOp<OpFloat>(op, type, imm, input);
821+
else
822+
return rewriter.replaceOpWithNewOp<OpFloat>(op, type, input, imm);
823+
}
824+
if (type.isInteger()) {
825+
auto imm = mlir::arith::ConstantOp::create(
826+
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, n));
827+
if constexpr (rev)
828+
return rewriter.replaceOpWithNewOp<OpInt>(op, type, imm, input);
829+
else
830+
return rewriter.replaceOpWithNewOp<OpInt>(op, type, input, imm);
831+
}
832+
op->emitError("Unsupported type: ") << type << " at " << op->getLoc();
833+
llvm_unreachable("CIRUnaryOpLowering met unsupported type");
834+
return nullptr;
835+
}
836+
810837
mlir::LogicalResult
811838
matchAndRewrite(cir::UnaryOp op, OpAdaptor adaptor,
812839
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -815,36 +842,31 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
815842

816843
switch (op.getKind()) {
817844
case cir::UnaryOpKind::Inc: {
818-
auto One = mlir::arith::ConstantOp::create(
819-
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 1));
820-
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, type, input, One);
845+
replaceImmediateOp<mlir::arith::AddFOp, mlir::arith::AddIOp, false>(
846+
op, type, input, 1, rewriter);
821847
break;
822848
}
823849
case cir::UnaryOpKind::Dec: {
824-
auto One = mlir::arith::ConstantOp::create(
825-
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 1));
826-
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, input, One);
850+
replaceImmediateOp<mlir::arith::AddFOp, mlir::arith::AddIOp, false>(
851+
op, type, input, -1, rewriter);
827852
break;
828853
}
829854
case cir::UnaryOpKind::Plus: {
830855
rewriter.replaceOp(op, op.getInput());
831856
break;
832857
}
833858
case cir::UnaryOpKind::Minus: {
834-
auto Zero = mlir::arith::ConstantOp::create(
835-
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 0));
836-
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, Zero, input);
859+
replaceImmediateOp<mlir::arith::SubFOp, mlir::arith::SubIOp, true>(
860+
op, type, input, 0, rewriter);
837861
break;
838862
}
839863
case cir::UnaryOpKind::Not: {
840-
auto MinusOne = mlir::arith::ConstantOp::create(
864+
auto o = mlir::arith::ConstantOp::create(
841865
rewriter, op.getLoc(), mlir::IntegerAttr::get(type, -1));
842-
rewriter.replaceOpWithNewOp<mlir::arith::XOrIOp>(op, type, MinusOne,
843-
input);
866+
rewriter.replaceOpWithNewOp<mlir::arith::XOrIOp>(op, type, o, input);
844867
break;
845868
}
846869
}
847-
848870
return mlir::LogicalResult::success();
849871
}
850872
};

clang/test/CIR/Lowering/ThroughMLIR/if.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ void foo() {
2929
//CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref<i32>
3030
//CHECK: } else {
3131
//CHECK: %[[SIX:.+]] = memref.load %[[alloca_0]][] : memref<i32>
32-
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
33-
//CHECK: %[[SEVEN:.+]] = arith.subi %[[SIX]], %[[C1_I32]] : i32
32+
//CHECK: %[[C1_I32:.+]] = arith.constant -1 : i32
33+
//CHECK: %[[SEVEN:.+]] = arith.addi %[[SIX]], %[[C1_I32]] : i32
3434
//CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref<i32>
3535
//CHECK: }
3636
//CHECK: }
@@ -106,8 +106,8 @@ void foo3() {
106106
//CHECK: memref.store %[[THIRTEEN]], %[[alloca_0]][] : memref<i32>
107107
//CHECK: } else {
108108
//CHECK: %[[TWELVE:.+]] = memref.load %[[alloca_0]][] : memref<i32>
109-
//CHECK: %[[C1_I32_5:.+]] = arith.constant 1 : i32
110-
//CHECK: %[[THIRTEEN:.+]] = arith.subi %[[TWELVE]], %[[C1_I32_5]] : i32
109+
//CHECK: %[[C1_I32_5:.+]] = arith.constant -1 : i32
110+
//CHECK: %[[THIRTEEN:.+]] = arith.addi %[[TWELVE]], %[[C1_I32_5]] : i32
111111
//CHECK: memref.store %[[THIRTEEN]], %[[alloca_0]][] : memref<i32>
112112
//CHECK: }
113113
//CHECK: }

clang/test/CIR/Lowering/ThroughMLIR/unary-inc-dec.cir

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: cir-opt %s -cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR
2-
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
1+
// RUN: cir-opt %s --cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR
2+
// RUN: cir-opt %s --cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
33

44
!s32i = !cir.int<s, 32>
55
module {
@@ -17,14 +17,32 @@ module {
1717
%5 = cir.load %1 : !cir.ptr<!s32i>, !s32i
1818
%6 = cir.unary(dec, %5) : !s32i, !s32i
1919
cir.store %6, %1 : !s32i, !cir.ptr<!s32i>
20+
21+
// test float
22+
%7 = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
2023
cir.return
2124
}
22-
}
2325

2426
// MLIR: = arith.constant 1
2527
// MLIR: = arith.addi
26-
// MLIR: = arith.constant 1
27-
// MLIR: = arith.subi
28+
// MLIR: = arith.constant -1
29+
// MLIR: = arith.addi
2830

2931
// LLVM: = add i32 %[[#]], 1
30-
// LLVM: = sub i32 %[[#]], 1
32+
// LLVM: = add i32 %[[#]], -1
33+
34+
35+
cir.func @floatingPoints(%arg0: !cir.double) {
36+
%0 = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["X", init] {alignment = 8 : i64}
37+
cir.store %arg0, %0 : !cir.double, !cir.ptr<!cir.double>
38+
%1 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
39+
%2 = cir.unary(inc, %1) : !cir.double, !cir.double
40+
%3 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
41+
%4 = cir.unary(dec, %3) : !cir.double, !cir.double
42+
cir.return
43+
}
44+
// MLIR: = arith.constant 1.0
45+
// MLIR: = arith.addf
46+
// MLIR: = arith.constant -1.0
47+
// MLIR: = arith.addf
48+
}

clang/test/CIR/Lowering/ThroughMLIR/unary-plus-minus.cir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ module {
1919
cir.store %6, %1 : !s32i, !cir.ptr<!s32i>
2020
cir.return
2121
}
22+
23+
cir.func @floatingPoints(%arg0: !cir.double) {
24+
%0 = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["X", init] {alignment = 8 : i64}
25+
cir.store %arg0, %0 : !cir.double, !cir.ptr<!cir.double>
26+
%1 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
27+
%2 = cir.unary(plus, %1) : !cir.double, !cir.double
28+
%3 = cir.load %0 : !cir.ptr<!cir.double>, !cir.double
29+
%4 = cir.unary(minus, %3) : !cir.double, !cir.double
30+
cir.return
31+
}
2232
}
2333

2434
// MLIR: %[[#INPUT_PLUS:]] = memref.load

0 commit comments

Comments
 (0)