diff --git a/Sources/SQLKitExtras/SQLKitExtras/Expressions/SQLCastExpression.swift b/Sources/SQLKitExtras/SQLKitExtras/Expressions/SQLCastExpression.swift index d2e596e..e207476 100644 --- a/Sources/SQLKitExtras/SQLKitExtras/Expressions/SQLCastExpression.swift +++ b/Sources/SQLKitExtras/SQLKitExtras/Expressions/SQLCastExpression.swift @@ -29,7 +29,17 @@ public struct SQLCastExpression: SQLExpression { /// See `SQLExpression.serialize(to:)`. public func serialize(to serializer: inout SQLSerializer) { - SQLFunction("CAST", args: SQLAlias(self.original, as: self.desiredType)) + let desiredType: any SQLExpression = if + serializer.dialect.name == "mysql", + let ident = self.desiredType as? SQLIdentifier, + ident.string.allSatisfy({ $0.isASCII && ($0.isLowercase || $0.isUppercase || $0.isWholeNumber || $0 == "_") }) + { + SQLRaw(ident.string) + } else { + self.desiredType + } + + SQLFunction("CAST", args: SQLAlias(self.original, as: desiredType)) .serialize(to: &serializer) } } diff --git a/Tests/SQLKitExtrasTests/FluentSQLKitExtrasTests.swift b/Tests/SQLKitExtrasTests/FluentSQLKitExtrasTests.swift index 41388ca..56c1acc 100644 --- a/Tests/SQLKitExtrasTests/FluentSQLKitExtrasTests.swift +++ b/Tests/SQLKitExtrasTests/FluentSQLKitExtrasTests.swift @@ -302,6 +302,9 @@ struct FluentSQLKitExtrasTests { func castExpression() { #expect(serialize(.cast(\FooModel.$field, to: "text")) == #"CAST("foos"."field" AS "text")"#) #expect(serialize(.cast(\FooModel.$field, to: .unsafeRaw("text"))) == #"CAST("foos"."field" AS text)"#) + + #expect(MockSQLDatabase(dialect: "mysql").serialize(.cast(\FooModel.$field, to: "text")).sql == #"CAST("foos"."field" AS text)"#) + #expect(MockSQLDatabase(dialect: "postgresql").serialize(.cast(\FooModel.$field, to: "text")).sql == #"CAST("foos"."field" AS "text")"#) } } } diff --git a/Tests/SQLKitExtrasTests/SQLKitExtrasTests.swift b/Tests/SQLKitExtrasTests/SQLKitExtrasTests.swift index f4eac2b..0a39cb5 100644 --- a/Tests/SQLKitExtrasTests/SQLKitExtrasTests.swift +++ b/Tests/SQLKitExtrasTests/SQLKitExtrasTests.swift @@ -259,6 +259,9 @@ struct SQLKitExtrasTests { #expect(serialize(.cast("foo", to: "text")) == #"CAST("foo" AS "text")"#) #expect(serialize(.cast(.column("foo"), to: "text")) == #"CAST("foo" AS "text")"#) #expect(serialize(.cast(.column("foo"), to: .unsafeRaw("text"))) == #"CAST("foo" AS text)"#) + + #expect(MockSQLDatabase(dialect: "mysql").serialize(SQLCastExpression(.column("foo"), to: "text")).sql == #"CAST("foo" AS text)"#) + #expect(MockSQLDatabase(dialect: "postgresql").serialize(SQLCastExpression(.column("foo"), to: "text")).sql == #"CAST("foo" AS "text")"#) } @Test