Skip to content

Commit df56e35

Browse files
dsp-antclaudepcarleton
authored
fix: Pass RequestInit options to auth requests (#1066)
Co-authored-by: Claude <[email protected]> Co-authored-by: Paul Carleton <[email protected]>
1 parent a7e525a commit df56e35

File tree

4 files changed

+165
-23
lines changed

4 files changed

+165
-23
lines changed

src/client/auth.test.ts

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2558,4 +2558,113 @@ describe('OAuth Authorization', () => {
25582558
expect(body.get('refresh_token')).toBe('refresh123');
25592559
});
25602560
});
2561+
2562+
describe('RequestInit headers passthrough', () => {
2563+
it('custom headers from RequestInit are passed to auth discovery requests', async () => {
2564+
const { createFetchWithInit } = await import('../shared/transport.js');
2565+
2566+
const customFetch = vi.fn().mockResolvedValue({
2567+
ok: true,
2568+
status: 200,
2569+
json: async () => ({
2570+
resource: 'https://resource.example.com',
2571+
authorization_servers: ['https://auth.example.com']
2572+
})
2573+
});
2574+
2575+
// Create a wrapped fetch with custom headers
2576+
const wrappedFetch = createFetchWithInit(customFetch, {
2577+
headers: {
2578+
'user-agent': 'MyApp/1.0',
2579+
'x-custom-header': 'test-value'
2580+
}
2581+
});
2582+
2583+
await discoverOAuthProtectedResourceMetadata('https://resource.example.com', undefined, wrappedFetch);
2584+
2585+
expect(customFetch).toHaveBeenCalledTimes(1);
2586+
const [url, options] = customFetch.mock.calls[0];
2587+
2588+
expect(url.toString()).toBe('https://resource.example.com/.well-known/oauth-protected-resource');
2589+
expect(options.headers).toMatchObject({
2590+
'user-agent': 'MyApp/1.0',
2591+
'x-custom-header': 'test-value',
2592+
'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION
2593+
});
2594+
});
2595+
2596+
it('auth-specific headers override base headers from RequestInit', async () => {
2597+
const { createFetchWithInit } = await import('../shared/transport.js');
2598+
2599+
const customFetch = vi.fn().mockResolvedValue({
2600+
ok: true,
2601+
status: 200,
2602+
json: async () => ({
2603+
issuer: 'https://auth.example.com',
2604+
authorization_endpoint: 'https://auth.example.com/authorize',
2605+
token_endpoint: 'https://auth.example.com/token',
2606+
response_types_supported: ['code'],
2607+
code_challenge_methods_supported: ['S256']
2608+
})
2609+
});
2610+
2611+
// Create a wrapped fetch with a custom Accept header
2612+
const wrappedFetch = createFetchWithInit(customFetch, {
2613+
headers: {
2614+
Accept: 'text/plain',
2615+
'user-agent': 'MyApp/1.0'
2616+
}
2617+
});
2618+
2619+
await discoverAuthorizationServerMetadata('https://auth.example.com', {
2620+
fetchFn: wrappedFetch
2621+
});
2622+
2623+
expect(customFetch).toHaveBeenCalled();
2624+
const [, options] = customFetch.mock.calls[0];
2625+
2626+
// Auth-specific Accept header should override base Accept header
2627+
expect(options.headers).toMatchObject({
2628+
Accept: 'application/json', // Auth-specific value wins
2629+
'user-agent': 'MyApp/1.0', // Base value preserved
2630+
'MCP-Protocol-Version': LATEST_PROTOCOL_VERSION
2631+
});
2632+
});
2633+
2634+
it('other RequestInit options are passed through', async () => {
2635+
const { createFetchWithInit } = await import('../shared/transport.js');
2636+
2637+
const customFetch = vi.fn().mockResolvedValue({
2638+
ok: true,
2639+
status: 200,
2640+
json: async () => ({
2641+
resource: 'https://resource.example.com',
2642+
authorization_servers: ['https://auth.example.com']
2643+
})
2644+
});
2645+
2646+
// Create a wrapped fetch with various RequestInit options
2647+
const wrappedFetch = createFetchWithInit(customFetch, {
2648+
credentials: 'include',
2649+
mode: 'cors',
2650+
cache: 'no-cache',
2651+
headers: {
2652+
'user-agent': 'MyApp/1.0'
2653+
}
2654+
});
2655+
2656+
await discoverOAuthProtectedResourceMetadata('https://resource.example.com', undefined, wrappedFetch);
2657+
2658+
expect(customFetch).toHaveBeenCalledTimes(1);
2659+
const [, options] = customFetch.mock.calls[0];
2660+
2661+
// All RequestInit options should be preserved
2662+
expect(options.credentials).toBe('include');
2663+
expect(options.mode).toBe('cors');
2664+
expect(options.cache).toBe('no-cache');
2665+
expect(options.headers).toMatchObject({
2666+
'user-agent': 'MyApp/1.0'
2667+
});
2668+
});
2669+
});
25612670
});

src/client/sse.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { EventSource, type ErrorEvent, type EventSourceInit } from 'eventsource';
2-
import { Transport, FetchLike } from '../shared/transport.js';
2+
import { Transport, FetchLike, createFetchWithInit } from '../shared/transport.js';
33
import { JSONRPCMessage, JSONRPCMessageSchema } from '../types.js';
44
import { auth, AuthResult, extractWWWAuthenticateParams, OAuthClientProvider, UnauthorizedError } from './auth.js';
55

@@ -70,6 +70,7 @@ export class SSEClientTransport implements Transport {
7070
private _requestInit?: RequestInit;
7171
private _authProvider?: OAuthClientProvider;
7272
private _fetch?: FetchLike;
73+
private _fetchWithInit: FetchLike;
7374
private _protocolVersion?: string;
7475

7576
onclose?: () => void;
@@ -84,6 +85,7 @@ export class SSEClientTransport implements Transport {
8485
this._requestInit = opts?.requestInit;
8586
this._authProvider = opts?.authProvider;
8687
this._fetch = opts?.fetch;
88+
this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit);
8789
}
8890

8991
private async _authThenStart(): Promise<void> {
@@ -97,7 +99,7 @@ export class SSEClientTransport implements Transport {
9799
serverUrl: this._url,
98100
resourceMetadataUrl: this._resourceMetadataUrl,
99101
scope: this._scope,
100-
fetchFn: this._fetch
102+
fetchFn: this._fetchWithInit
101103
});
102104
} catch (error) {
103105
this.onerror?.(error as Error);
@@ -220,7 +222,7 @@ export class SSEClientTransport implements Transport {
220222
authorizationCode,
221223
resourceMetadataUrl: this._resourceMetadataUrl,
222224
scope: this._scope,
223-
fetchFn: this._fetch
225+
fetchFn: this._fetchWithInit
224226
});
225227
if (result !== 'AUTHORIZED') {
226228
throw new UnauthorizedError('Failed to authorize');
@@ -260,7 +262,7 @@ export class SSEClientTransport implements Transport {
260262
serverUrl: this._url,
261263
resourceMetadataUrl: this._resourceMetadataUrl,
262264
scope: this._scope,
263-
fetchFn: this._fetch
265+
fetchFn: this._fetchWithInit
264266
});
265267
if (result !== 'AUTHORIZED') {
266268
throw new UnauthorizedError();

src/client/streamableHttp.ts

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Transport, FetchLike } from '../shared/transport.js';
1+
import { Transport, FetchLike, createFetchWithInit, normalizeHeaders } from '../shared/transport.js';
22
import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from '../types.js';
33
import { auth, AuthResult, extractWWWAuthenticateParams, OAuthClientProvider, UnauthorizedError } from './auth.js';
44
import { EventSourceParserStream } from 'eventsource-parser/stream';
@@ -129,6 +129,7 @@ export class StreamableHTTPClientTransport implements Transport {
129129
private _requestInit?: RequestInit;
130130
private _authProvider?: OAuthClientProvider;
131131
private _fetch?: FetchLike;
132+
private _fetchWithInit: FetchLike;
132133
private _sessionId?: string;
133134
private _reconnectionOptions: StreamableHTTPReconnectionOptions;
134135
private _protocolVersion?: string;
@@ -145,6 +146,7 @@ export class StreamableHTTPClientTransport implements Transport {
145146
this._requestInit = opts?.requestInit;
146147
this._authProvider = opts?.authProvider;
147148
this._fetch = opts?.fetch;
149+
this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit);
148150
this._sessionId = opts?.sessionId;
149151
this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS;
150152
}
@@ -160,7 +162,7 @@ export class StreamableHTTPClientTransport implements Transport {
160162
serverUrl: this._url,
161163
resourceMetadataUrl: this._resourceMetadataUrl,
162164
scope: this._scope,
163-
fetchFn: this._fetch
165+
fetchFn: this._fetchWithInit
164166
});
165167
} catch (error) {
166168
this.onerror?.(error as Error);
@@ -190,7 +192,7 @@ export class StreamableHTTPClientTransport implements Transport {
190192
headers['mcp-protocol-version'] = this._protocolVersion;
191193
}
192194

193-
const extraHeaders = this._normalizeHeaders(this._requestInit?.headers);
195+
const extraHeaders = normalizeHeaders(this._requestInit?.headers);
194196

195197
return new Headers({
196198
...headers,
@@ -255,20 +257,6 @@ export class StreamableHTTPClientTransport implements Transport {
255257
return Math.min(initialDelay * Math.pow(growFactor, attempt), maxDelay);
256258
}
257259

258-
private _normalizeHeaders(headers: HeadersInit | undefined): Record<string, string> {
259-
if (!headers) return {};
260-
261-
if (headers instanceof Headers) {
262-
return Object.fromEntries(headers.entries());
263-
}
264-
265-
if (Array.isArray(headers)) {
266-
return Object.fromEntries(headers);
267-
}
268-
269-
return { ...(headers as Record<string, string>) };
270-
}
271-
272260
/**
273261
* Schedule a reconnection attempt with exponential backoff
274262
*
@@ -388,7 +376,7 @@ export class StreamableHTTPClientTransport implements Transport {
388376
authorizationCode,
389377
resourceMetadataUrl: this._resourceMetadataUrl,
390378
scope: this._scope,
391-
fetchFn: this._fetch
379+
fetchFn: this._fetchWithInit
392380
});
393381
if (result !== 'AUTHORIZED') {
394382
throw new UnauthorizedError('Failed to authorize');
@@ -452,7 +440,7 @@ export class StreamableHTTPClientTransport implements Transport {
452440
serverUrl: this._url,
453441
resourceMetadataUrl: this._resourceMetadataUrl,
454442
scope: this._scope,
455-
fetchFn: this._fetch
443+
fetchFn: this._fetchWithInit
456444
});
457445
if (result !== 'AUTHORIZED') {
458446
throw new UnauthorizedError();

src/shared/transport.ts

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,49 @@ import { JSONRPCMessage, MessageExtraInfo, RequestId } from '../types.js';
22

33
export type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;
44

5+
/**
6+
* Normalizes HeadersInit to a plain Record<string, string> for manipulation.
7+
* Handles Headers objects, arrays of tuples, and plain objects.
8+
*/
9+
export function normalizeHeaders(headers: HeadersInit | undefined): Record<string, string> {
10+
if (!headers) return {};
11+
12+
if (headers instanceof Headers) {
13+
return Object.fromEntries(headers.entries());
14+
}
15+
16+
if (Array.isArray(headers)) {
17+
return Object.fromEntries(headers);
18+
}
19+
20+
return { ...(headers as Record<string, string>) };
21+
}
22+
23+
/**
24+
* Creates a fetch function that includes base RequestInit options.
25+
* This ensures requests inherit settings like credentials, mode, headers, etc. from the base init.
26+
*
27+
* @param baseFetch - The base fetch function to wrap (defaults to global fetch)
28+
* @param baseInit - The base RequestInit to merge with each request
29+
* @returns A wrapped fetch function that merges base options with call-specific options
30+
*/
31+
export function createFetchWithInit(baseFetch: FetchLike = fetch, baseInit?: RequestInit): FetchLike {
32+
if (!baseInit) {
33+
return baseFetch;
34+
}
35+
36+
// Return a wrapped fetch that merges base RequestInit with call-specific init
37+
return async (url: string | URL, init?: RequestInit): Promise<Response> => {
38+
const mergedInit: RequestInit = {
39+
...baseInit,
40+
...init,
41+
// Headers need special handling - merge instead of replace
42+
headers: init?.headers ? { ...normalizeHeaders(baseInit.headers), ...normalizeHeaders(init.headers) } : baseInit.headers
43+
};
44+
return baseFetch(url, mergedInit);
45+
};
46+
}
47+
548
/**
649
* Options for sending a JSON-RPC message.
750
*/

0 commit comments

Comments
 (0)