Skip to content

Commit f3803c6

Browse files
YakDriveraeschright
authored andcommitted
Ensure proper order for obtaining credentials, assuming roles, using profiles (#5)
* Adjust logic to validate creds before assumerole Previously, a bug existed that prevented session-derived creds from being used when assuming a role. This is because session- derived creds would not be gathered until the very last moment. Since all the assumerole logic was passed before this last moment, assumerole could not work with session-derived creds. Now, GetCredentials has a new contract - it provides and validates credentials. Before, GetCredentials would sometimes return unvalidated creds and sometimes validated creds. This meant that more error handling logic needed to be included in GetSession and GetSessionOptions. As part of validating creds, GetCredentials now gets session-derived creds, if necessary, prior to assuming a role. * Add error instance for better checking
1 parent 516649a commit f3803c6

File tree

4 files changed

+100
-79
lines changed

4 files changed

+100
-79
lines changed

awsauth.go

Lines changed: 88 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,21 @@ import (
2222
"github.com/hashicorp/go-multierror"
2323
)
2424

25+
const (
26+
// errMsgNoValidCredentialSources error getting credentials
27+
errMsgNoValidCredentialSources = `No valid credential sources found for AWS Provider.
28+
Please see https://terraform.io/docs/providers/aws/index.html for more information on
29+
providing credentials for the AWS Provider`
30+
)
31+
32+
var (
33+
// ErrNoValidCredentialSources indicates that no credentials source could be found
34+
ErrNoValidCredentialSources = errNoValidCredentialSources()
35+
)
36+
37+
func errNoValidCredentialSources() error { return errors.New(errMsgNoValidCredentialSources) }
38+
39+
// GetAccountIDAndPartition gets the account ID and associated partition.
2540
func GetAccountIDAndPartition(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) {
2641
var accountID, partition string
2742
var err, errors error
@@ -51,6 +66,8 @@ func GetAccountIDAndPartition(iamconn *iam.IAM, stsconn *sts.STS, authProviderNa
5166
return accountID, partition, errors
5267
}
5368

69+
// GetAccountIDAndPartitionFromEC2Metadata gets the account ID and associated
70+
// partition from EC2 metadata.
5471
func GetAccountIDAndPartitionFromEC2Metadata() (string, string, error) {
5572
log.Println("[DEBUG] Trying to get account information via EC2 Metadata")
5673

@@ -75,6 +92,8 @@ func GetAccountIDAndPartitionFromEC2Metadata() (string, string, error) {
7592
return parseAccountIDAndPartitionFromARN(info.InstanceProfileArn)
7693
}
7794

95+
// GetAccountIDAndPartitionFromIAMGetUser gets the account ID and associated
96+
// partition from IAM.
7897
func GetAccountIDAndPartitionFromIAMGetUser(iamconn *iam.IAM) (string, string, error) {
7998
log.Println("[DEBUG] Trying to get account information via iam:GetUser")
8099

@@ -102,6 +121,8 @@ func GetAccountIDAndPartitionFromIAMGetUser(iamconn *iam.IAM) (string, string, e
102121
return parseAccountIDAndPartitionFromARN(aws.StringValue(output.User.Arn))
103122
}
104123

124+
// GetAccountIDAndPartitionFromIAMListRoles gets the account ID and associated
125+
// partition from listing IAM roles.
105126
func GetAccountIDAndPartitionFromIAMListRoles(iamconn *iam.IAM) (string, string, error) {
106127
log.Println("[DEBUG] Trying to get account information via iam:ListRoles")
107128

@@ -123,6 +144,8 @@ func GetAccountIDAndPartitionFromIAMListRoles(iamconn *iam.IAM) (string, string,
123144
return parseAccountIDAndPartitionFromARN(aws.StringValue(output.Roles[0].Arn))
124145
}
125146

147+
// GetAccountIDAndPartitionFromSTSGetCallerIdentity gets the account ID and associated
148+
// partition from STS caller identity.
126149
func GetAccountIDAndPartitionFromSTSGetCallerIdentity(stsconn *sts.STS) (string, string, error) {
127150
log.Println("[DEBUG] Trying to get account information via sts:GetCallerIdentity")
128151

@@ -148,9 +171,54 @@ func parseAccountIDAndPartitionFromARN(inputARN string) (string, string, error)
148171
return arn.AccountID, arn.Partition, nil
149172
}
150173

151-
// This function is responsible for reading credentials from the
152-
// environment in the case that they're not explicitly specified
153-
// in the Terraform configuration.
174+
// GetCredentialsFromSession returns credentials derived from a session. A
175+
// session uses the AWS SDK Go chain of providers so may use a provider (e.g.,
176+
// ProcessProvider) that is not part of the Terraform provider chain.
177+
func GetCredentialsFromSession(c *Config) (*awsCredentials.Credentials, error) {
178+
log.Printf("[INFO] Attempting to use session-derived credentials")
179+
180+
var sess *session.Session
181+
var err error
182+
if c.Profile == "" {
183+
sess, err = session.NewSession()
184+
if err != nil {
185+
return nil, ErrNoValidCredentialSources
186+
}
187+
} else {
188+
options := &session.Options{
189+
Config: aws.Config{
190+
HTTPClient: cleanhttp.DefaultClient(),
191+
MaxRetries: aws.Int(0),
192+
Region: aws.String(c.Region),
193+
},
194+
}
195+
options.Profile = c.Profile
196+
options.SharedConfigState = session.SharedConfigEnable
197+
198+
sess, err = session.NewSessionWithOptions(*options)
199+
if err != nil {
200+
if IsAWSErr(err, "NoCredentialProviders", "") {
201+
return nil, ErrNoValidCredentialSources
202+
}
203+
return nil, fmt.Errorf("Error creating AWS session: %s", err)
204+
}
205+
}
206+
207+
creds := sess.Config.Credentials
208+
cp, err := sess.Config.Credentials.Get()
209+
if err != nil {
210+
return nil, ErrNoValidCredentialSources
211+
}
212+
213+
log.Printf("[INFO] Successfully derived credentials from session")
214+
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
215+
return creds, nil
216+
}
217+
218+
// GetCredentials gets credentials from the environment, shared credentials,
219+
// or the session (which may include a credential process). GetCredentials also
220+
// validates the credentials and the ability to assume a role or will return an
221+
// error if unsuccessful.
154222
func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
155223
// build a chain provider, lazy-evaluated by aws-sdk
156224
providers := []awsCredentials.Provider{
@@ -225,30 +293,32 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
225293
}
226294
}
227295

296+
// Validate the credentials before returning them
297+
creds := awsCredentials.NewChainCredentials(providers)
298+
cp, err := creds.Get()
299+
if err != nil {
300+
if IsAWSErr(err, "NoCredentialProviders", "") {
301+
creds, err = GetCredentialsFromSession(c)
302+
if err != nil {
303+
return nil, err
304+
}
305+
} else {
306+
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
307+
}
308+
} else {
309+
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
310+
}
311+
228312
// This is the "normal" flow (i.e. not assuming a role)
229313
if c.AssumeRoleARN == "" {
230-
return awsCredentials.NewChainCredentials(providers), nil
314+
return creds, nil
231315
}
232316

233317
// Otherwise we need to construct an STS client with the main credentials, and verify
234318
// that we can assume the defined role.
235319
log.Printf("[INFO] Attempting to AssumeRole %s (SessionName: %q, ExternalId: %q, Policy: %q)",
236320
c.AssumeRoleARN, c.AssumeRoleSessionName, c.AssumeRoleExternalID, c.AssumeRolePolicy)
237321

238-
creds := awsCredentials.NewChainCredentials(providers)
239-
cp, err := creds.Get()
240-
if err != nil {
241-
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
242-
return nil, errors.New(`No valid credential sources found for AWS Provider.
243-
Please see https://terraform.io/docs/providers/aws/index.html for more information on
244-
providing credentials for the AWS Provider`)
245-
}
246-
247-
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
248-
}
249-
250-
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
251-
252322
awsConfig := &aws.Config{
253323
Credentials: creds,
254324
Region: aws.String(c.Region),

awsauth_test.go

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"os"
88
"testing"
99

10-
"github.com/aws/aws-sdk-go/aws/awserr"
1110
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
1211
"github.com/aws/aws-sdk-go/service/iam"
1312
"github.com/aws/aws-sdk-go/service/sts"
@@ -423,20 +422,12 @@ func TestAWSGetCredentials_shouldErrorWhenBlank(t *testing.T) {
423422
defer resetEnv()
424423

425424
cfg := Config{}
426-
c, err := GetCredentials(&cfg)
425+
_, err := GetCredentials(&cfg)
427426

428-
if err != nil {
427+
if err != ErrNoValidCredentialSources {
429428
t.Fatalf("Unexpected error: %s", err)
430429
}
431430

432-
_, err = c.Get()
433-
if awsErr, ok := err.(awserr.Error); ok {
434-
if awsErr.Code() != "NoCredentialProviders" {
435-
t.Fatal("Expected NoCredentialProviders error")
436-
}
437-
} else {
438-
t.Fatal("Expected AWS error")
439-
}
440431
if err == nil {
441432
t.Fatal("Expected an error given empty env, keys, and IAM in AWS Config")
442433
}
@@ -586,22 +577,14 @@ func TestAWSGetCredentials_shouldErrorWithInvalidEndpoint(t *testing.T) {
586577
ts := invalidAwsEnv(t)
587578
defer ts()
588579

589-
creds, err := GetCredentials(&Config{})
590-
if err != nil {
580+
_, err := GetCredentials(&Config{})
581+
if err != ErrNoValidCredentialSources {
591582
t.Fatalf("Error gettings creds: %s", err)
592583
}
593-
if creds == nil {
594-
t.Fatal("Expected a static creds provider to be returned")
595-
}
596584

597-
v, err := creds.Get()
598585
if err == nil {
599586
t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint")
600587
}
601-
602-
if v.ProviderName != "" {
603-
t.Fatalf("Expected provider name to be empty, %q given", v.ProviderName)
604-
}
605588
}
606589

607590
func TestAWSGetCredentials_shouldIgnoreInvalidEndpoint(t *testing.T) {

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ require (
88
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd // indirect
99
golang.org/x/text v0.3.0 // indirect
1010
)
11+
12+
go 1.13

session.go

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package awsbase
22

33
import (
44
"crypto/tls"
5-
"errors"
65
"fmt"
76
"log"
87
"net/http"
@@ -28,45 +27,14 @@ func GetSessionOptions(c *Config) (*session.Options, error) {
2827
},
2928
}
3029

30+
// get and validate credentials
3131
creds, err := GetCredentials(c)
3232
if err != nil {
3333
return nil, err
3434
}
3535

36-
// Call Get to check for credential provider. If nothing found, we'll get an
37-
// error, and we can present it nicely to the user
38-
cp, err := creds.Get()
39-
if err != nil {
40-
if IsAWSErr(err, "NoCredentialProviders", "") {
41-
// If a profile wasn't specified, the session may still be able to resolve credentials from shared config.
42-
if c.Profile == "" {
43-
sess, err := session.NewSession()
44-
if err != nil {
45-
return nil, errors.New(`No valid credential sources found for AWS Provider.
46-
Please see https://terraform.io/docs/providers/aws/index.html for more information on
47-
providing credentials for the AWS Provider`)
48-
}
49-
_, err = sess.Config.Credentials.Get()
50-
if err != nil {
51-
return nil, errors.New(`No valid credential sources found for AWS Provider.
52-
Please see https://terraform.io/docs/providers/aws/index.html for more information on
53-
providing credentials for the AWS Provider`)
54-
}
55-
log.Printf("[INFO] Using session-derived AWS Auth")
56-
options.Config.Credentials = sess.Config.Credentials
57-
} else {
58-
log.Printf("[INFO] AWS Auth using Profile: %q", c.Profile)
59-
options.Profile = c.Profile
60-
options.SharedConfigState = session.SharedConfigEnable
61-
}
62-
} else {
63-
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
64-
}
65-
} else {
66-
// add the validated credentials to the session options
67-
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
68-
options.Config.Credentials = creds
69-
}
36+
// add the validated credentials to the session options
37+
options.Config.Credentials = creds
7038

7139
if c.Insecure {
7240
transport := options.Config.HTTPClient.Transport.(*http.Transport)
@@ -83,7 +51,7 @@ func GetSessionOptions(c *Config) (*session.Options, error) {
8351
return options, nil
8452
}
8553

86-
// GetSession attempts to return valid AWS Go SDK session
54+
// GetSession attempts to return valid AWS Go SDK session.
8755
func GetSession(c *Config) (*session.Session, error) {
8856
options, err := GetSessionOptions(c)
8957

@@ -94,9 +62,7 @@ func GetSession(c *Config) (*session.Session, error) {
9462
sess, err := session.NewSessionWithOptions(*options)
9563
if err != nil {
9664
if IsAWSErr(err, "NoCredentialProviders", "") {
97-
return nil, errors.New(`No valid credential sources found for AWS Provider.
98-
Please see https://terraform.io/docs/providers/aws/index.html for more information on
99-
providing credentials for the AWS Provider`)
65+
return nil, ErrNoValidCredentialSources
10066
}
10167
return nil, fmt.Errorf("Error creating AWS session: %s", err)
10268
}
@@ -138,7 +104,7 @@ func GetSession(c *Config) (*session.Session, error) {
138104
if !c.SkipCredsValidation {
139105
stsClient := sts.New(sess.Copy(&aws.Config{Endpoint: aws.String(c.StsEndpoint)}))
140106
if _, _, err := GetAccountIDAndPartitionFromSTSGetCallerIdentity(stsClient); err != nil {
141-
return nil, fmt.Errorf("error validating provider credentials: %s", err)
107+
return nil, fmt.Errorf("error using credentials to get account ID: %s", err)
142108
}
143109
}
144110

0 commit comments

Comments
 (0)