Skip to content

Commit a7b41e5

Browse files
authored
Merge pull request #178 from mwieczorek/web-identity-token
Add support for AssumeRoleWithWebIdentity
2 parents ad7ee72 + 8c89298 commit a7b41e5

File tree

12 files changed

+603
-160
lines changed

12 files changed

+603
-160
lines changed

aws_config_test.go

Lines changed: 170 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ func TestGetAwsConfig(t *testing.T) {
4141
Description string
4242
EnableEc2MetadataServer bool
4343
EnableEcsCredentialsServer bool
44-
EnableWebIdentityToken bool
44+
EnableWebIdentityEnvVars bool
45+
EnableWebIdentityConfig bool
4546
EnvironmentVariables map[string]string
4647
ExpectedCredentialsValue aws.Credentials
4748
ExpectedRegion string
@@ -99,7 +100,7 @@ func TestGetAwsConfig(t *testing.T) {
99100
Region: "us-east-1",
100101
SecretKey: servicemocks.MockStaticSecretKey,
101102
},
102-
Description: "config AssumeRoleDurationSeconds",
103+
Description: "config AssumeRoleDuration",
103104
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
104105
ExpectedRegion: "us-east-1",
105106
MockStsEndpoints: []*servicemocks.MockEndpoint{
@@ -494,7 +495,7 @@ aws_secret_access_key = DefaultSharedCredentialsSecretKey
494495
Region: "us-east-1",
495496
},
496497
Description: "web identity token access key",
497-
EnableWebIdentityToken: true,
498+
EnableWebIdentityEnvVars: true,
498499
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
499500
ExpectedRegion: "us-east-1",
500501
MockStsEndpoints: []*servicemocks.MockEndpoint{
@@ -560,6 +561,42 @@ aws_secret_access_key = DefaultSharedCredentialsSecretKey
560561
servicemocks.MockStsGetCallerIdentityValidEndpoint,
561562
},
562563
},
564+
{
565+
Config: &Config{
566+
AssumeRole: &AssumeRole{
567+
RoleARN: servicemocks.MockStsAssumeRoleArn,
568+
SessionName: servicemocks.MockStsAssumeRoleSessionName,
569+
},
570+
Region: "us-east-1",
571+
},
572+
Description: "AssumeWebIdentity envvar AssumeRoleARN access key",
573+
EnableWebIdentityEnvVars: true,
574+
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
575+
ExpectedRegion: "us-east-1",
576+
MockStsEndpoints: []*servicemocks.MockEndpoint{
577+
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
578+
servicemocks.MockStsAssumeRoleValidEndpoint,
579+
servicemocks.MockStsGetCallerIdentityValidEndpoint,
580+
},
581+
},
582+
{
583+
Config: &Config{
584+
AssumeRole: &AssumeRole{
585+
RoleARN: servicemocks.MockStsAssumeRoleArn,
586+
SessionName: servicemocks.MockStsAssumeRoleSessionName,
587+
},
588+
Region: "us-east-1",
589+
},
590+
Description: "AssumeWebIdentity config AssumeRoleARN access key",
591+
EnableWebIdentityConfig: true,
592+
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
593+
ExpectedRegion: "us-east-1",
594+
MockStsEndpoints: []*servicemocks.MockEndpoint{
595+
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
596+
servicemocks.MockStsAssumeRoleValidEndpoint,
597+
servicemocks.MockStsGetCallerIdentityValidEndpoint,
598+
},
599+
},
563600
{
564601
Config: &Config{
565602
AccessKey: servicemocks.MockStaticAccessKey,
@@ -912,9 +949,8 @@ aws_secret_access_key = DefaultSharedCredentialsSecretKey
912949
defer closeEcsCredentials()
913950
}
914951

915-
if testCase.EnableWebIdentityToken {
952+
if testCase.EnableWebIdentityEnvVars || testCase.EnableWebIdentityConfig {
916953
file, err := ioutil.TempFile("", "aws-sdk-go-base-web-identity-token-file")
917-
918954
if err != nil {
919955
t.Fatalf("unexpected error creating temporary web identity token file: %s", err)
920956
}
@@ -927,9 +963,17 @@ aws_secret_access_key = DefaultSharedCredentialsSecretKey
927963
t.Fatalf("unexpected error writing web identity token file: %s", err)
928964
}
929965

930-
os.Setenv("AWS_ROLE_ARN", servicemocks.MockStsAssumeRoleWithWebIdentityArn)
931-
os.Setenv("AWS_ROLE_SESSION_NAME", servicemocks.MockStsAssumeRoleWithWebIdentitySessionName)
932-
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", file.Name())
966+
if testCase.EnableWebIdentityEnvVars {
967+
os.Setenv("AWS_ROLE_ARN", servicemocks.MockStsAssumeRoleWithWebIdentityArn)
968+
os.Setenv("AWS_ROLE_SESSION_NAME", servicemocks.MockStsAssumeRoleWithWebIdentitySessionName)
969+
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", file.Name())
970+
} else if testCase.EnableWebIdentityConfig {
971+
testCase.Config.AssumeRoleWithWebIdentity = &AssumeRoleWithWebIdentity{
972+
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
973+
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
974+
WebIdentityTokenFile: file.Name(),
975+
}
976+
}
933977
}
934978

935979
closeSts, _, stsEndpoint := mockdata.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints)
@@ -2288,21 +2332,56 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) {
22882332
testCases := map[string]struct {
22892333
Config *Config
22902334
SetConfig bool
2335+
ExpandEnvVars bool
22912336
EnvironmentVariables map[string]string
22922337
SetEnvironmentVariable bool
22932338
SharedConfigurationFile string
22942339
SetSharedConfigurationFile bool
22952340
ExpectedCredentialsValue aws.Credentials
22962341
MockStsEndpoints []*servicemocks.MockEndpoint
22972342
}{
2298-
// "config": {
2299-
// Config: &Config{},
2300-
// SetConfig: true,
2301-
// ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
2302-
// MockStsEndpoints: []*servicemocks.MockEndpoint{
2303-
// servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
2304-
// },
2305-
// },
2343+
"config with inline token": {
2344+
Config: &Config{
2345+
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
2346+
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
2347+
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
2348+
WebIdentityToken: servicemocks.MockWebIdentityToken,
2349+
},
2350+
},
2351+
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
2352+
MockStsEndpoints: []*servicemocks.MockEndpoint{
2353+
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
2354+
},
2355+
},
2356+
2357+
"config with token file": {
2358+
Config: &Config{
2359+
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
2360+
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
2361+
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
2362+
},
2363+
},
2364+
SetConfig: true,
2365+
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
2366+
MockStsEndpoints: []*servicemocks.MockEndpoint{
2367+
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
2368+
},
2369+
},
2370+
2371+
"config with expanded path": {
2372+
Config: &Config{
2373+
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
2374+
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
2375+
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
2376+
},
2377+
},
2378+
SetConfig: true,
2379+
ExpandEnvVars: true,
2380+
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
2381+
MockStsEndpoints: []*servicemocks.MockEndpoint{
2382+
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
2383+
},
2384+
},
23062385

23072386
"envvar": {
23082387
Config: &Config{},
@@ -2331,19 +2410,24 @@ role_session_name = %[2]s
23312410
},
23322411
},
23332412

2334-
// "config overrides envvar": {
2335-
// Config: &Config{},
2336-
// SetConfig: true,
2337-
// EnvironmentVariables: map[string]string{
2338-
// "AWS_ROLE_ARN": servicemocks.MockStsAssumeRoleWithWebIdentityArn,
2339-
// "AWS_ROLE_SESSION_NAME": servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
2340-
// "AWS_WEB_IDENTITY_TOKEN_FILE": "no-such-file",
2341-
// },
2342-
// ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
2343-
// MockStsEndpoints: []*servicemocks.MockEndpoint{
2344-
// servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
2345-
// },
2346-
// },
2413+
"config overrides envvar": {
2414+
Config: &Config{
2415+
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
2416+
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
2417+
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
2418+
WebIdentityToken: servicemocks.MockWebIdentityToken,
2419+
},
2420+
},
2421+
EnvironmentVariables: map[string]string{
2422+
"AWS_ROLE_ARN": servicemocks.MockStsAssumeRoleWithWebIdentityArn,
2423+
"AWS_ROLE_SESSION_NAME": servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
2424+
"AWS_WEB_IDENTITY_TOKEN_FILE": "no-such-file",
2425+
},
2426+
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
2427+
MockStsEndpoints: []*servicemocks.MockEndpoint{
2428+
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
2429+
},
2430+
},
23472431

23482432
"envvar overrides shared configuration": {
23492433
Config: &Config{},
@@ -2363,6 +2447,36 @@ web_identity_token_file = no-such-file
23632447
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
23642448
},
23652449
},
2450+
2451+
"with duration": {
2452+
Config: &Config{
2453+
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
2454+
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
2455+
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
2456+
WebIdentityToken: servicemocks.MockWebIdentityToken,
2457+
Duration: 1 * time.Hour,
2458+
},
2459+
},
2460+
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
2461+
MockStsEndpoints: []*servicemocks.MockEndpoint{
2462+
servicemocks.MockStsAssumeRoleWithWebIdentityValidWithOptions(map[string]string{"DurationSeconds": "3600"}),
2463+
},
2464+
},
2465+
2466+
"with policy": {
2467+
Config: &Config{
2468+
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
2469+
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
2470+
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
2471+
WebIdentityToken: servicemocks.MockWebIdentityToken,
2472+
Policy: "{}",
2473+
},
2474+
},
2475+
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
2476+
MockStsEndpoints: []*servicemocks.MockEndpoint{
2477+
servicemocks.MockStsAssumeRoleWithWebIdentityValidWithOptions(map[string]string{"Policy": "{}"}),
2478+
},
2479+
},
23662480
}
23672481

23682482
for testName, testCase := range testCases {
@@ -2381,21 +2495,44 @@ web_identity_token_file = no-such-file
23812495

23822496
testCase.Config.StsEndpoint = stsEndpoint
23832497

2498+
tempdir, err := ioutil.TempDir("", "temp")
2499+
if err != nil {
2500+
t.Fatalf("error creating temp dir: %s", err)
2501+
}
2502+
defer os.Remove(tempdir)
2503+
os.Setenv("TMPDIR", tempdir)
2504+
23842505
tokenFile, err := ioutil.TempFile("", "aws-sdk-go-base-web-identity-token-file")
23852506
if err != nil {
23862507
t.Fatalf("unexpected error creating temporary web identity token file: %s", err)
23872508
}
2509+
tokenFileName := tokenFile.Name()
23882510

2389-
defer os.Remove(tokenFile.Name())
2511+
defer os.Remove(tokenFileName)
23902512

2391-
err = ioutil.WriteFile(tokenFile.Name(), []byte(servicemocks.MockWebIdentityToken), 0600)
2513+
err = ioutil.WriteFile(tokenFileName, []byte(servicemocks.MockWebIdentityToken), 0600)
23922514

23932515
if err != nil {
23942516
t.Fatalf("unexpected error writing web identity token file: %s", err)
23952517
}
23962518

2519+
if testCase.ExpandEnvVars {
2520+
tmpdir := os.Getenv("TMPDIR")
2521+
rel, err := filepath.Rel(tmpdir, tokenFileName)
2522+
if err != nil {
2523+
t.Fatalf("error making path relative: %s", err)
2524+
}
2525+
t.Logf("relative: %s", rel)
2526+
tokenFileName = filepath.Join("$TMPDIR", rel)
2527+
t.Logf("env tempfile: %s", tokenFileName)
2528+
}
2529+
2530+
if testCase.SetConfig {
2531+
testCase.Config.AssumeRoleWithWebIdentity.WebIdentityTokenFile = tokenFileName
2532+
}
2533+
23972534
if testCase.SetEnvironmentVariable {
2398-
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", tokenFile.Name())
2535+
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", tokenFileName)
23992536
}
24002537

24012538
if testCase.SharedConfigurationFile != "" {
@@ -2408,7 +2545,7 @@ web_identity_token_file = no-such-file
24082545
defer os.Remove(file.Name())
24092546

24102547
if testCase.SetSharedConfigurationFile {
2411-
testCase.SharedConfigurationFile += fmt.Sprintf("web_identity_token_file = %s\n", tokenFile.Name())
2548+
testCase.SharedConfigurationFile += fmt.Sprintf("web_identity_token_file = %s\n", tokenFileName)
24122549
}
24132550

24142551
err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600)

config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ type APNInfo = config.APNInfo
1313

1414
type AssumeRole = config.AssumeRole
1515

16+
type AssumeRoleWithWebIdentity = config.AssumeRoleWithWebIdentity
17+
1618
type UserAgentProducts = config.UserAgentProducts
1719

1820
type UserAgentProduct = config.UserAgentProduct

credentials.go

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,19 @@ func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProv
133133
return nil, "", fmt.Errorf("loading configuration: %w", err)
134134
}
135135

136+
// This can probably be configured directly in commonLoadOptions() once
137+
// https://github.com/aws/aws-sdk-go-v2/pull/1682 is merged
138+
if c.AssumeRoleWithWebIdentity != nil {
139+
if c.AssumeRoleWithWebIdentity.WebIdentityToken == "" && c.AssumeRoleWithWebIdentity.WebIdentityTokenFile == "" {
140+
return nil, "", c.NewCannotAssumeRoleWithWebIdentityError(fmt.Errorf("one of: WebIdentityToken, WebIdentityTokenFile must be set"))
141+
}
142+
provider, err := webIdentityCredentialsProvider(ctx, cfg, c)
143+
if err != nil {
144+
return nil, "", err
145+
}
146+
cfg.Credentials = provider
147+
}
148+
136149
creds, err := cfg.Credentials.Retrieve(ctx)
137150
if err != nil {
138151
if c.Profile != "" && os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
@@ -153,6 +166,30 @@ Error: %w`, err)
153166
return provider, creds.Source, err
154167
}
155168

169+
func webIdentityCredentialsProvider(ctx context.Context, awsConfig aws.Config, c *Config) (aws.CredentialsProvider, error) {
170+
ar := c.AssumeRoleWithWebIdentity
171+
client := stsClient(awsConfig, c)
172+
173+
appCreds := stscreds.NewWebIdentityRoleProvider(client, ar.RoleARN, ar, func(opts *stscreds.WebIdentityRoleOptions) {
174+
opts.RoleSessionName = ar.SessionName
175+
opts.Duration = ar.Duration
176+
177+
if ar.Policy != "" {
178+
opts.Policy = aws.String(ar.Policy)
179+
}
180+
181+
if len(ar.PolicyARNs) > 0 {
182+
opts.PolicyARNs = getPolicyDescriptorTypes(ar.PolicyARNs)
183+
}
184+
})
185+
186+
_, err := appCreds.Retrieve(ctx)
187+
if err != nil {
188+
return nil, c.NewCannotAssumeRoleWithWebIdentityError(err)
189+
}
190+
return aws.NewCredentialsCache(appCreds), nil
191+
}
192+
156193
func assumeRoleCredentialsProvider(ctx context.Context, awsConfig aws.Config, c *Config) (aws.CredentialsProvider, error) {
157194
ar := c.AssumeRole
158195
// When assuming a role, we need to first authenticate the base credentials above, then assume the desired role
@@ -173,16 +210,7 @@ func assumeRoleCredentialsProvider(ctx context.Context, awsConfig aws.Config, c
173210
}
174211

175212
if len(ar.PolicyARNs) > 0 {
176-
var policyDescriptorTypes []types.PolicyDescriptorType
177-
178-
for _, policyARN := range ar.PolicyARNs {
179-
policyDescriptorType := types.PolicyDescriptorType{
180-
Arn: aws.String(policyARN),
181-
}
182-
policyDescriptorTypes = append(policyDescriptorTypes, policyDescriptorType)
183-
}
184-
185-
opts.PolicyARNs = policyDescriptorTypes
213+
opts.PolicyARNs = getPolicyDescriptorTypes(ar.PolicyARNs)
186214
}
187215

188216
if len(ar.Tags) > 0 {
@@ -208,3 +236,15 @@ func assumeRoleCredentialsProvider(ctx context.Context, awsConfig aws.Config, c
208236
}
209237
return aws.NewCredentialsCache(appCreds), nil
210238
}
239+
240+
func getPolicyDescriptorTypes(policyARNs []string) []types.PolicyDescriptorType {
241+
var policyDescriptorTypes []types.PolicyDescriptorType
242+
243+
for _, policyARN := range policyARNs {
244+
policyDescriptorType := types.PolicyDescriptorType{
245+
Arn: aws.String(policyARN),
246+
}
247+
policyDescriptorTypes = append(policyDescriptorTypes, policyDescriptorType)
248+
}
249+
return policyDescriptorTypes
250+
}

0 commit comments

Comments
 (0)