diff --git a/dcr.go b/dcr.go index ec66d91..a56d09c 100644 --- a/dcr.go +++ b/dcr.go @@ -7,10 +7,39 @@ import ( "fmt" "io" "net/http" + "net/url" ) const DefaultRedirectURI = "https://mcp.docker.com/oauth/callback" +// isValidRedirectURI validates that the redirect URI is allowed for this library +// Only localhost and mcp.docker.com are permitted for security +func isValidRedirectURI(redirectURI string) error { + if redirectURI == "" { + return nil // Empty is OK (will use default) + } + + parsed, err := url.Parse(redirectURI) + if err != nil { + return fmt.Errorf("invalid redirect URI format: %w", err) + } + + // Extract hostname (handles ports automatically) + hostname := parsed.Hostname() + + // Allow localhost variations + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" { + return nil + } + + // Allow mcp.docker.com (production) + if hostname == "mcp.docker.com" { + return nil + } + + return fmt.Errorf("redirect URI host %q not allowed - must be localhost or mcp.docker.com", hostname) +} + // PerformDCR performs Dynamic Client Registration with the authorization server // Returns client credentials for the registered public client // @@ -18,15 +47,27 @@ const DefaultRedirectURI = "https://mcp.docker.com/oauth/callback" // - Uses token_endpoint_auth_method="none" for public clients // - Includes redirect_uris pointing to mcp-oauth proxy // - Requests authorization_code and refresh_token grant types -func PerformDCR(ctx context.Context, discovery *Discovery, serverName string) (*ClientCredentials, error) { +// +// redirectURI: The OAuth callback URI to register. If empty, uses DefaultRedirectURI. +func PerformDCR(ctx context.Context, discovery *Discovery, serverName string, redirectURI string) (*ClientCredentials, error) { if discovery.RegistrationEndpoint == "" { return nil, fmt.Errorf("no registration endpoint found for %s", serverName) } + // Validate redirect URI for security (only localhost or mcp.docker.com allowed) + if err := isValidRedirectURI(redirectURI); err != nil { + return nil, fmt.Errorf("invalid redirect URI: %w", err) + } + + // Use provided redirectURI, fallback to default if empty + if redirectURI == "" { + redirectURI = DefaultRedirectURI + } + // Build DCR request for PUBLIC client registration := DCRRequest{ ClientName: fmt.Sprintf("MCP Gateway - %s", serverName), - RedirectURIs: []string{DefaultRedirectURI}, + RedirectURIs: []string{redirectURI}, TokenEndpointAuthMethod: "none", // PUBLIC client (no client secret) GrantTypes: []string{"authorization_code", "refresh_token"}, ResponseTypes: []string{"code"}, diff --git a/dcr_test.go b/dcr_test.go index b03f016..64a1a20 100644 --- a/dcr_test.go +++ b/dcr_test.go @@ -39,8 +39,8 @@ func TestPerformDCR_PublicClient(t *testing.T) { Scopes: []string{"read", "write"}, } - // Perform DCR - creds, err := PerformDCR(context.Background(), discovery, "test-server") + // Perform DCR (empty redirectURI uses default) + creds, err := PerformDCR(context.Background(), discovery, "test-server", "") // Verify no error if err != nil { t.Fatalf("DCR failed: %v", err) @@ -82,8 +82,8 @@ func TestPerformDCR_NoRegistrationEndpoint(t *testing.T) { RegistrationEndpoint: "", // Empty - DCR not supported } - // Attempt DCR - creds, err := PerformDCR(context.Background(), discovery, "test-server") + // Attempt DCR (empty redirectURI uses default) + creds, err := PerformDCR(context.Background(), discovery, "test-server", "") // Verify error occurred if err == nil { @@ -93,3 +93,86 @@ func TestPerformDCR_NoRegistrationEndpoint(t *testing.T) { t.Error("Expected nil credentials on error") } } + +// TestIsValidRedirectURI verifies redirect URI validation logic +func TestIsValidRedirectURI(t *testing.T) { + tests := []struct { + name string + redirectURI string + expectError bool + description string + }{ + { + name: "empty string", + redirectURI: "", + expectError: false, + description: "Empty string should be allowed (uses default)", + }, + { + name: "localhost http", + redirectURI: "http://localhost:5000/callback", + expectError: false, + description: "Localhost with HTTP should be allowed", + }, + { + name: "localhost https", + redirectURI: "https://localhost:5000/callback", + expectError: false, + description: "Localhost with HTTPS should be allowed", + }, + { + name: "127.0.0.1", + redirectURI: "http://127.0.0.1:8080/callback", + expectError: false, + description: "127.0.0.1 should be allowed", + }, + { + name: "IPv6 localhost", + redirectURI: "http://[::1]:8080/callback", + expectError: false, + description: "IPv6 localhost should be allowed", + }, + { + name: "mcp.docker.com production", + redirectURI: "https://mcp.docker.com/oauth/callback", + expectError: false, + description: "Production mcp.docker.com should be allowed", + }, + { + name: "evil domain", + redirectURI: "https://evil.com/callback", + expectError: true, + description: "Arbitrary domains should be blocked", + }, + { + name: "attacker ngrok", + redirectURI: "https://attacker.ngrok.io/callback", + expectError: true, + description: "Attacker-controlled domains should be blocked", + }, + { + name: "subdomain of docker.com", + redirectURI: "https://evil.docker.com/callback", + expectError: true, + description: "Only mcp.docker.com should be allowed, not subdomains", + }, + { + name: "invalid URL", + redirectURI: "not-a-valid-url", + expectError: true, + description: "Invalid URL format should be rejected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isValidRedirectURI(tt.redirectURI) + if tt.expectError && err == nil { + t.Errorf("Expected error for %q (%s)", tt.redirectURI, tt.description) + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error for %q: %v (%s)", tt.redirectURI, err, tt.description) + } + }) + } +}