Skip to content

Commit 091a493

Browse files
pulimsrsbiscigl
andauthored
validating url for sso and sts (#3610)
Co-authored-by: sbiscigl <[email protected]>
1 parent 792e1c1 commit 091a493

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

src/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <aws/core/http/HttpResponse.h>
1111
#include <aws/core/utils/logging/LogMacros.h>
1212
#include <aws/core/utils/ARN.h>
13+
#include <aws/core/utils/DNS.h>
1314
#include <aws/core/utils/StringUtils.h>
1415
#include <aws/core/utils/HashingUtils.h>
1516
#include <aws/core/platform/Environment.h>
@@ -118,7 +119,6 @@ namespace Aws
118119
}
119120
std::shared_ptr<HttpRequest> request(CreateHttpRequest(ss.str(), HttpMethod::HTTP_GET,
120121
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod));
121-
122122
request->SetUserAgent(m_userAgent);
123123

124124
if (authToken)
@@ -132,6 +132,11 @@ namespace Aws
132132
AmazonWebServiceResult<Aws::String> AWSHttpResourceClient::GetResourceWithAWSWebServiceResult(const std::shared_ptr<HttpRequest> &httpRequest) const
133133
{
134134
AWS_LOGSTREAM_TRACE(m_logtag.c_str(), "Retrieving credentials from " << httpRequest->GetURIString());
135+
if (!Aws::Utils::IsValidHost(httpRequest->GetUri().GetHost())) {
136+
AWS_LOGSTREAM_FATAL(m_logtag.c_str(), "Invalid endpoint host constructed: " << httpRequest->GetURIString());
137+
return {{}, {}, HttpResponseCode::REQUEST_NOT_MADE};
138+
}
139+
135140
if (!m_httpClient)
136141
{
137142
AWS_LOGSTREAM_FATAL(m_logtag.c_str(), "Unable to get a response: missing http client!");
@@ -550,6 +555,7 @@ namespace Aws
550555
{
551556
ss << ".cn";
552557
}
558+
553559
m_endpoint = ss.str();
554560

555561
AWS_LOGSTREAM_INFO(STS_RESOURCE_CLIENT_LOG_TAG, "Creating STS ResourceClient with endpoint: " << m_endpoint);
@@ -685,6 +691,7 @@ namespace Aws
685691
{
686692
ss << ".cn";
687693
}
694+
688695
return ss.str();
689696
}
690697

tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <aws/core/auth/AWSCredentialsProvider.h>
77
#include <aws/core/auth/AWSCredentialsProviderChain.h>
88
#include <aws/core/auth/SSOCredentialsProvider.h>
9-
#include <aws/core/auth/STSCredentialsProvider.h>
109
#include <aws/core/client/AWSError.h>
1110
#include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h>
1211
#include <aws/core/config/AWSProfileConfigLoader.h>
@@ -956,6 +955,48 @@ sso_start_url = https://d-92671207e4.awsapps.com/start
956955
ASSERT_TRUE(mockHttpClient->GetAllRequestsMade().empty());
957956
}
958957

958+
TEST_F(SSOCredentialsProviderTest, TestInvalidRegionCredentials)
959+
{
960+
AWS_LOGSTREAM_DEBUG("TEST_SSO", "Preparing Test Token file in: " << m_ssoTokenFileName);
961+
Aws::OFStream tokenFile(m_ssoTokenFileName.c_str(), Aws::OFStream::out | Aws::OFStream::trunc);
962+
tokenFile << R"({
963+
"accessToken": "base64string",
964+
"expiresAt": ")";
965+
tokenFile << DateTime::Now().GetYear() + 1;
966+
tokenFile << R"(-01-02T00:00:00Z",
967+
"region": "us-west-2",
968+
"startUrl": "https://d-92671207e4.awsapps.com/start"
969+
})";
970+
tokenFile.close();
971+
972+
Aws::OFStream configFile(m_configFileName.c_str(), Aws::OFStream::out | Aws::OFStream::trunc);
973+
configFile << R"([default]
974+
sso_account_id = 012345678901
975+
sso_region = @amazon.com#
976+
sso_role_name = SampleRole
977+
sso_start_url = https://d-92671207e4.awsapps.com/start
978+
)";
979+
configFile.close();
980+
981+
// Mock DNS/connection failure for invalid region
982+
std::shared_ptr<HttpRequest> requestTmp = CreateHttpRequest(URI("https://[email protected]#.amazonaws.com/federation/credentials"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
983+
std::shared_ptr<StandardHttpResponse> dnsFailureResponse = Aws::MakeShared<StandardHttpResponse>(AllocationTag, requestTmp);
984+
dnsFailureResponse->SetResponseCode(HttpResponseCode::REQUEST_NOT_MADE);
985+
mockHttpClient->AddResponseToReturn(dnsFailureResponse);
986+
987+
Aws::Config::ReloadCachedConfigFile();
988+
SSOCredentialsProvider provider;
989+
990+
auto creds = provider.GetAWSCredentials();
991+
ASSERT_TRUE(creds.IsEmpty());
992+
993+
// Check if any requests were made before calling GetMostRecentHttpRequest
994+
if (!mockHttpClient->GetAllRequestsMade().empty()) {
995+
auto request = mockHttpClient->GetMostRecentHttpRequest();
996+
ASSERT_TRUE(request.GetURIString().find("@amazon.com#") != std::string::npos);
997+
}
998+
}
999+
9591000
class AWSCredentialsTest : public Aws::Testing::AwsCppSdkGTestSuite
9601001
{
9611002
};
@@ -1120,3 +1161,42 @@ TEST_F(AWSCachedCredentialsTest, ShouldCacheCredenitalAsync)
11201161
ASSERT_TRUE(containCredentials(creds, {"and", "no", "surprises"}));
11211162
ASSERT_FALSE(containCredentials(creds, {"a", "quiet", "life"}));
11221163
}
1164+
1165+
class STSCredentialsProviderTest : public Aws::Testing::AwsCppSdkGTestSuite {
1166+
public:
1167+
void SetUp() {
1168+
mockHttpClient = Aws::MakeShared<MockHttpClient>(AllocationTag);
1169+
mockHttpClientFactory = Aws::MakeShared<MockHttpClientFactory>(AllocationTag);
1170+
mockHttpClientFactory->SetClient(mockHttpClient);
1171+
SetHttpClientFactory(mockHttpClientFactory);
1172+
}
1173+
1174+
void TearDown() {
1175+
mockHttpClient = nullptr;
1176+
mockHttpClientFactory = nullptr;
1177+
CleanupHttp();
1178+
InitHttp();
1179+
}
1180+
1181+
std::shared_ptr<MockHttpClient> mockHttpClient;
1182+
std::shared_ptr<MockHttpClientFactory> mockHttpClientFactory;
1183+
};
1184+
1185+
TEST_F(STSCredentialsProviderTest, TestInvalidRegionCredentials) {
1186+
ClientConfiguration config;
1187+
config.region = "@amazon.com#";
1188+
1189+
Aws::Internal::STSCredentialsClient stsClient(config);
1190+
Aws::Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request;
1191+
request.roleArn = "arn:aws:iam::123456789012:role/TestRole";
1192+
request.roleSessionName = "test-session";
1193+
request.webIdentityToken = "test-token";
1194+
1195+
auto result = stsClient.GetAssumeRoleWithWebIdentityCredentials(request);
1196+
ASSERT_TRUE(result.creds.IsEmpty());
1197+
1198+
if (!mockHttpClient ->GetAllRequestsMade().empty()) {
1199+
auto httpRequest = mockHttpClient->GetMostRecentHttpRequest();
1200+
ASSERT_TRUE(httpRequest.GetURIString().find("@amazon.com#") != std::string::npos);
1201+
}
1202+
}

0 commit comments

Comments
 (0)