Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- 5432:5432
strategy:
matrix:
go: ['1.20', '1.21']
go: ['1.22', '1.23', '1.24']
name: Go ${{ matrix.go }}
steps:
- uses: actions/checkout@v3
Expand Down
127 changes: 127 additions & 0 deletions backoff/backoff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package backoff

import (
"math"
"math/rand/v2"
"time"

"github.com/acoshift/pgsql"
)

// Config contains common configuration for all backoff strategies
type Config struct {
BaseDelay time.Duration // Base delay for backoff
MaxDelay time.Duration // Maximum delay cap
}

// ExponentialConfig contains configuration for exponential backoff
type ExponentialConfig struct {
Config
Multiplier float64 // Multiplier for exponential growth
JitterType JitterType
}

// LinearConfig contains configuration for linear backoff
type LinearConfig struct {
Config
Increment time.Duration // Amount to increase delay each attempt
}

// JitterType defines the type of jitter to apply
type JitterType int

const (
// NoJitter applies no jitter
NoJitter JitterType = iota
// FullJitter applies full jitter (0 to calculated delay)
FullJitter
// EqualJitter applies equal jitter (half fixed + half random)
EqualJitter
)

// NewExponential creates a new exponential backoff function
func NewExponential(config ExponentialConfig) pgsql.BackoffDelayFunc {
return func(attempt int) time.Duration {
baseDelay := time.Duration(float64(config.BaseDelay) * math.Pow(config.Multiplier, float64(attempt)))
if baseDelay > config.MaxDelay {
baseDelay = config.MaxDelay
}

var delay time.Duration
switch config.JitterType {
case FullJitter:
// Full jitter: random delay between 0 and calculated delay
if baseDelay > 0 {
delay = time.Duration(rand.Int64N(int64(baseDelay)))
} else {
delay = baseDelay
}
case EqualJitter:
// Equal jitter: half fixed + half random
half := baseDelay / 2
if half > 0 {
delay = half + time.Duration(rand.Int64N(int64(half)))
} else {
delay = baseDelay
}
default:
delay = baseDelay
}

return delay
}
}

// NewLinear creates a new linear backoff function
func NewLinear(config LinearConfig) pgsql.BackoffDelayFunc {
return func(attempt int) time.Duration {
delay := config.BaseDelay + time.Duration(attempt)*config.Increment
if delay > config.MaxDelay {
delay = config.MaxDelay
}
return delay
}
}

func DefaultExponential() pgsql.BackoffDelayFunc {
return NewExponential(ExponentialConfig{
Config: Config{
BaseDelay: 100 * time.Millisecond,
MaxDelay: 5 * time.Second,
},
Multiplier: 2.0,
JitterType: NoJitter,
})
}

func DefaultExponentialWithFullJitter() pgsql.BackoffDelayFunc {
return NewExponential(ExponentialConfig{
Config: Config{
BaseDelay: 100 * time.Millisecond,
MaxDelay: 5 * time.Second,
},
Multiplier: 2.0,
JitterType: FullJitter,
})
}

func DefaultExponentialWithEqualJitter() pgsql.BackoffDelayFunc {
return NewExponential(ExponentialConfig{
Config: Config{
BaseDelay: 100 * time.Millisecond,
MaxDelay: 5 * time.Second,
},
Multiplier: 2.0,
JitterType: EqualJitter,
})
}

func DefaultLinear() pgsql.BackoffDelayFunc {
return NewLinear(LinearConfig{
Config: Config{
BaseDelay: 100 * time.Millisecond,
MaxDelay: 5 * time.Second,
},
Increment: 100 * time.Millisecond,
})
}
152 changes: 152 additions & 0 deletions backoff/backoff_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package backoff_test

import (
"testing"
"time"

"github.com/acoshift/pgsql/backoff"
)

func TestExponential(t *testing.T) {
t.Parallel()

config := backoff.ExponentialConfig{
Config: backoff.Config{
BaseDelay: 10 * time.Millisecond,
MaxDelay: 1 * time.Second,
},
Multiplier: 2.0,
}
backoff := backoff.NewExponential(config)

// Test exponential growth
delays := []time.Duration{}
for i := 0; i < 10; i++ {
delay := backoff(i)
delays = append(delays, delay)
}

// Verify exponential growth
for i := 1; i < len(delays); i++ {
if delays[i] < delays[i-1] {
t.Errorf("Expected delay[%d] >= delay[%d], got %v < %v", i, i-1, delays[i], delays[i-1])
}
}

// Verify max delay
for i := 0; i < 10; i++ {
delay := backoff(i)
if delay > config.MaxDelay {
t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay)
}
}
}

func TestExponentialWithFullJitter(t *testing.T) {
t.Parallel()

config := backoff.ExponentialConfig{
Config: backoff.Config{
BaseDelay: 100 * time.Millisecond,
MaxDelay: 1 * time.Second,
},
Multiplier: 2.0,
JitterType: backoff.FullJitter,
}
backoff := backoff.NewExponential(config)

// Test that jitter introduces randomness
var delays []time.Duration
for i := 0; i < 10; i++ {
delay := backoff(3) // Use same attempt number
delays = append(delays, delay)
}

// Check that not all delays are the same (indicating jitter is working)
allSame := true
for i := 1; i < len(delays); i++ {
if delays[i] != delays[0] {
allSame = false
break
}
}
if allSame {
t.Error("Expected jitter to produce different delays, but all delays were the same")
}

// Verify max delay
for i := 0; i < 15; i++ {
delay := backoff(i)
if delay > config.MaxDelay {
t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay)
}
}
}

func TestExponentialWithEqualJitter(t *testing.T) {
t.Parallel()

config := backoff.ExponentialConfig{
Config: backoff.Config{
BaseDelay: 100 * time.Millisecond,
MaxDelay: 1 * time.Second,
},
Multiplier: 2.0,
JitterType: backoff.EqualJitter,
}
backoff := backoff.NewExponential(config)

delay := backoff(2)

// With equal jitter, delay should be at least half of the calculated delay
expectedMin := 200 * time.Millisecond // (100ms * 2^2) / 2 = 200ms
if delay < expectedMin {
t.Errorf("Expected delay >= %v with equal jitter, got %v", expectedMin, delay)
}

// Verify max delay
for i := 0; i < 15; i++ {
delay := backoff(i)
if delay > config.MaxDelay {
t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay)
}
}
}

func TestLinearBackoff(t *testing.T) {
t.Parallel()

config := backoff.LinearConfig{
Config: backoff.Config{
BaseDelay: 100 * time.Millisecond,
MaxDelay: 1 * time.Second,
},
Increment: 100 * time.Millisecond,
}
backoff := backoff.NewLinear(config)

// Test linear growth
delays := []time.Duration{}
for i := 0; i < 5; i++ {
delay := backoff(i)
delays = append(delays, delay)
}

// Verify linear growth
for i := 1; i < len(delays); i++ {
expectedIncrease := 100 * time.Millisecond
actualIncrease := delays[i] - delays[i-1]

if actualIncrease != expectedIncrease {
t.Errorf("Expected linear increase of %v, got %v", expectedIncrease, actualIncrease)
}
}

// Verify max delay
for i := 0; i < 15; i++ {
delay := backoff(i)
if delay > config.MaxDelay {
t.Errorf("Expected delay[%d] <= MaxDelay (%v), got %v", i, config.MaxDelay, delay)
}
}
}
29 changes: 28 additions & 1 deletion tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"time"
)

// ErrAbortTx rollbacks transaction and return nil error
Expand All @@ -14,10 +15,14 @@ type BeginTxer interface {
BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
}

// BackoffDelayFunc is a function type that defines the delay for backoff
type BackoffDelayFunc func(attempt int) time.Duration

// TxOptions is the transaction options
type TxOptions struct {
sql.TxOptions
MaxAttempts int
MaxAttempts int
BackoffDelayFunc BackoffDelayFunc
}

const (
Expand Down Expand Up @@ -54,6 +59,8 @@ func RunInTxContext(ctx context.Context, db BeginTxer, opts *TxOptions, fn func(
if opts.Isolation == sql.LevelDefault {
option.Isolation = sql.LevelSerializable
}

option.BackoffDelayFunc = opts.BackoffDelayFunc
}

f := func() error {
Expand All @@ -80,7 +87,27 @@ func RunInTxContext(ctx context.Context, db BeginTxer, opts *TxOptions, fn func(
if !IsSerializationFailure(err) {
return err
}

if i < option.MaxAttempts-1 && option.BackoffDelayFunc != nil {
if err = wait(ctx, i, option.BackoffDelayFunc); err != nil {
return err
}
}
}

return err
}

func wait(ctx context.Context, attempt int, backOffDelayFunc BackoffDelayFunc) error {
delay := backOffDelayFunc(attempt)
if delay <= 0 {
return nil
}

select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
return nil
}
}
Loading