@@ -807,6 +807,33 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
807807public:
808808 using OpConversionPattern<cir::UnaryOp>::OpConversionPattern;
809809
810+ template <typename OpFloat, typename OpInt, bool rev>
811+ mlir::Operation *
812+ replaceImmediateOp (cir::UnaryOp op, mlir::Type type, mlir::Value input,
813+ int64_t n,
814+ mlir::ConversionPatternRewriter &rewriter) const {
815+ if (type.isFloat ()) {
816+ auto imm = mlir::arith::ConstantOp::create (
817+ rewriter, op.getLoc (),
818+ mlir::FloatAttr::get (type, static_cast <double >(n)));
819+ if constexpr (rev)
820+ return rewriter.replaceOpWithNewOp <OpFloat>(op, type, imm, input);
821+ else
822+ return rewriter.replaceOpWithNewOp <OpFloat>(op, type, input, imm);
823+ }
824+ if (type.isInteger ()) {
825+ auto imm = mlir::arith::ConstantOp::create (
826+ rewriter, op.getLoc (), mlir::IntegerAttr::get (type, n));
827+ if constexpr (rev)
828+ return rewriter.replaceOpWithNewOp <OpInt>(op, type, imm, input);
829+ else
830+ return rewriter.replaceOpWithNewOp <OpInt>(op, type, input, imm);
831+ }
832+ op->emitError (" Unsupported type: " ) << type << " at " << op->getLoc ();
833+ llvm_unreachable (" CIRUnaryOpLowering met unsupported type" );
834+ return nullptr ;
835+ }
836+
810837 mlir::LogicalResult
811838 matchAndRewrite (cir::UnaryOp op, OpAdaptor adaptor,
812839 mlir::ConversionPatternRewriter &rewriter) const override {
@@ -815,36 +842,31 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
815842
816843 switch (op.getKind ()) {
817844 case cir::UnaryOpKind::Inc: {
818- auto One = mlir::arith::ConstantOp::create (
819- rewriter, op.getLoc (), mlir::IntegerAttr::get (type, 1 ));
820- rewriter.replaceOpWithNewOp <mlir::arith::AddIOp>(op, type, input, One);
845+ replaceImmediateOp<mlir::arith::AddFOp, mlir::arith::AddIOp, false >(
846+ op, type, input, 1 , rewriter);
821847 break ;
822848 }
823849 case cir::UnaryOpKind::Dec: {
824- auto One = mlir::arith::ConstantOp::create (
825- rewriter, op.getLoc (), mlir::IntegerAttr::get (type, 1 ));
826- rewriter.replaceOpWithNewOp <mlir::arith::SubIOp>(op, type, input, One);
850+ replaceImmediateOp<mlir::arith::AddFOp, mlir::arith::AddIOp, false >(
851+ op, type, input, -1 , rewriter);
827852 break ;
828853 }
829854 case cir::UnaryOpKind::Plus: {
830855 rewriter.replaceOp (op, op.getInput ());
831856 break ;
832857 }
833858 case cir::UnaryOpKind::Minus: {
834- auto Zero = mlir::arith::ConstantOp::create (
835- rewriter, op.getLoc (), mlir::IntegerAttr::get (type, 0 ));
836- rewriter.replaceOpWithNewOp <mlir::arith::SubIOp>(op, type, Zero, input);
859+ replaceImmediateOp<mlir::arith::SubFOp, mlir::arith::SubIOp, true >(
860+ op, type, input, 0 , rewriter);
837861 break ;
838862 }
839863 case cir::UnaryOpKind::Not: {
840- auto MinusOne = mlir::arith::ConstantOp::create (
864+ auto o = mlir::arith::ConstantOp::create (
841865 rewriter, op.getLoc (), mlir::IntegerAttr::get (type, -1 ));
842- rewriter.replaceOpWithNewOp <mlir::arith::XOrIOp>(op, type, MinusOne,
843- input);
866+ rewriter.replaceOpWithNewOp <mlir::arith::XOrIOp>(op, type, o, input);
844867 break ;
845868 }
846869 }
847-
848870 return mlir::LogicalResult::success ();
849871 }
850872};
0 commit comments