|
| 1 | +package oauth |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "encoding/json" |
| 7 | + "fmt" |
| 8 | + "io" |
| 9 | + "net/http" |
| 10 | +) |
| 11 | + |
| 12 | +// PerformDCR performs Dynamic Client Registration with the authorization server |
| 13 | +// Returns client credentials for the registered public client |
| 14 | +// |
| 15 | +// RFC 7591 COMPLIANCE: |
| 16 | +// - Uses token_endpoint_auth_method="none" for public clients |
| 17 | +// - Includes redirect_uris pointing to mcp-oauth proxy |
| 18 | +// - Requests authorization_code and refresh_token grant types |
| 19 | +func PerformDCR(ctx context.Context, discovery *Discovery, serverName string) (*ClientCredentials, error) { |
| 20 | + if discovery.RegistrationEndpoint == "" { |
| 21 | + return nil, fmt.Errorf("no registration endpoint found for %s", serverName) |
| 22 | + } |
| 23 | + |
| 24 | + // Build DCR request for PUBLIC client |
| 25 | + registration := DCRRequest{ |
| 26 | + ClientName: fmt.Sprintf("MCP Gateway - %s", serverName), |
| 27 | + RedirectURIs: []string{ |
| 28 | + "https://mcp.docker.com/oauth/callback", // mcp-oauth proxy callback only |
| 29 | + }, |
| 30 | + TokenEndpointAuthMethod: "none", // PUBLIC client (no client secret) |
| 31 | + GrantTypes: []string{"authorization_code", "refresh_token"}, |
| 32 | + ResponseTypes: []string{"code"}, |
| 33 | + |
| 34 | + // Additional metadata for better client identification |
| 35 | + ClientURI: "https://github.com/docker/mcp-gateway", |
| 36 | + SoftwareID: "mcp-gateway", |
| 37 | + SoftwareVersion: "1.0.0", |
| 38 | + Contacts: [] string{ "[email protected]"}, |
| 39 | + } |
| 40 | + |
| 41 | + // Add requested scopes if provided |
| 42 | + if len(discovery.Scopes) > 0 { |
| 43 | + registration.Scope = joinScopes(discovery.Scopes) |
| 44 | + } else { |
| 45 | + } |
| 46 | + |
| 47 | + // Marshal the registration request |
| 48 | + body, err := json.Marshal(registration) |
| 49 | + if err != nil { |
| 50 | + return nil, fmt.Errorf("failed to marshal DCR request: %w", err) |
| 51 | + } |
| 52 | + |
| 53 | + // Create HTTP request |
| 54 | + req, err := http.NewRequestWithContext(ctx, http.MethodPost, discovery.RegistrationEndpoint, bytes.NewReader(body)) |
| 55 | + if err != nil { |
| 56 | + return nil, fmt.Errorf("failed to create DCR request: %w", err) |
| 57 | + } |
| 58 | + |
| 59 | + req.Header.Set("Content-Type", "application/json") |
| 60 | + req.Header.Set("Accept", "application/json") |
| 61 | + req.Header.Set("User-Agent", "MCP-Gateway/1.0.0") |
| 62 | + |
| 63 | + // Send the request |
| 64 | + client := &http.Client{} |
| 65 | + resp, err := client.Do(req) |
| 66 | + if err != nil { |
| 67 | + return nil, fmt.Errorf("failed to send DCR request to %s: %w", discovery.RegistrationEndpoint, err) |
| 68 | + } |
| 69 | + defer resp.Body.Close() |
| 70 | + |
| 71 | + // Check response status (201 Created or 200 OK are acceptable) |
| 72 | + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { |
| 73 | + // Read error response body to understand why DCR failed |
| 74 | + errorBody, err := io.ReadAll(resp.Body) |
| 75 | + if err != nil { |
| 76 | + return nil, fmt.Errorf("DCR failed with status %d for %s", resp.StatusCode, serverName) |
| 77 | + } |
| 78 | + |
| 79 | + errorMsg := string(errorBody) |
| 80 | + |
| 81 | + // Try to parse as JSON for structured error |
| 82 | + var errorResp map[string]any |
| 83 | + if err := json.Unmarshal(errorBody, &errorResp); err == nil { |
| 84 | + // Successfully parsed as JSON - look for common error fields |
| 85 | + if errDesc, ok := errorResp["error_description"].(string); ok { |
| 86 | + errorMsg = errDesc |
| 87 | + } else if errField, ok := errorResp["error"].(string); ok { |
| 88 | + errorMsg = errField |
| 89 | + } else if message, ok := errorResp["message"].(string); ok { |
| 90 | + errorMsg = message |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + return nil, fmt.Errorf("DCR failed with status %d for %s: %s", resp.StatusCode, serverName, errorMsg) |
| 95 | + } |
| 96 | + |
| 97 | + // Parse the response |
| 98 | + var dcrResponse DCRResponse |
| 99 | + if err := json.NewDecoder(resp.Body).Decode(&dcrResponse); err != nil { |
| 100 | + return nil, fmt.Errorf("failed to decode DCR response: %w", err) |
| 101 | + } |
| 102 | + |
| 103 | + if dcrResponse.ClientID == "" { |
| 104 | + return nil, fmt.Errorf("DCR response missing client_id for %s", serverName) |
| 105 | + } |
| 106 | + |
| 107 | + // Create client credentials (public client - no secret) |
| 108 | + creds := &ClientCredentials{ |
| 109 | + ClientID: dcrResponse.ClientID, |
| 110 | + ServerURL: discovery.ResourceURL, |
| 111 | + IsPublic: true, |
| 112 | + AuthorizationEndpoint: discovery.AuthorizationEndpoint, |
| 113 | + TokenEndpoint: discovery.TokenEndpoint, |
| 114 | + // No ClientSecret for public clients |
| 115 | + } |
| 116 | + |
| 117 | + return creds, nil |
| 118 | +} |
| 119 | + |
| 120 | +// joinScopes joins a slice of scopes into a space-separated string |
| 121 | +// per OAuth 2.0 specification (RFC 6749 Section 3.3) |
| 122 | +func joinScopes(scopes []string) string { |
| 123 | + if len(scopes) == 0 { |
| 124 | + return "" |
| 125 | + } |
| 126 | + |
| 127 | + // Use simple string concatenation for small arrays |
| 128 | + result := scopes[0] |
| 129 | + for i := 1; i < len(scopes); i++ { |
| 130 | + result += " " + scopes[i] |
| 131 | + } |
| 132 | + return result |
| 133 | +} |
0 commit comments