diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c08eb0b..119fba2 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: - 5432:5432 strategy: matrix: - go: ['1.20', '1.21'] + go: ['1.22', '1.23', '1.24'] name: Go ${{ matrix.go }} steps: - uses: actions/checkout@v3 diff --git a/backoff/backoff.go b/backoff/backoff.go new file mode 100644 index 0000000..ecfed3e --- /dev/null +++ b/backoff/backoff.go @@ -0,0 +1,127 @@ +package backoff + +import ( + "math" + "math/rand/v2" + "time" + + "github.com/acoshift/pgsql" +) + +// Config contains common configuration for all backoff strategies +type Config struct { + BaseDelay time.Duration // Base delay for backoff + MaxDelay time.Duration // Maximum delay cap +} + +// ExponentialConfig contains configuration for exponential backoff +type ExponentialConfig struct { + Config + Multiplier float64 // Multiplier for exponential growth + JitterType JitterType +} + +// LinearConfig contains configuration for linear backoff +type LinearConfig struct { + Config + Increment time.Duration // Amount to increase delay each attempt +} + +// JitterType defines the type of jitter to apply +type JitterType int + +const ( + // NoJitter applies no jitter + NoJitter JitterType = iota + // FullJitter applies full jitter (0 to calculated delay) + FullJitter + // EqualJitter applies equal jitter (half fixed + half random) + EqualJitter +) + +// NewExponential creates a new exponential backoff function +func NewExponential(config ExponentialConfig) pgsql.BackoffDelayFunc { + return func(attempt int) time.Duration { + baseDelay := time.Duration(float64(config.BaseDelay) * math.Pow(config.Multiplier, float64(attempt))) + if baseDelay > config.MaxDelay { + baseDelay = config.MaxDelay + } + + var delay time.Duration + switch config.JitterType { + case FullJitter: + // Full jitter: random delay between 0 and calculated delay + if baseDelay > 0 { + delay = time.Duration(rand.Int64N(int64(baseDelay))) + } else { + delay = baseDelay + } + case EqualJitter: + // Equal jitter: half fixed + half random + half := baseDelay / 2 + if half > 0 { + delay = half + time.Duration(rand.Int64N(int64(half))) + } else { + delay = baseDelay + } + default: + delay = baseDelay + } + + return delay + } +} + +// NewLinear creates a new linear backoff function +func NewLinear(config LinearConfig) pgsql.BackoffDelayFunc { + return func(attempt int) time.Duration { + delay := config.BaseDelay + time.Duration(attempt)*config.Increment + if delay > config.MaxDelay { + delay = config.MaxDelay + } + return delay + } +} + +func DefaultExponential() pgsql.BackoffDelayFunc { + return NewExponential(ExponentialConfig{ + Config: Config{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + }, + Multiplier: 2.0, + JitterType: NoJitter, + }) +} + +func DefaultExponentialWithFullJitter() pgsql.BackoffDelayFunc { + return NewExponential(ExponentialConfig{ + Config: Config{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + }, + Multiplier: 2.0, + JitterType: FullJitter, + }) +} + +func DefaultExponentialWithEqualJitter() pgsql.BackoffDelayFunc { + return NewExponential(ExponentialConfig{ + Config: Config{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + }, + Multiplier: 2.0, + JitterType: EqualJitter, + }) +} + +func DefaultLinear() pgsql.BackoffDelayFunc { + return NewLinear(LinearConfig{ + Config: Config{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + }, + Increment: 100 * time.Millisecond, + }) +} diff --git a/backoff/backoff_test.go b/backoff/backoff_test.go new file mode 100644 index 0000000..621f573 --- /dev/null +++ b/backoff/backoff_test.go @@ -0,0 +1,152 @@ +package backoff_test + +import ( + "testing" + "time" + + "github.com/acoshift/pgsql/backoff" +) + +func TestExponential(t *testing.T) { + t.Parallel() + + config := backoff.ExponentialConfig{ + Config: backoff.Config{ + BaseDelay: 10 * time.Millisecond, + MaxDelay: 1 * time.Second, + }, + Multiplier: 2.0, + } + backoff := backoff.NewExponential(config) + + // Test exponential growth + delays := []time.Duration{} + for i := 0; i < 10; i++ { + delay := backoff(i) + delays = append(delays, delay) + } + + // Verify exponential growth + for i := 1; i < len(delays); i++ { + if delays[i] < delays[i-1] { + t.Errorf("Expected delay[%d] >= delay[%d], got %v < %v", i, i-1, delays[i], delays[i-1]) + } + } + + // Verify max delay + for i := 0; i < 10; i++ { + delay := backoff(i) + if delay > config.MaxDelay { + t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay) + } + } +} + +func TestExponentialWithFullJitter(t *testing.T) { + t.Parallel() + + config := backoff.ExponentialConfig{ + Config: backoff.Config{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 1 * time.Second, + }, + Multiplier: 2.0, + JitterType: backoff.FullJitter, + } + backoff := backoff.NewExponential(config) + + // Test that jitter introduces randomness + var delays []time.Duration + for i := 0; i < 10; i++ { + delay := backoff(3) // Use same attempt number + delays = append(delays, delay) + } + + // Check that not all delays are the same (indicating jitter is working) + allSame := true + for i := 1; i < len(delays); i++ { + if delays[i] != delays[0] { + allSame = false + break + } + } + if allSame { + t.Error("Expected jitter to produce different delays, but all delays were the same") + } + + // Verify max delay + for i := 0; i < 15; i++ { + delay := backoff(i) + if delay > config.MaxDelay { + t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay) + } + } +} + +func TestExponentialWithEqualJitter(t *testing.T) { + t.Parallel() + + config := backoff.ExponentialConfig{ + Config: backoff.Config{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 1 * time.Second, + }, + Multiplier: 2.0, + JitterType: backoff.EqualJitter, + } + backoff := backoff.NewExponential(config) + + delay := backoff(2) + + // With equal jitter, delay should be at least half of the calculated delay + expectedMin := 200 * time.Millisecond // (100ms * 2^2) / 2 = 200ms + if delay < expectedMin { + t.Errorf("Expected delay >= %v with equal jitter, got %v", expectedMin, delay) + } + + // Verify max delay + for i := 0; i < 15; i++ { + delay := backoff(i) + if delay > config.MaxDelay { + t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay) + } + } +} + +func TestLinearBackoff(t *testing.T) { + t.Parallel() + + config := backoff.LinearConfig{ + Config: backoff.Config{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 1 * time.Second, + }, + Increment: 100 * time.Millisecond, + } + backoff := backoff.NewLinear(config) + + // Test linear growth + delays := []time.Duration{} + for i := 0; i < 5; i++ { + delay := backoff(i) + delays = append(delays, delay) + } + + // Verify linear growth + for i := 1; i < len(delays); i++ { + expectedIncrease := 100 * time.Millisecond + actualIncrease := delays[i] - delays[i-1] + + if actualIncrease != expectedIncrease { + t.Errorf("Expected linear increase of %v, got %v", expectedIncrease, actualIncrease) + } + } + + // Verify max delay + for i := 0; i < 15; i++ { + delay := backoff(i) + if delay > config.MaxDelay { + t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay) + } + } +} diff --git a/tx.go b/tx.go index e7615ef..905ff78 100644 --- a/tx.go +++ b/tx.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "time" ) // ErrAbortTx rollbacks transaction and return nil error @@ -14,10 +15,14 @@ type BeginTxer interface { BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) } +// BackoffDelayFunc is a function type that defines the delay for backoff +type BackoffDelayFunc func(attempt int) time.Duration + // TxOptions is the transaction options type TxOptions struct { sql.TxOptions - MaxAttempts int + MaxAttempts int + BackoffDelayFunc BackoffDelayFunc } const ( @@ -54,6 +59,8 @@ func RunInTxContext(ctx context.Context, db BeginTxer, opts *TxOptions, fn func( if opts.Isolation == sql.LevelDefault { option.Isolation = sql.LevelSerializable } + + option.BackoffDelayFunc = opts.BackoffDelayFunc } f := func() error { @@ -80,7 +87,27 @@ func RunInTxContext(ctx context.Context, db BeginTxer, opts *TxOptions, fn func( if !IsSerializationFailure(err) { return err } + + if i < option.MaxAttempts-1 && option.BackoffDelayFunc != nil { + if err = wait(ctx, i, option.BackoffDelayFunc); err != nil { + return err + } + } } return err } + +func wait(ctx context.Context, attempt int, backOffDelayFunc BackoffDelayFunc) error { + delay := backOffDelayFunc(attempt) + if delay <= 0 { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + return nil + } +} diff --git a/tx_test.go b/tx_test.go index b2d15ed..ccb3e40 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,12 +1,16 @@ package pgsql_test import ( + "context" "database/sql" + "database/sql/driver" + "errors" "fmt" "log" "math/rand" "sync" "testing" + "time" "github.com/acoshift/pgsql" ) @@ -148,3 +152,155 @@ func TestTx(t *testing.T) { t.Fatalf("expected sum all value to be 0; got %d", result) } } + +func TestTxRetryWithBackoff(t *testing.T) { + t.Parallel() + + t.Run("Backoff when serialization failure occurs", func(t *testing.T) { + t.Parallel() + + attemptCount := 0 + opts := &pgsql.TxOptions{ + MaxAttempts: 3, + BackoffDelayFunc: func(attempt int) time.Duration { + attemptCount++ + return 1 + }, + } + + pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error { + return &mockSerializationFailureError{} + }) + + if attemptCount != opts.MaxAttempts-1 { + t.Fatalf("expected BackoffDelayFunc to be called %d times, got %d", opts.MaxAttempts, attemptCount) + } + }) + + t.Run("Successful After Multiple Failures", func(t *testing.T) { + t.Parallel() + + failCount := 0 + maxFailures := 3 + opts := &pgsql.TxOptions{ + MaxAttempts: maxFailures + 1, + BackoffDelayFunc: func(attempt int) time.Duration { + return 1 + }, + } + + err := pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(tx *sql.Tx) error { + if failCount < maxFailures { + failCount++ + return &mockSerializationFailureError{} + } + return nil + }) + if err != nil { + t.Fatalf("expected success after failures, got error: %v", err) + } + if failCount != maxFailures { + t.Fatalf("expected %d failures before success, got %d", maxFailures, failCount) + } + }) + + t.Run("Context Cancellation", func(t *testing.T) { + t.Parallel() + + opts := &pgsql.TxOptions{ + MaxAttempts: 3, + BackoffDelayFunc: func(attempt int) time.Duration { + return 1 + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + err := pgsql.RunInTxContext(ctx, sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error { + return &mockSerializationFailureError{} + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled error, got %v", err) + } + }) + + t.Run("Max Attempts Reached", func(t *testing.T) { + t.Parallel() + + attemptCount := 0 + opts := &pgsql.TxOptions{ + MaxAttempts: 3, + BackoffDelayFunc: func(attempt int) time.Duration { + return 1 + }, + } + + err := pgsql.RunInTxContext(context.Background(), sql.OpenDB(&fakeConnector{}), opts, func(*sql.Tx) error { + attemptCount++ + return &mockSerializationFailureError{} + }) + if errors.As(err, &mockSerializationFailureError{}) { + t.Fatal("expected an error when max attempts reached") + } + if attemptCount != opts.MaxAttempts { + t.Fatalf("expected %d attempts, got %d", opts.MaxAttempts, attemptCount) + } + }) +} + +type fakeConnector struct { + driver.Connector +} + +func (c *fakeConnector) Connect(ctx context.Context) (driver.Conn, error) { + return &fakeConn{}, nil +} + +func (c *fakeConnector) Driver() driver.Driver { + panic("not implemented") +} + +type fakeConn struct { + driver.Conn +} + +func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { + return nil, fmt.Errorf("not implemented") +} + +func (c *fakeConn) Close() error { + return nil +} + +func (c *fakeConn) Begin() (driver.Tx, error) { + return &fakeTx{}, nil +} + +var _ driver.ConnBeginTx = (*fakeConn)(nil) + +func (c *fakeConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return &fakeTx{}, nil +} + +type fakeTx struct { + driver.Tx +} + +func (tx *fakeTx) Commit() error { + return nil +} + +func (tx *fakeTx) Rollback() error { + return nil +} + +type mockSerializationFailureError struct{} + +func (e mockSerializationFailureError) Error() string { + return "mock serialization failure error" +} + +func (e mockSerializationFailureError) SQLState() string { + return "40001" // SQLSTATE code for serialization failure +}