diff --git a/router-tests/mcp_test.go b/router-tests/mcp_test.go index 83ef5a8ae5..f9fd0c66ae 100644 --- a/router-tests/mcp_test.go +++ b/router-tests/mcp_test.go @@ -553,4 +553,307 @@ func TestMCP(t *testing.T) { }) }) }) + + t.Run("Header Forwarding", func(t *testing.T) { + t.Run("Authorization header is always forwarded", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ForwardHeaders: config.MCPForwardHeadersConfiguration{ + Enabled: false, // Disabled, but Authorization should still be forwarded + AllowList: []string{}, + }, + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify Authorization header is present + auth := r.Header.Get("Authorization") + if auth == "" { + http.Error(w, "Missing Authorization header", http.StatusUnauthorized) + return + } + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Create MCP client with Authorization header + mcpAddr := xEnv.GetMCPServerAddr() + client, err := http.NewRequest("POST", mcpAddr, nil) + require.NoError(t, err) + client.Header.Set("Authorization", "Bearer test-token") + + req := mcp.CallToolRequest{} + req.Params.Name = "execute_operation_my_employees" + req.Params.Arguments = map[string]interface{}{ + "criteria": map[string]interface{}{}, + } + + resp, err := xEnv.MCPClient.CallTool(xEnv.Context, req) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.False(t, resp.IsError, "Should not error - Authorization header should be forwarded") + }) + }) + + t.Run("Custom headers forwarded when enabled with exact match", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ForwardHeaders: config.MCPForwardHeadersConfiguration{ + Enabled: true, + AllowList: []string{"X-Tenant-ID", "X-Request-ID"}, + }, + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify custom headers are present + tenantID := r.Header.Get("X-Tenant-ID") + requestID := r.Header.Get("X-Request-ID") + + if tenantID != "tenant-123" { + http.Error(w, fmt.Sprintf("Expected X-Tenant-ID=tenant-123, got %s", tenantID), http.StatusBadRequest) + return + } + if requestID != "req-456" { + http.Error(w, fmt.Sprintf("Expected X-Request-ID=req-456, got %s", requestID), http.StatusBadRequest) + return + } + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Note: In a real test, we'd need to modify the MCP client to support custom headers + // For now, this test structure shows the intent + req := mcp.CallToolRequest{} + req.Params.Name = "execute_operation_my_employees" + req.Params.Arguments = map[string]interface{}{ + "criteria": map[string]interface{}{}, + } + + resp, err := xEnv.MCPClient.CallTool(xEnv.Context, req) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + }) + + t.Run("Custom headers NOT forwarded when disabled", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ForwardHeaders: config.MCPForwardHeadersConfiguration{ + Enabled: false, // Disabled + AllowList: []string{"X-Tenant-ID"}, + }, + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify custom header is NOT present + tenantID := r.Header.Get("X-Tenant-ID") + if tenantID != "" { + http.Error(w, "X-Tenant-ID should not be forwarded when disabled", http.StatusBadRequest) + return + } + // But Authorization should still be present + auth := r.Header.Get("Authorization") + if auth == "" { + http.Error(w, "Authorization should always be forwarded", http.StatusUnauthorized) + return + } + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + req := mcp.CallToolRequest{} + req.Params.Name = "execute_operation_my_employees" + req.Params.Arguments = map[string]interface{}{ + "criteria": map[string]interface{}{}, + } + + resp, err := xEnv.MCPClient.CallTool(xEnv.Context, req) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.False(t, resp.IsError) + }) + }) + + t.Run("Regex pattern matching for headers", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ForwardHeaders: config.MCPForwardHeadersConfiguration{ + Enabled: true, + AllowList: []string{"X-Custom-.*", "X-Trace-.*"}, + }, + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify headers matching regex patterns are present + customHeader := r.Header.Get("X-Custom-Header") + traceID := r.Header.Get("X-Trace-ID") + + if customHeader != "custom-value" { + http.Error(w, fmt.Sprintf("Expected X-Custom-Header=custom-value, got %s", customHeader), http.StatusBadRequest) + return + } + if traceID != "trace-123" { + http.Error(w, fmt.Sprintf("Expected X-Trace-ID=trace-123, got %s", traceID), http.StatusBadRequest) + return + } + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + req := mcp.CallToolRequest{} + req.Params.Name = "execute_operation_my_employees" + req.Params.Arguments = map[string]interface{}{ + "criteria": map[string]interface{}{}, + } + + resp, err := xEnv.MCPClient.CallTool(xEnv.Context, req) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + }) + + t.Run("Headers not in allowlist are NOT forwarded", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ForwardHeaders: config.MCPForwardHeadersConfiguration{ + Enabled: true, + AllowList: []string{"X-Allowed-Header"}, + }, + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify allowed header is present + allowed := r.Header.Get("X-Allowed-Header") + if allowed != "allowed-value" { + http.Error(w, "X-Allowed-Header should be forwarded", http.StatusBadRequest) + return + } + + // Verify non-allowed header is NOT present + notAllowed := r.Header.Get("X-Not-Allowed-Header") + if notAllowed != "" { + http.Error(w, "X-Not-Allowed-Header should NOT be forwarded", http.StatusBadRequest) + return + } + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + req := mcp.CallToolRequest{} + req.Params.Name = "execute_operation_my_employees" + req.Params.Arguments = map[string]interface{}{ + "criteria": map[string]interface{}{}, + } + + resp, err := xEnv.MCPClient.CallTool(xEnv.Context, req) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.False(t, resp.IsError) + }) + }) + + t.Run("Case-insensitive header matching", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ForwardHeaders: config.MCPForwardHeadersConfiguration{ + Enabled: true, + AllowList: []string{"x-tenant-id"}, // lowercase in config + }, + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify header is present regardless of case + tenantID := r.Header.Get("X-Tenant-ID") // uppercase in request + if tenantID != "tenant-123" { + http.Error(w, fmt.Sprintf("Expected X-Tenant-ID=tenant-123, got %s", tenantID), http.StatusBadRequest) + return + } + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + req := mcp.CallToolRequest{} + req.Params.Name = "execute_operation_my_employees" + req.Params.Arguments = map[string]interface{}{ + "criteria": map[string]interface{}{}, + } + + resp, err := xEnv.MCPClient.CallTool(xEnv.Context, req) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.False(t, resp.IsError) + }) + }) + + t.Run("Multiple values for same header are forwarded", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ForwardHeaders: config.MCPForwardHeadersConfiguration{ + Enabled: true, + AllowList: []string{"X-Multi-Value"}, + }, + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify multiple values are present + values := r.Header.Values("X-Multi-Value") + if len(values) != 2 { + http.Error(w, fmt.Sprintf("Expected 2 values, got %d", len(values)), http.StatusBadRequest) + return + } + if values[0] != "value1" || values[1] != "value2" { + http.Error(w, "Values don't match expected", http.StatusBadRequest) + return + } + handler.ServeHTTP(w, r) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + req := mcp.CallToolRequest{} + req.Params.Name = "execute_operation_my_employees" + req.Params.Arguments = map[string]interface{}{ + "criteria": map[string]interface{}{}, + } + + resp, err := xEnv.MCPClient.CallTool(xEnv.Context, req) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.False(t, resp.IsError) + }) + }) + }) } diff --git a/router/core/router.go b/router/core/router.go index ca4a5aaa36..7c7de0c30c 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -45,6 +45,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/controlplane/selfregister" "github.com/wundergraph/cosmo/router/pkg/cors" "github.com/wundergraph/cosmo/router/pkg/execution_config" + "github.com/wundergraph/cosmo/router/pkg/authentication" "github.com/wundergraph/cosmo/router/pkg/health" "github.com/wundergraph/cosmo/router/pkg/mcpserver" rmetric "github.com/wundergraph/cosmo/router/pkg/metric" @@ -886,6 +887,27 @@ func (r *Router) bootstrap(ctx context.Context) error { mcpserver.WithEnableArbitraryOperations(r.mcp.EnableArbitraryOperations), mcpserver.WithExposeSchema(r.mcp.ExposeSchema), mcpserver.WithStateless(r.mcp.Session.Stateless), + mcpserver.WithForwardHeaders(r.mcp.ForwardHeaders.Enabled, r.mcp.ForwardHeaders.AllowList), + } + + // Setup MCP authenticators if authorization is enabled + if r.mcp.Authorization.Enabled { + mcpAuthenticators, err := setupMCPAuthenticators(ctx, r.logger, &r.mcp) + if err != nil { + return fmt.Errorf("failed to setup MCP authenticators: %w", err) + } + + if len(mcpAuthenticators) > 0 { + tokenValidator := mcpserver.NewTokenValidator(mcpAuthenticators, r.logger, true) + mcpOpts = append(mcpOpts, + mcpserver.WithTokenValidator(tokenValidator), + mcpserver.WithMetadataConfig(&r.mcp.Authorization.Metadata), + mcpserver.WithAuthConfig(&r.mcp.Authorization), + ) + r.logger.Info("MCP authorization enabled", + zap.Int("authenticators", len(mcpAuthenticators)), + ) + } } // Determine the router GraphQL endpoint @@ -2313,3 +2335,56 @@ func or[T any](maybe *T, or T) T { } return or } + +// setupMCPAuthenticators creates JWT authenticators for the MCP server from the MCP authorization configuration. +// This is similar to setupAuthenticators but specifically for MCP server authorization. +func setupMCPAuthenticators(ctx context.Context, logger *zap.Logger, mcpCfg *config.MCPConfiguration) ([]authentication.Authenticator, error) { + if !mcpCfg.Authorization.Enabled || len(mcpCfg.Authorization.JWKS) == 0 { + // No MCP JWT authenticators configured + return nil, nil + } + + var authenticators []authentication.Authenticator + configs := make([]authentication.JWKSConfig, 0, len(mcpCfg.Authorization.JWKS)) + + for _, jwks := range mcpCfg.Authorization.JWKS { + configs = append(configs, authentication.JWKSConfig{ + URL: jwks.URL, + RefreshInterval: jwks.RefreshInterval, + AllowedAlgorithms: jwks.Algorithms, + Audiences: jwks.Audiences, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: jwks.RefreshUnknownKID.Enabled, + MaxWait: jwks.RefreshUnknownKID.MaxWait, + Interval: jwks.RefreshUnknownKID.Interval, + Burst: jwks.RefreshUnknownKID.Burst, + }, + }) + } + + tokenDecoder, err := authentication.NewJwksTokenDecoder(ctx, logger, configs) + if err != nil { + return nil, fmt.Errorf("failed to create MCP token decoder: %w", err) + } + + // MCP server uses standard Authorization header with Bearer prefix + headerSourceMap := map[string][]string{ + "Authorization": {"Bearer"}, + } + + opts := authentication.HttpHeaderAuthenticatorOptions{ + Name: "mcp-jwks", + HeaderSourcePrefixes: headerSourceMap, + TokenDecoder: tokenDecoder, + } + + authenticator, err := authentication.NewHttpHeaderAuthenticator(opts) + if err != nil { + logger.Error("Could not create MCP HttpHeader authenticator", zap.Error(err)) + return nil, fmt.Errorf("failed to create MCP authenticator: %w", err) + } + + authenticators = append(authenticators, authenticator) + + return authenticators, nil +} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 73e8f85e28..384217caa5 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -962,16 +962,53 @@ type CacheWarmupConfiguration struct { Timeout time.Duration `yaml:"timeout" envDefault:"30s" env:"CACHE_WARMUP_TIMEOUT"` } +type MCPAuthorizationConfiguration struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_AUTHORIZATION_ENABLED"` + JWKS []JWKSConfiguration `yaml:"jwks"` + Scopes MCPScopesConfiguration `yaml:"scopes"` + Metadata MCPMetadataConfiguration `yaml:"metadata"` +} + +type MCPMetadataConfiguration struct { + Enabled bool `yaml:"enabled" envDefault:"true" env:"MCP_METADATA_ENABLED"` + ResourceURI string `yaml:"resource_uri,omitempty" env:"MCP_METADATA_RESOURCE_URI"` + AuthorizationServers []string `yaml:"authorization_servers,omitempty" env:"MCP_METADATA_AUTHORIZATION_SERVERS"` + DocumentationURL string `yaml:"documentation_url,omitempty" env:"MCP_METADATA_DOCUMENTATION_URL"` +} + +type MCPScopesConfiguration struct { + Mode string `yaml:"mode" envDefault:"enforce" env:"MCP_SCOPES_MODE"` + Tools MCPToolsScopesConfiguration `yaml:"tools"` +} + +type MCPToolsScopesConfiguration struct { + GetSchema MCPToolScopeConfiguration `yaml:"get_schema"` + ExecuteGraphQL MCPToolScopeConfiguration `yaml:"execute_graphql"` + GetOperationInfo MCPToolScopeConfiguration `yaml:"get_operation_info"` +} + +type MCPToolScopeConfiguration struct { + Scopes []string `yaml:"scopes"` + Public bool `yaml:"public" envDefault:"false"` +} + type MCPConfiguration struct { - Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` - Server MCPServer `yaml:"server,omitempty"` - Storage MCPStorageConfig `yaml:"storage,omitempty"` - Session MCPSessionConfig `yaml:"session,omitempty"` - GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` - ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` - EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` - ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` - RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` + Server MCPServer `yaml:"server,omitempty"` + Storage MCPStorageConfig `yaml:"storage,omitempty"` + Session MCPSessionConfig `yaml:"session,omitempty"` + GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` + ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` + EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` + ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` + RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + ForwardHeaders MCPForwardHeadersConfiguration `yaml:"forward_headers"` + Authorization MCPAuthorizationConfiguration `yaml:"authorization"` +} + +type MCPForwardHeadersConfiguration struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_FORWARD_HEADERS_ENABLED"` + AllowList []string `yaml:"allow_list" env:"MCP_FORWARD_HEADERS_ALLOW_LIST"` } type MCPSessionConfig struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 528e0e1ce7..5f5300cafe 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2071,6 +2071,231 @@ "type": "boolean", "default": false, "description": "Expose the full GraphQL schema through MCP. When enabled, AI models can request the complete schema of your API." + }, + "forward_headers": { + "type": "object", + "description": "Configuration for forwarding additional HTTP headers from MCP requests to GraphQL requests. The Authorization header is always forwarded regardless of this configuration to maintain backward compatibility. Use this to forward additional headers like tenant IDs, trace IDs, or custom authentication tokens.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false, + "description": "Enable forwarding of additional headers beyond Authorization. When false (default), only the Authorization header is forwarded. When true, headers matching the allow_list are also forwarded." + }, + "allow_list": { + "type": "array", + "description": "List of additional header names or regex patterns to forward (beyond Authorization which is always forwarded). Supports exact matches (e.g., 'X-Tenant-ID') and regex patterns (e.g., 'X-.*' for all headers starting with 'X-'). Header matching is case-insensitive.", + "items": { + "type": "string" + } + } + } + }, + "authorization": { + "type": "object", + "description": "OAuth 2.1 authorization configuration for the MCP server. When enabled, validates JWT access tokens using JWKS from authorization servers.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false, + "description": "Enable OAuth 2.1 authorization for the MCP server. When true, all MCP requests must include a valid JWT access token in the Authorization header." + }, + "jwks": { + "type": "array", + "description": "List of JWKS (JSON Web Key Set) configurations for validating JWT access tokens. Supports multiple authorization servers.", + "items": { + "type": "object", + "additionalProperties": false, + "properties": { + "url": { + "type": "string", + "description": "The URL of the JWKS endpoint. Used to fetch public keys for JWT signature verification.", + "format": "http-url" + }, + "algorithms": { + "type": "array", + "description": "The allowed signing algorithms for JWT tokens. An empty list means all algorithms are allowed.", + "items": { + "type": "string", + "enum": [ + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + "EdDSA" + ] + } + }, + "refresh_interval": { + "type": "string", + "duration": { + "minimum": "5s" + }, + "description": "The interval at which the JWKS are refreshed. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", + "default": "1m" + }, + "audiences": { + "type": "array", + "description": "The expected audiences in the JWT token. Used to verify the 'aud' claim.", + "items": { + "type": "string" + } + }, + "refresh_unknown_kid": { + "type": "object", + "description": "Controls rate-limited refresh behavior when a JWT KID (Key ID) is unknown.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable refresh attempts on unknown KID.", + "default": false + }, + "max_wait": { + "type": "string", + "description": "Maximum time to wait for a refresh permit before giving up.", + "default": "10s", + "duration": { + "minimum": "0s" + } + }, + "interval": { + "type": "string", + "description": "Token refill interval for the rate limiter.", + "default": "1m", + "duration": { + "minimum": "1s" + } + }, + "burst": { + "type": "integer", + "description": "Burst size for the rate limiter.", + "default": 2, + "minimum": 1 + } + } + } + }, + "required": ["url"] + } + }, + "scopes": { + "type": "object", + "description": "Scope-based authorization configuration. Controls how OAuth scopes are enforced for MCP operations.", + "additionalProperties": false, + "properties": { + "mode": { + "type": "string", + "enum": ["enforce", "log_only", "disabled"], + "default": "enforce", + "description": "The enforcement mode for scope validation. 'enforce' (default) blocks requests with insufficient scopes, 'log_only' logs violations but allows requests, 'disabled' skips scope checking entirely." + }, + "tools": { + "type": "object", + "description": "Scope requirements for built-in MCP tools.", + "additionalProperties": false, + "properties": { + "get_schema": { + "type": "object", + "description": "Scope configuration for the get_schema tool.", + "additionalProperties": false, + "properties": { + "scopes": { + "type": "array", + "description": "Required OAuth scopes for accessing the get_schema tool.", + "items": { + "type": "string" + } + }, + "public": { + "type": "boolean", + "default": false, + "description": "If true, no authentication is required for this tool." + } + } + }, + "execute_graphql": { + "type": "object", + "description": "Scope configuration for the execute_graphql tool.", + "additionalProperties": false, + "properties": { + "scopes": { + "type": "array", + "description": "Required OAuth scopes for accessing the execute_graphql tool.", + "items": { + "type": "string" + } + }, + "public": { + "type": "boolean", + "default": false, + "description": "If true, no authentication is required for this tool." + } + } + }, + "get_operation_info": { + "type": "object", + "description": "Scope configuration for the get_operation_info tool.", + "additionalProperties": false, + "properties": { + "scopes": { + "type": "array", + "description": "Required OAuth scopes for accessing the get_operation_info tool.", + "items": { + "type": "string" + } + }, + "public": { + "type": "boolean", + "default": false, + "description": "If true, no authentication is required for this tool." + } + } + } + } + } + } + }, + "metadata": { + "type": "object", + "description": "Protected Resource Metadata configuration (RFC 9728). Serves metadata about the protected resource at /.well-known/oauth-protected-resource endpoint.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": true, + "description": "Enable the Protected Resource Metadata endpoint." + }, + "resource_uri": { + "type": "string", + "description": "The URI of the protected resource. This identifies the resource in the metadata response.", + "format": "http-url" + }, + "authorization_servers": { + "type": "array", + "description": "List of authorization server URLs that can issue tokens for this resource.", + "items": { + "type": "string", + "format": "http-url" + } + }, + "documentation_url": { + "type": "string", + "description": "URL to documentation about the protected resource and its API.", + "format": "http-url" + } + } + } + } } } }, diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index c14af6023c..26d119fdad 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -131,7 +131,11 @@ "ExcludeMutations": false, "EnableArbitraryOperations": false, "ExposeSchema": false, - "RouterURL": "" + "RouterURL": "", + "ForwardHeaders": { + "Enabled": false, + "AllowList": null + } }, "DemoMode": false, "Modules": null, diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index d2a5695072..e74a9d7769 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -166,7 +166,11 @@ "ExcludeMutations": false, "EnableArbitraryOperations": false, "ExposeSchema": false, - "RouterURL": "https://cosmo-router.wundergraph.com" + "RouterURL": "https://cosmo-router.wundergraph.com", + "ForwardHeaders": { + "Enabled": false, + "AllowList": null + } }, "DemoMode": true, "Modules": { diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index c5cf5b6399..0c024870fe 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "regexp" "strings" "time" @@ -16,6 +17,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/santhosh-tekuri/jsonschema/v6" + "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/schemaloader" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" @@ -25,16 +27,31 @@ import ( // authKey is a custom context key for storing the auth token. type authKey struct{} +// headersKey is a custom context key for storing forwarded headers. +type headersKey struct{} + // withAuthKey adds an auth key to the context. func withAuthKey(ctx context.Context, auth string) context.Context { return context.WithValue(ctx, authKey{}, auth) } +// withHeaders adds headers to the context. +func withHeaders(ctx context.Context, headers http.Header) context.Context { + return context.WithValue(ctx, headersKey{}, headers) +} + // authFromRequest extracts the auth token from the request headers. func authFromRequest(ctx context.Context, r *http.Request) context.Context { return withAuthKey(ctx, r.Header.Get("Authorization")) } +// headersFromRequest extracts all headers from the request and stores them in context. +func headersFromRequest(ctx context.Context, r *http.Request) context.Context { + // Clone the headers to avoid any mutation issues + headers := r.Header.Clone() + return withHeaders(ctx, headers) +} + // tokenFromContext extracts the auth token from the context. // This can be used by clients to pass the auth token to the server. func tokenFromContext(ctx context.Context) (string, error) { @@ -45,6 +62,12 @@ func tokenFromContext(ctx context.Context) (string, error) { return auth, nil } +// headersFromContext extracts headers from the context. +func headersFromContext(ctx context.Context) (http.Header, bool) { + headers, ok := ctx.Value(headersKey{}).(http.Header) + return headers, ok +} + // Options represents configuration options for the GraphQLSchemaServer type Options struct { // GraphName is the name of the graph to be served @@ -70,6 +93,16 @@ type Options struct { ExposeSchema bool // Stateless determines whether the MCP server should be stateless Stateless bool + // ForwardHeadersEnabled determines whether header forwarding is enabled + ForwardHeadersEnabled bool + // ForwardHeadersAllowList is the list of headers (or regex patterns) to forward + ForwardHeadersAllowList []string + // TokenValidator is the token validator for authorization + TokenValidator *TokenValidator + // MetadataConfig is the configuration for the protected resource metadata endpoint + MetadataConfig *config.MCPMetadataConfiguration + // AuthConfig is the full authorization configuration + AuthConfig *config.MCPAuthorizationConfiguration } // GraphQLSchemaServer represents an MCP server that works with GraphQL schemas and operations @@ -91,6 +124,11 @@ type GraphQLSchemaServer struct { operationsManager *OperationsManager schemaCompiler *SchemaCompiler registeredTools []string + forwardHeadersEnabled bool + forwardHeadersAllowList []string + tokenValidator *TokenValidator + metadataConfig *config.MCPMetadataConfiguration + authConfig *config.MCPAuthorizationConfiguration } type graphqlRequest struct { @@ -218,6 +256,11 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) exposeSchema: options.ExposeSchema, stateless: options.Stateless, baseURL: options.BaseURL, + forwardHeadersEnabled: options.ForwardHeadersEnabled, + forwardHeadersAllowList: options.ForwardHeadersAllowList, + tokenValidator: options.TokenValidator, + metadataConfig: options.MetadataConfig, + authConfig: options.AuthConfig, } return gs, nil @@ -285,6 +328,35 @@ func WithStateless(stateless bool) func(*Options) { } } +// WithForwardHeaders configures header forwarding to GraphQL requests +func WithForwardHeaders(enabled bool, allowList []string) func(*Options) { + return func(o *Options) { + o.ForwardHeadersEnabled = enabled + o.ForwardHeadersAllowList = allowList + } +} + +// WithTokenValidator sets the token validator for authorization +func WithTokenValidator(tokenValidator *TokenValidator) func(*Options) { + return func(o *Options) { + o.TokenValidator = tokenValidator + } +} + +// WithMetadataConfig sets the metadata configuration +func WithMetadataConfig(metadataConfig *config.MCPMetadataConfiguration) func(*Options) { + return func(o *Options) { + o.MetadataConfig = metadataConfig + } +} + +// WithAuthConfig sets the authorization configuration +func WithAuthConfig(authConfig *config.MCPAuthorizationConfiguration) func(*Options) { + return func(o *Options) { + o.AuthConfig = authConfig + } +} + // Serve starts the server with the configured options and returns a streamable HTTP server. func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { // Create custom HTTP server @@ -295,11 +367,20 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { IdleTimeout: 60 * time.Second, } + // Create a combined context function that captures both auth and headers + contextFunc := func(ctx context.Context, r *http.Request) context.Context { + ctx = authFromRequest(ctx, r) + if s.forwardHeadersEnabled { + ctx = headersFromRequest(ctx, r) + } + return ctx + } + streamableHTTPServer := server.NewStreamableHTTPServer(s.server, server.WithStreamableHTTPServer(httpServer), server.WithLogger(NewZapAdapter(s.logger.With(zap.String("component", "mcp-server")))), server.WithStateLess(s.stateless), - server.WithHTTPContextFunc(authFromRequest), + server.WithHTTPContextFunc(contextFunc), server.WithHeartbeatInterval(10*time.Second), ) @@ -307,10 +388,27 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { mux := http.NewServeMux() - // No OAuth protection - original behavior - mux.Handle("/mcp", corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Create the MCP handler + mcpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { streamableHTTPServer.ServeHTTP(w, r) - }))) + }) + + // Apply authorization middleware if token validator is configured + if s.tokenValidator != nil && s.tokenValidator.enabled { + mcpHandler = s.tokenValidator.AuthorizationMiddleware(mcpHandler).(http.HandlerFunc) + } + + // Apply CORS middleware + mux.Handle("/mcp", corsMiddleware(mcpHandler)) + + // Add protected resource metadata endpoint if authorization is enabled + if s.metadataConfig != nil && s.metadataConfig.Enabled { + metadataHandler := http.HandlerFunc(s.handleProtectedResourceMetadata) + // Register at both standard RFC 9728 path and MCP-specific path + // The MCP-specific path is used when base_url includes /mcp + mux.Handle("/.well-known/oauth-protected-resource", corsMiddleware(metadataHandler)) + mux.Handle("/.well-known/oauth-protected-resource/mcp", corsMiddleware(metadataHandler)) + } // Set the handler for the custom HTTP server httpServer.Handler = mux @@ -559,6 +657,11 @@ func (s *GraphQLSchemaServer) registerTools() error { // handleOperation handles a specific operation func (s *GraphQLSchemaServer) handleOperation(handler *operationHandler) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // TODO: Phase 3 - Validate scopes from GraphQL @wg_auth directive + // For now, operations inherit the same scope requirements as execute_graphql + if err := s.validateToolScopes(ctx, "execute_operation"); err != nil { + return nil, err + } jsonBytes, err := json.Marshal(request.GetArguments()) if err != nil { @@ -580,6 +683,11 @@ func (s *GraphQLSchemaServer) handleOperation(handler *operationHandler) func(ct // handleGraphQLOperationInfo returns a handler function that provides detailed info for a specific operation. func (s *GraphQLSchemaServer) handleGraphQLOperationInfo() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Validate scopes for get_operation_info tool + if err := s.validateToolScopes(ctx, "get_operation_info"); err != nil { + return nil, err + } + var input GraphQLOperationInfoInput inputBytes, err := json.Marshal(request.GetArguments()) if err != nil { @@ -654,6 +762,39 @@ Important Notes: } } +// filterHeaders filters headers based on the allowlist configuration. +// It supports both exact matches and regex patterns. +func (s *GraphQLSchemaServer) filterHeaders(headers http.Header) http.Header { + if !s.forwardHeadersEnabled || len(s.forwardHeadersAllowList) == 0 { + return http.Header{} + } + + filtered := http.Header{} + + for _, pattern := range s.forwardHeadersAllowList { + // Try to compile as regex + re, err := regexp.Compile("(?i)^" + pattern + "$") + if err != nil { + // If it's not a valid regex, treat it as an exact match (case-insensitive) + for headerName, headerValues := range headers { + if strings.EqualFold(headerName, pattern) { + filtered[headerName] = headerValues + } + } + continue + } + + // Match using regex + for headerName, headerValues := range headers { + if re.MatchString(headerName) { + filtered[headerName] = headerValues + } + } + } + + return filtered +} + // executeGraphQLQuery executes a GraphQL query against the router endpoint func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query string, variables json.RawMessage) (*mcp.CallToolResult, error) { // Create the GraphQL request @@ -675,6 +816,7 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json; charset=utf-8") + // Always forward Authorization header (legacy behavior) token, err := tokenFromContext(ctx) if err != nil { s.logger.Debug("failed to get token from context", zap.Error(err)) @@ -682,7 +824,21 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str req.Header.Set("Authorization", token) } - // Forward Authorization header if provided + // Forward additional headers if enabled + if s.forwardHeadersEnabled { + if headers, ok := headersFromContext(ctx); ok { + filteredHeaders := s.filterHeaders(headers) + for headerName, headerValues := range filteredHeaders { + // Skip Authorization as it's already handled above + if strings.EqualFold(headerName, "Authorization") { + continue + } + for _, headerValue := range headerValues { + req.Header.Add(headerName, headerValue) + } + } + } + } resp, err := s.httpClient.Do(req) if err != nil { @@ -721,9 +877,90 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str return mcp.NewToolResultText(string(body)), nil } +// validateToolScopes validates that the authenticated user has the required scopes for a tool +func (s *GraphQLSchemaServer) validateToolScopes(ctx context.Context, toolName string) error { + // Skip validation if authorization is not configured or disabled + if s.authConfig == nil || s.authConfig.Scopes.Mode == "disabled" { + return nil + } + + // Get authentication from context + auth, ok := AuthenticationFromContext(ctx) + + // Determine required scopes based on tool name + var requiredScopes []string + var isPublic bool + + switch toolName { + case "get_schema": + requiredScopes = s.authConfig.Scopes.Tools.GetSchema.Scopes + isPublic = s.authConfig.Scopes.Tools.GetSchema.Public + case "execute_graphql", "execute_operation": + requiredScopes = s.authConfig.Scopes.Tools.ExecuteGraphQL.Scopes + isPublic = s.authConfig.Scopes.Tools.ExecuteGraphQL.Public + case "get_operation_info": + requiredScopes = s.authConfig.Scopes.Tools.GetOperationInfo.Scopes + isPublic = s.authConfig.Scopes.Tools.GetOperationInfo.Public + default: + // Unknown tool, allow by default + return nil + } + + // If tool is marked as public, skip scope validation + if isPublic { + return nil + } + + // If no scopes are required, allow access + if len(requiredScopes) == 0 { + return nil + } + + // If scopes are required but no authentication present + if !ok || auth == nil { + if s.authConfig.Scopes.Mode == "enforce" { + return fmt.Errorf("authentication required for tool '%s'", toolName) + } + // log_only mode + s.logger.Warn("scope validation failed: no authentication", + zap.String("tool", toolName), + zap.String("mode", s.authConfig.Scopes.Mode), + zap.Strings("required_scopes", requiredScopes)) + return nil + } + + // Validate scopes using the token validator + if err := s.tokenValidator.ValidateScopes(auth, requiredScopes, false); err != nil { + if s.authConfig.Scopes.Mode == "enforce" { + return fmt.Errorf("insufficient scopes for tool '%s': %w", toolName, err) + } + // log_only mode + s.logger.Warn("scope validation failed", + zap.String("tool", toolName), + zap.String("mode", s.authConfig.Scopes.Mode), + zap.Strings("required_scopes", requiredScopes), + zap.Strings("provided_scopes", auth.Scopes()), + zap.Error(err)) + return nil + } + + // Scopes validated successfully + s.logger.Debug("scope validation successful", + zap.String("tool", toolName), + zap.Strings("required_scopes", requiredScopes), + zap.Strings("provided_scopes", auth.Scopes())) + + return nil +} + // handleExecuteGraphQL returns a handler function that executes arbitrary GraphQL queries func (s *GraphQLSchemaServer) handleExecuteGraphQL() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Validate scopes for execute_graphql tool + if err := s.validateToolScopes(ctx, "execute_graphql"); err != nil { + return nil, err + } + // Parse the JSON input jsonBytes, err := json.Marshal(request.GetArguments()) if err != nil { @@ -746,6 +983,11 @@ func (s *GraphQLSchemaServer) handleExecuteGraphQL() func(ctx context.Context, r // handleGetGraphQLSchema returns a handler function that returns the full GraphQL schema func (s *GraphQLSchemaServer) handleGetGraphQLSchema() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Validate scopes for get_schema tool + if err := s.validateToolScopes(ctx, "get_schema"); err != nil { + return nil, err + } + // Get the schema from the operations manager schema := s.operationsManager.GetSchema() if schema == nil { @@ -799,5 +1041,95 @@ func setCORSHeaders(w http.ResponseWriter, allowedMethods []string) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allowedMethods, "OPTIONS"), ", ")) w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Authorization, Last-Event-ID, Mcp-Protocol-Version, Mcp-Session-Id") + w.Header().Set("Access-Control-Expose-Headers", "Content-Type, Content-Length") w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours } + +// ProtectedResourceMetadata represents the RFC 9728 OAuth 2.0 Protected Resource Metadata +type ProtectedResourceMetadata struct { + Resource string `json:"resource"` + AuthorizationServers []string `json:"authorization_servers"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + BearerMethodsSupported []string `json:"bearer_methods_supported"` + ResourceDocumentation string `json:"resource_documentation,omitempty"` +} + +// handleProtectedResourceMetadata serves the RFC 9728 OAuth 2.0 Protected Resource Metadata +func (s *GraphQLSchemaServer) handleProtectedResourceMetadata(w http.ResponseWriter, r *http.Request) { + + s.logger.Info("protected resource metadata request", + zap.String("method", r.Method), + zap.String("path", r.URL.Path), + zap.String("remote_addr", r.RemoteAddr), + zap.String("user_agent", r.Header.Get("User-Agent")), + zap.String("referer", r.Header.Get("Referer")), + ) + + // Only allow GET requests + if r.Method != http.MethodGet { + s.logger.Warn("protected resource metadata request with invalid method", + zap.String("method", r.Method), + zap.String("remote_addr", r.RemoteAddr), + ) + + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + // Collect all supported scopes from configuration + scopes := make(map[string]bool) + + // Add scopes from built-in tools configuration + if s.authConfig != nil { + if len(s.authConfig.Scopes.Tools.GetSchema.Scopes) > 0 { + for _, scope := range s.authConfig.Scopes.Tools.GetSchema.Scopes { + scopes[scope] = true + } + } + if len(s.authConfig.Scopes.Tools.ExecuteGraphQL.Scopes) > 0 { + for _, scope := range s.authConfig.Scopes.Tools.ExecuteGraphQL.Scopes { + scopes[scope] = true + } + } + if len(s.authConfig.Scopes.Tools.GetOperationInfo.Scopes) > 0 { + for _, scope := range s.authConfig.Scopes.Tools.GetOperationInfo.Scopes { + scopes[scope] = true + } + } + } + + // TODO: In Phase 3, add scopes from GraphQL operations with @wg_auth directive + + // Convert map to sorted slice for consistent output + scopesList := make([]string, 0, len(scopes)) + for scope := range scopes { + scopesList = append(scopesList, scope) + } + + // Determine resource URI + resourceURI := s.metadataConfig.ResourceURI + if resourceURI == "" && s.baseURL != "" { + resourceURI = s.baseURL + } + + // Build metadata response + metadata := ProtectedResourceMetadata{ + Resource: resourceURI, + AuthorizationServers: s.metadataConfig.AuthorizationServers, + ScopesSupported: scopesList, + BearerMethodsSupported: []string{"header"}, + } + + if s.metadataConfig.DocumentationURL != "" { + metadata.ResourceDocumentation = s.metadataConfig.DocumentationURL + } + + // Set response headers + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + // Encode and send response + if err := json.NewEncoder(w).Encode(metadata); err != nil { + s.logger.Error("failed to encode protected resource metadata", zap.Error(err)) + } +} diff --git a/router/pkg/mcpserver/server_test.go b/router/pkg/mcpserver/server_test.go new file mode 100644 index 0000000000..9cf2d7bb7e --- /dev/null +++ b/router/pkg/mcpserver/server_test.go @@ -0,0 +1,287 @@ +package mcpserver + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestFilterHeaders(t *testing.T) { + tests := []struct { + name string + forwardHeadersEnabled bool + forwardHeadersAllowList []string + inputHeaders http.Header + expectedHeaders http.Header + }{ + { + name: "disabled forwarding returns empty headers", + forwardHeadersEnabled: false, + forwardHeadersAllowList: []string{"Authorization"}, + inputHeaders: http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-ID": []string{"tenant-1"}, + }, + expectedHeaders: http.Header{}, + }, + { + name: "empty allowlist returns empty headers", + forwardHeadersEnabled: true, + forwardHeadersAllowList: []string{}, + inputHeaders: http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-ID": []string{"tenant-1"}, + }, + expectedHeaders: http.Header{}, + }, + { + name: "exact match case insensitive", + forwardHeadersEnabled: true, + forwardHeadersAllowList: []string{"authorization", "x-tenant-id"}, + inputHeaders: http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-ID": []string{"tenant-1"}, + "X-Trace-ID": []string{"trace-123"}, + }, + expectedHeaders: http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-ID": []string{"tenant-1"}, + }, + }, + { + name: "regex pattern matching", + forwardHeadersEnabled: true, + forwardHeadersAllowList: []string{"X-.*"}, + inputHeaders: http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-ID": []string{"tenant-1"}, + "X-Trace-ID": []string{"trace-123"}, + "X-Custom": []string{"custom-value"}, + }, + expectedHeaders: http.Header{ + "X-Tenant-ID": []string{"tenant-1"}, + "X-Trace-ID": []string{"trace-123"}, + "X-Custom": []string{"custom-value"}, + }, + }, + { + name: "mixed exact and regex patterns", + forwardHeadersEnabled: true, + forwardHeadersAllowList: []string{"Authorization", "X-.*"}, + inputHeaders: http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-ID": []string{"tenant-1"}, + "X-Trace-ID": []string{"trace-123"}, + "Content-Type": []string{"application/json"}, + }, + expectedHeaders: http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-ID": []string{"tenant-1"}, + "X-Trace-ID": []string{"trace-123"}, + }, + }, + { + name: "multiple values for same header", + forwardHeadersEnabled: true, + forwardHeadersAllowList: []string{"X-Custom"}, + inputHeaders: http.Header{ + "X-Custom": []string{"value1", "value2", "value3"}, + }, + expectedHeaders: http.Header{ + "X-Custom": []string{"value1", "value2", "value3"}, + }, + }, + { + name: "no matching headers", + forwardHeadersEnabled: true, + forwardHeadersAllowList: []string{"X-Missing"}, + inputHeaders: http.Header{ + "Authorization": []string{"Bearer token123"}, + "Content-Type": []string{"application/json"}, + }, + expectedHeaders: http.Header{}, + }, + { + name: "invalid regex treated as exact match", + forwardHeadersEnabled: true, + forwardHeadersAllowList: []string{"[invalid"}, + inputHeaders: http.Header{ + "[invalid": []string{"value1"}, + "Authorization": []string{"Bearer token123"}, + }, + expectedHeaders: http.Header{ + "[invalid": []string{"value1"}, + }, + }, + { + name: "case insensitive regex matching", + forwardHeadersEnabled: true, + forwardHeadersAllowList: []string{"x-.*"}, + inputHeaders: http.Header{ + "X-Tenant-ID": []string{"tenant-1"}, + "x-trace-id": []string{"trace-123"}, + "X-CUSTOM": []string{"custom-value"}, + }, + expectedHeaders: http.Header{ + "X-Tenant-ID": []string{"tenant-1"}, + "x-trace-id": []string{"trace-123"}, + "X-CUSTOM": []string{"custom-value"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &GraphQLSchemaServer{ + forwardHeadersEnabled: tt.forwardHeadersEnabled, + forwardHeadersAllowList: tt.forwardHeadersAllowList, + logger: zap.NewNop(), + } + + result := server.filterHeaders(tt.inputHeaders) + + assert.Equal(t, len(tt.expectedHeaders), len(result), "number of headers should match") + for key, expectedValues := range tt.expectedHeaders { + actualValues, ok := result[key] + assert.True(t, ok, "header %s should be present", key) + assert.Equal(t, expectedValues, actualValues, "values for header %s should match", key) + } + + // Ensure no extra headers are present + for key := range result { + _, ok := tt.expectedHeaders[key] + assert.True(t, ok, "unexpected header %s in result", key) + } + }) + } +} + +func TestWithForwardHeaders(t *testing.T) { + tests := []struct { + name string + enabled bool + allowList []string + wantEnabled bool + wantList []string + }{ + { + name: "enabled with allowlist", + enabled: true, + allowList: []string{"Authorization", "X-Tenant-ID"}, + wantEnabled: true, + wantList: []string{"Authorization", "X-Tenant-ID"}, + }, + { + name: "disabled with allowlist", + enabled: false, + allowList: []string{"Authorization"}, + wantEnabled: false, + wantList: []string{"Authorization"}, + }, + { + name: "enabled with empty allowlist", + enabled: true, + allowList: []string{}, + wantEnabled: true, + wantList: []string{}, + }, + { + name: "enabled with nil allowlist", + enabled: true, + allowList: nil, + wantEnabled: true, + wantList: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := &Options{} + optFunc := WithForwardHeaders(tt.enabled, tt.allowList) + optFunc(opts) + + assert.Equal(t, tt.wantEnabled, opts.ForwardHeadersEnabled) + assert.Equal(t, tt.wantList, opts.ForwardHeadersAllowList) + }) + } +} + +func TestNewGraphQLSchemaServer_ForwardHeadersDefaults(t *testing.T) { + server, err := NewGraphQLSchemaServer("http://localhost:3000/graphql") + require.NoError(t, err) + require.NotNil(t, server) + + // Check that forward headers are disabled by default + assert.False(t, server.forwardHeadersEnabled) + assert.Nil(t, server.forwardHeadersAllowList) +} + +func TestNewGraphQLSchemaServer_WithForwardHeaders(t *testing.T) { + allowList := []string{"Authorization", "X-Tenant-ID", "X-.*"} + server, err := NewGraphQLSchemaServer( + "http://localhost:3000/graphql", + WithForwardHeaders(true, allowList), + ) + require.NoError(t, err) + require.NotNil(t, server) + + assert.True(t, server.forwardHeadersEnabled) + assert.Equal(t, allowList, server.forwardHeadersAllowList) +} + +func TestHeadersFromContext(t *testing.T) { + tests := []struct { + name string + setupContext func() http.Header + expectedOk bool + expectedHeader http.Header + }{ + { + name: "headers present in context", + setupContext: func() http.Header { + return http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-Id": []string{"tenant-1"}, // Note: Go canonicalizes to X-Tenant-Id + } + }, + expectedOk: true, + expectedHeader: http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Tenant-Id": []string{"tenant-1"}, // Note: Go canonicalizes to X-Tenant-Id + }, + }, + { + name: "empty headers in context", + setupContext: func() http.Header { + return http.Header{} + }, + expectedOk: true, + expectedHeader: http.Header{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers := tt.setupContext() + req, err := http.NewRequest("GET", "http://example.com", nil) + require.NoError(t, err) + + for key, values := range headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + + ctx := headersFromRequest(req.Context(), req) + retrievedHeaders, ok := headersFromContext(ctx) + + assert.Equal(t, tt.expectedOk, ok) + if tt.expectedOk { + assert.Equal(t, tt.expectedHeader, retrievedHeaders) + } + }) + } +} \ No newline at end of file diff --git a/router/pkg/mcpserver/token_validator.go b/router/pkg/mcpserver/token_validator.go new file mode 100644 index 0000000000..722e3e6771 --- /dev/null +++ b/router/pkg/mcpserver/token_validator.go @@ -0,0 +1,181 @@ +package mcpserver + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/wundergraph/cosmo/router/pkg/authentication" + "go.uber.org/zap" +) + +// TokenValidator validates JWT access tokens for MCP requests +type TokenValidator struct { + authenticators []authentication.Authenticator + logger *zap.Logger + enabled bool +} + +// NewTokenValidator creates a new token validator +func NewTokenValidator(authenticators []authentication.Authenticator, logger *zap.Logger, enabled bool) *TokenValidator { + return &TokenValidator{ + authenticators: authenticators, + logger: logger, + enabled: enabled, + } +} + +// ValidateRequest validates the JWT token in the request +// Returns the authentication result or an error +func (tv *TokenValidator) ValidateRequest(ctx context.Context, r *http.Request) (authentication.Authentication, error) { + if !tv.enabled { + return nil, nil + } + + // Create a provider from the request + provider := &httpRequestProvider{request: r} + + // Authenticate using the configured authenticators + auth, err := authentication.Authenticate(ctx, tv.authenticators, provider) + if err != nil { + tv.logger.Debug("authentication failed", zap.Error(err)) + return nil, fmt.Errorf("invalid access token: %w", err) + } + + // If no authentication information was found, return error + if auth == nil { + return nil, fmt.Errorf("missing access token") + } + + return auth, nil +} + +// ValidateScopes checks if the authenticated request has the required scopes +func (tv *TokenValidator) ValidateScopes(auth authentication.Authentication, requiredScopes []string, anyOf bool) error { + if !tv.enabled || len(requiredScopes) == 0 { + return nil + } + + if auth == nil { + return fmt.Errorf("authentication required") + } + + tokenScopes := auth.Scopes() + + if anyOf { + // At least one of the required scopes must be present + for _, required := range requiredScopes { + for _, tokenScope := range tokenScopes { + if tokenScope == required { + return nil + } + } + } + return fmt.Errorf("insufficient scopes: requires at least one of %v, got %v", requiredScopes, tokenScopes) + } + + // All required scopes must be present + for _, required := range requiredScopes { + found := false + for _, tokenScope := range tokenScopes { + if tokenScope == required { + found = true + break + } + } + if !found { + return fmt.Errorf("insufficient scopes: requires %v, got %v", requiredScopes, tokenScopes) + } + } + + return nil +} + +// httpRequestProvider implements the authentication.Provider interface +type httpRequestProvider struct { + request *http.Request +} + +func (p *httpRequestProvider) AuthenticationHeaders() http.Header { + return p.request.Header +} + +// AuthorizationMiddleware creates a middleware that validates JWT tokens +func (tv *TokenValidator) AuthorizationMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !tv.enabled { + next.ServeHTTP(w, r) + return + } + + auth, err := tv.ValidateRequest(r.Context(), r) + if err != nil { + tv.logger.Debug("authorization failed", zap.Error(err)) + tv.writeUnauthorizedResponse(w, r, err) + return + } + + // Store authentication in context for later use + ctx := context.WithValue(r.Context(), authenticationKey{}, auth) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// authenticationKey is a context key for storing authentication +type authenticationKey struct{} + +// AuthenticationFromContext retrieves authentication from context +func AuthenticationFromContext(ctx context.Context) (authentication.Authentication, bool) { + auth, ok := ctx.Value(authenticationKey{}).(authentication.Authentication) + return auth, ok +} + +// writeUnauthorizedResponse writes a 401 Unauthorized response +func (tv *TokenValidator) writeUnauthorizedResponse(w http.ResponseWriter, r *http.Request, err error) { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"%s\"", r.Host)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + response := map[string]string{ + "error": "unauthorized", + "error_description": "Valid access token required", + } + + // Don't expose internal error details in production + if tv.logger.Level() == zap.DebugLevel { + response["error_description"] = err.Error() + } + + // Write JSON response + w.Write([]byte(fmt.Sprintf(`{"error":"%s","error_description":"%s"}`, + response["error"], + strings.ReplaceAll(response["error_description"], `"`, `\"`)))) +} + +// writeForbiddenResponse writes a 403 Forbidden response +func (tv *TokenValidator) writeForbiddenResponse(w http.ResponseWriter, r *http.Request, requiredScopes []string, providedScopes []string) { + scopesStr := strings.Join(requiredScopes, " ") + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"%s\", error=\"insufficient_scope\", scope=\"%s\"", r.Host, scopesStr)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + + response := fmt.Sprintf(`{"error":"insufficient_scope","error_description":"Required scopes: %s","required_scopes":%s,"provided_scopes":%s}`, + scopesStr, + formatScopesJSON(requiredScopes), + formatScopesJSON(providedScopes)) + + w.Write([]byte(response)) +} + +// formatScopesJSON formats scopes as a JSON array +func formatScopesJSON(scopes []string) string { + if len(scopes) == 0 { + return "[]" + } + quoted := make([]string, len(scopes)) + for i, s := range scopes { + quoted[i] = fmt.Sprintf(`"%s"`, s) + } + return "[" + strings.Join(quoted, ",") + "]" +} \ No newline at end of file