From 5824b970a904a17244330c834c6a22bf765b8ab4 Mon Sep 17 00:00:00 2001 From: Matthew Bamber Date: Sat, 21 May 2022 20:27:27 +0100 Subject: [PATCH] feat: Allow customizing of AWS retry codes Allow users to customize the error codes that should be retried by the AWS SDK. This enables advanced workflows such as retrying authentication failures --- aws_config.go | 14 ++++-- aws_config_test.go | 88 +++++++++++++++++++++++++++++++++ v2/awsv1shim/session.go | 13 +++++ v2/awsv1shim/session_test.go | 95 ++++++++++++++++++++++++++++++++++++ 4 files changed, 207 insertions(+), 3 deletions(-) diff --git a/aws_config.go b/aws_config.go index 6a6e6de5..45ddaa50 100644 --- a/aws_config.go +++ b/aws_config.go @@ -84,10 +84,18 @@ func resolveRetryer(ctx context.Context, awsConfig *aws.Config) { }) } + var r aws.Retryer = &networkErrorShortcutter{ + RetryerV2: retry.NewStandard(standardOptions...), + } + + // Add additional retry codes + if retryCodes := os.Getenv("AWS_RETRY_CODES"); retryCodes != "" { + codes := strings.Split(retryCodes, ",") + r = retry.AddWithErrorCodes(r, codes...) + } + awsConfig.Retryer = func() aws.Retryer { - return &networkErrorShortcutter{ - RetryerV2: retry.NewStandard(standardOptions...), - } + return r } } diff --git a/aws_config_test.go b/aws_config_test.go index d57e0339..faf5342d 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -20,6 +20,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go" "github.com/aws/smithy-go/middleware" smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/google/go-cmp/cmp" @@ -1649,6 +1650,93 @@ max_attempts = 10 } } +func TestRetryCodes(t *testing.T) { + testCases := map[string]struct { + Config *Config + EnvironmentVariables map[string]string + ExpectedRetryableErrors []smithy.APIError + ExpectedNonRetryableErrors []smithy.APIError + }{ + "no configuration": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedNonRetryableErrors: []smithy.APIError{ + &smithy.GenericAPIError{Code: "error 1"}, + }, + }, + + "AWS_RETRY_CODES single": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + EnvironmentVariables: map[string]string{ + "AWS_RETRY_CODES": "error 1", + }, + ExpectedRetryableErrors: []smithy.APIError{ + &smithy.GenericAPIError{Code: "error 1"}, + }, + ExpectedNonRetryableErrors: []smithy.APIError{ + &smithy.GenericAPIError{Code: "error 2"}, + }, + }, + + "AWS_RETRY_CODES multiple": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + EnvironmentVariables: map[string]string{ + "AWS_RETRY_CODES": "error 1,error 2", + }, + ExpectedRetryableErrors: []smithy.APIError{ + &smithy.GenericAPIError{Code: "error 1"}, + &smithy.GenericAPIError{Code: "error 2"}, + }, + ExpectedNonRetryableErrors: []smithy.APIError{ + &smithy.GenericAPIError{Code: "error 3"}, + }, + }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + for k, v := range testCase.EnvironmentVariables { + os.Setenv(k, v) + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err) + } + + retryer := awsConfig.Retryer() + if retryer == nil { + t.Fatal("no retryer set") + } + for _, e := range testCase.ExpectedRetryableErrors { + if a := retryer.IsErrorRetryable(e); !a { + t.Errorf(`expected error %q would be retryable, got not retryable`, e) + } + } + for _, e := range testCase.ExpectedNonRetryableErrors { + if a := retryer.IsErrorRetryable(e); a { + t.Errorf(`expected error %q would not be retryable, got retryable`, e) + } + } + }) + } +} + func TestServiceEndpointTypes(t *testing.T) { testCases := map[string]struct { Config *Config diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index ca415db3..11bec902 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -5,6 +5,7 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim "fmt" "log" "os" + "strings" awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws" @@ -89,6 +90,18 @@ func GetSession(awsC *awsv2.Config, c *awsbase.Config) (*session.Session, error) sess = sess.Copy(&aws.Config{MaxRetries: aws.Int(retryer.MaxAttempts())}) } + // Add custom error code retries. It's easier to recheck the environment variable + // here as the retry codes aren't available from the original v2 config + if retryCodes := os.Getenv("AWS_RETRY_CODES"); retryCodes != "" { + codes := strings.Split(retryCodes, ",") + log.Printf("[DEBUG] Using additional retry codes: %s", codes) + sess.Handlers.Retry.PushBack(func(r *request.Request) { + if tfawserr.ErrCodeEquals(r.Error, codes...) { + r.Retryable = aws.Bool(true) + } + }) + } + SetSessionUserAgent(sess, c.APNInfo, c.UserAgent) // Add custom input from ENV to the User-Agent request header diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index 2e0a1743..185973d6 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -1470,6 +1470,101 @@ max_attempts = 10 } } +func TestRetryCodes(t *testing.T) { + testCases := map[string]struct { + Config *awsbase.Config + EnvironmentVariables map[string]string + ExpectedRetryableErrors []awserr.Error + ExpectedNonRetryableErrors []awserr.Error + }{ + "no configuration": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedNonRetryableErrors: []awserr.Error{ + awserr.New("error 1", "", nil), + }, + }, + + "AWS_RETRY_CODES single": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + EnvironmentVariables: map[string]string{ + "AWS_RETRY_CODES": "error 1", + }, + ExpectedRetryableErrors: []awserr.Error{ + awserr.New("error 1", "", nil), + }, + ExpectedNonRetryableErrors: []awserr.Error{ + awserr.New("error 2", "", nil), + }, + }, + + "AWS_RETRY_CODES multiple": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + SecretKey: servicemocks.MockStaticSecretKey, + }, + EnvironmentVariables: map[string]string{ + "AWS_RETRY_CODES": "error 1,error 2", + }, + ExpectedRetryableErrors: []awserr.Error{ + awserr.New("error 1", "", nil), + awserr.New("error 2", "", nil), + }, + ExpectedNonRetryableErrors: []awserr.Error{ + awserr.New("error 3", "", nil), + }, + }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + for k, v := range testCase.EnvironmentVariables { + os.Setenv(k, v) + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := awsbase.GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + t.Fatalf("GetAwsConfig() returned error: %s", err) + } + actualSession, err := GetSession(&awsConfig, testCase.Config) + if err != nil { + t.Fatalf("error in GetSession() '%[1]T': %[1]s", err) + } + + for _, e := range testCase.ExpectedRetryableErrors { + r := &request.Request{ + Error: e, + } + actualSession.Handlers.Retry.Run(r) + if !aws.BoolValue(r.Retryable) { + t.Errorf(`expected error %q would be retryable, got not retryable`, e) + } + } + for _, e := range testCase.ExpectedNonRetryableErrors { + r := &request.Request{ + Error: e, + } + actualSession.Handlers.Retry.Run(r) + if aws.BoolValue(r.Retryable) { + t.Errorf(`expected error %q would not be retryable, got retryable`, e) + } + } + }) + } +} + func TestServiceEndpointTypes(t *testing.T) { testCases := map[string]struct { Config *awsbase.Config