Skip to content

Commit 8ccf26c

Browse files
committed
mcp: add ModifyRequest to HTTP client transports
Add ModifyRequest func(*http.Request) field to both SSEClientTransport and StreamableClientTransport. This callback is invoked before each outgoing HTTP request, allowing users to add headers, authentication, or other request modifications. This provides a simpler alternative to implementing custom RoundTrippers for common use cases like adding authorization headers or request IDs. Fixes #533
1 parent fd0dc9d commit 8ccf26c

File tree

5 files changed

+361
-26
lines changed

5 files changed

+361
-26
lines changed

mcp/sse.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ type SSEClientTransport struct {
329329
// HTTPClient is the client to use for making HTTP requests. If nil,
330330
// http.DefaultClient is used.
331331
HTTPClient *http.Client
332+
333+
// If set, ModifyRequest is called before each outgoing HTTP request made by the client
334+
// connection. It can be used to, for example, add headers to outgoing requests.
335+
ModifyRequest func(*http.Request)
332336
}
333337

334338
// Connect connects through the client endpoint.
@@ -346,6 +350,9 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
346350
httpClient = http.DefaultClient
347351
}
348352
req.Header.Set("Accept", "text/event-stream")
353+
if c.ModifyRequest != nil {
354+
c.ModifyRequest(req)
355+
}
349356
resp, err := httpClient.Do(req)
350357
if err != nil {
351358
return nil, err
@@ -372,11 +379,12 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
372379

373380
// From here on, the stream takes ownership of resp.Body.
374381
s := &sseClientConn{
375-
client: httpClient,
376-
msgEndpoint: msgEndpoint,
377-
incoming: make(chan []byte, 100),
378-
body: resp.Body,
379-
done: make(chan struct{}),
382+
client: httpClient,
383+
msgEndpoint: msgEndpoint,
384+
modifyRequest: c.ModifyRequest,
385+
incoming: make(chan []byte, 100),
386+
body: resp.Body,
387+
done: make(chan struct{}),
380388
}
381389

382390
go func() {
@@ -403,9 +411,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) {
403411
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
404412
// - Close terminates the GET request.
405413
type sseClientConn struct {
406-
client *http.Client // HTTP client to use for requests
407-
msgEndpoint *url.URL // session endpoint for POSTs
408-
incoming chan []byte // queue of incoming messages
414+
client *http.Client // HTTP client to use for requests
415+
msgEndpoint *url.URL // session endpoint for POSTs
416+
modifyRequest func(*http.Request) // optional callback to modify outgoing requests
417+
incoming chan []byte // queue of incoming messages
409418

410419
mu sync.Mutex
411420
body io.ReadCloser // body of the hanging GET
@@ -456,6 +465,9 @@ func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error {
456465
return err
457466
}
458467
req.Header.Set("Content-Type", "application/json")
468+
if c.modifyRequest != nil {
469+
c.modifyRequest(req)
470+
}
459471
resp, err := c.client.Do(req)
460472
if err != nil {
461473
return err

mcp/sse_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"io"
1212
"net/http"
1313
"net/http/httptest"
14+
"sync"
1415
"sync/atomic"
1516
"testing"
1617

@@ -131,3 +132,127 @@ type roundTripperFunc func(*http.Request) (*http.Response, error)
131132
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
132133
return f(req)
133134
}
135+
136+
func TestSSEClientModifyRequest(t *testing.T) {
137+
ctx := context.Background()
138+
139+
// Track all HTTP requests
140+
var mu sync.Mutex
141+
var requestMethods []string
142+
var requestHeaders []http.Header
143+
144+
// Create a server
145+
server := NewServer(testImpl, nil)
146+
AddTool(server, &Tool{Name: "greet"}, sayHi)
147+
148+
sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil)
149+
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
150+
mu.Lock()
151+
requestMethods = append(requestMethods, req.Method)
152+
requestHeaders = append(requestHeaders, req.Header.Clone())
153+
mu.Unlock()
154+
sseHandler.ServeHTTP(w, req)
155+
}))
156+
defer httpServer.Close()
157+
158+
// Create transport with ModifyRequest
159+
clientTransport := &SSEClientTransport{
160+
Endpoint: httpServer.URL,
161+
ModifyRequest: func(req *http.Request) {
162+
req.Header.Set("X-Custom-Header", "test-value")
163+
req.Header.Set("Authorization", "Bearer test-token")
164+
},
165+
}
166+
167+
c := NewClient(testImpl, nil)
168+
cs, err := c.Connect(ctx, clientTransport, nil)
169+
if err != nil {
170+
t.Fatalf("Connect failed: %v", err)
171+
}
172+
defer cs.Close()
173+
174+
// Call a tool (which will make a POST request)
175+
_, err = cs.CallTool(ctx, &CallToolParams{
176+
Name: "greet",
177+
Arguments: map[string]any{"Name": "user"},
178+
})
179+
if err != nil {
180+
t.Fatalf("CallTool failed: %v", err)
181+
}
182+
183+
// Verify that we have both GET and POST requests
184+
mu.Lock()
185+
defer mu.Unlock()
186+
187+
if len(requestMethods) < 2 {
188+
t.Fatalf("Expected at least 2 requests (GET and POST), got %d", len(requestMethods))
189+
}
190+
191+
// Verify GET request has custom headers
192+
foundGET := false
193+
for i, method := range requestMethods {
194+
if method == "GET" {
195+
foundGET = true
196+
if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" {
197+
t.Errorf("GET request: X-Custom-Header = %q, want %q", got, "test-value")
198+
}
199+
if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" {
200+
t.Errorf("GET request: Authorization = %q, want %q", got, "Bearer test-token")
201+
}
202+
}
203+
}
204+
if !foundGET {
205+
t.Error("No GET request found")
206+
}
207+
208+
// Verify POST request has custom headers
209+
foundPOST := false
210+
for i, method := range requestMethods {
211+
if method == "POST" {
212+
foundPOST = true
213+
if got := requestHeaders[i].Get("X-Custom-Header"); got != "test-value" {
214+
t.Errorf("POST request: X-Custom-Header = %q, want %q", got, "test-value")
215+
}
216+
if got := requestHeaders[i].Get("Authorization"); got != "Bearer test-token" {
217+
t.Errorf("POST request: Authorization = %q, want %q", got, "Bearer test-token")
218+
}
219+
}
220+
}
221+
if !foundPOST {
222+
t.Error("No POST request found")
223+
}
224+
}
225+
226+
func TestSSEClientModifyRequestNil(t *testing.T) {
227+
ctx := context.Background()
228+
229+
// Create a server
230+
server := NewServer(testImpl, nil)
231+
AddTool(server, &Tool{Name: "greet"}, sayHi)
232+
233+
sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil)
234+
httpServer := httptest.NewServer(sseHandler)
235+
defer httpServer.Close()
236+
237+
// Create transport with nil ModifyRequest (should not panic)
238+
clientTransport := &SSEClientTransport{
239+
Endpoint: httpServer.URL,
240+
ModifyRequest: nil, // explicitly nil
241+
}
242+
243+
c := NewClient(testImpl, nil)
244+
cs, err := c.Connect(ctx, clientTransport, nil)
245+
if err != nil {
246+
t.Fatalf("Connect failed: %v", err)
247+
}
248+
defer cs.Close()
249+
250+
// Call a tool - should work normally
251+
_, err = cs.CallTool(ctx, &CallToolParams{
252+
Name: "greet",
253+
Arguments: map[string]any{"Name": "user"},
254+
})
255+
if err != nil {
256+
t.Fatalf("CallTool failed: %v", err)
257+
}
258+
}

mcp/streamable.go

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,10 @@ type StreamableClientTransport struct {
983983
// It defaults to 5. To disable retries, use a negative number.
984984
MaxRetries int
985985

986+
// If set, ModifyRequest is called before each outgoing HTTP request made by the client
987+
// connection. It can be used to, for example, add headers to outgoing requests.
988+
ModifyRequest func(*http.Request)
989+
986990
// TODO(rfindley): propose exporting these.
987991
// If strict is set, the transport is in 'strict mode', where any violation
988992
// of the MCP spec causes a failure.
@@ -1029,29 +1033,31 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
10291033
// cancelling its blocking network operations, which prevents hangs on exit.
10301034
connCtx, cancel := context.WithCancel(ctx)
10311035
conn := &streamableClientConn{
1032-
url: t.Endpoint,
1033-
client: client,
1034-
incoming: make(chan jsonrpc.Message, 10),
1035-
done: make(chan struct{}),
1036-
maxRetries: maxRetries,
1037-
strict: t.strict,
1038-
logger: t.logger,
1039-
ctx: connCtx,
1040-
cancel: cancel,
1041-
failed: make(chan struct{}),
1036+
url: t.Endpoint,
1037+
client: client,
1038+
modifyRequest: t.ModifyRequest,
1039+
incoming: make(chan jsonrpc.Message, 10),
1040+
done: make(chan struct{}),
1041+
maxRetries: maxRetries,
1042+
strict: t.strict,
1043+
logger: t.logger,
1044+
ctx: connCtx,
1045+
cancel: cancel,
1046+
failed: make(chan struct{}),
10421047
}
10431048
return conn, nil
10441049
}
10451050

10461051
type streamableClientConn struct {
1047-
url string
1048-
client *http.Client
1049-
ctx context.Context
1050-
cancel context.CancelFunc
1051-
incoming chan jsonrpc.Message
1052-
maxRetries int
1053-
strict bool // from [StreamableClientTransport.strict]
1054-
logger *slog.Logger // from [StreamableClientTransport.logger]
1052+
url string
1053+
client *http.Client
1054+
modifyRequest func(*http.Request) // from [StreamableClientTransport.ModifyRequest]
1055+
ctx context.Context
1056+
cancel context.CancelFunc
1057+
incoming chan jsonrpc.Message
1058+
maxRetries int
1059+
strict bool // from [StreamableClientTransport.strict]
1060+
logger *slog.Logger // from [StreamableClientTransport.logger]
10551061

10561062
// Guard calls to Close, as it may be called multiple times.
10571063
closeOnce sync.Once
@@ -1188,6 +1194,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
11881194
req.Header.Set("Content-Type", "application/json")
11891195
req.Header.Set("Accept", "application/json, text/event-stream")
11901196
c.setMCPHeaders(req)
1197+
if c.modifyRequest != nil {
1198+
c.modifyRequest(req)
1199+
}
11911200

11921201
resp, err := c.client.Do(req)
11931202
if err != nil {
@@ -1448,6 +1457,9 @@ func (c *streamableClientConn) Close() error {
14481457
c.closeErr = err
14491458
} else {
14501459
c.setMCPHeaders(req)
1460+
if c.modifyRequest != nil {
1461+
c.modifyRequest(req)
1462+
}
14511463
if _, err := c.client.Do(req); err != nil {
14521464
c.closeErr = err
14531465
}
@@ -1474,6 +1486,9 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
14741486
req.Header.Set("Last-Event-ID", lastEventID)
14751487
}
14761488
req.Header.Set("Accept", "text/event-stream")
1489+
if c.modifyRequest != nil {
1490+
c.modifyRequest(req)
1491+
}
14771492

14781493
return c.client.Do(req)
14791494
}

0 commit comments

Comments
 (0)