diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs index cc6a8952e..ecb57df0a 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions /// /// public IDictionary AdditionalAuthorizationParameters { get; set; } = new Dictionary(); + + /// + /// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport. + /// If none is provided, tokens will be cached with the transport. + /// + public ITokenCache? TokenCache { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 468728982..572024fdc 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -13,9 +13,7 @@ namespace ModelContextProtocol.Authentication; /// -/// A generic implementation of an OAuth authorization provider for MCP. This does not do any advanced token -/// protection or caching - it acquires a token and server metadata and holds it in memory. -/// This is suitable for demonstration and development purposes. +/// A generic implementation of an OAuth authorization provider. /// internal sealed partial class ClientOAuthProvider { @@ -24,6 +22,8 @@ internal sealed partial class ClientOAuthProvider /// private const string BearerScheme = "Bearer"; + private static readonly string[] s_wellKnownPaths = [".well-known/openid-configuration", ".well-known/oauth-authorization-server"]; + private readonly Uri _serverUrl; private readonly Uri _redirectUri; private readonly string[]? _scopes; @@ -43,7 +43,7 @@ internal sealed partial class ClientOAuthProvider private string? _clientId; private string? _clientSecret; - private TokenContainer? _token; + private ITokenCache _tokenCache; private AuthorizationServerMetadata? _authServerMetadata; /// @@ -57,11 +57,11 @@ internal sealed partial class ClientOAuthProvider public ClientOAuthProvider( Uri serverUrl, ClientOAuthOptions options, - HttpClient? httpClient = null, + HttpClient httpClient, ILoggerFactory? loggerFactory = null) { _serverUrl = serverUrl ?? throw new ArgumentNullException(nameof(serverUrl)); - _httpClient = httpClient ?? new HttpClient(); + _httpClient = httpClient; _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; if (options is null) @@ -85,6 +85,7 @@ public ClientOAuthProvider( _dcrClientUri = options.DynamicClientRegistration?.ClientUri; _dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken; _dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate; + _tokenCache = options.TokenCache ?? new InMemoryTokenCache(); } /// @@ -138,20 +139,21 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); + var tokens = await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false); + // Return the token if it's valid - if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + if (tokens is not null && !tokens.IsExpired) { - return _token.AccessToken; + return tokens.AccessToken; } - // Try to refresh the token if we have a refresh token - if (_token?.RefreshToken != null && _authServerMetadata != null) + // Try to refresh the access token if it is invalid and we have a refresh token. + if (tokens?.RefreshToken != null && _authServerMetadata != null) { - var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); - if (newToken != null) + var newTokens = await RefreshTokenAsync(tokens.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + if (newTokens is not null) { - _token = newToken; - return _token.AccessToken; + return newTokens.AccessToken; } } @@ -174,12 +176,7 @@ public async Task HandleUnauthorizedResponseAsync( HttpResponseMessage response, CancellationToken cancellationToken = default) { - // This provider only supports Bearer scheme - if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)) - { - throw new InvalidOperationException("This credential provider only supports the Bearer scheme"); - } - + ThrowIfNotBearerScheme(scheme); await PerformOAuthAuthorizationAsync(response, cancellationToken).ConfigureAwait(false); } @@ -223,6 +220,17 @@ private async Task PerformOAuthAuthorizationAsync( // Store auth server metadata for future refresh operations _authServerMetadata = authServerMetadata; + // The existing access token must be invalid to have resulted in a 401 response, but refresh might still work. + if (await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false) is { RefreshToken: {} refreshToken }) + { + var refreshedTokens = await RefreshTokenAsync(refreshToken, protectedResourceMetadata.Resource, authServerMetadata, cancellationToken).ConfigureAwait(false); + if (refreshedTokens is not null) + { + // A non-null result indicates the refresh succeeded and the new tokens have been stored. + return; + } + } + // Perform dynamic client registration if needed if (string.IsNullOrEmpty(_clientId)) { @@ -230,19 +238,11 @@ private async Task PerformOAuthAuthorizationAsync( } // Perform the OAuth flow - var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); - - if (token is null) - { - ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); - } + await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); - _token = token; LogOAuthAuthorizationCompleted(); } - private static readonly string[] s_wellKnownPaths = [".well-known/openid-configuration", ".well-known/oauth-authorization-server"]; - private async Task GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken) { if (authServerUri.OriginalString.Length == 0 || @@ -298,7 +298,7 @@ private async Task GetAuthServerMetadataAsync(Uri a throw new McpException($"Failed to find .well-known/openid-configuration or .well-known/oauth-authorization-server metadata for authorization server: '{authServerUri}'"); } - private async Task RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) + private async Task RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) { var requestContent = new FormUrlEncodedContent(new Dictionary { @@ -314,10 +314,17 @@ private async Task RefreshTokenAsync(string refreshToken, Uri re Content = requestContent }; - return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + + if (!httpResponse.IsSuccessStatusCode) + { + return null; + } + + return await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false); } - private async Task InitiateAuthorizationCodeFlowAsync( + private async Task InitiateAuthorizationCodeFlowAsync( ProtectedResourceMetadata protectedResourceMetadata, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) @@ -330,10 +337,10 @@ private async Task RefreshTokenAsync(string refreshToken, Uri re if (string.IsNullOrEmpty(authCode)) { - return null; + ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty authorization code."); } - return await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false); + await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false); } private Uri BuildAuthorizationUrl( @@ -377,7 +384,7 @@ private Uri BuildAuthorizationUrl( return uriBuilder.Uri; } - private async Task ExchangeCodeForTokenAsync( + private async Task ExchangeCodeForTokenAsync( ProtectedResourceMetadata protectedResourceMetadata, AuthorizationServerMetadata authServerMetadata, string authorizationCode, @@ -400,24 +407,39 @@ private async Task ExchangeCodeForTokenAsync( Content = requestContent }; - return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); - } - - private async Task FetchTokenAsync(HttpRequestMessage request, CancellationToken cancellationToken) - { using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); httpResponse.EnsureSuccessStatusCode(); + await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false); + } - using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false); + private async Task HandleSuccessfulTokenResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken) + { + using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenResponse, cancellationToken).ConfigureAwait(false); if (tokenResponse is null) { - ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response."); + ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{response.RequestMessage?.RequestUri}' returned an empty response."); + } + + if (tokenResponse.TokenType is null || !string.Equals(tokenResponse.TokenType, BearerScheme, StringComparison.OrdinalIgnoreCase)) + { + ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{response.RequestMessage?.RequestUri}' returned an unsupported token type: '{tokenResponse.TokenType ?? ""}'. Only 'Bearer' tokens are supported."); } - tokenResponse.ObtainedAt = DateTimeOffset.UtcNow; - return tokenResponse; + TokenContainer tokens = new() + { + AccessToken = tokenResponse.AccessToken, + RefreshToken = tokenResponse.RefreshToken, + ExpiresIn = tokenResponse.ExpiresIn, + TokenType = tokenResponse.TokenType, + Scope = tokenResponse.Scope, + ObtainedAt = DateTimeOffset.UtcNow, + }; + + await _tokenCache.StoreTokensAsync(tokens, cancellationToken).ConfigureAwait(false); + + return tokens; } /// @@ -581,7 +603,7 @@ private async Task ExtractProtectedResourceMetadata(H string? resourceMetadataUrl = null; foreach (var header in response.Headers.WwwAuthenticate) { - if (string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter)) + if (string.Equals(header.Scheme, BearerScheme, StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter)) { resourceMetadataUrl = ParseWwwAuthenticateParameters(header.Parameter, "resource_metadata"); if (resourceMetadataUrl != null) diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs new file mode 100644 index 000000000..3dc6e6351 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Allows the client to cache access tokens beyond the lifetime of the transport. +/// +public interface ITokenCache +{ + /// + /// Cache the token. After a new access token is acquired, this method is invoked to store it. + /// + ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken); + + /// + /// Get the cached token. This method is invoked for every request. + /// + ValueTask GetTokensAsync(CancellationToken cancellationToken); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs new file mode 100644 index 000000000..977cb6f88 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -0,0 +1,27 @@ + +namespace ModelContextProtocol.Authentication; + +/// +/// Caches the token in-memory within this instance. +/// +internal class InMemoryTokenCache : ITokenCache +{ + private TokenContainer? _tokens; + + /// + /// Cache the token. + /// + public ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken) + { + _tokens = tokens; + return default; + } + + /// + /// Get the cached token. + /// + public ValueTask GetTokensAsync(CancellationToken cancellationToken) + { + return new ValueTask(_tokens); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index dc55292b9..8126137f8 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -1,57 +1,39 @@ -using System.Text.Json.Serialization; - namespace ModelContextProtocol.Authentication; /// -/// Represents a token response from the OAuth server. +/// Represents a cacheable combination of tokens ready to be used for authentication. /// -internal sealed class TokenContainer +public sealed class TokenContainer { + /// + /// Gets or sets the token type (typically "Bearer"). + /// + public required string TokenType { get; set; } + /// /// Gets or sets the access token. /// - [JsonPropertyName("access_token")] - public string AccessToken { get; set; } = string.Empty; + public required string AccessToken { get; set; } /// /// Gets or sets the refresh token. /// - [JsonPropertyName("refresh_token")] public string? RefreshToken { get; set; } /// /// Gets or sets the number of seconds until the access token expires. /// - [JsonPropertyName("expires_in")] - public int ExpiresIn { get; set; } - - /// - /// Gets or sets the extended expiration time in seconds. - /// - [JsonPropertyName("ext_expires_in")] - public int ExtExpiresIn { get; set; } - - /// - /// Gets or sets the token type (typically "Bearer"). - /// - [JsonPropertyName("token_type")] - public string TokenType { get; set; } = string.Empty; + public int? ExpiresIn { get; set; } /// /// Gets or sets the scope of the access token. /// - [JsonPropertyName("scope")] - public string Scope { get; set; } = string.Empty; + public string? Scope { get; set; } /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonIgnore] - public DateTimeOffset ObtainedAt { get; set; } + public required DateTimeOffset ObtainedAt { get; set; } - /// - /// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn. - /// - [JsonIgnore] - public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn); + internal bool IsExpired => ExpiresIn is not null && DateTimeOffset.UtcNow >= ObtainedAt.AddSeconds(ExpiresIn.Value); } diff --git a/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs b/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs new file mode 100644 index 000000000..721196734 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenResponse.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a token response from the OAuth server. +/// +internal sealed class TokenResponse +{ + /// + /// Gets or sets the access token. + /// + [JsonPropertyName("access_token")] + public required string AccessToken { get; set; } + + /// + /// Gets or sets the refresh token. + /// + [JsonPropertyName("refresh_token")] + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the number of seconds until the access token expires. + /// + [JsonPropertyName("expires_in")] + public int? ExpiresIn { get; set; } + + /// + /// Gets or sets the token type (typically "Bearer"). + /// + [JsonPropertyName("token_type")] + public required string TokenType { get; set; } + + /// + /// Gets or sets the scope of the access token. + /// + [JsonPropertyName("scope")] + public string? Scope { get; set; } +} diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index 3d08bd82e..5aff9fb83 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -158,7 +158,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(ProtectedResourceMetadata))] [JsonSerializable(typeof(AuthorizationServerMetadata))] - [JsonSerializable(typeof(TokenContainer))] + [JsonSerializable(typeof(TokenResponse))] [JsonSerializable(typeof(DynamicClientRegistrationRequest))] [JsonSerializable(typeof(DynamicClientRegistrationResponse))] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthEventTests.cs similarity index 63% rename from tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs rename to tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthEventTests.cs index 9144121e8..001aa82f0 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthEventTests.cs @@ -1,110 +1,43 @@ -using System.Net; -using System.Net.Http.Json; -using System.Text.Json; -using Microsoft.AspNetCore.Authentication.JwtBearer; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; -using Microsoft.IdentityModel.Tokens; using ModelContextProtocol.AspNetCore.Authentication; -using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Authentication; using ModelContextProtocol.Client; +using System.Net; +using System.Net.Http.Json; -namespace ModelContextProtocol.AspNetCore.Tests; +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; /// /// Tests for MCP authentication when resource metadata is provided via events rather than static configuration. /// -public class AuthEventTests : KestrelInMemoryTest, IAsyncDisposable +public class AuthEventTests : OAuthTestBase { - private const string McpServerUrl = "http://localhost:5000"; - private const string OAuthServerUrl = "https://localhost:7029"; - - private readonly CancellationTokenSource _testCts = new(); - private readonly TestOAuthServer.Program _testOAuthServer; - private readonly Task _testOAuthRunTask; - public AuthEventTests(ITestOutputHelper outputHelper) - : base(outputHelper) + : base(outputHelper, configureMcpMetadata: false) { - // Let the HandleAuthorizationUrlAsync take a look at the Location header - SocketsHttpHandler.AllowAutoRedirect = false; - // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. - // The easiest workaround is to disable cert validation for testing purposes. - SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; - - _testOAuthServer = new TestOAuthServer.Program( - XunitLoggerProvider, - KestrelInMemoryTransport - ); - _testOAuthRunTask = _testOAuthServer.RunServerAsync(cancellationToken: _testCts.Token); + Builder.Services.Configure(McpAuthenticationDefaults.AuthenticationScheme, options => + { + // Note: ResourceMetadata is NOT set here - it will be provided via events + options.ResourceMetadata = null; - Builder - .Services.AddAuthentication(options => - { - options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; - options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; - }) - .AddJwtBearer(options => - { - options.Backchannel = HttpClient; - options.Authority = OAuthServerUrl; - options.TokenValidationParameters = new TokenValidationParameters - { - ValidateIssuer = true, - ValidateAudience = true, - ValidateLifetime = true, - ValidateIssuerSigningKey = true, - ValidAudience = McpServerUrl, - ValidIssuer = OAuthServerUrl, - NameClaimType = "name", - RoleClaimType = "roles", - }; - }) - .AddMcp(options => + options.Events.OnResourceMetadataRequest = async context => { - // Note: ResourceMetadata is NOT set here - it will be provided via events - options.Events.OnResourceMetadataRequest = async context => + // Dynamically provide the resource metadata + context.ResourceMetadata = new ProtectedResourceMetadata { - // Dynamically provide the resource metadata - context.ResourceMetadata = new ProtectedResourceMetadata - { - Resource = new Uri(McpServerUrl), - AuthorizationServers = { new Uri(OAuthServerUrl) }, - ScopesSupported = ["mcp:tools"], - }; - await Task.CompletedTask; + Resource = new Uri(McpServerUrl), + AuthorizationServers = { new Uri(OAuthServerUrl) }, + ScopesSupported = ["mcp:tools"], }; - }); - - Builder.Services.AddAuthorization(); - } - - public async ValueTask DisposeAsync() - { - _testCts.Cancel(); - try - { - await _testOAuthRunTask; - } - catch (OperationCanceledException) { } - finally - { - _testCts.Dispose(); - } + await Task.CompletedTask; + }; + }); } [Fact] public async Task CanAuthenticate_WithResourceMetadataFromEvent() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport( new() @@ -132,13 +65,7 @@ public async Task CanAuthenticate_WithResourceMetadataFromEvent() [Fact] public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); DynamicClientRegistrationResponse? dcrResponse = null; @@ -181,13 +108,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() [Fact] public async Task ResourceMetadataEndpoint_ReturnsCorrectMetadata_FromEvent() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint using var response = await HttpClient.GetAsync( @@ -211,8 +132,6 @@ public async Task ResourceMetadataEndpoint_ReturnsCorrectMetadata_FromEvent() [Fact] public async Task ResourceMetadataEndpoint_CanModifyExistingMetadata_InEvent() { - Builder.Services.AddMcpServer().WithHttpTransport(); - // Override the configuration to test modification of existing metadata Builder.Services.Configure( McpAuthenticationDefaults.AuthenticationScheme, @@ -240,11 +159,7 @@ public async Task ResourceMetadataEndpoint_CanModifyExistingMetadata_InEvent() } ); - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint using var response = await HttpClient.GetAsync( @@ -270,15 +185,12 @@ public async Task ResourceMetadataEndpoint_CanModifyExistingMetadata_InEvent() [Fact] public async Task ResourceMetadataEndpoint_ThrowsException_WhenNoMetadataProvided() { - Builder.Services.AddMcpServer().WithHttpTransport(); - // Override the configuration to test the error case where no metadata is provided Builder.Services.Configure( McpAuthenticationDefaults.AuthenticationScheme, options => { // Don't set ResourceMetadata and provide an event that doesn't set it either - options.ResourceMetadata = null; options.Events.OnResourceMetadataRequest = async context => { // Intentionally don't set context.ResourceMetadata to test error handling @@ -287,11 +199,7 @@ public async Task ResourceMetadataEndpoint_ThrowsException_WhenNoMetadataProvide } ); - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint - this should fail using var response = await HttpClient.GetAsync( @@ -306,14 +214,11 @@ public async Task ResourceMetadataEndpoint_ThrowsException_WhenNoMetadataProvide [Fact] public async Task ResourceMetadataEndpoint_HandlesResponse_WhenHandleResponseCalled() { - Builder.Services.AddMcpServer().WithHttpTransport(); - // Override the configuration to test HandleResponse behavior Builder.Services.Configure( McpAuthenticationDefaults.AuthenticationScheme, options => { - options.ResourceMetadata = null; options.Events.OnResourceMetadataRequest = async context => { // Call HandleResponse() to discontinue processing and return to client @@ -323,11 +228,7 @@ public async Task ResourceMetadataEndpoint_HandlesResponse_WhenHandleResponseCal } ); - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint using var response = await HttpClient.GetAsync( @@ -349,14 +250,11 @@ public async Task ResourceMetadataEndpoint_HandlesResponse_WhenHandleResponseCal [Fact] public async Task ResourceMetadataEndpoint_SkipsHandler_WhenSkipHandlerCalled() { - Builder.Services.AddMcpServer().WithHttpTransport(); - // Override the configuration to test SkipHandler behavior Builder.Services.Configure( McpAuthenticationDefaults.AuthenticationScheme, options => { - options.ResourceMetadata = null; options.Events.OnResourceMetadataRequest = async context => { // Call SkipHandler() to discontinue processing in the current handler @@ -366,11 +264,7 @@ public async Task ResourceMetadataEndpoint_SkipsHandler_WhenSkipHandlerCalled() } ); - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); // Make a direct request to the resource metadata endpoint using var response = await HttpClient.GetAsync( @@ -383,23 +277,4 @@ public async Task ResourceMetadataEndpoint_SkipsHandler_WhenSkipHandlerCalled() // other handlers configured for this endpoint, this should result in a 404 Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); } - - private async Task HandleAuthorizationUrlAsync( - Uri authorizationUri, - Uri redirectUri, - CancellationToken cancellationToken - ) - { - using var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); - Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); - var location = redirectResponse.Headers.Location; - - if (location is not null && !string.IsNullOrEmpty(location.Query)) - { - var queryParams = QueryHelpers.ParseQuery(location.Query); - return queryParams["code"]; - } - - return null; - } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthTests.cs similarity index 68% rename from tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs rename to tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthTests.cs index fff7d6d42..bed6f8aaa 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/AuthTests.cs @@ -1,101 +1,23 @@ -using Microsoft.AspNetCore.Authentication.JwtBearer; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.WebUtilities; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.IdentityModel.Tokens; using ModelContextProtocol.AspNetCore.Authentication; -using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Authentication; using ModelContextProtocol.Client; using System.Net; using System.Reflection; using Xunit.Sdk; -namespace ModelContextProtocol.AspNetCore.Tests; +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; -public class AuthTests : KestrelInMemoryTest, IAsyncDisposable +public class AuthTests : OAuthTestBase { - private const string McpServerUrl = "http://localhost:5000"; - private const string OAuthServerUrl = "https://localhost:7029"; - - private readonly CancellationTokenSource _testCts = new(); - private readonly TestOAuthServer.Program _testOAuthServer; - private readonly Task _testOAuthRunTask; - - private Uri? _lastAuthorizationUri; - - public AuthTests(ITestOutputHelper outputHelper) + public AuthTests(ITestOutputHelper outputHelper) : base(outputHelper) { - // Let the HandleAuthorizationUrlAsync take a look at the Location header - SocketsHttpHandler.AllowAutoRedirect = false; - // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. - // The easiest workaround is to disable cert validation for testing purposes. - SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; - - _testOAuthServer = new TestOAuthServer.Program(XunitLoggerProvider, KestrelInMemoryTransport); - _testOAuthRunTask = _testOAuthServer.RunServerAsync(cancellationToken: _testCts.Token); - - Builder.Services.AddAuthentication(options => - { - options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; - options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; - }) - .AddJwtBearer(options => - { - options.Backchannel = HttpClient; - options.Authority = OAuthServerUrl; - options.TokenValidationParameters = new TokenValidationParameters - { - ValidateIssuer = true, - ValidateAudience = true, - ValidateLifetime = true, - ValidateIssuerSigningKey = true, - ValidAudience = McpServerUrl, - ValidIssuer = OAuthServerUrl, - NameClaimType = "name", - RoleClaimType = "roles" - }; - }) - .AddMcp(options => - { - options.ResourceMetadata = new ProtectedResourceMetadata - { - Resource = new Uri(McpServerUrl), - AuthorizationServers = { new Uri(OAuthServerUrl) }, - ScopesSupported = ["mcp:tools"] - }; - }); - - Builder.Services.AddAuthorization(); - } - - public async ValueTask DisposeAsync() - { - _testCts.Cancel(); - try - { - await _testOAuthRunTask; - } - catch (OperationCanceledException) - { - } - finally - { - _testCts.Dispose(); - } } [Fact] public async Task CanAuthenticate() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -116,13 +38,7 @@ public async Task CanAuthenticate() [Fact] public async Task CannotAuthenticate_WithoutOAuthConfiguration() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -138,13 +54,7 @@ public async Task CannotAuthenticate_WithoutOAuthConfiguration() [Fact] public async Task CannotAuthenticate_WithUnregisteredClient() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -166,13 +76,7 @@ public async Task CannotAuthenticate_WithUnregisteredClient() [Fact] public async Task CanAuthenticate_WithDynamicClientRegistration() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -197,13 +101,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration() [Fact] public async Task CanAuthenticate_WithTokenRefresh() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -222,19 +120,15 @@ public async Task CanAuthenticate_WithTokenRefresh() await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); - Assert.True(_testOAuthServer.HasIssuedRefreshToken); + Assert.True(TestOAuthServer.HasRefreshedToken); } [Fact] public async Task CanAuthenticate_WithExtraParams() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); + await using var app = await StartMcpServerAsync(); - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + Uri? lastAuthorizationUri = null; await using var transport = new HttpClientTransport(new() { @@ -244,7 +138,11 @@ public async Task CanAuthenticate_WithExtraParams() ClientId = "demo-client", ClientSecret = "demo-secret", RedirectUri = new Uri("http://localhost:1179/callback"), - AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + lastAuthorizationUri = uri; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, AdditionalAuthorizationParameters = new Dictionary { ["custom_param"] = "custom_value", @@ -255,20 +153,14 @@ public async Task CanAuthenticate_WithExtraParams() await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); - Assert.NotNull(_lastAuthorizationUri?.Query); - Assert.Contains("custom_param=custom_value", _lastAuthorizationUri?.Query); + Assert.NotNull(lastAuthorizationUri?.Query); + Assert.Contains("custom_param=custom_value", lastAuthorizationUri?.Query); } [Fact] public async Task CannotOverrideExistingParameters_WithExtraParams() { - Builder.Services.AddMcpServer().WithHttpTransport(); - - await using var app = Builder.Build(); - - app.MapMcp().RequireAuthorization(); - - await app.StartAsync(TestContext.Current.CancellationToken); + await using var app = await StartMcpServerAsync(); await using var transport = new HttpClientTransport(new() { @@ -390,21 +282,4 @@ public void CloneResourceMetadataClonesAllProperties() // Ensure we've checked every property. When new properties get added, we'll have to update this test along with the CloneResourceMetadata implementation. Assert.Empty(propertyNames); } - - private async Task HandleAuthorizationUrlAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) - { - _lastAuthorizationUri = authorizationUri; - - var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); - Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); - var location = redirectResponse.Headers.Location; - - if (location is not null && !string.IsNullOrEmpty(location.Query)) - { - var queryParams = QueryHelpers.ParseQuery(location.Query); - return queryParams["code"]; - } - - return null; - } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/OAuthTestBase.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/OAuthTestBase.cs new file mode 100644 index 000000000..834bf83fd --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/OAuthTestBase.cs @@ -0,0 +1,107 @@ +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Authentication; +using System.Net; +using Xunit.Sdk; + +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; + +public abstract class OAuthTestBase : KestrelInMemoryTest, IAsyncDisposable +{ + protected const string McpServerUrl = "http://localhost:5000"; + protected const string OAuthServerUrl = "https://localhost:7029"; + + protected readonly CancellationTokenSource TestCts = new(); + protected readonly TestOAuthServer.Program TestOAuthServer; + private readonly Task _testOAuthRunTask; + + protected OAuthTestBase(ITestOutputHelper outputHelper, bool configureMcpMetadata = true) + : base(outputHelper) + { + // Let the HandleAuthorizationUrlAsync take a look at the Location header + SocketsHttpHandler.AllowAutoRedirect = false; + // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. + // The easiest workaround is to disable cert validation for testing purposes. + SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; + + TestOAuthServer = new TestOAuthServer.Program(XunitLoggerProvider, KestrelInMemoryTransport); + _testOAuthRunTask = TestOAuthServer.RunServerAsync(cancellationToken: TestCts.Token); + + Builder.Services.AddAuthentication(options => + { + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + }) + .AddJwtBearer(options => + { + options.Backchannel = HttpClient; + options.Authority = OAuthServerUrl; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidAudience = McpServerUrl, + ValidIssuer = OAuthServerUrl, + NameClaimType = "name", + RoleClaimType = "roles" + }; + }) + .AddMcp(options => + { + if (configureMcpMetadata) + { + options.ResourceMetadata = new ProtectedResourceMetadata + { + Resource = new Uri(McpServerUrl), + AuthorizationServers = { new Uri(OAuthServerUrl) }, + ScopesSupported = ["mcp:tools"] + }; + } + }); + + Builder.Services.AddAuthorization(); + Builder.Services.AddMcpServer().WithHttpTransport(); + } + + public async ValueTask DisposeAsync() + { + TestCts.Cancel(); + try + { + await _testOAuthRunTask; + } + catch (OperationCanceledException) + { + } + finally + { + TestCts.Dispose(); + } + } + + protected async Task StartMcpServerAsync() + { + var app = Builder.Build(); + app.MapMcp().RequireAuthorization(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + protected async Task HandleAuthorizationUrlAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) + { + using var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); + Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); + var location = redirectResponse.Headers.Location; + + if (location is not null && !string.IsNullOrEmpty(location.Query)) + { + var queryParams = QueryHelpers.ParseQuery(location.Query); + return queryParams["code"]; + } + + return null; + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/TokenCacheTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/TokenCacheTests.cs new file mode 100644 index 000000000..fb9e2bfda --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/TokenCacheTests.cs @@ -0,0 +1,215 @@ +using ModelContextProtocol.Authentication; +using ModelContextProtocol.Client; + +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; + +public class TokenCacheTests : OAuthTestBase +{ + public TokenCacheTests(ITestOutputHelper outputHelper) + : base(outputHelper) + { + } + + [Fact] + public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() + { + await using var app = await StartMcpServerAsync(); + + var tokenCache = new TestTokenCache(); + bool authDelegateCalledInitially = false; + + await using var setupTransport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalledInitially = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache, + }, + }, HttpClient, LoggerFactory); + + await using (var setupClient = await McpClient.CreateAsync(setupTransport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) + { + // Just connecting should trigger auth and storage. + } + + Assert.True(authDelegateCalledInitially, "AuthorizationRedirectDelegate should be called to get initial token"); + Assert.NotNull(tokenCache.LastStoredToken); + + var authDelegateCalledAgain = false; + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalledAgain = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(authDelegateCalledAgain, "AuthorizationRedirectDelegate should not be called when token is valid"); + } + + [Fact] + public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() + { + await using var app = await StartMcpServerAsync(); + + var tokenCache = new TestTokenCache(); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + TokenCache = tokenCache + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(tokenCache.LastStoredToken); + Assert.False(string.IsNullOrEmpty(tokenCache.LastStoredToken.AccessToken)); + } + + [Fact] + public async Task GetTokenAsync_InvalidCachedTokenTriggersAuthDelegate() + { + await using var app = await StartMcpServerAsync(); + + var tokenCache = new TestTokenCache(CreateInvalidToken()); + bool authDelegateCalled = false; + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalled = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache, + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.True(authDelegateCalled, "AuthorizationRedirectDelegate should be called when cached token is invalid"); + Assert.NotNull(tokenCache.LastStoredToken); + Assert.NotEqual("invalid-token", tokenCache.LastStoredToken.AccessToken); + } + + [Fact] + public async Task GetTokenAsync_InvalidAccessTokenTriggersRefresh() + { + await using var app = await StartMcpServerAsync(); + + var tokenCache = new TestTokenCache(); + bool authDelegateCalledInitially = false; + + await using var setupTransport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalledInitially = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache, + }, + }, HttpClient, LoggerFactory); + + await using (var setupClient = await McpClient.CreateAsync(setupTransport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) + { + // Just connecting should trigger auth and storage. + } + + Assert.True(authDelegateCalledInitially, "AuthorizationRedirectDelegate should be called to get initial token"); + Assert.False(TestOAuthServer.HasRefreshedToken, "Token should not have been refreshed yet"); + Assert.NotNull(tokenCache.LastStoredToken); + + // Invalidate the access token but keep the refresh token valid (if any) + tokenCache.LastStoredToken.AccessToken = "invalid-token"; + var authDelegateCalledAgain = false; + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = (uri, redirect, ct) => + { + authDelegateCalledAgain = true; + return HandleAuthorizationUrlAsync(uri, redirect, ct); + }, + TokenCache = tokenCache + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(authDelegateCalledAgain, "AuthorizationRedirectDelegate should not be called when refresh token is valid"); + Assert.True(TestOAuthServer.HasRefreshedToken, "Token should have been refreshed"); + Assert.NotEqual("invalid-token", tokenCache.LastStoredToken.AccessToken); + } + + private TokenContainer CreateInvalidToken() + { + return new TokenContainer + { + TokenType = "Bearer", + AccessToken = "invalid-token", + ObtainedAt = DateTimeOffset.UtcNow, + }; + } + + private class TestTokenCache(TokenContainer? initialToken = null) : ITokenCache + { + public TokenContainer? LastStoredToken { get; private set; } = initialToken; + + public ValueTask GetTokensAsync(CancellationToken cancellationToken) + { + return new ValueTask(LastStoredToken); + } + + public ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken) + { + LastStoredToken = tokens; + return ValueTask.CompletedTask; + } + } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/Program.cs b/tests/ModelContextProtocol.TestOAuthServer/Program.cs index dea484bfe..2a67a0c9c 100644 --- a/tests/ModelContextProtocol.TestOAuthServer/Program.cs +++ b/tests/ModelContextProtocol.TestOAuthServer/Program.cs @@ -42,7 +42,7 @@ public Program(ILoggerProvider? loggerProvider = null, IConnectionListenerFactor // Track if we've already issued an already-expired token for the CanAuthenticate_WithTokenRefresh test which uses the test-refresh-client registration. public bool HasIssuedExpiredToken { get; set; } - public bool HasIssuedRefreshToken { get; set; } + public bool HasRefreshedToken { get; set; } /// /// Entry point for the application. @@ -368,7 +368,7 @@ public async Task RunServerAsync(string[]? args = null, CancellationToken cancel _tokens.TryRemove(refresh_token, out _); } - HasIssuedRefreshToken = true; + HasRefreshedToken = true; return Results.Ok(response); } else