Skip to content

Commit f7c14fd

Browse files
Introduces the extensibility API to allow users to add custom HTTP headers to token acquisition requests (under extensibility) (#5440)
* init * more tests * new surface * fix --------- Co-authored-by: Gladwin Johnson <[email protected]>
1 parent b9e5aa8 commit f7c14fd

File tree

8 files changed

+312
-0
lines changed

8 files changed

+312
-0
lines changed

src/client/Microsoft.Identity.Client/Extensibility/AcquireTokenParameterBuilderExtensions.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,26 @@ public static T WithExtraHttpHeaders<T>(
2626
}
2727
}
2828
}
29+
30+
// Extensibility (new surface for WithExtraHttpHeaders)
31+
namespace Microsoft.Identity.Client.Extensibility
32+
{
33+
/// <summary>
34+
/// Extensibility helpers for acquire token parameter builders.
35+
/// </summary>
36+
public static class AcquireTokenParameterBuilderExtensions
37+
{
38+
/// <summary>Adds additional HTTP headers to the token request.</summary>
39+
/// <param name="builder">Parameter builder for acquiring tokens.</param>
40+
/// <param name="extraHttpHeaders">Additional HTTP headers to add to the token request.</param>
41+
public static T WithExtraHttpHeaders<T>(
42+
this AbstractAcquireTokenParameterBuilder<T> builder,
43+
IDictionary<string, string> extraHttpHeaders)
44+
where T : AbstractAcquireTokenParameterBuilder<T>
45+
{
46+
// Delegate to the Advanced implementation to keep a single source of truth.
47+
return Advanced.AcquireTokenParameterBuilderExtensions
48+
.WithExtraHttpHeaders(builder, extraHttpHeaders);
49+
}
50+
}
51+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions
2+
static Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions.WithExtraHttpHeaders<T>(this Microsoft.Identity.Client.AbstractAcquireTokenParameterBuilder<T> builder, System.Collections.Generic.IDictionary<string, string> extraHttpHeaders) -> T
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions
2+
static Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions.WithExtraHttpHeaders<T>(this Microsoft.Identity.Client.AbstractAcquireTokenParameterBuilder<T> builder, System.Collections.Generic.IDictionary<string, string> extraHttpHeaders) -> T
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions
2+
static Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions.WithExtraHttpHeaders<T>(this Microsoft.Identity.Client.AbstractAcquireTokenParameterBuilder<T> builder, System.Collections.Generic.IDictionary<string, string> extraHttpHeaders) -> T
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions
2+
static Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions.WithExtraHttpHeaders<T>(this Microsoft.Identity.Client.AbstractAcquireTokenParameterBuilder<T> builder, System.Collections.Generic.IDictionary<string, string> extraHttpHeaders) -> T
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions
2+
static Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions.WithExtraHttpHeaders<T>(this Microsoft.Identity.Client.AbstractAcquireTokenParameterBuilder<T> builder, System.Collections.Generic.IDictionary<string, string> extraHttpHeaders) -> T
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions
2+
static Microsoft.Identity.Client.Extensibility.AcquireTokenParameterBuilderExtensions.WithExtraHttpHeaders<T>(this Microsoft.Identity.Client.AbstractAcquireTokenParameterBuilder<T> builder, System.Collections.Generic.IDictionary<string, string> extraHttpHeaders) -> T
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
// Copyright (c) Microsoft.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Net.Http;
8+
using System.Threading.Tasks;
9+
using Microsoft.Identity.Client;
10+
using Microsoft.Identity.Client.Extensibility;
11+
using Microsoft.Identity.Test.Common.Core.Helpers;
12+
using Microsoft.Identity.Test.Common.Core.Mocks;
13+
using Microsoft.VisualStudio.TestTools.UnitTesting;
14+
15+
namespace Microsoft.Identity.Test.Unit.PublicApiTests
16+
{
17+
[TestClass]
18+
public class ExtraHttpHeadersTests : TestBase
19+
{
20+
private readonly string _clientId = "4df2cbbb-8612-49c1-87c8-f334d6d065ad";
21+
private readonly string _scope = "api://msaltokenexchange/.default";
22+
private readonly string _tenantId = "tenantid";
23+
24+
private static bool TryGetHeader(HttpRequestMessage req, string name, out string value)
25+
{
26+
if (req.Headers.TryGetValues(name, out var v) && v != null)
27+
{
28+
value = v.Single();
29+
return true;
30+
}
31+
32+
if (req.Content?.Headers != null &&
33+
req.Content.Headers.TryGetValues(name, out var v2) &&
34+
v2 != null)
35+
{
36+
value = v2.Single();
37+
return true;
38+
}
39+
40+
value = null;
41+
return false;
42+
}
43+
44+
[TestMethod]
45+
public async Task AcquireTokenForClient_WithExtraHttpHeaders_SendsHeaders_Async()
46+
{
47+
using var httpManager = new MockHttpManager();
48+
{
49+
// 1) Instance discovery
50+
httpManager.AddInstanceDiscoveryMockHandler();
51+
52+
// 2) Token endpoint
53+
httpManager.AddMockHandler(new MockHttpMessageHandler
54+
{
55+
ExpectedMethod = HttpMethod.Post,
56+
ResponseMessage = MockHelpers.CreateSuccessfulClientCredentialTokenResponseMessage(),
57+
AdditionalRequestValidation = (HttpRequestMessage req) =>
58+
{
59+
Assert.IsTrue(TryGetHeader(req, "x-ms-test", out var v1), "x-ms-test not present.");
60+
Assert.AreEqual("abc", v1);
61+
62+
Assert.IsTrue(TryGetHeader(req, "x-correlation-id", out var v2), "x-correlation-id not present.");
63+
Assert.AreEqual("123", v2);
64+
}
65+
});
66+
67+
var app = ConfidentialClientApplicationBuilder
68+
.Create(_clientId)
69+
.WithAuthority("https://login.microsoftonline.com/", _tenantId)
70+
.WithClientSecret("ClientSecret")
71+
.WithHttpManager(httpManager)
72+
.BuildConcrete();
73+
74+
var headers = new Dictionary<string, string>
75+
{
76+
["x-ms-test"] = "abc",
77+
["x-correlation-id"] = "123"
78+
};
79+
80+
var result = await app.AcquireTokenForClient(new[] { _scope })
81+
.WithExtraHttpHeaders(headers) // <-- new API under test
82+
.ExecuteAsync()
83+
.ConfigureAwait(false);
84+
85+
Assert.IsNotNull(result);
86+
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
87+
}
88+
}
89+
90+
[TestMethod]
91+
public async Task AcquireTokenForClient_ListAllRequestHeaders_Async()
92+
{
93+
using var httpManager = new MockHttpManager();
94+
{
95+
httpManager.AddInstanceDiscoveryMockHandler();
96+
97+
httpManager.AddMockHandler(new MockHttpMessageHandler
98+
{
99+
ExpectedMethod = HttpMethod.Post,
100+
ResponseMessage = MockHelpers.CreateSuccessfulClientCredentialTokenResponseMessage(),
101+
AdditionalRequestValidation = (HttpRequestMessage req) =>
102+
{
103+
// 1) Dump everything to the test output (no assumptions)
104+
foreach (var kv in EnumerateAllHeaders(req))
105+
{
106+
TestContext.WriteLine($"{kv.Key}: {string.Join(", ", kv.Value)}");
107+
}
108+
109+
// 2) (Optional) Assert a few stable MSAL defaults are present.
110+
// Keep this list small to avoid flakiness across platforms.
111+
AssertHeaderExists(req, "client-request-id");
112+
AssertHeaderExists(req, "return-client-request-id");
113+
AssertHeaderExists(req, "x-client-sku");
114+
AssertHeaderExists(req, "x-client-ver");
115+
AssertHeaderExists(req, "x-client-os");
116+
AssertHeaderExists(req, "Accept");
117+
AssertHeaderExists(req, "Content-Type");
118+
AssertHeaderExists(req, "x-ms-test");
119+
}
120+
});
121+
122+
var app = ConfidentialClientApplicationBuilder
123+
.Create(_clientId)
124+
.WithAuthority("https://login.microsoftonline.com/", _tenantId)
125+
.WithClientSecret("ClientSecret")
126+
.WithHttpManager(httpManager)
127+
.BuildConcrete();
128+
129+
// Include one custom header to prove user-provided + defaults both show up
130+
var custom = new Dictionary<string, string>
131+
{
132+
["x-ms-test"] = "abc"
133+
};
134+
135+
var result = await app.AcquireTokenForClient(new[] { _scope })
136+
.WithExtraHttpHeaders(custom)
137+
.ExecuteAsync()
138+
.ConfigureAwait(false);
139+
140+
Assert.IsNotNull(result);
141+
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
142+
}
143+
}
144+
145+
[TestMethod]
146+
public async Task AcquireTokenForClient_WithExtraHttpHeaders_Null_DoesNotChangeHeaders_Async()
147+
{
148+
using var httpManager = new MockHttpManager();
149+
httpManager.AddInstanceDiscoveryMockHandler();
150+
151+
HashSet<string> baseline = new(StringComparer.OrdinalIgnoreCase);
152+
HashSet<string> afterNull = new(StringComparer.OrdinalIgnoreCase);
153+
154+
httpManager.AddMockHandler(new MockHttpMessageHandler
155+
{
156+
ExpectedMethod = HttpMethod.Post,
157+
ResponseMessage = MockHelpers.CreateSuccessfulClientCredentialTokenResponseMessage(),
158+
AdditionalRequestValidation = req => { foreach (var h in EnumerateAllHeaders(req)) baseline.Add(h.Key); }
159+
});
160+
161+
var app1 = ConfidentialClientApplicationBuilder.Create(_clientId)
162+
.WithAuthority("https://login.microsoftonline.com/", _tenantId)
163+
.WithClientSecret("ClientSecret")
164+
.WithHttpManager(httpManager)
165+
.BuildConcrete();
166+
167+
await app1.AcquireTokenForClient(new[] { _scope }).ExecuteAsync().ConfigureAwait(false);
168+
169+
httpManager.AddMockHandler(new MockHttpMessageHandler
170+
{
171+
ExpectedMethod = HttpMethod.Post,
172+
ResponseMessage = MockHelpers.CreateSuccessfulClientCredentialTokenResponseMessage(),
173+
AdditionalRequestValidation = req => { foreach (var h in EnumerateAllHeaders(req)) afterNull.Add(h.Key); }
174+
});
175+
176+
var app2 = ConfidentialClientApplicationBuilder.Create(_clientId)
177+
.WithAuthority("https://login.microsoftonline.com/", _tenantId)
178+
.WithClientSecret("ClientSecret")
179+
.WithHttpManager(httpManager)
180+
.BuildConcrete();
181+
182+
Dictionary<string, string> headers = null;
183+
await app2.AcquireTokenForClient(new[] { _scope })
184+
.WithExtraHttpHeaders(headers)
185+
.ExecuteAsync().ConfigureAwait(false);
186+
187+
CollectionAssert.AreEquivalent(baseline.ToList(), afterNull.ToList(),
188+
"Null headers should not change the header set.");
189+
}
190+
191+
[TestMethod]
192+
public async Task AcquireTokenForClient_ExtraHeaders_OverridesDefault_Async()
193+
{
194+
using var httpManager = new MockHttpManager();
195+
httpManager.AddInstanceDiscoveryMockHandler();
196+
httpManager.AddMockHandler(new MockHttpMessageHandler
197+
{
198+
ExpectedMethod = HttpMethod.Post,
199+
ResponseMessage = MockHelpers.CreateSuccessfulClientCredentialTokenResponseMessage(),
200+
AdditionalRequestValidation = (HttpRequestMessage req) =>
201+
{
202+
Assert.IsTrue(TryGetHeader(req, "Accept", out var v), "Accept not present");
203+
Assert.AreEqual("text/plain", v); // user value should win
204+
}
205+
});
206+
207+
var app = ConfidentialClientApplicationBuilder.Create(_clientId)
208+
.WithAuthority("https://login.microsoftonline.com/", _tenantId)
209+
.WithClientSecret("ClientSecret")
210+
.WithHttpManager(httpManager)
211+
.BuildConcrete();
212+
213+
var headers = new Dictionary<string, string> { ["Accept"] = "text/plain" };
214+
var result = await app.AcquireTokenForClient(new[] { _scope })
215+
.WithExtraHttpHeaders(headers)
216+
.ExecuteAsync().ConfigureAwait(false);
217+
218+
Assert.IsNotNull(result);
219+
}
220+
221+
[TestMethod]
222+
public async Task AcquireTokenForClient_MultipleWithExtraHttpHeaders_Calls_LastWins_Async()
223+
{
224+
using var httpManager = new MockHttpManager();
225+
httpManager.AddInstanceDiscoveryMockHandler();
226+
227+
httpManager.AddMockHandler(new MockHttpMessageHandler
228+
{
229+
ExpectedMethod = HttpMethod.Post,
230+
ResponseMessage = MockHelpers.CreateSuccessfulClientCredentialTokenResponseMessage(),
231+
AdditionalRequestValidation = (HttpRequestMessage req) =>
232+
{
233+
// Only the last set of headers should be present
234+
Assert.IsTrue(TryGetHeader(req, "x-ms-test", out var v1), "x-ms-test not present.");
235+
Assert.AreEqual("final", v1);
236+
Assert.IsFalse(TryGetHeader(req, "x-ms-old", out _), "x-ms-old should not be present.");
237+
}
238+
});
239+
240+
var app = ConfidentialClientApplicationBuilder
241+
.Create(_clientId)
242+
.WithAuthority("https://login.microsoftonline.com/", _tenantId)
243+
.WithClientSecret("ClientSecret")
244+
.WithHttpManager(httpManager)
245+
.BuildConcrete();
246+
247+
var result = await app.AcquireTokenForClient(new[] { _scope })
248+
.WithExtraHttpHeaders(new Dictionary<string, string> { ["x-ms-test"] = "initial", ["x-ms-old"] = "old" })
249+
.WithExtraHttpHeaders(new Dictionary<string, string> { ["x-ms-test"] = "final" }) // last call should win
250+
.ExecuteAsync()
251+
.ConfigureAwait(false);
252+
253+
Assert.IsNotNull(result);
254+
}
255+
256+
private static IEnumerable<KeyValuePair<string, IEnumerable<string>>> EnumerateAllHeaders(HttpRequestMessage req)
257+
{
258+
foreach (var h in req.Headers)
259+
yield return new KeyValuePair<string, IEnumerable<string>>(h.Key, h.Value);
260+
261+
if (req.Content != null)
262+
{
263+
foreach (var h in req.Content.Headers)
264+
yield return new KeyValuePair<string, IEnumerable<string>>(h.Key, h.Value);
265+
}
266+
}
267+
268+
private static void AssertHeaderExists(HttpRequestMessage req, string name)
269+
{
270+
bool found =
271+
(req.Headers.TryGetValues(name, out var v1) && v1 != null) ||
272+
(req.Content?.Headers?.TryGetValues(name, out var v2) ?? false);
273+
274+
Assert.IsTrue(found, $"Expected header '{name}' not found.");
275+
}
276+
}
277+
}

0 commit comments

Comments
 (0)