From 93f0d91d0b7b332f79816f3574e4b797e7abd17f Mon Sep 17 00:00:00 2001 From: Manuel Naujoks Date: Sat, 4 Oct 2025 14:10:22 +0200 Subject: [PATCH 1/6] Tokens can be cached beyond the lifetime of the (http) transport. --- .../Authentication/ClientOAuthOptions.cs | 6 +++++ .../Authentication/ClientOAuthProvider.cs | 19 +++++++------ .../Authentication/ITokenCache.cs | 17 ++++++++++++ .../Authentication/InMemoryTokenCache.cs | 27 +++++++++++++++++++ .../Authentication/TokenContainer.cs | 4 +-- 5 files changed, 63 insertions(+), 10 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Authentication/ITokenCache.cs create mode 100644 src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs index cc6a8952e..ecb57df0a 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions /// /// public IDictionary AdditionalAuthorizationParameters { get; set; } = new Dictionary(); + + /// + /// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport. + /// If none is provided, tokens will be cached with the transport. + /// + public ITokenCache? TokenCache { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 468728982..e59fc22e8 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -43,7 +43,7 @@ internal sealed partial class ClientOAuthProvider private string? _clientId; private string? _clientSecret; - private TokenContainer? _token; + private ITokenCache _tokenCache; private AuthorizationServerMetadata? _authServerMetadata; /// @@ -85,6 +85,7 @@ public ClientOAuthProvider( _dcrClientUri = options.DynamicClientRegistration?.ClientUri; _dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken; _dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate; + _tokenCache = options.TokenCache ?? new InMemoryTokenCache(); } /// @@ -138,20 +139,22 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); + var token = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + // Return the token if it's valid - if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) { - return _token.AccessToken; + return token.AccessToken; } // Try to refresh the token if we have a refresh token - if (_token?.RefreshToken != null && _authServerMetadata != null) + if (token?.RefreshToken != null && _authServerMetadata != null) { - var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); if (newToken != null) { - _token = newToken; - return _token.AccessToken; + await _tokenCache.StoreTokenAsync(newToken, cancellationToken).ConfigureAwait(false); + return newToken.AccessToken; } } @@ -237,7 +240,7 @@ private async Task PerformOAuthAuthorizationAsync( ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - _token = token; + await _tokenCache.StoreTokenAsync(token, cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs new file mode 100644 index 000000000..3619286b3 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Allows the client to cache access tokens beyond the lifetime of the transport. +/// +public interface ITokenCache +{ + /// + /// Cache the token. + /// + Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken); + + /// + /// Get the cached token. + /// + Task GetTokenAsync(CancellationToken cancellationToken); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs new file mode 100644 index 000000000..529d56269 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -0,0 +1,27 @@ + +namespace ModelContextProtocol.Authentication; + +/// +/// Caches the token in-memory within this instance. +/// +internal class InMemoryTokenCache : ITokenCache +{ + private TokenContainer? _token; + + /// + /// Cache the token. + /// + public Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken) + { + _token = token; + return Task.CompletedTask; + } + + /// + /// Get the cached token. + /// + public Task GetTokenAsync(CancellationToken cancellationToken) + { + return Task.FromResult(_token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index dc55292b9..7ffe05372 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication; /// /// Represents a token response from the OAuth server. /// -internal sealed class TokenContainer +public sealed class TokenContainer { /// /// Gets or sets the access token. @@ -46,7 +46,7 @@ internal sealed class TokenContainer /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonIgnore] + [JsonPropertyName("obtained_at")] public DateTimeOffset ObtainedAt { get; set; } /// From 26f80ef6dbab2215f095a3c15cb72d417b727d0c Mon Sep 17 00:00:00 2001 From: Manuel Naujoks Date: Sat, 11 Oct 2025 18:45:31 +0200 Subject: [PATCH 2/6] Tests, ValueTasks, and dedicated type for caching. --- .../Authentication/ClientOAuthProvider.cs | 7 +- .../Authentication/ITokenCache.cs | 8 +- .../Authentication/InMemoryTokenCache.cs | 10 +- .../Authentication/TokenContainer.cs | 4 +- .../Authentication/TokenContainerCacheable.cs | 42 ++++ .../Authentication/TokenContainerConvert.cs | 26 ++ .../Client/CustomTokenCacheTests.cs | 233 ++++++++++++++++++ 7 files changed, 316 insertions(+), 14 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs create mode 100644 src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs create mode 100644 tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index e59fc22e8..bb411eae8 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -139,7 +139,8 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); - var token = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + var cachedToken = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + var token = cachedToken?.ForUse(); // Return the token if it's valid if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) @@ -153,7 +154,7 @@ public ClientOAuthProvider( var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); if (newToken != null) { - await _tokenCache.StoreTokenAsync(newToken, cancellationToken).ConfigureAwait(false); + await _tokenCache.StoreTokenAsync(newToken.ForCache(), cancellationToken).ConfigureAwait(false); return newToken.AccessToken; } } @@ -240,7 +241,7 @@ private async Task PerformOAuthAuthorizationAsync( ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - await _tokenCache.StoreTokenAsync(token, cancellationToken).ConfigureAwait(false); + await _tokenCache.StoreTokenAsync(token.ForCache(), cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs index 3619286b3..46d4cc37b 100644 --- a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -6,12 +6,12 @@ namespace ModelContextProtocol.Authentication; public interface ITokenCache { /// - /// Cache the token. + /// Cache the token. After a new access token is acquired, this method is invoked to store it. /// - Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken); + ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken); /// - /// Get the cached token. + /// Get the cached token. This method is invoked for every request. /// - Task GetTokenAsync(CancellationToken cancellationToken); + ValueTask GetTokenAsync(CancellationToken cancellationToken); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs index 529d56269..56346f731 100644 --- a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -6,22 +6,22 @@ namespace ModelContextProtocol.Authentication; /// internal class InMemoryTokenCache : ITokenCache { - private TokenContainer? _token; + private TokenContainerCacheable? _token; /// /// Cache the token. /// - public Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken) + public ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken) { _token = token; - return Task.CompletedTask; + return default; } /// /// Get the cached token. /// - public Task GetTokenAsync(CancellationToken cancellationToken) + public ValueTask GetTokenAsync(CancellationToken cancellationToken) { - return Task.FromResult(_token); + return new ValueTask(_token); } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index 7ffe05372..dc55292b9 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication; /// /// Represents a token response from the OAuth server. /// -public sealed class TokenContainer +internal sealed class TokenContainer { /// /// Gets or sets the access token. @@ -46,7 +46,7 @@ public sealed class TokenContainer /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonPropertyName("obtained_at")] + [JsonIgnore] public DateTimeOffset ObtainedAt { get; set; } /// diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs new file mode 100644 index 000000000..5f6bf0e5c --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs @@ -0,0 +1,42 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a cacheable token representation. +/// +public class TokenContainerCacheable +{ + /// + /// Gets or sets the access token. + /// + public string AccessToken { get; set; } = string.Empty; + + /// + /// Gets or sets the refresh token. + /// + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the number of seconds until the access token expires. + /// + public int ExpiresIn { get; set; } + + /// + /// Gets or sets the extended expiration time in seconds. + /// + public int ExtExpiresIn { get; set; } + + /// + /// Gets or sets the token type (typically "Bearer"). + /// + public string TokenType { get; set; } = string.Empty; + + /// + /// Gets or sets the scope of the access token. + /// + public string Scope { get; set; } = string.Empty; + + /// + /// Gets or sets the timestamp when the token was obtained. + /// + public DateTimeOffset ObtainedAt { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs new file mode 100644 index 000000000..6e2c8e9cd --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs @@ -0,0 +1,26 @@ +namespace ModelContextProtocol.Authentication; + +internal static class TokenContainerConvert +{ + internal static TokenContainer ForUse(this TokenContainerCacheable token) => new() + { + AccessToken = token.AccessToken, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn, + ExtExpiresIn = token.ExtExpiresIn, + TokenType = token.TokenType, + Scope = token.Scope, + ObtainedAt = token.ObtainedAt, + }; + + internal static TokenContainerCacheable ForCache(this TokenContainer token) => new() + { + AccessToken = token.AccessToken, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn, + ExtExpiresIn = token.ExtExpiresIn, + TokenType = token.TokenType, + Scope = token.Scope, + ObtainedAt = token.ObtainedAt, + }; +} diff --git a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs new file mode 100644 index 000000000..3ea1262ae --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs @@ -0,0 +1,233 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Authentication; +using System.Text.Json; +using Moq; +using Moq.Protected; +using System.Net; +using System.Text.Json.Nodes; +using System.Linq.Expressions; + +namespace ModelContextProtocol.Tests.Client; + +public class CustomTokenCacheTests +{ + [Fact] + public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() + { + // Arrange + var cachedAccessToken = $"my_access_token_{Guid.NewGuid()}"; + + var tokenCacheMock = new Mock(); + MockCachedAccessToken(tokenCacheMock, cachedAccessToken); + + var httpMessageHandlerMock = new Mock(); + MockInitializeResponse(httpMessageHandlerMock); + + var httpClientTransport = new HttpClientTransport( + transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), + httpClient: new HttpClient(httpMessageHandlerMock.Object)); + + var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Act + var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; + await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); + + // Assert + httpMessageHandlerMock + .Protected() + .Verify("SendAsync", Times.AtLeastOnce(), ItExpr.Is(req => + req.RequestUri == new Uri("http://localhost:1337/") + && req.Headers.Authorization != null + && req.Headers.Authorization.Scheme == "Bearer" + && req.Headers.Authorization.Parameter == cachedAccessToken + ), ItExpr.IsAny()); + + httpMessageHandlerMock + .Protected() + .Verify("SendAsync", Times.Never(), ItExpr.Is(req => + req.RequestUri == new Uri("http://localhost:1337/") + && (req.Headers.Authorization == null || req.Headers.Authorization.Parameter != cachedAccessToken) + ), ItExpr.IsAny()); + } + + [Fact] + public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() + { + // Arrange + var tokenCacheMock = new Mock(); + MockNoAccessTokenUntilStored(tokenCacheMock); + + var newAccessToken = $"new_access_token_{Guid.NewGuid()}"; + + var httpMessageHandlerMock = new Mock(); + MockUnauthorizedResponse(httpMessageHandlerMock); + MockProtectedResourceMetadataResponse(httpMessageHandlerMock); + MockAuthorizationServerMetadataResponse(httpMessageHandlerMock); + MockAccessTokenResponse(httpMessageHandlerMock, newAccessToken); + MockInitializeResponse(httpMessageHandlerMock); + + var httpClientTransport = new HttpClientTransport( + transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), + httpClient: new HttpClient(httpMessageHandlerMock.Object)); + + var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Act + var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; + await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); + + // Assert + tokenCacheMock + .Verify(tc => tc.StoreTokenAsync( + It.Is(token => token.AccessToken == newAccessToken), + It.IsAny()), Times.Once); + } + + static HttpClientTransportOptions NewHttpClientTransportOptions(ITokenCache? tokenCache = null) => new() + { + Name = "MCP Server", + Endpoint = new Uri("http://localhost:1337/"), + TransportMode = HttpTransportMode.StreamableHttp, + OAuth = new() + { + ClientId = "mcp_inspector", + RedirectUri = new Uri("http://localhost:6274/oauth/callback"), + Scopes = ["openid", "profile", "offline_access"], + AuthorizationRedirectDelegate = (authorizationUrl, redirectUri, cancellationToken) => Task.FromResult($"auth_code_{Guid.NewGuid()}"), + TokenCache = tokenCache, + }, + }; + + static void MockCachedAccessToken(Mock tokenCache, string cachedAccessToken) + { + tokenCache + .Setup(tc => tc.GetTokenAsync(It.IsAny())) + .ReturnsAsync(new TokenContainerCacheable + { + AccessToken = cachedAccessToken, + ObtainedAt = DateTimeOffset.UtcNow, + ExpiresIn = (int)TimeSpan.FromHours(1).TotalSeconds, + }); + } + + static void MockNoAccessTokenUntilStored(Mock tokenCache) + { + tokenCache + .Setup(tc => tc.StoreTokenAsync(It.IsAny(), It.IsAny())) + .Callback((token, ct) => + { + // Simulate that the token is now cached + MockCachedAccessToken(tokenCache, token.AccessToken); + }) + .Returns(default(ValueTask)); + } + + static void MockUnauthorizedResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/") + && req.Method == HttpMethod.Post + && (req.Headers.Authorization == null || string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter)), + response: new HttpResponseMessage(HttpStatusCode.Unauthorized) + { + Headers = { + { "WWW-Authenticate", "Bearer realm=\"Bearer\", resource_metadata=\"http://localhost:1337/.well-known/oauth-protected-resource\"" } + }, + }); + } + + static void MockProtectedResourceMetadataResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/.well-known/oauth-protected-resource"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + resource = "http://localhost:1337/", + authorization_servers = new[] { "http://localhost:1336/" }, + }) + }); + } + + static void MockAuthorizationServerMetadataResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1336/.well-known/openid-configuration"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + authorization_endpoint = "http://localhost:1336/connect/authorize", + token_endpoint = "http://localhost:1336/connect/token", + }) + }); + } + + static void MockAccessTokenResponse(Mock httpMessageHandler, string accessToken) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1336/connect/token"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + access_token = accessToken, + }) + }); + } + + static void MockInitializeResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/") + && req.Method == HttpMethod.Post + && req.Headers.Authorization != null + && req.Headers.Authorization.Scheme == "Bearer" + && !string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new JsonRpcResponse + { + Id = new RequestId(1), + Result = ToJson(new InitializeResult + { + ProtocolVersion = "2024-11-05", + Capabilities = new ServerCapabilities + { + Prompts = new PromptsCapability { ListChanged = true }, + Resources = new ResourcesCapability { Subscribe = true, ListChanged = true }, + Tools = new ToolsCapability { ListChanged = true }, + Logging = new LoggingCapability(), + Completions = new CompletionsCapability(), + }, + ServerInfo = new Implementation + { + Name = "mcp-test-server", + Version = "1.0.0" + }, + Instructions = "This server provides weather information and file system access." + }) + }), + }); + } + + static void MockHttpResponse(Mock httpMessageHandler, Expression>? request = null, HttpResponseMessage? response = null) + { + httpMessageHandler + .Protected() + .Setup>("SendAsync", request != null ? ItExpr.Is(request) : ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync(response ?? new HttpResponseMessage()); + } + + static StringContent ToJsonContent(T content) => new( + content: JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions), + encoding: System.Text.Encoding.UTF8, + mediaType: "application/json"); + + static JsonNode? ToJson(T content) => JsonSerializer.SerializeToNode( + value: content, + options: McpJsonUtilities.DefaultOptions); +} From bd4f0ff95f916f67a4072388c7f63674dd3b3e02 Mon Sep 17 00:00:00 2001 From: Manuel Naujoks Date: Sun, 26 Oct 2025 23:23:30 +0100 Subject: [PATCH 3/6] Type rename; alignment test --- .../Authentication/ClientOAuthProvider.cs | 39 +++++++++++-------- .../Authentication/ITokenCache.cs | 4 +- .../Authentication/InMemoryTokenCache.cs | 10 ++--- .../Authentication/TokenContainer.cs | 14 +------ .../Authentication/TokenContainerConvert.cs | 26 ------------- ...ContainerCacheable.cs => TokenResponse.cs} | 17 ++++---- .../McpJsonUtilities.cs | 2 +- .../Client/CustomTokenCacheTests.cs | 34 +++++++++++----- 8 files changed, 68 insertions(+), 78 deletions(-) delete mode 100644 src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs rename src/ModelContextProtocol.Core/Authentication/{TokenContainerCacheable.cs => TokenResponse.cs} (72%) diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index bb411eae8..503c6e402 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -139,23 +139,22 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); - var cachedToken = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); - var token = cachedToken?.ForUse(); - + var tokens = await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false); + // Return the token if it's valid - if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + if (tokens != null && tokens.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) { - return token.AccessToken; + return tokens.AccessToken; } // Try to refresh the token if we have a refresh token - if (token?.RefreshToken != null && _authServerMetadata != null) + if (tokens?.RefreshToken != null && _authServerMetadata != null) { - var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); - if (newToken != null) + var newTokens = await RefreshTokenAsync(tokens.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + if (newTokens != null) { - await _tokenCache.StoreTokenAsync(newToken.ForCache(), cancellationToken).ConfigureAwait(false); - return newToken.AccessToken; + await _tokenCache.StoreTokensAsync(newTokens, cancellationToken).ConfigureAwait(false); + return newTokens.AccessToken; } } @@ -234,14 +233,14 @@ private async Task PerformOAuthAuthorizationAsync( } // Perform the OAuth flow - var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); + var tokens = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); - if (token is null) + if (tokens is null) { ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - await _tokenCache.StoreTokenAsync(token.ForCache(), cancellationToken).ConfigureAwait(false); + await _tokenCache.StoreTokensAsync(tokens, cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } @@ -413,15 +412,23 @@ private async Task FetchTokenAsync(HttpRequestMessage request, C httpResponse.EnsureSuccessStatusCode(); using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false); + var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenResponse, cancellationToken).ConfigureAwait(false); if (tokenResponse is null) { ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response."); } - tokenResponse.ObtainedAt = DateTimeOffset.UtcNow; - return tokenResponse; + return new() + { + AccessToken = tokenResponse.AccessToken, + RefreshToken = tokenResponse.RefreshToken, + ExpiresIn = tokenResponse.ExpiresIn, + ExtExpiresIn = tokenResponse.ExtExpiresIn, + TokenType = tokenResponse.TokenType, + Scope = tokenResponse.Scope, + ObtainedAt = DateTimeOffset.UtcNow, + }; } /// diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs index 46d4cc37b..3dc6e6351 100644 --- a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -8,10 +8,10 @@ public interface ITokenCache /// /// Cache the token. After a new access token is acquired, this method is invoked to store it. /// - ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken); + ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken); /// /// Get the cached token. This method is invoked for every request. /// - ValueTask GetTokenAsync(CancellationToken cancellationToken); + ValueTask GetTokensAsync(CancellationToken cancellationToken); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs index 56346f731..977cb6f88 100644 --- a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -6,22 +6,22 @@ namespace ModelContextProtocol.Authentication; /// internal class InMemoryTokenCache : ITokenCache { - private TokenContainerCacheable? _token; + private TokenContainer? _tokens; /// /// Cache the token. /// - public ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken) + public ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken) { - _token = token; + _tokens = tokens; return default; } /// /// Get the cached token. /// - public ValueTask GetTokenAsync(CancellationToken cancellationToken) + public ValueTask GetTokensAsync(CancellationToken cancellationToken) { - return new ValueTask(_token); + return new ValueTask(_tokens); } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index dc55292b9..5503c96f1 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -1,57 +1,47 @@ -using System.Text.Json.Serialization; - namespace ModelContextProtocol.Authentication; /// -/// Represents a token response from the OAuth server. +/// Represents a cacheable combination of tokens ready to be used for authentication. /// -internal sealed class TokenContainer +public class TokenContainer { /// /// Gets or sets the access token. /// - [JsonPropertyName("access_token")] public string AccessToken { get; set; } = string.Empty; /// /// Gets or sets the refresh token. /// - [JsonPropertyName("refresh_token")] public string? RefreshToken { get; set; } /// /// Gets or sets the number of seconds until the access token expires. /// - [JsonPropertyName("expires_in")] public int ExpiresIn { get; set; } /// /// Gets or sets the extended expiration time in seconds. /// - [JsonPropertyName("ext_expires_in")] public int ExtExpiresIn { get; set; } /// /// Gets or sets the token type (typically "Bearer"). /// - [JsonPropertyName("token_type")] public string TokenType { get; set; } = string.Empty; /// /// Gets or sets the scope of the access token. /// - [JsonPropertyName("scope")] public string Scope { get; set; } = string.Empty; /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonIgnore] public DateTimeOffset ObtainedAt { get; set; } /// /// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn. /// - [JsonIgnore] public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn); } diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs deleted file mode 100644 index 6e2c8e9cd..000000000 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs +++ /dev/null @@ -1,26 +0,0 @@ -namespace ModelContextProtocol.Authentication; - -internal static class TokenContainerConvert -{ - internal static TokenContainer ForUse(this TokenContainerCacheable token) => new() - { - AccessToken = token.AccessToken, - RefreshToken = token.RefreshToken, - ExpiresIn = token.ExpiresIn, - ExtExpiresIn = token.ExtExpiresIn, - TokenType = token.TokenType, - Scope = token.Scope, - ObtainedAt = token.ObtainedAt, - }; - - internal static TokenContainerCacheable ForCache(this TokenContainer token) => new() - { - AccessToken = token.AccessToken, - RefreshToken = token.RefreshToken, - ExpiresIn = token.ExpiresIn, - ExtExpiresIn = token.ExtExpiresIn, - TokenType = token.TokenType, - Scope = token.Scope, - ObtainedAt = token.ObtainedAt, - }; -} diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs b/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs similarity index 72% rename from src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs rename to src/ModelContextProtocol.Core/Authentication/TokenResponse.cs index 5f6bf0e5c..9eba5ffbf 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs @@ -1,42 +1,45 @@ +using System.Text.Json.Serialization; + namespace ModelContextProtocol.Authentication; /// -/// Represents a cacheable token representation. +/// Represents a token response from the OAuth server. /// -public class TokenContainerCacheable +internal sealed class TokenResponse { /// /// Gets or sets the access token. /// + [JsonPropertyName("access_token")] public string AccessToken { get; set; } = string.Empty; /// /// Gets or sets the refresh token. /// + [JsonPropertyName("refresh_token")] public string? RefreshToken { get; set; } /// /// Gets or sets the number of seconds until the access token expires. /// + [JsonPropertyName("expires_in")] public int ExpiresIn { get; set; } /// /// Gets or sets the extended expiration time in seconds. /// + [JsonPropertyName("ext_expires_in")] public int ExtExpiresIn { get; set; } /// /// Gets or sets the token type (typically "Bearer"). /// + [JsonPropertyName("token_type")] public string TokenType { get; set; } = string.Empty; /// /// Gets or sets the scope of the access token. /// + [JsonPropertyName("scope")] public string Scope { get; set; } = string.Empty; - - /// - /// Gets or sets the timestamp when the token was obtained. - /// - public DateTimeOffset ObtainedAt { get; set; } } diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index 8bc9e21b0..a6cb2e13e 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -158,7 +158,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(ProtectedResourceMetadata))] [JsonSerializable(typeof(AuthorizationServerMetadata))] - [JsonSerializable(typeof(TokenContainer))] + [JsonSerializable(typeof(TokenResponse))] [JsonSerializable(typeof(DynamicClientRegistrationRequest))] [JsonSerializable(typeof(DynamicClientRegistrationResponse))] diff --git a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs index 3ea1262ae..fd16d3073 100644 --- a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs @@ -2,6 +2,7 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Authentication; using System.Text.Json; +using System.Text.Json.Serialization.Metadata; using Moq; using Moq.Protected; using System.Net; @@ -12,6 +13,16 @@ namespace ModelContextProtocol.Tests.Client; public class CustomTokenCacheTests { + [Fact] + public void TokenContainerIsAlignedWithTokenResponse() + { + var tokenResponseType = Type.GetType("ModelContextProtocol.Authentication.TokenResponse, ModelContextProtocol.Core"); + Assert.NotNull(tokenResponseType); + var tokenResponseProperties = tokenResponseType.GetProperties().Select(p => p.Name); + var tokenContainerProperties = typeof(TokenContainer).GetProperties().Select(p => p.Name); + Assert.Equivalent(tokenResponseProperties, tokenContainerProperties); + } + [Fact] public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() { @@ -80,8 +91,8 @@ public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() // Assert tokenCacheMock - .Verify(tc => tc.StoreTokenAsync( - It.Is(token => token.AccessToken == newAccessToken), + .Verify(tc => tc.StoreTokensAsync( + It.Is(token => token.AccessToken == newAccessToken), It.IsAny()), Times.Once); } @@ -103,8 +114,8 @@ public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() static void MockCachedAccessToken(Mock tokenCache, string cachedAccessToken) { tokenCache - .Setup(tc => tc.GetTokenAsync(It.IsAny())) - .ReturnsAsync(new TokenContainerCacheable + .Setup(tc => tc.GetTokensAsync(It.IsAny())) + .ReturnsAsync(new TokenContainer { AccessToken = cachedAccessToken, ObtainedAt = DateTimeOffset.UtcNow, @@ -115,8 +126,8 @@ static void MockCachedAccessToken(Mock tokenCache, string cachedAcc static void MockNoAccessTokenUntilStored(Mock tokenCache) { tokenCache - .Setup(tc => tc.StoreTokenAsync(It.IsAny(), It.IsAny())) - .Callback((token, ct) => + .Setup(tc => tc.StoreTokensAsync(It.IsAny(), It.IsAny())) + .Callback((token, ct) => { // Simulate that the token is now cached MockCachedAccessToken(tokenCache, token.AccessToken); @@ -216,18 +227,23 @@ static void MockInitializeResponse(Mock httpMessageHandler) static void MockHttpResponse(Mock httpMessageHandler, Expression>? request = null, HttpResponseMessage? response = null) { - httpMessageHandler + _ = httpMessageHandler .Protected() .Setup>("SendAsync", request != null ? ItExpr.Is(request) : ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(response ?? new HttpResponseMessage()); } static StringContent ToJsonContent(T content) => new( - content: JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions), + content: JsonSerializer.Serialize(content, GetReflectionCapableJsonOptions()), encoding: System.Text.Encoding.UTF8, mediaType: "application/json"); static JsonNode? ToJson(T content) => JsonSerializer.SerializeToNode( value: content, - options: McpJsonUtilities.DefaultOptions); + options: GetReflectionCapableJsonOptions()); + + static JsonSerializerOptions GetReflectionCapableJsonOptions() => new(JsonSerializerDefaults.Web) + { + TypeInfoResolver = new DefaultJsonTypeInfoResolver() + }; } From db341ad1f5c7a70389b3b05889495474883f6e4a Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Tue, 18 Nov 2025 15:36:30 -0800 Subject: [PATCH 4/6] Add OAuthTestBase and use it to test TokenCache --- .../Authentication/ClientOAuthProvider.cs | 82 +++--- .../Authentication/TokenContainer.cs | 30 +-- .../Authentication/TokenResponse.cs | 14 +- .../{ => OAuth}/AuthEventTests.cs | 175 ++---------- .../{ => OAuth}/AuthTests.cs | 163 ++---------- .../OAuth/OAuthTestBase.cs | 107 ++++++++ .../OAuth/TokenCacheTests.cs | 215 +++++++++++++++ .../Program.cs | 4 +- .../Client/CustomTokenCacheTests.cs | 249 ------------------ 9 files changed, 433 insertions(+), 606 deletions(-) rename tests/ModelContextProtocol.AspNetCore.Tests/{ => OAuth}/AuthEventTests.cs (63%) rename tests/ModelContextProtocol.AspNetCore.Tests/{ => OAuth}/AuthTests.cs (68%) create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/OAuth/OAuthTestBase.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/OAuth/TokenCacheTests.cs delete mode 100644 tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 503c6e402..162f3ecb5 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -24,6 +24,8 @@ internal sealed partial class ClientOAuthProvider /// private const string BearerScheme = "Bearer"; + private static readonly string[] s_wellKnownPaths = [".well-known/openid-configuration", ".well-known/oauth-authorization-server"]; + private readonly Uri _serverUrl; private readonly Uri _redirectUri; private readonly string[]? _scopes; @@ -57,11 +59,11 @@ internal sealed partial class ClientOAuthProvider public ClientOAuthProvider( Uri serverUrl, ClientOAuthOptions options, - HttpClient? httpClient = null, + HttpClient httpClient, ILoggerFactory? loggerFactory = null) { _serverUrl = serverUrl ?? throw new ArgumentNullException(nameof(serverUrl)); - _httpClient = httpClient ?? new HttpClient(); + _httpClient = httpClient; _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; if (options is null) @@ -140,20 +142,19 @@ public ClientOAuthProvider( ThrowIfNotBearerScheme(scheme); var tokens = await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false); - + // Return the token if it's valid - if (tokens != null && tokens.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + if (tokens is not null && !tokens.IsExpired) { return tokens.AccessToken; } - // Try to refresh the token if we have a refresh token + // Try to refresh the access token if it is invalid and we have a refresh token. if (tokens?.RefreshToken != null && _authServerMetadata != null) { var newTokens = await RefreshTokenAsync(tokens.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); - if (newTokens != null) + if (newTokens is not null) { - await _tokenCache.StoreTokensAsync(newTokens, cancellationToken).ConfigureAwait(false); return newTokens.AccessToken; } } @@ -226,6 +227,17 @@ private async Task PerformOAuthAuthorizationAsync( // Store auth server metadata for future refresh operations _authServerMetadata = authServerMetadata; + // The existing access token must be invalid to have resulted in a 401 response, but refresh might still work. + if (await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false) is { RefreshToken: {} refreshToken }) + { + var refreshedTokens = await RefreshTokenAsync(refreshToken, protectedResourceMetadata.Resource, authServerMetadata, cancellationToken).ConfigureAwait(false); + if (refreshedTokens is not null) + { + // A non-null result indicates the refresh succeeded and the new tokens have been stored. + return; + } + } + // Perform dynamic client registration if needed if (string.IsNullOrEmpty(_clientId)) { @@ -233,19 +245,11 @@ private async Task PerformOAuthAuthorizationAsync( } // Perform the OAuth flow - var tokens = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); + await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); - if (tokens is null) - { - ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); - } - - await _tokenCache.StoreTokensAsync(tokens, cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } - private static readonly string[] s_wellKnownPaths = [".well-known/openid-configuration", ".well-known/oauth-authorization-server"]; - private async Task GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken) { if (authServerUri.OriginalString.Length == 0 || @@ -301,7 +305,7 @@ private async Task GetAuthServerMetadataAsync(Uri a throw new McpException($"Failed to find .well-known/openid-configuration or .well-known/oauth-authorization-server metadata for authorization server: '{authServerUri}'"); } - private async Task RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) + private async Task RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) { var requestContent = new FormUrlEncodedContent(new Dictionary { @@ -317,10 +321,17 @@ private async Task RefreshTokenAsync(string refreshToken, Uri re Content = requestContent }; - return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + + if (!httpResponse.IsSuccessStatusCode) + { + return null; + } + + return await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false); } - private async Task InitiateAuthorizationCodeFlowAsync( + private async Task InitiateAuthorizationCodeFlowAsync( ProtectedResourceMetadata protectedResourceMetadata, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) @@ -333,10 +344,10 @@ private async Task RefreshTokenAsync(string refreshToken, Uri re if (string.IsNullOrEmpty(authCode)) { - return null; + ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty authorization code."); } - return await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false); + await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode, codeVerifier, cancellationToken).ConfigureAwait(false); } private Uri BuildAuthorizationUrl( @@ -380,7 +391,7 @@ private Uri BuildAuthorizationUrl( return uriBuilder.Uri; } - private async Task ExchangeCodeForTokenAsync( + private async Task ExchangeCodeForTokenAsync( ProtectedResourceMetadata protectedResourceMetadata, AuthorizationServerMetadata authServerMetadata, string authorizationCode, @@ -403,32 +414,39 @@ private async Task ExchangeCodeForTokenAsync( Content = requestContent }; - return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); - } - - private async Task FetchTokenAsync(HttpRequestMessage request, CancellationToken cancellationToken) - { using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); httpResponse.EnsureSuccessStatusCode(); + await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false); + } - using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + private async Task HandleSuccessfulTokenResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken) + { + using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenResponse, cancellationToken).ConfigureAwait(false); if (tokenResponse is null) { - ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response."); + ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{response.RequestMessage?.RequestUri}' returned an empty response."); + } + + if (tokenResponse.TokenType is null || !string.Equals(tokenResponse.TokenType, BearerScheme, StringComparison.OrdinalIgnoreCase)) + { + ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{response.RequestMessage?.RequestUri}' returned an unsupported token type: '{tokenResponse.TokenType ?? ""}'. Only 'Bearer' tokens are supported."); } - return new() + TokenContainer tokens = new() { AccessToken = tokenResponse.AccessToken, RefreshToken = tokenResponse.RefreshToken, ExpiresIn = tokenResponse.ExpiresIn, - ExtExpiresIn = tokenResponse.ExtExpiresIn, TokenType = tokenResponse.TokenType, Scope = tokenResponse.Scope, ObtainedAt = DateTimeOffset.UtcNow, }; + + await _tokenCache.StoreTokensAsync(tokens, cancellationToken).ConfigureAwait(false); + + return tokens; } /// @@ -592,7 +610,7 @@ private async Task ExtractProtectedResourceMetadata(H string? resourceMetadataUrl = null; foreach (var header in response.Headers.WwwAuthenticate) { - if (string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter)) + if (string.Equals(header.Scheme, BearerScheme, StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter)) { resourceMetadataUrl = ParseWwwAuthenticateParameters(header.Parameter, "resource_metadata"); if (resourceMetadataUrl != null) diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index 5503c96f1..8126137f8 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -3,12 +3,17 @@ namespace ModelContextProtocol.Authentication; /// /// Represents a cacheable combination of tokens ready to be used for authentication. /// -public class TokenContainer +public sealed class TokenContainer { + /// + /// Gets or sets the token type (typically "Bearer"). + /// + public required string TokenType { get; set; } + /// /// Gets or sets the access token. /// - public string AccessToken { get; set; } = string.Empty; + public required string AccessToken { get; set; } /// /// Gets or sets the refresh token. @@ -18,30 +23,17 @@ public class TokenContainer /// /// Gets or sets the number of seconds until the access token expires. /// - public int ExpiresIn { get; set; } - - /// - /// Gets or sets the extended expiration time in seconds. - /// - public int ExtExpiresIn { get; set; } - - /// - /// Gets or sets the token type (typically "Bearer"). - /// - public string TokenType { get; set; } = string.Empty; + public int? ExpiresIn { get; set; } /// /// Gets or sets the scope of the access token. /// - public string Scope { get; set; } = string.Empty; + public string? Scope { get; set; } /// /// Gets or sets the timestamp when the token was obtained. /// - public DateTimeOffset ObtainedAt { get; set; } + public required DateTimeOffset ObtainedAt { get; set; } - /// - /// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn. - /// - public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn); + internal bool IsExpired => ExpiresIn is not null && DateTimeOffset.UtcNow >= ObtainedAt.AddSeconds(ExpiresIn.Value); } diff --git a/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs b/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs index 9eba5ffbf..721196734 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs @@ -11,7 +11,7 @@ internal sealed class TokenResponse /// Gets or sets the access token. /// [JsonPropertyName("access_token")] - public string AccessToken { get; set; } = string.Empty; + public required string AccessToken { get; set; } /// /// Gets or sets the refresh token. @@ -23,23 +23,17 @@ internal sealed class TokenResponse /// Gets or sets the number of seconds until the access token expires. /// [JsonPropertyName("expires_in")] - public int ExpiresIn { get; set; } - - /// - /// Gets or sets the extended expiration time in seconds. - /// - [JsonPropertyName("ext_expires_in")] - public int ExtExpiresIn { get; set; } + public int? ExpiresIn { get; set; } /// /// Gets or sets the token type (typically "Bearer"). /// [JsonPropertyName("token_type")] - public string TokenType { get; set; } = string.Empty; + public required string TokenType { get; set; } /// /// Gets or sets the scope of the access token. /// [JsonPropertyName("scope")] - public string Scope { get; set; } = string.Empty; + public string? Scope { get; set; } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthEventTests.cs similarity index 63% rename from tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs rename to tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthEventTests.cs index 9144121e8..001aa82f0 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthEventTests.cs @@ -1,110 +1,43 @@ -using System.Net; -using System.Net.Http.Json; -using System.Text.Json; -using Microsoft.AspNetCore.Authentication.JwtBearer; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; -using Microsoft.IdentityModel.Tokens; using ModelContextProtocol.AspNetCore.Authentication; -using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Authentication; using ModelContextProtocol.Client; +using System.Net; +using System.Net.Http.Json; -namespace ModelContextProtocol.AspNetCore.Tests; +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; /// /// Tests for MCP authentication when resource metadata is provided via events rather than static configuration. /// -public class AuthEventTests : KestrelInMemoryTest, IAsyncDisposable +public class AuthEventTests : OAuthTestBase { - private const string McpServerUrl = "http://localhost:5000"; - private const string OAuthServerUrl = "https://localhost:7029"; - - private readonly CancellationTokenSource _testCts = new(); - private readonly TestOAuthServer.Program _testOAuthServer; - private readonly Task _testOAuthRunTask; - public AuthEventTests(ITestOutputHelper outputHelper) - : base(outputHelper) + : base(outputHelper, configureMcpMetadata: false) { - // Let the HandleAuthorizationUrlAsync take a look at the Location header - SocketsHttpHandler.AllowAutoRedirect = false; - // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. - // The easiest workaround is to disable cert validation for testing purposes. - SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; - - _testOAuthServer = new TestOAuthServer.Program( - XunitLoggerProvider, - KestrelInMemoryTransport - ); - _testOAuthRunTask = _testOAuthServer.RunServerAsync(cancellationToken: _testCts.Token); + Builder.Services.Configure(McpAuthenticationDefaults.AuthenticationScheme, options => + { + // Note: ResourceMetadata is NOT set here - it will be provided via events + options.ResourceMetadata = null; - Builder - .Services.AddAuthentication(options => - { - options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; - options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; - }) - .AddJwtBearer(options => - { - options.Backchannel = HttpClient; - options.Authority = OAuthServerUrl; - options.TokenValidationParameters = new TokenValidationParameters - { - ValidateIssuer = true, - ValidateAudience = true, - ValidateLifetime = true, - ValidateIssuerSigningKey = true, - ValidAudience = McpServerUrl, - ValidIssuer = OAuthServerUrl, - NameClaimType = "name", - RoleClaimType = "roles", - }; - }) - .AddMcp(options => + options.Events.OnResourceMetadataRequest = async context => { - // Note: ResourceMetadata is NOT set here - it will be provided via events - options.Events.OnResourceMetadataRequest = async context => + // Dynamically provide the resource metadata + context.ResourceMetadata = new ProtectedResourceMetadata { - // Dynamically provide the resource metadata - context.ResourceMetadata = new ProtectedResourceMetadata - { - Resource = new Uri(McpServerUrl), - AuthorizationServers = { new Uri(OAuthServerUrl) }, - ScopesSupported = ["mcp:tools"], - }; - await Task.CompletedTask; + Resource = new Uri(McpServerUrl), + AuthorizationServers = { new Uri(OAuthServerUrl) }, + ScopesSupported = ["mcp:tools"], }; - }); - - Builder.Services.AddAuthorization(); - } - - public async ValueTask DisposeAsync() - { - _testCts.Cancel(); - try - { - await _testOAuthRunTask; - } - catch (OperationCanceledException) { } - finally - { - _testCts.Dispose(); - } + await Task.CompletedTask; + }; + }); } [Fact] public async Task CanAuthenticate_WithResourceMetadataFromEvent() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport( new() @@ -132,13 +65,7 @@ public async Task CanAuthenticate_WithResourceMetadataFromEvent() [Fact] public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); DynamicClientRegistrationResponse? dcrResponse = null; @@ -181,13 +108,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() [Fact] public async Task ResourceMetadataEndpoint_ReturnsCorrectMetadata_FromEvent() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint using var response = await HttpClient.GetAsync( @@ -211,8 +132,6 @@ public async Task ResourceMetadataEndpoint_ReturnsCorrectMetadata_FromEvent() [Fact] public async Task ResourceMetadataEndpoint_CanModifyExistingMetadata_InEvent() { - Builder.Services.AddMcpServer().WithHttpTransport(); - // Override the configuration to test modification of existing metadata Builder.Services.Configure( McpAuthenticationDefaults.AuthenticationScheme, @@ -240,11 +159,7 @@ public async Task ResourceMetadataEndpoint_CanModifyExistingMetadata_InEvent() } ); - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint using var response = await HttpClient.GetAsync( @@ -270,15 +185,12 @@ public async Task ResourceMetadataEndpoint_CanModifyExistingMetadata_InEvent() [Fact] public async Task ResourceMetadataEndpoint_ThrowsException_WhenNoMetadataProvided() { - Builder.Services.AddMcpServer().WithHttpTransport(); - // Override the configuration to test the error case where no metadata is provided Builder.Services.Configure( McpAuthenticationDefaults.AuthenticationScheme, options => { // Don't set ResourceMetadata and provide an event that doesn't set it either - options.ResourceMetadata = null; options.Events.OnResourceMetadataRequest = async context => { // Intentionally don't set context.ResourceMetadata to test error handling @@ -287,11 +199,7 @@ public async Task ResourceMetadataEndpoint_ThrowsException_WhenNoMetadataProvide } ); - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint - this should fail using var response = await HttpClient.GetAsync( @@ -306,14 +214,11 @@ public async Task ResourceMetadataEndpoint_ThrowsException_WhenNoMetadataProvide [Fact] public async Task ResourceMetadataEndpoint_HandlesResponse_WhenHandleResponseCalled() { - Builder.Services.AddMcpServer().WithHttpTransport(); - // Override the configuration to test HandleResponse behavior Builder.Services.Configure( McpAuthenticationDefaults.AuthenticationScheme, options => { - options.ResourceMetadata = null; options.Events.OnResourceMetadataRequest = async context => { // Call HandleResponse() to discontinue processing and return to client @@ -323,11 +228,7 @@ public async Task ResourceMetadataEndpoint_HandlesResponse_WhenHandleResponseCal } ); - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint using var response = await HttpClient.GetAsync( @@ -349,14 +250,11 @@ public async Task ResourceMetadataEndpoint_HandlesResponse_WhenHandleResponseCal [Fact] public async Task ResourceMetadataEndpoint_SkipsHandler_WhenSkipHandlerCalled() { - Builder.Services.AddMcpServer().WithHttpTransport(); - // Override the configuration to test SkipHandler behavior Builder.Services.Configure( McpAuthenticationDefaults.AuthenticationScheme, options => { - options.ResourceMetadata = null; options.Events.OnResourceMetadataRequest = async context => { // Call SkipHandler() to discontinue processing in the current handler @@ -366,11 +264,7 @@ public async Task ResourceMetadataEndpoint_SkipsHandler_WhenSkipHandlerCalled() } ); - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint using var response = await HttpClient.GetAsync( @@ -383,23 +277,4 @@ public async Task ResourceMetadataEndpoint_SkipsHandler_WhenSkipHandlerCalled() // other handlers configured for this endpoint, this should result in a 404 Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); } - - private async Task HandleAuthorizationUrlAsync( - Uri authorizationUri, - Uri redirectUri, - CancellationToken cancellationToken - ) - { - using var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); - Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); - var location = redirectResponse.Headers.Location; - - if (location is not null && !string.IsNullOrEmpty(location.Query)) - { - var queryParams = QueryHelpers.ParseQuery(location.Query); - return queryParams["code"]; - } - - return null; - } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthTests.cs similarity index 68% rename from tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs rename to tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthTests.cs index fff7d6d42..bed6f8aaa 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthTests.cs @@ -1,101 +1,23 @@ -using Microsoft.AspNetCore.Authentication.JwtBearer; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.WebUtilities; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.IdentityModel.Tokens; using ModelContextProtocol.AspNetCore.Authentication; -using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Authentication; using ModelContextProtocol.Client; using System.Net; using System.Reflection; using Xunit.Sdk; -namespace ModelContextProtocol.AspNetCore.Tests; +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; -public class AuthTests : KestrelInMemoryTest, IAsyncDisposable +public class AuthTests : OAuthTestBase { - private const string McpServerUrl = "http://localhost:5000"; - private const string OAuthServerUrl = "https://localhost:7029"; - - private readonly CancellationTokenSource _testCts = new(); - private readonly TestOAuthServer.Program _testOAuthServer; - private readonly Task _testOAuthRunTask; - - private Uri? _lastAuthorizationUri; - - public AuthTests(ITestOutputHelper outputHelper) + public AuthTests(ITestOutputHelper outputHelper) : base(outputHelper) { - // Let the HandleAuthorizationUrlAsync take a look at the Location header - SocketsHttpHandler.AllowAutoRedirect = false; - // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. - // The easiest workaround is to disable cert validation for testing purposes. - SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; - - _testOAuthServer = new TestOAuthServer.Program(XunitLoggerProvider, KestrelInMemoryTransport); - _testOAuthRunTask = _testOAuthServer.RunServerAsync(cancellationToken: _testCts.Token); - - Builder.Services.AddAuthentication(options => - { - options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; - options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; - }) - .AddJwtBearer(options => - { - options.Backchannel = HttpClient; - options.Authority = OAuthServerUrl; - options.TokenValidationParameters = new TokenValidationParameters - { - ValidateIssuer = true, - ValidateAudience = true, - ValidateLifetime = true, - ValidateIssuerSigningKey = true, - ValidAudience = McpServerUrl, - ValidIssuer = OAuthServerUrl, - NameClaimType = "name", - RoleClaimType = "roles" - }; - }) - .AddMcp(options => - { - options.ResourceMetadata = new ProtectedResourceMetadata - { - Resource = new Uri(McpServerUrl), - AuthorizationServers = { new Uri(OAuthServerUrl) }, - ScopesSupported = ["mcp:tools"] - }; - }); - - Builder.Services.AddAuthorization(); - } - - public async ValueTask DisposeAsync() - { - _testCts.Cancel(); - try - { - await _testOAuthRunTask; - } - catch (OperationCanceledException) - { - } - finally - { - _testCts.Dispose(); - } } [Fact] public async Task CanAuthenticate() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -116,13 +38,7 @@ public async Task CanAuthenticate() [Fact] public async Task CannotAuthenticate_WithoutOAuthConfiguration() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -138,13 +54,7 @@ public async Task CannotAuthenticate_WithoutOAuthConfiguration() [Fact] public async Task CannotAuthenticate_WithUnregisteredClient() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -166,13 +76,7 @@ public async Task CannotAuthenticate_WithUnregisteredClient() [Fact] public async Task CanAuthenticate_WithDynamicClientRegistration() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -197,13 +101,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration() [Fact] public async Task CanAuthenticate_WithTokenRefresh() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -222,19 +120,15 @@ public async Task CanAuthenticate_WithTokenRefresh() await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); - Assert.True(_testOAuthServer.HasIssuedRefreshToken); + Assert.True(TestOAuthServer.HasRefreshedToken); } [Fact] public async Task CanAuthenticate_WithExtraParams() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); + await using var app = await StartMcpServerAsync(); - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + Uri? lastAuthorizationUri = null; await using var transport = new HttpClientTransport(new() { @@ -244,7 +138,11 @@ public async Task CanAuthenticate_WithExtraParams() ClientId = "demo-client", ClientSecret = "demo-secret", RedirectUri = new Uri("http://localhost:1179/callback"), - AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + lastAuthorizationUri = uri; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, AdditionalAuthorizationParameters = new Dictionary { ["custom_param"] = "custom_value", @@ -255,20 +153,14 @@ public async Task CanAuthenticate_WithExtraParams() await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); - Assert.NotNull(_lastAuthorizationUri?.Query); - Assert.Contains("custom_param=custom_value", _lastAuthorizationUri?.Query); + Assert.NotNull(lastAuthorizationUri?.Query); + Assert.Contains("custom_param=custom_value", lastAuthorizationUri?.Query); } [Fact] public async Task CannotOverrideExistingParameters_WithExtraParams() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -390,21 +282,4 @@ public void CloneResourceMetadataClonesAllProperties() // Ensure we've checked every property. When new properties get added, we'll have to update this test along with the CloneResourceMetadata implementation. Assert.Empty(propertyNames); } - - private async Task HandleAuthorizationUrlAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) - { - _lastAuthorizationUri = authorizationUri; - - var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); - Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); - var location = redirectResponse.Headers.Location; - - if (location is not null && !string.IsNullOrEmpty(location.Query)) - { - var queryParams = QueryHelpers.ParseQuery(location.Query); - return queryParams["code"]; - } - - return null; - } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/OAuthTestBase.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/OAuthTestBase.cs new file mode 100644 index 000000000..834bf83fd --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/OAuthTestBase.cs @@ -0,0 +1,107 @@ +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Authentication; +using System.Net; +using Xunit.Sdk; + +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; + +public abstract class OAuthTestBase : KestrelInMemoryTest, IAsyncDisposable +{ + protected const string McpServerUrl = "http://localhost:5000"; + protected const string OAuthServerUrl = "https://localhost:7029"; + + protected readonly CancellationTokenSource TestCts = new(); + protected readonly TestOAuthServer.Program TestOAuthServer; + private readonly Task _testOAuthRunTask; + + protected OAuthTestBase(ITestOutputHelper outputHelper, bool configureMcpMetadata = true) + : base(outputHelper) + { + // Let the HandleAuthorizationUrlAsync take a look at the Location header + SocketsHttpHandler.AllowAutoRedirect = false; + // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. + // The easiest workaround is to disable cert validation for testing purposes. + SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; + + TestOAuthServer = new TestOAuthServer.Program(XunitLoggerProvider, KestrelInMemoryTransport); + _testOAuthRunTask = TestOAuthServer.RunServerAsync(cancellationToken: TestCts.Token); + + Builder.Services.AddAuthentication(options => + { + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + }) + .AddJwtBearer(options => + { + options.Backchannel = HttpClient; + options.Authority = OAuthServerUrl; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidAudience = McpServerUrl, + ValidIssuer = OAuthServerUrl, + NameClaimType = "name", + RoleClaimType = "roles" + }; + }) + .AddMcp(options => + { + if (configureMcpMetadata) + { + options.ResourceMetadata = new ProtectedResourceMetadata + { + Resource = new Uri(McpServerUrl), + AuthorizationServers = { new Uri(OAuthServerUrl) }, + ScopesSupported = ["mcp:tools"] + }; + } + }); + + Builder.Services.AddAuthorization(); + Builder.Services.AddMcpServer().WithHttpTransport(); + } + + public async ValueTask DisposeAsync() + { + TestCts.Cancel(); + try + { + await _testOAuthRunTask; + } + catch (OperationCanceledException) + { + } + finally + { + TestCts.Dispose(); + } + } + + protected async Task StartMcpServerAsync() + { + var app = Builder.Build(); + app.MapMcp().RequireAuthorization(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + protected async Task HandleAuthorizationUrlAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) + { + using var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); + Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); + var location = redirectResponse.Headers.Location; + + if (location is not null && !string.IsNullOrEmpty(location.Query)) + { + var queryParams = QueryHelpers.ParseQuery(location.Query); + return queryParams["code"]; + } + + return null; + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/TokenCacheTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/TokenCacheTests.cs new file mode 100644 index 000000000..fb9e2bfda --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/TokenCacheTests.cs @@ -0,0 +1,215 @@ +using ModelContextProtocol.Authentication; +using ModelContextProtocol.Client; + +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; + +public class TokenCacheTests : OAuthTestBase +{ + public TokenCacheTests(ITestOutputHelper outputHelper) + : base(outputHelper) + { + } + + [Fact] + public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() + { + await using var app = await StartMcpServerAsync(); + + var tokenCache = new TestTokenCache(); + bool authDelegateCalledInitially = false; + + await using var setupTransport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalledInitially = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache, + }, + }, HttpClient, LoggerFactory); + + await using (var setupClient = await McpClient.CreateAsync(setupTransport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) + { + // Just connecting should trigger auth and storage. + } + + Assert.True(authDelegateCalledInitially, "AuthorizationRedirectDelegate should be called to get initial token"); + Assert.NotNull(tokenCache.LastStoredToken); + + var authDelegateCalledAgain = false; + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalledAgain = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(authDelegateCalledAgain, "AuthorizationRedirectDelegate should not be called when token is valid"); + } + + [Fact] + public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() + { + await using var app = await StartMcpServerAsync(); + + var tokenCache = new TestTokenCache(); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + TokenCache = tokenCache + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(tokenCache.LastStoredToken); + Assert.False(string.IsNullOrEmpty(tokenCache.LastStoredToken.AccessToken)); + } + + [Fact] + public async Task GetTokenAsync_InvalidCachedTokenTriggersAuthDelegate() + { + await using var app = await StartMcpServerAsync(); + + var tokenCache = new TestTokenCache(CreateInvalidToken()); + bool authDelegateCalled = false; + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalled = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache, + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.True(authDelegateCalled, "AuthorizationRedirectDelegate should be called when cached token is invalid"); + Assert.NotNull(tokenCache.LastStoredToken); + Assert.NotEqual("invalid-token", tokenCache.LastStoredToken.AccessToken); + } + + [Fact] + public async Task GetTokenAsync_InvalidAccessTokenTriggersRefresh() + { + await using var app = await StartMcpServerAsync(); + + var tokenCache = new TestTokenCache(); + bool authDelegateCalledInitially = false; + + await using var setupTransport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalledInitially = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache, + }, + }, HttpClient, LoggerFactory); + + await using (var setupClient = await McpClient.CreateAsync(setupTransport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) + { + // Just connecting should trigger auth and storage. + } + + Assert.True(authDelegateCalledInitially, "AuthorizationRedirectDelegate should be called to get initial token"); + Assert.False(TestOAuthServer.HasRefreshedToken, "Token should not have been refreshed yet"); + Assert.NotNull(tokenCache.LastStoredToken); + + // Invalidate the access token but keep the refresh token valid (if any) + tokenCache.LastStoredToken.AccessToken = "invalid-token"; + var authDelegateCalledAgain = false; + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalledAgain = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(authDelegateCalledAgain, "AuthorizationRedirectDelegate should not be called when refresh token is valid"); + Assert.True(TestOAuthServer.HasRefreshedToken, "Token should have been refreshed"); + Assert.NotEqual("invalid-token", tokenCache.LastStoredToken.AccessToken); + } + + private TokenContainer CreateInvalidToken() + { + return new TokenContainer + { + TokenType = "Bearer", + AccessToken = "invalid-token", + ObtainedAt = DateTimeOffset.UtcNow, + }; + } + + private class TestTokenCache(TokenContainer? initialToken = null) : ITokenCache + { + public TokenContainer? LastStoredToken { get; private set; } = initialToken; + + public ValueTask GetTokensAsync(CancellationToken cancellationToken) + { + return new ValueTask(LastStoredToken); + } + + public ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken) + { + LastStoredToken = tokens; + return ValueTask.CompletedTask; + } + } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/Program.cs b/tests/ModelContextProtocol.TestOAuthServer/Program.cs index dea484bfe..2a67a0c9c 100644 --- a/tests/ModelContextProtocol.TestOAuthServer/Program.cs +++ b/tests/ModelContextProtocol.TestOAuthServer/Program.cs @@ -42,7 +42,7 @@ public Program(ILoggerProvider? loggerProvider = null, IConnectionListenerFactor // Track if we've already issued an already-expired token for the CanAuthenticate_WithTokenRefresh test which uses the test-refresh-client registration. public bool HasIssuedExpiredToken { get; set; } - public bool HasIssuedRefreshToken { get; set; } + public bool HasRefreshedToken { get; set; } /// /// Entry point for the application. @@ -368,7 +368,7 @@ public async Task RunServerAsync(string[]? args = null, CancellationToken cancel _tokens.TryRemove(refresh_token, out _); } - HasIssuedRefreshToken = true; + HasRefreshedToken = true; return Results.Ok(response); } else diff --git a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs deleted file mode 100644 index fd16d3073..000000000 --- a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs +++ /dev/null @@ -1,249 +0,0 @@ -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Authentication; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; -using Moq; -using Moq.Protected; -using System.Net; -using System.Text.Json.Nodes; -using System.Linq.Expressions; - -namespace ModelContextProtocol.Tests.Client; - -public class CustomTokenCacheTests -{ - [Fact] - public void TokenContainerIsAlignedWithTokenResponse() - { - var tokenResponseType = Type.GetType("ModelContextProtocol.Authentication.TokenResponse, ModelContextProtocol.Core"); - Assert.NotNull(tokenResponseType); - var tokenResponseProperties = tokenResponseType.GetProperties().Select(p => p.Name); - var tokenContainerProperties = typeof(TokenContainer).GetProperties().Select(p => p.Name); - Assert.Equivalent(tokenResponseProperties, tokenContainerProperties); - } - - [Fact] - public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() - { - // Arrange - var cachedAccessToken = $"my_access_token_{Guid.NewGuid()}"; - - var tokenCacheMock = new Mock(); - MockCachedAccessToken(tokenCacheMock, cachedAccessToken); - - var httpMessageHandlerMock = new Mock(); - MockInitializeResponse(httpMessageHandlerMock); - - var httpClientTransport = new HttpClientTransport( - transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), - httpClient: new HttpClient(httpMessageHandlerMock.Object)); - - var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Act - var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; - await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - httpMessageHandlerMock - .Protected() - .Verify("SendAsync", Times.AtLeastOnce(), ItExpr.Is(req => - req.RequestUri == new Uri("http://localhost:1337/") - && req.Headers.Authorization != null - && req.Headers.Authorization.Scheme == "Bearer" - && req.Headers.Authorization.Parameter == cachedAccessToken - ), ItExpr.IsAny()); - - httpMessageHandlerMock - .Protected() - .Verify("SendAsync", Times.Never(), ItExpr.Is(req => - req.RequestUri == new Uri("http://localhost:1337/") - && (req.Headers.Authorization == null || req.Headers.Authorization.Parameter != cachedAccessToken) - ), ItExpr.IsAny()); - } - - [Fact] - public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() - { - // Arrange - var tokenCacheMock = new Mock(); - MockNoAccessTokenUntilStored(tokenCacheMock); - - var newAccessToken = $"new_access_token_{Guid.NewGuid()}"; - - var httpMessageHandlerMock = new Mock(); - MockUnauthorizedResponse(httpMessageHandlerMock); - MockProtectedResourceMetadataResponse(httpMessageHandlerMock); - MockAuthorizationServerMetadataResponse(httpMessageHandlerMock); - MockAccessTokenResponse(httpMessageHandlerMock, newAccessToken); - MockInitializeResponse(httpMessageHandlerMock); - - var httpClientTransport = new HttpClientTransport( - transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), - httpClient: new HttpClient(httpMessageHandlerMock.Object)); - - var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Act - var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; - await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - tokenCacheMock - .Verify(tc => tc.StoreTokensAsync( - It.Is(token => token.AccessToken == newAccessToken), - It.IsAny()), Times.Once); - } - - static HttpClientTransportOptions NewHttpClientTransportOptions(ITokenCache? tokenCache = null) => new() - { - Name = "MCP Server", - Endpoint = new Uri("http://localhost:1337/"), - TransportMode = HttpTransportMode.StreamableHttp, - OAuth = new() - { - ClientId = "mcp_inspector", - RedirectUri = new Uri("http://localhost:6274/oauth/callback"), - Scopes = ["openid", "profile", "offline_access"], - AuthorizationRedirectDelegate = (authorizationUrl, redirectUri, cancellationToken) => Task.FromResult($"auth_code_{Guid.NewGuid()}"), - TokenCache = tokenCache, - }, - }; - - static void MockCachedAccessToken(Mock tokenCache, string cachedAccessToken) - { - tokenCache - .Setup(tc => tc.GetTokensAsync(It.IsAny())) - .ReturnsAsync(new TokenContainer - { - AccessToken = cachedAccessToken, - ObtainedAt = DateTimeOffset.UtcNow, - ExpiresIn = (int)TimeSpan.FromHours(1).TotalSeconds, - }); - } - - static void MockNoAccessTokenUntilStored(Mock tokenCache) - { - tokenCache - .Setup(tc => tc.StoreTokensAsync(It.IsAny(), It.IsAny())) - .Callback((token, ct) => - { - // Simulate that the token is now cached - MockCachedAccessToken(tokenCache, token.AccessToken); - }) - .Returns(default(ValueTask)); - } - - static void MockUnauthorizedResponse(Mock httpMessageHandler) - { - MockHttpResponse(httpMessageHandler, - request: req => req.RequestUri == new Uri("http://localhost:1337/") - && req.Method == HttpMethod.Post - && (req.Headers.Authorization == null || string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter)), - response: new HttpResponseMessage(HttpStatusCode.Unauthorized) - { - Headers = { - { "WWW-Authenticate", "Bearer realm=\"Bearer\", resource_metadata=\"http://localhost:1337/.well-known/oauth-protected-resource\"" } - }, - }); - } - - static void MockProtectedResourceMetadataResponse(Mock httpMessageHandler) - { - MockHttpResponse(httpMessageHandler, - request: req => req.RequestUri == new Uri("http://localhost:1337/.well-known/oauth-protected-resource"), - response: new HttpResponseMessage(HttpStatusCode.OK) - { - Content = ToJsonContent(new - { - resource = "http://localhost:1337/", - authorization_servers = new[] { "http://localhost:1336/" }, - }) - }); - } - - static void MockAuthorizationServerMetadataResponse(Mock httpMessageHandler) - { - MockHttpResponse(httpMessageHandler, - request: req => req.RequestUri == new Uri("http://localhost:1336/.well-known/openid-configuration"), - response: new HttpResponseMessage(HttpStatusCode.OK) - { - Content = ToJsonContent(new - { - authorization_endpoint = "http://localhost:1336/connect/authorize", - token_endpoint = "http://localhost:1336/connect/token", - }) - }); - } - - static void MockAccessTokenResponse(Mock httpMessageHandler, string accessToken) - { - MockHttpResponse(httpMessageHandler, - request: req => req.RequestUri == new Uri("http://localhost:1336/connect/token"), - response: new HttpResponseMessage(HttpStatusCode.OK) - { - Content = ToJsonContent(new - { - access_token = accessToken, - }) - }); - } - - static void MockInitializeResponse(Mock httpMessageHandler) - { - MockHttpResponse(httpMessageHandler, - request: req => req.RequestUri == new Uri("http://localhost:1337/") - && req.Method == HttpMethod.Post - && req.Headers.Authorization != null - && req.Headers.Authorization.Scheme == "Bearer" - && !string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter), - response: new HttpResponseMessage(HttpStatusCode.OK) - { - Content = ToJsonContent(new JsonRpcResponse - { - Id = new RequestId(1), - Result = ToJson(new InitializeResult - { - ProtocolVersion = "2024-11-05", - Capabilities = new ServerCapabilities - { - Prompts = new PromptsCapability { ListChanged = true }, - Resources = new ResourcesCapability { Subscribe = true, ListChanged = true }, - Tools = new ToolsCapability { ListChanged = true }, - Logging = new LoggingCapability(), - Completions = new CompletionsCapability(), - }, - ServerInfo = new Implementation - { - Name = "mcp-test-server", - Version = "1.0.0" - }, - Instructions = "This server provides weather information and file system access." - }) - }), - }); - } - - static void MockHttpResponse(Mock httpMessageHandler, Expression>? request = null, HttpResponseMessage? response = null) - { - _ = httpMessageHandler - .Protected() - .Setup>("SendAsync", request != null ? ItExpr.Is(request) : ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(response ?? new HttpResponseMessage()); - } - - static StringContent ToJsonContent(T content) => new( - content: JsonSerializer.Serialize(content, GetReflectionCapableJsonOptions()), - encoding: System.Text.Encoding.UTF8, - mediaType: "application/json"); - - static JsonNode? ToJson(T content) => JsonSerializer.SerializeToNode( - value: content, - options: GetReflectionCapableJsonOptions()); - - static JsonSerializerOptions GetReflectionCapableJsonOptions() => new(JsonSerializerDefaults.Web) - { - TypeInfoResolver = new DefaultJsonTypeInfoResolver() - }; -} From 71bf435251328b02e67557fa6adcc4d5c388252d Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 19 Nov 2025 14:17:27 -0800 Subject: [PATCH 5/6] Add back null forgiving operator that should be unnecessary --- .../Authentication/ClientOAuthProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 162f3ecb5..b0d66bebe 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -347,7 +347,7 @@ private async Task InitiateAuthorizationCodeFlowAsync( ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty authorization code."); } - await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode, codeVerifier, cancellationToken).ConfigureAwait(false); + await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false); } private Uri BuildAuthorizationUrl( From feb41026d0b60dd7c70e1baa092720f17befa3cc Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Mon, 24 Nov 2025 17:37:02 -0800 Subject: [PATCH 6/6] Address PR feedback --- .../Authentication/ClientOAuthProvider.cs | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index b0d66bebe..572024fdc 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -13,9 +13,7 @@ namespace ModelContextProtocol.Authentication; /// -/// A generic implementation of an OAuth authorization provider for MCP. This does not do any advanced token -/// protection or caching - it acquires a token and server metadata and holds it in memory. -/// This is suitable for demonstration and development purposes. +/// A generic implementation of an OAuth authorization provider. /// internal sealed partial class ClientOAuthProvider { @@ -178,12 +176,7 @@ public async Task HandleUnauthorizedResponseAsync( HttpResponseMessage response, CancellationToken cancellationToken = default) { - // This provider only supports Bearer scheme - if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)) - { - throw new InvalidOperationException("This credential provider only supports the Bearer scheme"); - } - + ThrowIfNotBearerScheme(scheme); await PerformOAuthAuthorizationAsync(response, cancellationToken).ConfigureAwait(false); }