Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 45 additions & 26 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgsql

import (
"errors"
"reflect"
"regexp"

"github.com/lib/pq"
Expand Down Expand Up @@ -41,14 +42,13 @@ func IsErrorClass(err error, class string) bool {
// IsUniqueViolation checks is error an unique_violation with given constraint,
// constraint can be empty to ignore constraint name checks
func IsUniqueViolation(err error, constraint ...string) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
if len(constraint) == 0 {
return true
}
return contains(constraint, extractConstraint(pqErr))
if !IsErrorCode(err, "23505") { // for drivers that implement sqlState
return false
}
return false
if len(constraint) == 0 {
return true
}
return contains(constraint, extractConstraint(err))
}

// IsInvalidTextRepresentation checks is error an invalid_text_representation
Expand All @@ -61,16 +61,15 @@ func IsCharacterNotInRepertoire(err error) bool {
return IsErrorCode(err, "22021")
}

// IsForeignKeyViolation checks is error an foreign_key_violation
// IsForeignKeyViolation checks is error a foreign_key_violation
func IsForeignKeyViolation(err error, constraint ...string) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
if len(constraint) == 0 {
return true
}
return contains(constraint, extractConstraint(pqErr))
if !IsErrorCode(err, "23503") { // for drivers that implement sqlState
return false
}
return false
if len(constraint) == 0 {
return true
}
return contains(constraint, extractConstraint(err))
}

// IsQueryCanceled checks is error an query_canceled error
Expand All @@ -85,19 +84,39 @@ func IsSerializationFailure(err error) bool {
return IsErrorCode(err, "40001")
}

func extractConstraint(err *pq.Error) string {
if err.Constraint != "" {
return err.Constraint
}
if err.Message == "" {
return ""
}
if s := extractCRDBKey(err.Message); s != "" {
return s
func extractConstraint(err error) string {
{ // pq
var pqErr *pq.Error
if errors.As(err, &pqErr) {
if pqErr.Constraint != "" {
return pqErr.Constraint
}
if pqErr.Message == "" {
return ""
}
if s := extractCRDBKey(pqErr.Message); s != "" {
return s
}
if s := extractLastQuote(pqErr.Message); s != "" {
return s
}
return ""
}
}
if s := extractLastQuote(err.Message); s != "" {
return s

{ // pgx
v := reflect.ValueOf(err)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return ""
}
if f := v.FieldByName("ConstraintName"); f.IsValid() {
return f.String()
}
}

return ""
}

Expand Down
28 changes: 28 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ import (
"github.com/acoshift/pgsql"
)

type pgxError struct {
Code string
ConstraintName string
}

func (e *pgxError) Error() string {
return "pgxError"
}

func (e *pgxError) SQLState() string {
return e.Code
}

func TestIsUniqueViolation(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -45,6 +58,21 @@ func TestIsUniqueViolation(t *testing.T) {
Table: "users",
Constraint: "users_email_key",
}))

assert.True(t, pgsql.IsUniqueViolation(&pgxError{
Code: "23505",
ConstraintName: "users_email_key",
}))

assert.True(t, pgsql.IsUniqueViolation(&pgxError{
Code: "23505",
ConstraintName: "users_email_key",
}, "users_email_key"))

assert.False(t, pgsql.IsUniqueViolation(&pgxError{
Code: "23505",
ConstraintName: "users_email_key",
}, "pkey"))
}

func TestIsForeignKeyViolation(t *testing.T) {
Expand Down