diff --git a/error.go b/error.go index 3414019..1b61377 100644 --- a/error.go +++ b/error.go @@ -2,6 +2,7 @@ package pgsql import ( "errors" + "reflect" "regexp" "github.com/lib/pq" @@ -41,14 +42,13 @@ func IsErrorClass(err error, class string) bool { // IsUniqueViolation checks is error an unique_violation with given constraint, // constraint can be empty to ignore constraint name checks func IsUniqueViolation(err error, constraint ...string) bool { - var pqErr *pq.Error - if errors.As(err, &pqErr) && pqErr.Code == "23505" { - if len(constraint) == 0 { - return true - } - return contains(constraint, extractConstraint(pqErr)) + if !IsErrorCode(err, "23505") { // for drivers that implement sqlState + return false } - return false + if len(constraint) == 0 { + return true + } + return contains(constraint, extractConstraint(err)) } // IsInvalidTextRepresentation checks is error an invalid_text_representation @@ -61,16 +61,15 @@ func IsCharacterNotInRepertoire(err error) bool { return IsErrorCode(err, "22021") } -// IsForeignKeyViolation checks is error an foreign_key_violation +// IsForeignKeyViolation checks is error a foreign_key_violation func IsForeignKeyViolation(err error, constraint ...string) bool { - var pqErr *pq.Error - if errors.As(err, &pqErr) && pqErr.Code == "23503" { - if len(constraint) == 0 { - return true - } - return contains(constraint, extractConstraint(pqErr)) + if !IsErrorCode(err, "23503") { // for drivers that implement sqlState + return false } - return false + if len(constraint) == 0 { + return true + } + return contains(constraint, extractConstraint(err)) } // IsQueryCanceled checks is error an query_canceled error @@ -85,19 +84,39 @@ func IsSerializationFailure(err error) bool { return IsErrorCode(err, "40001") } -func extractConstraint(err *pq.Error) string { - if err.Constraint != "" { - return err.Constraint - } - if err.Message == "" { - return "" - } - if s := extractCRDBKey(err.Message); s != "" { - return s +func extractConstraint(err error) string { + { // pq + var pqErr *pq.Error + if errors.As(err, &pqErr) { + if pqErr.Constraint != "" { + return pqErr.Constraint + } + if pqErr.Message == "" { + return "" + } + if s := extractCRDBKey(pqErr.Message); s != "" { + return s + } + if s := extractLastQuote(pqErr.Message); s != "" { + return s + } + return "" + } } - if s := extractLastQuote(err.Message); s != "" { - return s + + { // pgx + v := reflect.ValueOf(err) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return "" + } + if f := v.FieldByName("ConstraintName"); f.IsValid() { + return f.String() + } } + return "" } diff --git a/error_test.go b/error_test.go index 961b75b..1b1dbe3 100644 --- a/error_test.go +++ b/error_test.go @@ -11,6 +11,19 @@ import ( "github.com/acoshift/pgsql" ) +type pgxError struct { + Code string + ConstraintName string +} + +func (e *pgxError) Error() string { + return "pgxError" +} + +func (e *pgxError) SQLState() string { + return e.Code +} + func TestIsUniqueViolation(t *testing.T) { t.Parallel() @@ -45,6 +58,21 @@ func TestIsUniqueViolation(t *testing.T) { Table: "users", Constraint: "users_email_key", })) + + assert.True(t, pgsql.IsUniqueViolation(&pgxError{ + Code: "23505", + ConstraintName: "users_email_key", + })) + + assert.True(t, pgsql.IsUniqueViolation(&pgxError{ + Code: "23505", + ConstraintName: "users_email_key", + }, "users_email_key")) + + assert.False(t, pgsql.IsUniqueViolation(&pgxError{ + Code: "23505", + ConstraintName: "users_email_key", + }, "pkey")) } func TestIsForeignKeyViolation(t *testing.T) {