diff --git a/src/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp b/src/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp index 683cf68a953..be617649e34 100644 --- a/src/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp +++ b/src/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -118,7 +119,6 @@ namespace Aws } std::shared_ptr request(CreateHttpRequest(ss.str(), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); - request->SetUserAgent(m_userAgent); if (authToken) @@ -132,6 +132,11 @@ namespace Aws AmazonWebServiceResult AWSHttpResourceClient::GetResourceWithAWSWebServiceResult(const std::shared_ptr &httpRequest) const { AWS_LOGSTREAM_TRACE(m_logtag.c_str(), "Retrieving credentials from " << httpRequest->GetURIString()); + if (!Aws::Utils::IsValidHost(httpRequest->GetUri().GetHost())) { + AWS_LOGSTREAM_FATAL(m_logtag.c_str(), "Invalid endpoint host constructed: " << httpRequest->GetURIString()); + return {{}, {}, HttpResponseCode::REQUEST_NOT_MADE}; + } + if (!m_httpClient) { AWS_LOGSTREAM_FATAL(m_logtag.c_str(), "Unable to get a response: missing http client!"); @@ -550,6 +555,7 @@ namespace Aws { ss << ".cn"; } + m_endpoint = ss.str(); AWS_LOGSTREAM_INFO(STS_RESOURCE_CLIENT_LOG_TAG, "Creating STS ResourceClient with endpoint: " << m_endpoint); @@ -685,6 +691,7 @@ namespace Aws { ss << ".cn"; } + return ss.str(); } diff --git a/tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp b/tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp index b5524217aef..c67272fcd5f 100644 --- a/tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp +++ b/tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -956,6 +955,48 @@ sso_start_url = https://d-92671207e4.awsapps.com/start ASSERT_TRUE(mockHttpClient->GetAllRequestsMade().empty()); } +TEST_F(SSOCredentialsProviderTest, TestInvalidRegionCredentials) +{ + AWS_LOGSTREAM_DEBUG("TEST_SSO", "Preparing Test Token file in: " << m_ssoTokenFileName); + Aws::OFStream tokenFile(m_ssoTokenFileName.c_str(), Aws::OFStream::out | Aws::OFStream::trunc); + tokenFile << R"({ + "accessToken": "base64string", + "expiresAt": ")"; + tokenFile << DateTime::Now().GetYear() + 1; + tokenFile << R"(-01-02T00:00:00Z", + "region": "us-west-2", + "startUrl": "https://d-92671207e4.awsapps.com/start" +})"; + tokenFile.close(); + + Aws::OFStream configFile(m_configFileName.c_str(), Aws::OFStream::out | Aws::OFStream::trunc); + configFile << R"([default] +sso_account_id = 012345678901 +sso_region = @amazon.com# +sso_role_name = SampleRole +sso_start_url = https://d-92671207e4.awsapps.com/start +)"; + configFile.close(); + + // Mock DNS/connection failure for invalid region + std::shared_ptr requestTmp = CreateHttpRequest(URI("https://portal.sso.@amazon.com#.amazonaws.com/federation/credentials"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + std::shared_ptr dnsFailureResponse = Aws::MakeShared(AllocationTag, requestTmp); + dnsFailureResponse->SetResponseCode(HttpResponseCode::REQUEST_NOT_MADE); + mockHttpClient->AddResponseToReturn(dnsFailureResponse); + + Aws::Config::ReloadCachedConfigFile(); + SSOCredentialsProvider provider; + + auto creds = provider.GetAWSCredentials(); + ASSERT_TRUE(creds.IsEmpty()); + + // Check if any requests were made before calling GetMostRecentHttpRequest + if (!mockHttpClient->GetAllRequestsMade().empty()) { + auto request = mockHttpClient->GetMostRecentHttpRequest(); + ASSERT_TRUE(request.GetURIString().find("@amazon.com#") != std::string::npos); + } +} + class AWSCredentialsTest : public Aws::Testing::AwsCppSdkGTestSuite { }; @@ -1120,3 +1161,42 @@ TEST_F(AWSCachedCredentialsTest, ShouldCacheCredenitalAsync) ASSERT_TRUE(containCredentials(creds, {"and", "no", "surprises"})); ASSERT_FALSE(containCredentials(creds, {"a", "quiet", "life"})); } + +class STSCredentialsProviderTest : public Aws::Testing::AwsCppSdkGTestSuite { +public: + void SetUp() { + mockHttpClient = Aws::MakeShared(AllocationTag); + mockHttpClientFactory = Aws::MakeShared(AllocationTag); + mockHttpClientFactory->SetClient(mockHttpClient); + SetHttpClientFactory(mockHttpClientFactory); + } + + void TearDown() { + mockHttpClient = nullptr; + mockHttpClientFactory = nullptr; + CleanupHttp(); + InitHttp(); + } + + std::shared_ptr mockHttpClient; + std::shared_ptr mockHttpClientFactory; +}; + +TEST_F(STSCredentialsProviderTest, TestInvalidRegionCredentials) { + ClientConfiguration config; + config.region = "@amazon.com#"; + + Aws::Internal::STSCredentialsClient stsClient(config); + Aws::Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request; + request.roleArn = "arn:aws:iam::123456789012:role/TestRole"; + request.roleSessionName = "test-session"; + request.webIdentityToken = "test-token"; + + auto result = stsClient.GetAssumeRoleWithWebIdentityCredentials(request); + ASSERT_TRUE(result.creds.IsEmpty()); + + if (!mockHttpClient ->GetAllRequestsMade().empty()) { + auto httpRequest = mockHttpClient->GetMostRecentHttpRequest(); + ASSERT_TRUE(httpRequest.GetURIString().find("@amazon.com#") != std::string::npos); + } +} \ No newline at end of file